add local sparse retriever, ann dense retriever and online search engine
@@ -50,7 +50,7 @@ conda activate retriever

# we recommend installing torch with conda for faiss-gpu
conda install pytorch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0 pytorch-cuda=12.1 -c pytorch -c nvidia
-pip install transformers datasets
+pip install transformers datasets pyserini

## install the gpu version faiss to guarantee efficient RL rollout
conda install -c pytorch -c nvidia faiss-gpu=1.8.0
@@ -163,6 +163,8 @@ You can change ```retriever_name``` and ```retriever_model``` to your interested

## Use your own search engine

Our codebase supports a local sparse retriever (e.g., BM25), a local dense retriever (both flat indexing with GPUs and ANN indexing with CPUs) and online search engines (e.g., Google, Bing, etc.). More details can be found [here](https://github.com/PeterGriffinJin/Search-R1/tree/main/docs/retriever.md).

The main philosophy is to launch a local or remote search engine server separately from the main RL training pipeline.

The LLM can call the search engine through the search API (e.g., "http://127.0.0.1:8000/retrieve").
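For a quick sanity check of a launched server (an illustrative sketch, not part of the repo; the endpoint and port are the defaults used by the servers in this commit, and any extra fields your particular retrieval server expects should be added accordingly):

```python
import requests

# Post a batch of queries to a locally launched search server.
# The {"queries": [...]} payload mirrors the FastAPI `SearchRequest` model
# used by the search servers added in this commit.
payload = {"queries": ["Who wrote the novel Dune?"]}
response = requests.post("http://127.0.0.1:8000/retrieve", json=payload)

# The servers respond with {"result": [...]}: one list of documents per query.
for docs in response.json()["result"]:
    for doc in docs:
        print(doc)
```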
@@ -170,7 +172,7 @@ The LLM can call the search engine by calling the search API (e.g., "http://127.

You can refer to ```search_r1/search/retrieval_server.py``` for an example of launching a local retriever server.

## To do
-- Support google search / bing search / brave search API and others.
+- Support google search / bing search / brave search API and others. ✔️
- Support LoRA tuning.
- Support supervised finetuning.
- Support off-the-shelf rerankers.
docs/retriever.md (new file, 127 lines)
@@ -0,0 +1,127 @@

## Search Engine

In this document, we provide examples of how to launch different retrievers, including a local sparse retriever (e.g., BM25), a local dense retriever (e.g., e5) and an online search engine.
For the local retrievers, we use the [wiki-18](https://huggingface.co/datasets/PeterJinGo/wiki-18-corpus) corpus as an example; the corresponding pre-built indexes can be found at [bm25](https://huggingface.co/datasets/PeterJinGo/wiki-18-bm25-index), [e5-flat](https://huggingface.co/datasets/PeterJinGo/wiki-18-e5-index) and [e5-HNSW64](https://huggingface.co/datasets/PeterJinGo/wiki-18-e5-index-HNSW64).

### How to choose the retriever?

- If you have a private or domain-specific corpus, choose a **local retriever**.

  - If there is no high-quality embedding-based retriever (dense retriever) for your domain, choose a **sparse local retriever** (e.g., BM25).

  - Otherwise, choose a **dense local retriever**.

    - If you do not have sufficient GPUs to conduct exact dense embedding matching, choose **ANN indexing** on CPUs.

    - If you have sufficient GPUs, choose **flat indexing** on GPUs.

- If you want to train a general LLM search agent and have enough funding, choose an **online search engine** (e.g., [SerpAPI](https://serpapi.com/)).

  - If you have a domain-specific online search engine (e.g., PubMed search), you can refer to [this script](https://github.com/PeterGriffinJin/Search-R1/blob/main/search_r1/search/serp_search_server.py) to integrate it into Search-R1 yourself.
### Local Sparse Retriever

A sparse retriever (e.g., BM25) is a traditional lexical method. The retrieval process is very efficient and no GPUs are needed. However, it may not be as accurate as dense retrievers in some specific domains.

(1) Download the index.
```bash
save_path=/your/path/to/save
huggingface-cli download PeterJinGo/wiki-18-bm25-index --repo-type dataset --local-dir $save_path
```

(2) Launch a local BM25 retriever server.
```bash
conda activate retriever

index_file=$save_path/bm25
corpus_file=$save_path/wiki-18.jsonl
retriever_name=bm25

python search_r1/search/retrieval_server.py --index_path $index_file --corpus_path $corpus_file --topk 3 --retriever_name $retriever_name
```
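For intuition, the BM25 server builds on a Lucene index (pyserini is added to the installation above); a minimal standalone sketch (assuming the downloaded index is a standard Lucene index and that the path below matches where you saved it) looks like this:

```python
from pyserini.search.lucene import LuceneSearcher

# Hypothetical path for illustration; point it at the downloaded bm25 index.
searcher = LuceneSearcher("/your/path/to/save/bm25")
hits = searcher.search("who invented the telephone", k=3)

for hit in hits:
    # Each hit exposes the Lucene document id and its BM25 score.
    print(hit.docid, hit.score)
```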
### Local Dense Retriever

You can also adopt off-the-shelf dense retrievers, e.g., e5. These models are much stronger than sparse retrievers in some specific domains.
If you have sufficient GPUs, we recommend the flat indexing variant below; otherwise, you can adopt the ANN variant.
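To illustrate what the dense retriever computes (a minimal sketch, not the server code; it assumes the usual e5 conventions of "query:"/"passage:" prefixes with mean pooling, matching the mean-pooling setting used by ```retrieval_server.py```):

```python
import torch
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("intfloat/e5-base-v2")
model = AutoModel.from_pretrained("intfloat/e5-base-v2")

def embed(texts):
    # Mean-pool the last hidden states over non-padding tokens, then L2-normalize.
    batch = tokenizer(texts, padding=True, truncation=True, max_length=256, return_tensors="pt")
    with torch.no_grad():
        hidden = model(**batch).last_hidden_state
    mask = batch["attention_mask"].unsqueeze(-1)
    emb = (hidden * mask).sum(dim=1) / mask.sum(dim=1)
    return torch.nn.functional.normalize(emb, dim=-1)

# e5 expects "query: " / "passage: " prefixes.
q = embed(["query: who invented the telephone"])
p = embed(["passage: Alexander Graham Bell was credited with inventing the telephone.",
           "passage: The Wright brothers built the first successful airplane."])
print(q @ p.T)  # cosine similarities; higher means more relevant
```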
#### Flat indexing

Flat indexing conducts exact embedding matching, which is slow but very accurate. To make it efficient enough to support online RL, we recommend enabling **GPU** usage via ```--faiss_gpu```.

(1) Download the index and corpus.
```bash
save_path=/the/path/to/save
python scripts/download.py --save_path $save_path
cat $save_path/part_* > $save_path/e5_Flat.index
gzip -d $save_path/wiki-18.jsonl.gz
```

(2) Launch a local flat e5 retriever server.

```bash
conda activate retriever

index_file=$save_path/e5_Flat.index
corpus_file=$save_path/wiki-18.jsonl
retriever_name=e5
retriever_path=intfloat/e5-base-v2

python search_r1/search/retrieval_server.py --index_path $index_file --corpus_path $corpus_file --topk 3 --retriever_name $retriever_name --retriever_model $retriever_path --faiss_gpu
```
#### ANN indexing (HNSW64)

To improve search efficiency with only **CPUs**, you can adopt approximate nearest neighbor (ANN) indexing, e.g., HNSW64.
It is very efficient, but may not be as accurate as flat indexing, especially when the number of retrieved passages is small.
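For intuition before downloading the pre-built index, the toy sketch below contrasts a flat index with an HNSW64 index in faiss (illustrative only; the dimensions and random vectors are made up, and the released wiki-18 indexes above are already built):

```python
import faiss
import numpy as np

d = 768                                   # e5-base-v2 embedding dimension
xb = np.random.rand(10000, d).astype("float32")
faiss.normalize_L2(xb)                    # cosine similarity via inner product

# Flat index: exact inner-product search, accurate but slow on CPUs.
flat = faiss.IndexFlatIP(d)
flat.add(xb)
# With faiss-gpu installed, --faiss_gpu roughly corresponds to:
# flat = faiss.index_cpu_to_all_gpus(flat)

# HNSW64 index: approximate graph-based search that stays fast on CPUs.
hnsw = faiss.IndexHNSWFlat(d, 64, faiss.METRIC_INNER_PRODUCT)
hnsw.add(xb)

q = xb[:1]
print(flat.search(q, 3))                  # exact top-3 neighbors
print(hnsw.search(q, 3))                  # approximate top-3 neighbors
```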
(1) Download the index.
```bash
save_path=/the/path/to/save
huggingface-cli download PeterJinGo/wiki-18-e5-index-HNSW64 --repo-type dataset --local-dir $save_path
cat $save_path/part_* > $save_path/e5_HNSW64.index
```

(2) Launch a local ANN e5 retriever server.
```bash
conda activate retriever

index_file=$save_path/e5_HNSW64.index
corpus_file=$save_path/wiki-18.jsonl
retriever_name=e5
retriever_path=intfloat/e5-base-v2

python search_r1/search/retrieval_server.py --index_path $index_file --corpus_path $corpus_file --topk 3 --retriever_name $retriever_name --retriever_model $retriever_path
```
### Online Search Engine

We support both the [Google Search API](https://developers.google.com/custom-search/v1/overview) and [SerpAPI](https://serpapi.com/). We recommend [SerpAPI](https://serpapi.com/) since it integrates multiple online search engine APIs (including Google, Bing, Baidu, etc.) and does not have a monthly quota limitation (the [Google Search API](https://developers.google.com/custom-search/v1/overview) has a hard 10k monthly quota, which is not sufficient to support online LLM RL training).

#### SerpAPI online search server

```bash
search_url=https://serpapi.com/search
serp_api_key="" # put your serp api key here (https://serpapi.com/)

python search_r1/search/online_search_server.py --search_url $search_url --topk 3 --serp_api_key $serp_api_key
```
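Once launched, the server exposes the same ```/retrieve``` endpoint as the local retrievers. The shape of a response (values are made up; the structure follows the ```_process_result``` logic of the SerpAPI server added in this commit) looks like:

```python
# Illustrative /retrieve response from the SerpAPI server (contents are fabricated).
example_response = {
    "result": [  # one inner list per input query
        [
            {"document": {"contents": '"Answer box title"\nAnswer box snippet text.'}},
            {"document": {"contents": '"First organic result title"\nIts snippet text.'}},
        ]
    ]
}
```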
#### Google online search server

```bash
api_key="" # put your google custom API key here (https://developers.google.com/custom-search/v1/overview)
cse_id="" # put your google cse API key here (https://developers.google.com/custom-search/v1/overview)

python search_r1/search/internal_google_server.py --api_key $api_key --topk 5 --cse_id $cse_id --snippet_only
```
example/retriever/retrieval_launch_google.sh (new file, 8 lines)
@@ -0,0 +1,8 @@
api_key="" # put your google custom API key here (https://developers.google.com/custom-search/v1/overview)
cse_id="" # put your google cse API key here (https://developers.google.com/custom-search/v1/overview)

python search_r1/search/internal_google_server.py --api_key $api_key \
    --topk 5 \
    --cse_id $cse_id \
    --snippet_only
example/retriever/retrieval_launch_serpapi.sh (new file, 7 lines)
@@ -0,0 +1,7 @@
search_url=https://serpapi.com/search
serp_api_key="" # put your serp api key here (https://serpapi.com/)

python search_r1/search/online_search_server.py --search_url $search_url \
    --topk 3 \
    --serp_api_key $serp_api_key
@@ -2,9 +2,12 @@
file_path=/the/path/you/save/corpus
index_file=$file_path/e5_Flat.index
corpus_file=$file_path/wiki-18.jsonl
-retriever=intfloat/e5-base-v2
+retriever_name=e5
+retriever_path=intfloat/e5-base-v2

python search_r1/search/retrieval_server.py --index_path $index_file \
    --corpus_path $corpus_file \
    --topk 3 \
-    --retriever_model $retriever
+    --retriever_name $retriever_name \
+    --retriever_model $retriever_path \
+    --faiss_gpu
search_r1/search/google_search_server.py (new file, 202 lines)
@@ -0,0 +1,202 @@
import os
import re
import requests
import argparse
import asyncio
import random
from typing import List, Optional, Dict
from concurrent.futures import ThreadPoolExecutor

import chardet
import aiohttp
import bs4
import uvicorn
from fastapi import FastAPI
from pydantic import BaseModel
from googleapiclient.discovery import build


# --- CLI Args ---
parser = argparse.ArgumentParser(description="Launch online search server.")
parser.add_argument('--api_key', type=str, required=True, help="API key for Google search")
parser.add_argument('--cse_id', type=str, required=True, help="CSE ID for Google search")
parser.add_argument('--topk', type=int, default=3, help="Number of results to return per query")
parser.add_argument('--snippet_only', action='store_true', help="If set, only return snippets; otherwise, return full context.")
args = parser.parse_args()


# --- Config ---
class OnlineSearchConfig:
    def __init__(self, topk: int = 3, api_key: Optional[str] = None, cse_id: Optional[str] = None, snippet_only: bool = False):
        self.topk = topk
        self.api_key = api_key
        self.cse_id = cse_id
        self.snippet_only = snippet_only


# --- Utilities ---
def parse_snippet(snippet: str) -> List[str]:
    segments = snippet.split("...")
    return [s.strip() for s in segments if len(s.strip().split()) > 5]


def sanitize_search_query(query: str) -> str:
    # Remove or replace special characters that might cause issues.
    # This is a basic example; you might need to add more characters or patterns.
    sanitized_query = re.sub(r'[^\w\s]', ' ', query)  # Replace non-alphanumeric and non-whitespace with spaces.
    sanitized_query = re.sub(r'[\t\r\f\v\n]', ' ', sanitized_query)  # Replace tab, return, formfeed, vertical tab with spaces.
    sanitized_query = re.sub(r'\s+', ' ', sanitized_query).strip()  # Remove duplicate spaces and trailing/leading spaces.

    return sanitized_query


def filter_links(search_results: List[Dict]) -> List[str]:
    links = []
    for result in search_results:
        for item in result.get("items", []):
            if "mime" in item:
                continue
            ext = os.path.splitext(item["link"])[1]
            if ext in ["", ".html", ".htm", ".shtml"]:
                links.append(item["link"])
    return links


async def fetch(session: aiohttp.ClientSession, url: str, semaphore: asyncio.Semaphore) -> str:
    user_agents = [
        "Mozilla/5.0 (Linux; Android 6.0.1; Nexus 5X Build/MMB29P)...",
        "Mozilla/5.0 AppleWebKit/537.36...",
        "Mozilla/5.0 (compatible; Googlebot/2.1; +https://www.google.com/bot.html)",
    ]
    headers = {"User-Agent": random.choice(user_agents)}

    async with semaphore:
        try:
            async with session.get(url, headers=headers) as response:
                raw = await response.read()
                detected = chardet.detect(raw)
                encoding = detected["encoding"] or "utf-8"
                return raw.decode(encoding, errors="ignore")
        except (aiohttp.ClientError, asyncio.TimeoutError):
            return ""


async def fetch_all(urls: List[str], limit: int = 8) -> List[str]:
    semaphore = asyncio.Semaphore(limit)
    timeout = aiohttp.ClientTimeout(total=5)
    connector = aiohttp.TCPConnector(limit_per_host=limit, force_close=True)

    async with aiohttp.ClientSession(timeout=timeout, connector=connector) as session:
        tasks = [fetch(session, url, semaphore) for url in urls]
        return await asyncio.gather(*tasks)


# --- Search Engine ---
class OnlineSearchEngine:
    def __init__(self, config: OnlineSearchConfig):
        self.config = config

    def collect_context(self, snippet: str, doc: str) -> str:
        snippets = parse_snippet(snippet)
        ctx_paras = []

        for s in snippets:
            pos = doc.replace("\n", " ").find(s)
            if pos == -1:
                continue
            sta = pos
            while sta > 0 and doc[sta] != "\n":
                sta -= 1
            end = pos + len(s)
            while end < len(doc) and doc[end] != "\n":
                end += 1
            para = doc[sta:end].strip()
            if para not in ctx_paras:
                ctx_paras.append(para)

        return "\n".join(ctx_paras)

    def fetch_web_content(self, search_results: List[Dict]) -> Dict[str, str]:
        links = filter_links(search_results)
        contents = asyncio.run(fetch_all(links))
        content_dict = {}
        for html, link in zip(contents, links):
            soup = bs4.BeautifulSoup(html, "html.parser")
            text = "\n".join([p.get_text() for p in soup.find_all("p")])
            content_dict[link] = text
        return content_dict

    def search(self, search_term: str, num_iter: int = 1) -> List[Dict]:
        service = build('customsearch', 'v1', developerKey=self.config.api_key)
        results = []
        sanitize_search_term = sanitize_search_query(search_term)
        if search_term.isspace():
            return results
        res = service.cse().list(q=sanitize_search_term, cx=self.config.cse_id).execute()
        results.append(res)

        for _ in range(num_iter - 1):
            if 'nextPage' not in res.get('queries', {}):
                break
            start_idx = res['queries']['nextPage'][0]['startIndex']
            res = service.cse().list(q=search_term, cx=self.config.cse_id, start=start_idx).execute()
            results.append(res)

        return results

    def batch_search(self, queries: List[str]) -> List[List[str]]:
        with ThreadPoolExecutor() as executor:
            return list(executor.map(self._retrieve_context, queries))
    def _retrieve_context(self, query: str) -> List[str]:
        # Run the Google CSE query once; both branches below consume its results.
        search_results = self.search(query)

        if self.config.snippet_only:
            contexts = []
            for result in search_results:
                for item in result.get("items", []):
                    title = item.get("title", "")
                    context = ' '.join(parse_snippet(item.get("snippet", "")))
                    if title != "" or context != "":
                        title = "No title." if not title else title
                        context = "No snippet available." if not context else context
                        contexts.append({
                            'document': {"contents": f'\"{title}\"\n{context}'},
                        })
        else:
            content_dict = self.fetch_web_content(search_results)
            contexts = []
            for result in search_results:
                for item in result.get("items", []):
                    link = item["link"]
                    title = item.get("title", "")
                    snippet = item.get("snippet", "")
                    if link in content_dict:
                        context = self.collect_context(snippet, content_dict[link])
                        if title != "" or context != "":
                            title = "No title." if not title else title
                            context = "No snippet available." if not context else context
                            contexts.append({
                                'document': {"contents": f'\"{title}\"\n{context}'},
                            })

        return contexts[:self.config.topk]
# --- FastAPI App ---
app = FastAPI(title="Online Search Proxy Server")


class SearchRequest(BaseModel):
    queries: List[str]


config = OnlineSearchConfig(api_key=args.api_key, cse_id=args.cse_id, topk=args.topk, snippet_only=args.snippet_only)
engine = OnlineSearchEngine(config)


@app.post("/retrieve")
def search_endpoint(request: SearchRequest):
    results = engine.batch_search(request.queries)
    return {"result": results}


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)
@@ -20,7 +20,9 @@ parser = argparse.ArgumentParser(description="Launch the local faiss retriever."
parser.add_argument("--index_path", type=str, default="/home/peterjin/mnt/index/wiki-18/e5_Flat.index", help="Corpus indexing file.")
parser.add_argument("--corpus_path", type=str, default="/home/peterjin/mnt/data/retrieval-corpus/wiki-18.jsonl", help="Local corpus file.")
parser.add_argument("--topk", type=int, default=3, help="Number of retrieved passages for one query.")
-parser.add_argument("--retriever_model", type=str, default="intfloat/e5-base-v2", help="Name of the retriever model.")
+parser.add_argument("--retriever_name", type=str, default="e5", help="Name of the retriever model.")
+parser.add_argument("--retriever_model", type=str, default="intfloat/e5-base-v2", help="Path of the retriever model.")
+parser.add_argument('--faiss_gpu', action='store_true', help='Use GPU for computation')

args = parser.parse_args()
@@ -335,11 +337,11 @@ app = FastAPI()
# 1) Build a config (could also parse from arguments).
# In real usage, you'd parse your CLI arguments or environment variables.
config = Config(
-    retrieval_method = "e5", # or "dense"
+    retrieval_method = args.retriever_name, # or "dense"
    index_path=args.index_path,
    corpus_path=args.corpus_path,
    retrieval_topk=args.topk,
-    faiss_gpu=True,
+    faiss_gpu=args.faiss_gpu,
    retrieval_model_path=args.retriever_model,
    retrieval_pooling_method="mean",
    retrieval_query_max_length=256,
search_r1/search/serp_search_server.py (new file, 112 lines)
@@ -0,0 +1,112 @@
import os
import requests
from fastapi import FastAPI
from pydantic import BaseModel
from typing import List, Optional, Dict
from concurrent.futures import ThreadPoolExecutor
import argparse
import uvicorn

parser = argparse.ArgumentParser(description="Launch online search server.")
parser.add_argument('--search_url', type=str, required=True,
                    help="URL for search engine (e.g. https://serpapi.com/search)")
parser.add_argument('--topk', type=int, default=3,
                    help="Number of results to return per query")
parser.add_argument('--serp_api_key', type=str, default=None,
                    help="SerpAPI key for online search")
parser.add_argument('--serp_engine', type=str, default="google",
                    help="SerpAPI engine for online search")
args = parser.parse_args()


# --- Config ---
class OnlineSearchConfig:
    def __init__(
        self,
        search_url: str = "https://serpapi.com/search",
        topk: int = 3,
        serp_api_key: Optional[str] = None,
        serp_engine: Optional[str] = None,
    ):
        self.search_url = search_url
        self.topk = topk
        self.serp_api_key = serp_api_key
        self.serp_engine = serp_engine


# --- Online Search Wrapper ---
class OnlineSearchEngine:
    def __init__(self, config: OnlineSearchConfig):
        self.config = config

    def _search_query(self, query: str):
        params = {
            "engine": self.config.serp_engine,
            "q": query,
            "api_key": self.config.serp_api_key,
        }
        response = requests.get(self.config.search_url, params=params)
        return response.json()

    def batch_search(self, queries: List[str]):
        results = []
        with ThreadPoolExecutor() as executor:
            for result in executor.map(self._search_query, queries):
                results.append(self._process_result(result))
        return results

    def _process_result(self, search_result: Dict):
        results = []

        answer_box = search_result.get('answer_box', {})
        if answer_box:
            title = answer_box.get('title', 'No title.')
            snippet = answer_box.get('snippet', 'No snippet available.')
            results.append({
                'document': {"contents": f'\"{title}\"\n{snippet}'},
            })

        organic_results = search_result.get('organic_results', [])
        for _, result in enumerate(organic_results[:self.config.topk]):
            title = result.get('title', 'No title.')
            snippet = result.get('snippet', 'No snippet available.')
            results.append({
                'document': {"contents": f'\"{title}\"\n{snippet}'},
            })

        related_results = search_result.get('related_questions', [])
        for _, result in enumerate(related_results[:self.config.topk]):
            title = result.get('question', 'No title.')  # question is the title here
            snippet = result.get('snippet', 'No snippet available.')
            results.append({
                'document': {"contents": f'\"{title}\"\n{snippet}'},
            })

        return results
# --- FastAPI Setup ---
app = FastAPI(title="Online Search Proxy Server")


class SearchRequest(BaseModel):
    queries: List[str]


# Instantiate global config + engine
config = OnlineSearchConfig(
    search_url=args.search_url,
    topk=args.topk,
    serp_api_key=args.serp_api_key,
    serp_engine=args.serp_engine,
)
engine = OnlineSearchEngine(config)


# --- Routes ---
@app.post("/retrieve")
def search_endpoint(request: SearchRequest):
    results = engine.batch_search(request.queries)
    return {"result": results}

## return {"result": List[List[{'document': {"id": xx, "content": "title" + \n + "content"}, 'score': xx}]]}

if __name__ == "__main__":
    # 3) Launch the server. By default, it listens on http://127.0.0.1:8000
    uvicorn.run(app, host="0.0.0.0", port=8000)