add local sparse retriever, ann dense retriever and online search engine

This commit is contained in:
PeterGriffinJin
2025-04-07 18:20:43 +00:00
parent 0b26e614f7
commit ba152349fd
8 changed files with 470 additions and 7 deletions

View File

@@ -50,7 +50,7 @@ conda activate retriever
# we recommend installing torch with conda for faiss-gpu
conda install pytorch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0 pytorch-cuda=12.1 -c pytorch -c nvidia
pip install transformers datasets
pip install transformers datasets pyserini
## install the gpu version faiss to guarantee efficient RL rollout
conda install -c pytorch -c nvidia faiss-gpu=1.8.0
@@ -163,6 +163,8 @@ You can change ```retriever_name``` and ```retriever_model``` to your interested
## Use your own search engine
Our codebase supports local sparse retriever (e.g., BM25), local dense retriever (both flat indexing with GPUs and ANN indexing with CPUs) and online search engine (e.g., Google, Bing, etc). More details can be found [here](https://github.com/PeterGriffinJin/Search-R1/tree/main/docs/retriever.md).
The main philosophy is to launch a local or remote search engine server separately from the main RL training pipeline.
The LLM can call the search engine by calling the search API (e.g., "http://127.0.0.1:8000/retrieve").
@@ -170,7 +172,7 @@ The LLM can call the search engine by calling the search API (e.g., "http://127.
You can refer to ```search_r1/search/retriever_server.py``` for an example of launching a local retriever server.
## To do
- Support google search / bing search / brave search API and others.
- Support google search / bing search / brave search API and others. ✔️
- Support LoRA tuning.
- Support supervised finetuning.
- Support off-the-shelf rerankers.

127
docs/retriever.md Normal file
View File

@@ -0,0 +1,127 @@
## Search Engine
In this document, we provide examples of how to launch different retrievers, including local sparse retriever (e.g., BM25), local dense retriever (e.g., e5) and online search engine.
For local retrievers, we use [wiki-18](https://huggingface.co/datasets/PeterJinGo/wiki-18-corpus) corpus as an example and the corpus indexing can be found at [bm25](https://huggingface.co/datasets/PeterJinGo/wiki-18-bm25-index), [e5-flat](https://huggingface.co/datasets/PeterJinGo/wiki-18-e5-index), [e5-HNSW64](PeterJinGo/wiki-18-e5-index-HNSW64).
### How to choose the retriever?
- If you have a private or domain-specific corpus, choose **local retriever**.
- If there is no high quality embedding-based retrievers (dense retrievers) in your domain, choose **sparse local retriever** (e.g., BM25).
- Otherwise choose **dense local retriever**.
- If you do not have sufficent GPUs to conduct exact dense embedding matching, choose **ANN indexing** on CPUs.
- If you have sufficient GPUs, choose **flat indexing** on GPUs.
- If you want to train a general LLM search agent and have enough funding, choose **online search engine** (e.g., [SerpAPI](https://serpapi.com/)).
- If you have a domain specific online search engine (e.g., PubMed search), you can refer to [link](https://github.com/PeterGriffinJin/Search-R1/blob/main/search_r1/search/serp_search_server.py) to integrate it to Search-R1 by yourself.
### Local Sparse Retriever
Sparse retriever (e.g., bm25) is a traditional method. The retrieval process is very efficient and no GPUs are needed. However, it may not be as accurate as dense retrievers in some specific domain.
(1) Download the indexing.
```bash
save_path=/your/path/to/save
huggingface-cli download PeterJinGo/wiki-18-bm25-index --repo-type dataset --local-dir $save_path
```
(2) Launch a local BM25 retriever server.
```bash
conda activate retriever
index_file=$save_path/bm25
corpus_file=$save_path/wiki-18.jsonl
retriever_name=bm25
python search_r1/search/retrieval_server.py --index_path $index_file --corpus_path $corpus_file --topk 3 --retriever_name $retriever_name
```
### Local Dense Retriever
You can also adopt some off-the-shelf dense retrievers, e.g., e5. These models are much stronger than sparse retriever in some specific domains.
If you have sufficient GPU, we would recommend the flat indexing variant below, otherwise you can adopt the ANN variant.
#### Flat indexing
Flat indexing conducts exact embedding match, which is slow but very accurate. To make it efficient enough to support online RL, we would recommend enable **GPU** usage by ```--faiss_gpu```.
(1) Download the indexing and corpus.
```bash
save_path=/the/path/to/save
python scripts/download.py --save_path $save_path
cat $save_path/part_* > $save_path/e5_Flat.index
gzip -d $save_path/wiki-18.jsonl.gz
```
(2) Launch a local flat e5 retriever server.
```bash
conda activate retriever
index_file=$save_path/e5_Flat.index
corpus_file=$save_path/wiki-18.jsonl
retriever_name=e5
retriever_path=intfloat/e5-base-v2
python search_r1/search/retrieval_server.py --index_path $index_file --corpus_path $corpus_file --topk 3 --retriever_name $retriever_name --retriever_model $retriever_path --faiss_gpu
```
#### ANN indexing (HNSW64)
To improve the search efficient with only **CPU**, you can adopt approximate nearest neighbor (ANN) indexing, e.g., with HNSW64.
It is very efficient, but may not be as accurate as flat indexing, especially when the number of retrieved passages is small.
(1) Download the indexing.
```bash
save_path=/the/path/to/save
huggingface-cli download PeterJinGo/wiki-18-e5-index-HNSW64 --repo-type dataset --local-dir $save_path
cat $save_path/part_* > $save_path/e5_HNSW64.index
```
(2) Launch a local BM25 retriever server.
```bash
conda activate retriever
index_file=$save_path/e5_HNSW64.index
corpus_file=$save_path/wiki-18.jsonl
retriever_name=e5
retriever_path=intfloat/e5-base-v2
python search_r1/search/retrieval_server.py --index_path $index_file --corpus_path $corpus_file --topk 3 --retriever_name $retriever_name --retriever_model $retriever_path
```
### Online Search Engine
We support both [Google Search API](https://developers.google.com/custom-search/v1/overview) and [SerpAPI](https://serpapi.com/). We would recommend [SerpAPI](https://serpapi.com/) since it integrates multiple online search engine APIs (including Google, Bing, Baidu, etc) and does not have a monthly quota limitation ([Google Search API](https://developers.google.com/custom-search/v1/overview) has a hard 10k monthly quota, which is not sufficient to fulfill online LLM RL training).
#### SerAPI online search server
```bash
search_url=https://serpapi.com/search
serp_api_key="" # put your serp api key here (https://serpapi.com/)
python search_r1/search/online_search_server.py --search_url $search_url --topk 3 --serp_api_key $serp_api_key
```
#### Google online search server
```bash
api_key="" # put your google custom API key here (https://developers.google.com/custom-search/v1/overview)
cse_id="" # put your google cse API key here (https://developers.google.com/custom-search/v1/overview)
python search_r1/search/internal_google_server.py --api_key $api_key --topk 5 --cse_id $cse_id --snippet_only
```

View File

@@ -0,0 +1,8 @@
api_key="" # put your google custom API key here (https://developers.google.com/custom-search/v1/overview)
cse_id="" # put your google cse API key here (https://developers.google.com/custom-search/v1/overview)
python search_r1/search/internal_google_server.py --api_key $api_key \
--topk 5 \
--cse_id $cse_id \
--snippet_only

View File

@@ -0,0 +1,7 @@
search_url=https://serpapi.com/search
serp_api_key="" # put your serp api key here (https://serpapi.com/)
python search_r1/search/online_search_server.py --search_url $search_url \
--topk 3 \
--serp_api_key $serp_api_key

View File

@@ -2,9 +2,12 @@
file_path=/the/path/you/save/corpus
index_file=$file_path/e5_Flat.index
corpus_file=$file_path/wiki-18.jsonl
retriever=intfloat/e5-base-v2
retriever_name=e5
retriever_path=intfloat/e5-base-v2
python search_r1/search/retrieval_server.py --index_path $index_file \
--corpus_path $corpus_file \
--topk 3 \
--retriever_model $retriever
--retriever_name $retriever_name \
--retriever_model $retriever_path \
--faiss_gpu

View File

@@ -0,0 +1,202 @@
import os
import re
import requests
import argparse
import asyncio
import random
from typing import List, Optional, Dict
from concurrent.futures import ThreadPoolExecutor
import chardet
import aiohttp
import bs4
import uvicorn
from fastapi import FastAPI
from pydantic import BaseModel
from googleapiclient.discovery import build
# --- CLI Args ---
parser = argparse.ArgumentParser(description="Launch online search server.")
parser.add_argument('--api_key', type=str, required=True, help="API key for Google search")
parser.add_argument('--cse_id', type=str, required=True, help="CSE ID for Google search")
parser.add_argument('--topk', type=int, default=3, help="Number of results to return per query")
parser.add_argument('--snippet_only', action='store_true', help="If set, only return snippets; otherwise, return full context.")
args = parser.parse_args()
# --- Config ---
class OnlineSearchConfig:
def __init__(self, topk: int = 3, api_key: Optional[str] = None, cse_id: Optional[str] = None, snippet_only: bool = False):
self.topk = topk
self.api_key = api_key
self.cse_id = cse_id
self.snippet_only = snippet_only
# --- Utilities ---
def parse_snippet(snippet: str) -> List[str]:
segments = snippet.split("...")
return [s.strip() for s in segments if len(s.strip().split()) > 5]
def sanitize_search_query(query: str) -> str:
# Remove or replace special characters that might cause issues.
# This is a basic example; you might need to add more characters or patterns.
sanitized_query = re.sub(r'[^\w\s]', ' ', query) # Replace non-alphanumeric and non-whitespace with spaces.
sanitized_query = re.sub(r'[\t\r\f\v\n]', ' ', sanitized_query) # replace tab, return, formfeed, vertical tab with spaces.
sanitized_query = re.sub(r'\s+', ' ', sanitized_query).strip() #remove duplicate spaces, and trailing/leading spaces.
return sanitized_query
def filter_links(search_results: List[Dict]) -> List[str]:
links = []
for result in search_results:
for item in result.get("items", []):
if "mime" in item:
continue
ext = os.path.splitext(item["link"])[1]
if ext in ["", ".html", ".htm", ".shtml"]:
links.append(item["link"])
return links
async def fetch(session: aiohttp.ClientSession, url: str, semaphore: asyncio.Semaphore) -> str:
user_agents = [
"Mozilla/5.0 (Linux; Android 6.0.1; Nexus 5X Build/MMB29P)...",
"Mozilla/5.0 AppleWebKit/537.36...",
"Mozilla/5.0 (compatible; Googlebot/2.1; +https://www.google.com/bot.html)",
]
headers = {"User-Agent": random.choice(user_agents)}
async with semaphore:
try:
async with session.get(url, headers=headers) as response:
raw = await response.read()
detected = chardet.detect(raw)
encoding = detected["encoding"] or "utf-8"
return raw.decode(encoding, errors="ignore")
except (aiohttp.ClientError, asyncio.TimeoutError):
return ""
async def fetch_all(urls: List[str], limit: int = 8) -> List[str]:
semaphore = asyncio.Semaphore(limit)
timeout = aiohttp.ClientTimeout(total=5)
connector = aiohttp.TCPConnector(limit_per_host=limit, force_close=True)
async with aiohttp.ClientSession(timeout=timeout, connector=connector) as session:
tasks = [fetch(session, url, semaphore) for url in urls]
return await asyncio.gather(*tasks)
# --- Search Engine ---
class OnlineSearchEngine:
def __init__(self, config: OnlineSearchConfig):
self.config = config
def collect_context(self, snippet: str, doc: str) -> str:
snippets = parse_snippet(snippet)
ctx_paras = []
for s in snippets:
pos = doc.replace("\n", " ").find(s)
if pos == -1:
continue
sta = pos
while sta > 0 and doc[sta] != "\n":
sta -= 1
end = pos + len(s)
while end < len(doc) and doc[end] != "\n":
end += 1
para = doc[sta:end].strip()
if para not in ctx_paras:
ctx_paras.append(para)
return "\n".join(ctx_paras)
def fetch_web_content(self, search_results: List[Dict]) -> Dict[str, str]:
links = filter_links(search_results)
contents = asyncio.run(fetch_all(links))
content_dict = {}
for html, link in zip(contents, links):
soup = bs4.BeautifulSoup(html, "html.parser")
text = "\n".join([p.get_text() for p in soup.find_all("p")])
content_dict[link] = text
return content_dict
def search(self, search_term: str, num_iter: int = 1) -> List[Dict]:
service = build('customsearch', 'v1', developerKey=self.config.api_key)
results = []
sanitize_search_term = sanitize_search_query(search_term)
if search_term.isspace():
return results
res = service.cse().list(q=sanitize_search_term, cx=self.config.cse_id).execute()
results.append(res)
for _ in range(num_iter - 1):
if 'nextPage' not in res.get('queries', {}):
break
start_idx = res['queries']['nextPage'][0]['startIndex']
res = service.cse().list(q=search_term, cx=self.config.cse_id, start=start_idx).execute()
results.append(res)
return results
def batch_search(self, queries: List[str]) -> List[List[str]]:
with ThreadPoolExecutor() as executor:
return list(executor.map(self._retrieve_context, queries))
def _retrieve_context(self, query: str) -> List[str]:
if self.config.snippet_only:
search_results = self.search(query)
contexts = []
for result in search_results:
for item in result.get("items", []):
title = item.get("title", "")
context = ' '.join(parse_snippet(item.get("snippet", "")))
if title != "" or context != "":
title = "No title." if not title else title
context = "No snippet available." if not context else context
contexts.append({
'document': {"contents": f'\"{title}\"\n{context}'},
})
else:
content_dict = self.fetch_web_content(search_results)
contexts = []
for result in search_results:
for item in result.get("items", []):
link = item["link"]
title = item.get("title", "")
snippet = item.get("snippet", "")
if link in content_dict:
context = self.collect_context(snippet, content_dict[link])
if title != "" or context != "":
title = "No title." if not title else title
context = "No snippet available." if not context else context
contexts.append({
'document': {"contents": f'\"{title}\"\n{context}'},
})
return contexts[:self.config.topk]
# --- FastAPI App ---
app = FastAPI(title="Online Search Proxy Server")
class SearchRequest(BaseModel):
queries: List[str]
config = OnlineSearchConfig(api_key=args.api_key, cse_id=args.cse_id, topk=args.topk, snippet_only=args.snippet_only)
engine = OnlineSearchEngine(config)
@app.post("/retrieve")
def search_endpoint(request: SearchRequest):
results = engine.batch_search(request.queries)
return {"result": results}
if __name__ == "__main__":
uvicorn.run(app, host="0.0.0.0", port=8000)

View File

@@ -20,7 +20,9 @@ parser = argparse.ArgumentParser(description="Launch the local faiss retriever."
parser.add_argument("--index_path", type=str, default="/home/peterjin/mnt/index/wiki-18/e5_Flat.index", help="Corpus indexing file.")
parser.add_argument("--corpus_path", type=str, default="/home/peterjin/mnt/data/retrieval-corpus/wiki-18.jsonl", help="Local corpus file.")
parser.add_argument("--topk", type=int, default=3, help="Number of retrieved passages for one query.")
parser.add_argument("--retriever_model", type=str, default="intfloat/e5-base-v2", help="Name of the retriever model.")
parser.add_argument("--retriever_name", type=str, default="e5", help="Name of the retriever model.")
parser.add_argument("--retriever_model", type=str, default="intfloat/e5-base-v2", help="Path of the retriever model.")
parser.add_argument('--faiss_gpu', action='store_true', help='Use GPU for computation')
args = parser.parse_args()
@@ -335,11 +337,11 @@ app = FastAPI()
# 1) Build a config (could also parse from arguments).
# In real usage, you'd parse your CLI arguments or environment variables.
config = Config(
retrieval_method = "e5", # or "dense"
retrieval_method = args.retriever_name, # or "dense"
index_path=args.index_path,
corpus_path=args.corpus_path,
retrieval_topk=args.topk,
faiss_gpu=True,
faiss_gpu=args.faiss_gpu,
retrieval_model_path=args.retriever_model,
retrieval_pooling_method="mean",
retrieval_query_max_length=256,

View File

@@ -0,0 +1,112 @@
import os
import requests
from fastapi import FastAPI
from pydantic import BaseModel
from typing import List, Optional, Dict
from concurrent.futures import ThreadPoolExecutor
import argparse
import uvicorn
parser = argparse.ArgumentParser(description="Launch online search server.")
parser.add_argument('--search_url', type=str, required=True,
help="URL for search engine (e.g. https://serpapi.com/search)")
parser.add_argument('--topk', type=int, default=3,
help="Number of results to return per query")
parser.add_argument('--serp_api_key', type=str, default=None,
help="SerpAPI key for online search")
parser.add_argument('--serp_engine', type=str, default="google",
help="SerpAPI engine for online search")
args = parser.parse_args()
# --- Config ---
class OnlineSearchConfig:
def __init__(
self,
search_url: str = "https://serpapi.com/search",
topk: int = 3,
serp_api_key: Optional[str] = None,
serp_engine: Optional[str] = None,
):
self.search_url = search_url
self.topk = topk
self.serp_api_key = serp_api_key
self.serp_engine = serp_engine
# --- Online Search Wrapper ---
class OnlineSearchEngine:
def __init__(self, config: OnlineSearchConfig):
self.config = config
def _search_query(self, query: str):
params = {
"engine": self.config.serp_engine,
"q": query,
"api_key": self.config.serp_api_key,
}
response = requests.get(self.config.search_url, params=params)
return response.json()
def batch_search(self, queries: List[str]):
results = []
with ThreadPoolExecutor() as executor:
for result in executor.map(self._search_query, queries):
results.append(self._process_result(result))
return results
def _process_result(self, search_result: Dict):
results = []
answer_box = search_result.get('answer_box', {})
if answer_box:
title = answer_box.get('title', 'No title.')
snippet = answer_box.get('snippet', 'No snippet available.')
results.append({
'document': {"contents": f'\"{title}\"\n{snippet}'},
})
organic_results = search_result.get('organic_results', [])
for _, result in enumerate(organic_results[:self.config.topk]):
title = result.get('title', 'No title.')
snippet = result.get('snippet', 'No snippet available.')
results.append({
'document': {"contents": f'\"{title}\"\n{snippet}'},
})
related_results = search_result.get('related_questions', [])
for _, result in enumerate(related_results[:self.config.topk]):
title = result.get('question', 'No title.') # question is the title here
snippet = result.get('snippet', 'No snippet available.')
results.append({
'document': {"contents": f'\"{title}\"\n{snippet}'},
})
return results
# --- FastAPI Setup ---
app = FastAPI(title="Online Search Proxy Server")
class SearchRequest(BaseModel):
queries: List[str]
# Instantiate global config + engine
config = OnlineSearchConfig(
search_url=args.search_url,
topk=args.topk,
serp_api_key=args.serp_api_key,
serp_engine=args.serp_engine,
)
engine = OnlineSearchEngine(config)
# --- Routes ---
@app.post("/retrieve")
def search_endpoint(request: SearchRequest):
results = engine.batch_search(request.queries)
return {"result": results}
## return {"result": List[List[{'document': {"id": xx, "content": "title" + \n + "content"}, 'score': xx}]]}
if __name__ == "__main__":
# 3) Launch the server. By default, it listens on http://127.0.0.1:8000
uvicorn.run(app, host="0.0.0.0", port=8000)