Add local sparse retriever, ANN dense retriever, and online search engine
202  search_r1/search/google_search_server.py  Normal file
@@ -0,0 +1,202 @@
import os
import re
import requests
import argparse
import asyncio
import random
from typing import List, Optional, Dict
from concurrent.futures import ThreadPoolExecutor

import chardet
import aiohttp
import bs4
import uvicorn
from fastapi import FastAPI
from pydantic import BaseModel
from googleapiclient.discovery import build


# --- CLI Args ---
parser = argparse.ArgumentParser(description="Launch online search server.")
parser.add_argument('--api_key', type=str, required=True, help="API key for Google search")
parser.add_argument('--cse_id', type=str, required=True, help="CSE ID for Google search")
parser.add_argument('--topk', type=int, default=3, help="Number of results to return per query")
parser.add_argument('--snippet_only', action='store_true', help="If set, only return snippets; otherwise, return full context.")
args = parser.parse_args()


# --- Config ---
class OnlineSearchConfig:
    def __init__(self, topk: int = 3, api_key: Optional[str] = None, cse_id: Optional[str] = None, snippet_only: bool = False):
        self.topk = topk
        self.api_key = api_key
        self.cse_id = cse_id
        self.snippet_only = snippet_only
# --- Utilities ---
def parse_snippet(snippet: str) -> List[str]:
    segments = snippet.split("...")
    return [s.strip() for s in segments if len(s.strip().split()) > 5]


def sanitize_search_query(query: str) -> str:
    # Remove or replace special characters that might cause issues.
    # This is a basic example; you might need to add more characters or patterns.
    sanitized_query = re.sub(r'[^\w\s]', ' ', query)  # Replace non-alphanumeric, non-whitespace characters with spaces.
    sanitized_query = re.sub(r'[\t\r\f\v\n]', ' ', sanitized_query)  # Replace tab, carriage return, form feed, vertical tab, and newline with spaces.
    sanitized_query = re.sub(r'\s+', ' ', sanitized_query).strip()  # Collapse repeated spaces and trim leading/trailing whitespace.

    return sanitized_query
def filter_links(search_results: List[Dict]) -> List[str]:
    links = []
    for result in search_results:
        for item in result.get("items", []):
            if "mime" in item:
                continue
            ext = os.path.splitext(item["link"])[1]
            if ext in ["", ".html", ".htm", ".shtml"]:
                links.append(item["link"])
    return links


async def fetch(session: aiohttp.ClientSession, url: str, semaphore: asyncio.Semaphore) -> str:
    user_agents = [
        "Mozilla/5.0 (Linux; Android 6.0.1; Nexus 5X Build/MMB29P)...",
        "Mozilla/5.0 AppleWebKit/537.36...",
        "Mozilla/5.0 (compatible; Googlebot/2.1; +https://www.google.com/bot.html)",
    ]
    headers = {"User-Agent": random.choice(user_agents)}

    async with semaphore:
        try:
            async with session.get(url, headers=headers) as response:
                raw = await response.read()
                detected = chardet.detect(raw)
                encoding = detected["encoding"] or "utf-8"
                return raw.decode(encoding, errors="ignore")
        except (aiohttp.ClientError, asyncio.TimeoutError):
            return ""


async def fetch_all(urls: List[str], limit: int = 8) -> List[str]:
    semaphore = asyncio.Semaphore(limit)
    timeout = aiohttp.ClientTimeout(total=5)
    connector = aiohttp.TCPConnector(limit_per_host=limit, force_close=True)

    async with aiohttp.ClientSession(timeout=timeout, connector=connector) as session:
        tasks = [fetch(session, url, semaphore) for url in urls]
        return await asyncio.gather(*tasks)


# --- Search Engine ---
class OnlineSearchEngine:
    def __init__(self, config: OnlineSearchConfig):
        self.config = config

    def collect_context(self, snippet: str, doc: str) -> str:
        snippets = parse_snippet(snippet)
        ctx_paras = []

        for s in snippets:
            pos = doc.replace("\n", " ").find(s)
            if pos == -1:
                continue
            sta = pos
            while sta > 0 and doc[sta] != "\n":
                sta -= 1
            end = pos + len(s)
            while end < len(doc) and doc[end] != "\n":
                end += 1
            para = doc[sta:end].strip()
            if para not in ctx_paras:
                ctx_paras.append(para)

        return "\n".join(ctx_paras)

    def fetch_web_content(self, search_results: List[Dict]) -> Dict[str, str]:
        links = filter_links(search_results)
        contents = asyncio.run(fetch_all(links))
        content_dict = {}
        for html, link in zip(contents, links):
            soup = bs4.BeautifulSoup(html, "html.parser")
            text = "\n".join([p.get_text() for p in soup.find_all("p")])
            content_dict[link] = text
        return content_dict
    def search(self, search_term: str, num_iter: int = 1) -> List[Dict]:
        service = build('customsearch', 'v1', developerKey=self.config.api_key)
        results = []
        sanitized_term = sanitize_search_query(search_term)
        if not sanitized_term:
            # Query was empty or contained only whitespace/special characters.
            return results
        res = service.cse().list(q=sanitized_term, cx=self.config.cse_id).execute()
        results.append(res)

        for _ in range(num_iter - 1):
            if 'nextPage' not in res.get('queries', {}):
                break
            start_idx = res['queries']['nextPage'][0]['startIndex']
            res = service.cse().list(q=sanitized_term, cx=self.config.cse_id, start=start_idx).execute()
            results.append(res)

        return results
    def batch_search(self, queries: List[str]) -> List[List[Dict]]:
        with ThreadPoolExecutor() as executor:
            return list(executor.map(self._retrieve_context, queries))

    def _retrieve_context(self, query: str) -> List[Dict]:
        search_results = self.search(query)
        contexts = []

        if self.config.snippet_only:
            for result in search_results:
                for item in result.get("items", []):
                    title = item.get("title", "")
                    context = ' '.join(parse_snippet(item.get("snippet", "")))
                    if title != "" or context != "":
                        title = "No title." if not title else title
                        context = "No snippet available." if not context else context
                        contexts.append({
                            'document': {"contents": f'\"{title}\"\n{context}'},
                        })
        else:
            content_dict = self.fetch_web_content(search_results)
            for result in search_results:
                for item in result.get("items", []):
                    link = item["link"]
                    title = item.get("title", "")
                    snippet = item.get("snippet", "")
                    if link in content_dict:
                        context = self.collect_context(snippet, content_dict[link])
                        if title != "" or context != "":
                            title = "No title." if not title else title
                            context = "No snippet available." if not context else context
                            contexts.append({
                                'document': {"contents": f'\"{title}\"\n{context}'},
                            })

        return contexts[:self.config.topk]
# --- FastAPI App ---
app = FastAPI(title="Online Search Proxy Server")

class SearchRequest(BaseModel):
    queries: List[str]

config = OnlineSearchConfig(api_key=args.api_key, cse_id=args.cse_id, topk=args.topk, snippet_only=args.snippet_only)
engine = OnlineSearchEngine(config)

@app.post("/retrieve")
def search_endpoint(request: SearchRequest):
    results = engine.batch_search(request.queries)
    return {"result": results}


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)
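For reference, a minimal client sketch for the /retrieve route defined above, assuming the server is running locally on port 8000; the example queries are illustrative only:

import requests

# Illustrative queries; any list of strings works.
payload = {"queries": ["who wrote the odyssey", "capital of australia"]}

# POST /retrieve with {"queries": [...]} and read back {"result": [...]}.
resp = requests.post("http://127.0.0.1:8000/retrieve", json=payload)
resp.raise_for_status()

# "result" holds one list per query, each with up to --topk entries of the form
# {"document": {"contents": '"<title>"\n<context>'}}.
for query, docs in zip(payload["queries"], resp.json()["result"]):
    print(query, "->", len(docs), "documents")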
@@ -20,7 +20,9 @@ parser = argparse.ArgumentParser(description="Launch the local faiss retriever."
 parser.add_argument("--index_path", type=str, default="/home/peterjin/mnt/index/wiki-18/e5_Flat.index", help="Corpus indexing file.")
 parser.add_argument("--corpus_path", type=str, default="/home/peterjin/mnt/data/retrieval-corpus/wiki-18.jsonl", help="Local corpus file.")
 parser.add_argument("--topk", type=int, default=3, help="Number of retrieved passages for one query.")
-parser.add_argument("--retriever_model", type=str, default="intfloat/e5-base-v2", help="Name of the retriever model.")
+parser.add_argument("--retriever_name", type=str, default="e5", help="Name of the retriever model.")
+parser.add_argument("--retriever_model", type=str, default="intfloat/e5-base-v2", help="Path of the retriever model.")
+parser.add_argument('--faiss_gpu', action='store_true', help='Use GPU for computation')
 
 args = parser.parse_args()
@@ -335,11 +337,11 @@ app = FastAPI()
 # 1) Build a config (could also parse from arguments).
 # In real usage, you'd parse your CLI arguments or environment variables.
 config = Config(
-    retrieval_method = "e5", # or "dense"
+    retrieval_method = args.retriever_name, # or "dense"
     index_path=args.index_path,
     corpus_path=args.corpus_path,
     retrieval_topk=args.topk,
-    faiss_gpu=True,
+    faiss_gpu=args.faiss_gpu,
     retrieval_model_path=args.retriever_model,
     retrieval_pooling_method="mean",
     retrieval_query_max_length=256,
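Since --faiss_gpu is a store_true flag, GPU indexing becomes opt-in rather than the previously hardcoded faiss_gpu=True. A small standalone sketch of the parsing behavior (just the two new flags in isolation, not the actual server file):

import argparse

# Mirror of the two flags added in the diff above, shown in isolation.
p = argparse.ArgumentParser()
p.add_argument("--retriever_name", type=str, default="e5")
p.add_argument("--faiss_gpu", action="store_true")

print(p.parse_args([]))               # Namespace(faiss_gpu=False, retriever_name='e5') -> CPU index
print(p.parse_args(["--faiss_gpu"]))  # Namespace(faiss_gpu=True, retriever_name='e5')  -> GPU index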
112  search_r1/search/serp_search_server.py  Normal file
@@ -0,0 +1,112 @@
import os
import requests
from fastapi import FastAPI
from pydantic import BaseModel
from typing import List, Optional, Dict
from concurrent.futures import ThreadPoolExecutor
import argparse
import uvicorn

parser = argparse.ArgumentParser(description="Launch online search server.")
parser.add_argument('--search_url', type=str, required=True,
                    help="URL for search engine (e.g. https://serpapi.com/search)")
parser.add_argument('--topk', type=int, default=3,
                    help="Number of results to return per query")
parser.add_argument('--serp_api_key', type=str, default=None,
                    help="SerpAPI key for online search")
parser.add_argument('--serp_engine', type=str, default="google",
                    help="SerpAPI engine for online search")
args = parser.parse_args()

# --- Config ---
class OnlineSearchConfig:
    def __init__(
        self,
        search_url: str = "https://serpapi.com/search",
        topk: int = 3,
        serp_api_key: Optional[str] = None,
        serp_engine: Optional[str] = None,
    ):
        self.search_url = search_url
        self.topk = topk
        self.serp_api_key = serp_api_key
        self.serp_engine = serp_engine


# --- Online Search Wrapper ---
class OnlineSearchEngine:
    def __init__(self, config: OnlineSearchConfig):
        self.config = config

    def _search_query(self, query: str):
        params = {
            "engine": self.config.serp_engine,
            "q": query,
            "api_key": self.config.serp_api_key,
        }
        response = requests.get(self.config.search_url, params=params)
        return response.json()

    def batch_search(self, queries: List[str]):
        results = []
        with ThreadPoolExecutor() as executor:
            for result in executor.map(self._search_query, queries):
                results.append(self._process_result(result))
        return results

    def _process_result(self, search_result: Dict):
        results = []

        answer_box = search_result.get('answer_box', {})
        if answer_box:
            title = answer_box.get('title', 'No title.')
            snippet = answer_box.get('snippet', 'No snippet available.')
            results.append({
                'document': {"contents": f'\"{title}\"\n{snippet}'},
            })

        organic_results = search_result.get('organic_results', [])
        for result in organic_results[:self.config.topk]:
            title = result.get('title', 'No title.')
            snippet = result.get('snippet', 'No snippet available.')
            results.append({
                'document': {"contents": f'\"{title}\"\n{snippet}'},
            })

        related_results = search_result.get('related_questions', [])
        for result in related_results[:self.config.topk]:
            title = result.get('question', 'No title.')  # The question serves as the title here.
            snippet = result.get('snippet', 'No snippet available.')
            results.append({
                'document': {"contents": f'\"{title}\"\n{snippet}'},
            })

        return results


# --- FastAPI Setup ---
app = FastAPI(title="Online Search Proxy Server")

class SearchRequest(BaseModel):
    queries: List[str]

# Instantiate global config + engine
config = OnlineSearchConfig(
    search_url=args.search_url,
    topk=args.topk,
    serp_api_key=args.serp_api_key,
    serp_engine=args.serp_engine,
)
engine = OnlineSearchEngine(config)

# --- Routes ---
@app.post("/retrieve")
def search_endpoint(request: SearchRequest):
    results = engine.batch_search(request.queries)
    return {"result": results}

## return {"result": List[List[{'document': {"id": xx, "content": "title" + \n + "content"}, 'score': xx}]]}

if __name__ == "__main__":
    # 3) Launch the server. By default, it listens on http://127.0.0.1:8000
    uvicorn.run(app, host="0.0.0.0", port=8000)
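As an illustration of how _process_result flattens a response, here is a sketch with a hypothetical, trimmed SerpAPI-style payload (only the keys the code reads are shown, and the values are made up); it assumes the two classes above are in scope:

# Hypothetical, trimmed SerpAPI-style payload for illustration only.
sample_response = {
    "answer_box": {"title": "Mount Everest", "snippet": "Mount Everest is Earth's highest mountain above sea level."},
    "organic_results": [
        {"title": "Mount Everest - Wikipedia", "snippet": "Mount Everest is the highest of the Himalayan mountains."},
    ],
    "related_questions": [
        {"question": "How tall is Mount Everest?", "snippet": "Its peak is 8,849 metres above sea level."},
    ],
}

engine = OnlineSearchEngine(OnlineSearchConfig(topk=3))
docs = engine._process_result(sample_response)

# docs is a flat list of {'document': {'contents': '"<title>"\n<snippet>'}} entries: the answer box
# first, then up to topk organic results, then up to topk related questions.
print(len(docs))  # 3 for this sample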