Search-R1/search_r1/search/google_search_server.py

import os
import re
import requests
import argparse
import asyncio
import random
from typing import List, Optional, Dict
from concurrent.futures import ThreadPoolExecutor
import chardet
import aiohttp
import bs4
import uvicorn
from fastapi import FastAPI
from pydantic import BaseModel
from googleapiclient.discovery import build

# --- CLI Args ---
parser = argparse.ArgumentParser(description="Launch online search server.")
parser.add_argument('--api_key', type=str, required=True, help="API key for Google search")
parser.add_argument('--cse_id', type=str, required=True, help="CSE ID for Google search")
parser.add_argument('--topk', type=int, default=3, help="Number of results to return per query")
parser.add_argument('--snippet_only', action='store_true', help="If set, only return snippets; otherwise, return full context.")
args = parser.parse_args()
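
# Example launch (values are placeholders; --snippet_only is optional):
#   python google_search_server.py --api_key <GOOGLE_API_KEY> --cse_id <CSE_ID> --topk 3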

# --- Config ---
class OnlineSearchConfig:
    def __init__(self, topk: int = 3, api_key: Optional[str] = None, cse_id: Optional[str] = None, snippet_only: bool = False):
        self.topk = topk
        self.api_key = api_key
        self.cse_id = cse_id
        self.snippet_only = snippet_only

# --- Utilities ---
def parse_snippet(snippet: str) -> List[str]:
    # Split a Google result snippet on "..." and keep only segments long
    # enough (more than 5 words) to be matched against the page text later.
    segments = snippet.split("...")
    return [s.strip() for s in segments if len(s.strip().split()) > 5]
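
# e.g. parse_snippet("Google was founded in September 1998 by Larry Page ... more")
#      -> ["Google was founded in September 1998 by Larry Page"]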

def sanitize_search_query(query: str) -> str:
    # Remove or replace special characters that might cause issues with the search API.
    # This is a basic pass; more characters or patterns may need handling.
    sanitized_query = re.sub(r'[^\w\s]', ' ', query)  # Replace non-alphanumeric, non-whitespace characters with spaces.
    sanitized_query = re.sub(r'[\t\r\f\v\n]', ' ', sanitized_query)  # Replace tabs, returns, form feeds, vertical tabs, and newlines with spaces.
    sanitized_query = re.sub(r'\s+', ' ', sanitized_query).strip()  # Collapse repeated spaces and trim leading/trailing whitespace.
    return sanitized_query
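
# e.g. sanitize_search_query('who is "Elon Musk"?\n') -> 'who is Elon Musk'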

def filter_links(search_results: List[Dict]) -> List[str]:
    links = []
    for result in search_results:
        for item in result.get("items", []):
            if "mime" in item:  # Skip non-HTML results (PDFs, images, etc.) that carry a MIME type.
                continue
            ext = os.path.splitext(item["link"])[1]
            if ext in ["", ".html", ".htm", ".shtml"]:
                links.append(item["link"])
    return links

async def fetch(session: aiohttp.ClientSession, url: str, semaphore: asyncio.Semaphore) -> str:
    # Rotate User-Agent strings to reduce the chance of being blocked.
    user_agents = [
        "Mozilla/5.0 (Linux; Android 6.0.1; Nexus 5X Build/MMB29P)...",
        "Mozilla/5.0 AppleWebKit/537.36...",
        "Mozilla/5.0 (compatible; Googlebot/2.1; +https://www.google.com/bot.html)",
    ]
    headers = {"User-Agent": random.choice(user_agents)}
    async with semaphore:
        try:
            async with session.get(url, headers=headers) as response:
                raw = await response.read()
                # Detect the page encoding rather than trusting the HTTP headers.
                detected = chardet.detect(raw)
                encoding = detected["encoding"] or "utf-8"
                return raw.decode(encoding, errors="ignore")
        except (aiohttp.ClientError, asyncio.TimeoutError):
            return ""

async def fetch_all(urls: List[str], limit: int = 8) -> List[str]:
    semaphore = asyncio.Semaphore(limit)
    timeout = aiohttp.ClientTimeout(total=5)
    connector = aiohttp.TCPConnector(limit_per_host=limit, force_close=True)
    async with aiohttp.ClientSession(timeout=timeout, connector=connector) as session:
        tasks = [fetch(session, url, semaphore) for url in urls]
        return await asyncio.gather(*tasks)
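
# Minimal usage sketch (assumes no event loop is already running in this thread):
#   pages = asyncio.run(fetch_all(["https://example.com"]))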

# --- Search Engine ---
class OnlineSearchEngine:
    def __init__(self, config: OnlineSearchConfig):
        self.config = config

    def collect_context(self, snippet: str, doc: str) -> str:
        # Locate each snippet segment in the page text and expand it to the
        # enclosing newline-delimited paragraph.
        snippets = parse_snippet(snippet)
        ctx_paras = []
        for s in snippets:
            pos = doc.replace("\n", " ").find(s)
            if pos == -1:
                continue
            sta = pos
            while sta > 0 and doc[sta] != "\n":
                sta -= 1
            end = pos + len(s)
            while end < len(doc) and doc[end] != "\n":
                end += 1
            para = doc[sta:end].strip()
            if para not in ctx_paras:
                ctx_paras.append(para)
        return "\n".join(ctx_paras)

    def fetch_web_content(self, search_results: List[Dict]) -> Dict[str, str]:
        # Download all candidate pages concurrently and extract their <p> text.
        links = filter_links(search_results)
        contents = asyncio.run(fetch_all(links))
        content_dict = {}
        for html, link in zip(contents, links):
            soup = bs4.BeautifulSoup(html, "html.parser")
            text = "\n".join([p.get_text() for p in soup.find_all("p")])
            content_dict[link] = text
        return content_dict

    def search(self, search_term: str, num_iter: int = 1) -> List[Dict]:
        service = build('customsearch', 'v1', developerKey=self.config.api_key)
        results = []
        sanitized_query = sanitize_search_query(search_term)
        if not sanitized_query:
            return results
        res = service.cse().list(q=sanitized_query, cx=self.config.cse_id).execute()
        results.append(res)
        # Follow pagination for additional result pages, if requested.
        for _ in range(num_iter - 1):
            if 'nextPage' not in res.get('queries', {}):
                break
            start_idx = res['queries']['nextPage'][0]['startIndex']
            res = service.cse().list(q=sanitized_query, cx=self.config.cse_id, start=start_idx).execute()
            results.append(res)
        return results
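
    # Each element of the returned list is one raw CSE API response page;
    # individual hits live under res["items"], each with "title", "link",
    # and "snippet" fields.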

    def batch_search(self, queries: List[str]) -> List[List[Dict]]:
        # Each query runs in its own thread; each thread drives its own
        # event loop via asyncio.run inside fetch_web_content.
        with ThreadPoolExecutor() as executor:
            return list(executor.map(self._retrieve_context, queries))

    def _retrieve_context(self, query: str) -> List[Dict]:
        search_results = self.search(query)
        if self.config.snippet_only:
            # Build contexts from the search snippets alone.
            contexts = []
            for result in search_results:
                for item in result.get("items", []):
                    title = item.get("title", "")
                    context = ' '.join(parse_snippet(item.get("snippet", "")))
                    if title != "" or context != "":
                        title = "No title." if not title else title
                        context = "No snippet available." if not context else context
                        contexts.append({
                            'document': {"contents": f'"{title}"\n{context}'},
                        })
        else:
            # Fetch the full pages and expand each snippet to its surrounding paragraphs.
            content_dict = self.fetch_web_content(search_results)
            contexts = []
            for result in search_results:
                for item in result.get("items", []):
                    link = item["link"]
                    title = item.get("title", "")
                    snippet = item.get("snippet", "")
                    if link in content_dict:
                        context = self.collect_context(snippet, content_dict[link])
                        if title != "" or context != "":
                            title = "No title." if not title else title
                            context = "No snippet available." if not context else context
                            contexts.append({
                                'document': {"contents": f'"{title}"\n{context}'},
                            })
        return contexts[:self.config.topk]

# --- FastAPI App ---
app = FastAPI(title="Online Search Proxy Server")

class SearchRequest(BaseModel):
    queries: List[str]

config = OnlineSearchConfig(api_key=args.api_key, cse_id=args.cse_id, topk=args.topk, snippet_only=args.snippet_only)
engine = OnlineSearchEngine(config)

@app.post("/retrieve")
def search_endpoint(request: SearchRequest):
    results = engine.batch_search(request.queries)
    return {"result": results}

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)
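
# Minimal client sketch (assumes the `requests` package and a running server):
#   import requests
#   resp = requests.post("http://localhost:8000/retrieve", json={"queries": ["capital of France"]})
#   print(resp.json()["result"][0])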