import os
import re
import requests
import argparse
import asyncio
import random
from typing import List, Optional, Dict
from concurrent.futures import ThreadPoolExecutor

import chardet
import aiohttp
import bs4
import uvicorn
from fastapi import FastAPI
from pydantic import BaseModel
from googleapiclient.discovery import build

# --- CLI Args ---
parser = argparse.ArgumentParser(description="Launch online search server.")
parser.add_argument('--api_key', type=str, required=True, help="API key for Google search")
parser.add_argument('--cse_id', type=str, required=True, help="CSE ID for Google search")
parser.add_argument('--topk', type=int, default=3, help="Number of results to return per query")
parser.add_argument('--snippet_only', action='store_true', help="If set, only return snippets; otherwise, return full context.")
args = parser.parse_args()

# --- Config ---
class OnlineSearchConfig:
    def __init__(self, topk: int = 3, api_key: Optional[str] = None, cse_id: Optional[str] = None, snippet_only: bool = False):
        self.topk = topk
        self.api_key = api_key
        self.cse_id = cse_id
        self.snippet_only = snippet_only

# --- Utilities ---
def parse_snippet(snippet: str) -> List[str]:
    # Split a Google snippet on "..." and keep only segments longer than five words.
    segments = snippet.split("...")
    return [s.strip() for s in segments if len(s.strip().split()) > 5]

def sanitize_search_query(query: str) -> str:
    # Remove or replace special characters that might cause issues.
    # This is a basic example; you might need to add more characters or patterns.
    sanitized_query = re.sub(r'[^\w\s]', ' ', query)  # Replace non-alphanumeric and non-whitespace with spaces.
    sanitized_query = re.sub(r'[\t\r\f\v\n]', ' ', sanitized_query)  # Replace tab, return, form feed, vertical tab, and newline with spaces.
    sanitized_query = re.sub(r'\s+', ' ', sanitized_query).strip()  # Collapse duplicate spaces and trim leading/trailing spaces.
    return sanitized_query
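# Illustrative example of the sanitizer above (hypothetical input, shown only as a sketch):
#   sanitize_search_query('Who wrote "Hamlet"?\n')  ->  'Who wrote Hamlet'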
def filter_links(search_results: List[Dict]) -> List[str]:
    # Keep only links that look like plain HTML pages: skip items with an explicit MIME type
    # (e.g. PDFs) and links whose extension is not empty or .html/.htm/.shtml.
    links = []
    for result in search_results:
        for item in result.get("items", []):
            if "mime" in item:
                continue
            ext = os.path.splitext(item["link"])[1]
            if ext in ["", ".html", ".htm", ".shtml"]:
                links.append(item["link"])
    return links

async def fetch(session: aiohttp.ClientSession, url: str, semaphore: asyncio.Semaphore) -> str:
    user_agents = [
        "Mozilla/5.0 (Linux; Android 6.0.1; Nexus 5X Build/MMB29P)...",
        "Mozilla/5.0 AppleWebKit/537.36...",
        "Mozilla/5.0 (compatible; Googlebot/2.1; +https://www.google.com/bot.html)",
    ]
    headers = {"User-Agent": random.choice(user_agents)}
    async with semaphore:
        try:
            async with session.get(url, headers=headers) as response:
                raw = await response.read()
                detected = chardet.detect(raw)
                encoding = detected["encoding"] or "utf-8"
                return raw.decode(encoding, errors="ignore")
        except (aiohttp.ClientError, asyncio.TimeoutError):
            return ""

async def fetch_all(urls: List[str], limit: int = 8) -> List[str]:
    semaphore = asyncio.Semaphore(limit)
    timeout = aiohttp.ClientTimeout(total=5)
    connector = aiohttp.TCPConnector(limit_per_host=limit, force_close=True)
    async with aiohttp.ClientSession(timeout=timeout, connector=connector) as session:
        tasks = [fetch(session, url, semaphore) for url in urls]
        return await asyncio.gather(*tasks)

# --- Search Engine ---
class OnlineSearchEngine:
    def __init__(self, config: OnlineSearchConfig):
        self.config = config

    def collect_context(self, snippet: str, doc: str) -> str:
        # Expand each snippet segment to the full paragraph of the fetched page that contains it.
        snippets = parse_snippet(snippet)
        ctx_paras = []
        for s in snippets:
            pos = doc.replace("\n", " ").find(s)
            if pos == -1:
                continue
            # Walk outwards to the surrounding newlines to recover the whole paragraph.
            sta = pos
            while sta > 0 and doc[sta] != "\n":
                sta -= 1
            end = pos + len(s)
            while end < len(doc) and doc[end] != "\n":
                end += 1
            para = doc[sta:end].strip()
            if para not in ctx_paras:
                ctx_paras.append(para)
        return "\n".join(ctx_paras)

    def fetch_web_content(self, search_results: List[Dict]) -> Dict[str, str]:
        links = filter_links(search_results)
        contents = asyncio.run(fetch_all(links))
        content_dict = {}
        for html, link in zip(contents, links):
            soup = bs4.BeautifulSoup(html, "html.parser")
            text = "\n".join([p.get_text() for p in soup.find_all("p")])
            content_dict[link] = text
        return content_dict

    def search(self, search_term: str, num_iter: int = 1) -> List[Dict]:
        service = build('customsearch', 'v1', developerKey=self.config.api_key)
        results = []
        sanitized_term = sanitize_search_query(search_term)
        if not sanitized_term:
            return results
        res = service.cse().list(q=sanitized_term, cx=self.config.cse_id).execute()
        results.append(res)
        # Follow the 'nextPage' cursor for additional result pages, if requested.
        for _ in range(num_iter - 1):
            if 'nextPage' not in res.get('queries', {}):
                break
            start_idx = res['queries']['nextPage'][0]['startIndex']
            res = service.cse().list(q=sanitized_term, cx=self.config.cse_id, start=start_idx).execute()
            results.append(res)
        return results

    def batch_search(self, queries: List[str]) -> List[List[Dict]]:
        with ThreadPoolExecutor() as executor:
            return list(executor.map(self._retrieve_context, queries))

    def _retrieve_context(self, query: str) -> List[Dict]:
        # Run the search once; both branches below need the raw results.
        search_results = self.search(query)

        if self.config.snippet_only:
            contexts = []
            for result in search_results:
                for item in result.get("items", []):
                    title = item.get("title", "")
                    context = ' '.join(parse_snippet(item.get("snippet", "")))
                    if title != "" or context != "":
                        title = "No title." if not title else title
                        context = "No snippet available." if not context else context
                        contexts.append({
                            'document': {"contents": f'"{title}"\n{context}'},
                        })
        else:
            content_dict = self.fetch_web_content(search_results)
            contexts = []
            for result in search_results:
                for item in result.get("items", []):
                    link = item["link"]
                    title = item.get("title", "")
                    snippet = item.get("snippet", "")
                    if link in content_dict:
                        context = self.collect_context(snippet, content_dict[link])
                        if title != "" or context != "":
                            title = "No title." if not title else title
                            context = "No snippet available." if not context else context
                            contexts.append({
                                'document': {"contents": f'"{title}"\n{context}'},
                            })
        return contexts[:self.config.topk]
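# Illustrative shape of a single retrieved entry (field values invented for the example):
#   {'document': {'contents': '"Charles Darwin - Wikipedia"\nCharles Robert Darwin was an English naturalist...'}}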
# --- FastAPI App ---
app = FastAPI(title="Online Search Proxy Server")

class SearchRequest(BaseModel):
    queries: List[str]

config = OnlineSearchConfig(api_key=args.api_key, cse_id=args.cse_id, topk=args.topk, snippet_only=args.snippet_only)
engine = OnlineSearchEngine(config)

@app.post("/retrieve")
def search_endpoint(request: SearchRequest):
    results = engine.batch_search(request.queries)
    return {"result": results}

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)
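# Usage sketch (the file name and key values are placeholders, not defined by this module):
#
#   python online_search_server.py --api_key <GOOGLE_API_KEY> --cse_id <CSE_ID> --topk 3 --snippet_only
#
# and then, from any HTTP client:
#
#   import requests
#   resp = requests.post("http://localhost:8000/retrieve", json={"queries": ["who wrote on the origin of species"]})
#   print(resp.json()["result"])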