diff --git a/example/retriever/retrieval_launch_hierarchical.sh b/example/retriever/retrieval_launch_hierarchical.sh
new file mode 100644
index 0000000..7536b80
--- /dev/null
+++ b/example/retriever/retrieval_launch_hierarchical.sh
@@ -0,0 +1,17 @@
+
+file_path=/the/path/you/save/corpus
+index_file=$file_path/e5_Flat.index
+corpus_file=$file_path/wiki-18.jsonl
+retriever_name=e5
+retriever_path=intfloat/e5-base-v2
+reranker_path=cross-encoder/ms-marco-MiniLM-L12-v2
+
+python search_r1/search/retrieval_rerank_server.py --index_path $index_file \
+    --corpus_path $corpus_file \
+    --retrieval_topk 10 \
+    --retriever_name $retriever_name \
+    --retriever_model $retriever_path \
+    --faiss_gpu \
+    --reranking_topk 3 \
+    --reranker_model $reranker_path \
+    --reranker_batch_size 32
diff --git a/search_r1/search/rerank_server.py b/search_r1/search/rerank_server.py
new file mode 100644
index 0000000..9edabe8
--- /dev/null
+++ b/search_r1/search/rerank_server.py
@@ -0,0 +1,161 @@
+import argparse
+from collections import defaultdict
+from typing import Optional
+from dataclasses import dataclass, field
+
+from sentence_transformers import CrossEncoder
+import torch
+from transformers import HfArgumentParser
+import numpy as np
+
+import uvicorn
+from fastapi import FastAPI
+from pydantic import BaseModel
+
+
+class BaseCrossEncoder:
+    """Shared scaffolding for cross-encoder rerankers: build (query, passage)
+    pairs, score them via `_predict`, and return per-query sorted results."""
+
+    def __init__(self, model, batch_size=32, device="cuda"):
+        self.model = model
+        self.batch_size = batch_size
+        self.model.to(device)
+
+    def _passage_to_string(self, doc_item):
+        # Accept both a raw corpus row ({"contents": ...}) and the retrieval
+        # server's wrapped form ({"document": {"contents": ...}}).
+        if "document" not in doc_item:
+            content = doc_item['contents']
+        else:
+            content = doc_item['document']['contents']
+        # Corpus convention: first line of "contents" is the title, rest is body.
+        title = content.split("\n")[0]
+        text = "\n".join(content.split("\n")[1:])
+
+        return f"(Title: {title}) {text}"
+
+    def rerank(self,
+               queries: list[str],
+               documents: list[list[dict]]):
+        """
+        Assume documents is a list of list of dicts, where each dict is a document with keys "id" and "contents".
+        This assumption is made to be consistent with the output of the retrieval server.
+
+        Returns a dict {query_index: [(doc_string, score), ...]} with each
+        list sorted by descending score; insertion order follows `queries`.
+        """
+        assert len(queries) == len(documents)
+
+        pairs = []
+        qids = []
+        for qid, query in enumerate(queries):
+            # BUGFIX: pair each query only with its OWN candidate list.
+            # The previous nested loop ("for document in documents") scored
+            # every query against every query's candidates, cross-contaminating
+            # results and doing O(Q^2 * K) predictions instead of O(Q * K).
+            for doc_item in documents[qid]:
+                doc = self._passage_to_string(doc_item)
+                pairs.append((query, doc))
+                qids.append(qid)
+
+        scores = self._predict(pairs)
+        query_to_doc_scores = defaultdict(list)
+
+        assert len(scores) == len(pairs) == len(qids)
+        for i in range(len(pairs)):
+            query, doc = pairs[i]
+            score = scores[i]
+            qid = qids[i]
+            query_to_doc_scores[qid].append((doc, score))
+
+        # Sort each query's candidates by score, best first.
+        sorted_query_to_doc_scores = {}
+        for qid, doc_scores in query_to_doc_scores.items():
+            sorted_query_to_doc_scores[qid] = sorted(doc_scores, key=lambda x: x[1], reverse=True)
+
+        return sorted_query_to_doc_scores
+
+    def _predict(self, pairs: list[tuple[str, str]]):
+        raise NotImplementedError
+
+    @classmethod
+    def load(cls, model_name_or_path, **kwargs):
+        raise NotImplementedError
+
+
+class SentenceTransformerCrossEncoder(BaseCrossEncoder):
+    """Cross-encoder reranker backed by sentence_transformers.CrossEncoder."""
+
+    def __init__(self, model, batch_size=32, device="cuda"):
+        super().__init__(model, batch_size, device)
+
+    def _predict(self, pairs: list[tuple[str, str]]):
+        scores = self.model.predict(pairs, batch_size=self.batch_size)
+        # CrossEncoder.predict may return a torch.Tensor or np.ndarray;
+        # normalize to a plain list so the result is JSON-serializable.
+        scores = scores.tolist() if isinstance(scores, torch.Tensor) or isinstance(scores, np.ndarray) else scores
+        return scores
+
+    @classmethod
+    def load(cls, model_name_or_path, **kwargs):
+        model = CrossEncoder(model_name_or_path)
+        return cls(model, **kwargs)
+
+
+class RerankRequest(BaseModel):
+    queries: list[str]
+    documents: list[list[dict]]
+    rerank_topk: Optional[int] = None
+    return_scores: bool = False
+
+
+@dataclass
+class RerankerArguments:
+    max_length: int = field(default=512)
+    rerank_topk: int = field(default=3)
+    rerank_model_name_or_path: str = field(default="cross-encoder/ms-marco-MiniLM-L12-v2")
+    batch_size: int = field(default=32)
+    reranker_type: str = field(default="sentence_transformer")
+
+def get_reranker(config):
+    # Factory keyed on reranker_type; only the sentence-transformers backend
+    # exists today, but this keeps the door open for others.
+    if config.reranker_type == "sentence_transformer":
+        return SentenceTransformerCrossEncoder.load(
+            config.rerank_model_name_or_path,
+            batch_size=config.batch_size,
+            device="cuda" if torch.cuda.is_available() else "cpu"
+        )
+    else:
+        raise ValueError(f"Unknown reranker type: {config.reranker_type}")
+
+
+app = FastAPI()
+
+@app.post("/rerank")
+def rerank_endpoint(request: RerankRequest):
+    """
+    Endpoint that accepts queries plus candidate documents and performs reranking.
+    Input format:
+    {
+      "queries": ["What is Python?", "Tell me about neural networks."],
+      "documents": [[doc_item_1, ..., doc_item_k], [doc_item_1, ..., doc_item_k]],
+      "rerank_topk": 3,
+      "return_scores": true
+    }
+    """
+    if not request.rerank_topk:
+        request.rerank_topk = config.rerank_topk  # fallback to default
+
+    # Perform batch reranking; each doc_scores list is already sorted by score.
+    query_to_doc_scores = reranker.rerank(request.queries, request.documents)
+
+    # Format response: one entry per query, truncated to rerank_topk.
+    resp = []
+    for _, doc_scores in query_to_doc_scores.items():
+        doc_scores = doc_scores[:request.rerank_topk]
+        if request.return_scores:
+            combined = []
+            for doc, score in doc_scores:
+                combined.append({"document": doc, "score": score})
+            resp.append(combined)
+        else:
+            resp.append([doc for doc, _ in doc_scores])
+    return {"result": resp}
+
+
+if __name__ == "__main__":
+
+    # 1) Build a config (could also parse from arguments).
+    #    In real usage, you'd parse your CLI arguments or environment variables.
+    parser = HfArgumentParser((RerankerArguments))
+    config = parser.parse_args_into_dataclasses()[0]
+
+    # 2) Instantiate a global reranker so it is loaded once and reused.
+    reranker = get_reranker(config)
+
+    # 3) Launch the server. It listens on http://0.0.0.0:6980
+    #    (the comment previously claimed port 8000, contradicting the code).
+    uvicorn.run(app, host="0.0.0.0", port=6980)
diff --git a/search_r1/search/retrieval_rerank_server.py b/search_r1/search/retrieval_rerank_server.py
new file mode 100644
index 0000000..a9e14f7
--- /dev/null
+++ b/search_r1/search/retrieval_rerank_server.py
@@ -0,0 +1,123 @@
+# pip install -U sentence-transformers
+import os
+import re
+import argparse
+from dataclasses import dataclass, field
+from typing import List, Optional
+from collections import defaultdict
+
+import torch
+import numpy as np
+from fastapi import FastAPI
+from pydantic import BaseModel
+from sentence_transformers import CrossEncoder
+
+from retrieval_server import get_retriever, Config as RetrieverConfig
+from rerank_server import SentenceTransformerCrossEncoder
+
+app = FastAPI()
+
+def convert_title_format(text):
+    # Use regex to extract the title and the content from the reranker's
+    # "(Title: ...) body" string and rewrite it as '"Title"\nbody'.
+    # NOTE(review): a title containing ')' will not match [^)]+ and the text
+    # is returned unchanged — confirm titles never contain ')'.
+    match = re.match(r'\(Title:\s*([^)]+)\)\s*(.+)', text, re.DOTALL)
+    if match:
+        title, content = match.groups()
+        return f'\"{title}\"\n{content}'
+    else:
+        return text
+
+# ----------- Combined Request Schema -----------
+class SearchRequest(BaseModel):
+    queries: List[str]
+    topk_retrieval: Optional[int] = 10
+    topk_rerank: Optional[int] = 3
+    return_scores: bool = False
+
+# ----------- Reranker Config Schema -----------
+@dataclass
+class RerankerArguments:
+    max_length: int = field(default=512)
+    rerank_topk: int = field(default=3)
+    rerank_model_name_or_path: str = field(default="cross-encoder/ms-marco-MiniLM-L12-v2")
+    batch_size: int = field(default=32)
+    reranker_type: str = field(default="sentence_transformer")
+
+def get_reranker(config):
+    if config.reranker_type == "sentence_transformer":
+        return SentenceTransformerCrossEncoder.load(
+            config.rerank_model_name_or_path,
+            batch_size=config.batch_size,
+            device="cuda" if torch.cuda.is_available() else "cpu"
+        )
+    else:
+        raise ValueError(f"Unknown reranker type: {config.reranker_type}")
+
+# ----------- Endpoint -----------
+@app.post("/retrieve")
+def search_endpoint(request: SearchRequest):
+    """Two-stage search: dense retrieval, then cross-encoder reranking.
+
+    Returns {"result": [per-query list of reranked passages]}, where each
+    passage is either a formatted string or {"document", "score"} when
+    return_scores is set.
+    """
+    # Step 1: Retrieve candidate documents with the dense retriever.
+    retrieved_docs = retriever.batch_search(
+        query_list=request.queries,
+        num=request.topk_retrieval,
+        return_score=False
+    )
+
+    # Step 2: Rerank. The result maps query index -> [(doc, score), ...]
+    # sorted by descending score; insertion order follows request.queries.
+    reranked = reranker.rerank(request.queries, retrieved_docs)
+
+    # Step 3: Format response. Only the values are needed — dicts preserve
+    # insertion order, so responses stay aligned with the request queries.
+    response = []
+    for doc_scores in reranked.values():
+        doc_scores = doc_scores[:request.topk_rerank]
+        if request.return_scores:
+            combined = []
+            for doc, score in doc_scores:
+                combined.append({"document": convert_title_format(doc), "score": score})
+            response.append(combined)
+        else:
+            response.append([convert_title_format(doc) for doc, _ in doc_scores])
+
+    return {"result": response}
+
+
+if __name__ == "__main__":
+
+    parser = argparse.ArgumentParser(description="Launch the local faiss retriever with a cross-encoder reranker.")
+    # retriever
+    parser.add_argument("--index_path", type=str, default="/home/peterjin/mnt/index/wiki-18/e5_Flat.index", help="Corpus indexing file.")
+    parser.add_argument("--corpus_path", type=str, default="/home/peterjin/mnt/data/retrieval-corpus/wiki-18.jsonl", help="Local corpus file.")
+    parser.add_argument("--retrieval_topk", type=int, default=10, help="Number of retrieved passages for one query.")
+    parser.add_argument("--retriever_name", type=str, default="e5", help="Name of the retriever model.")
+    parser.add_argument("--retriever_model", type=str, default="intfloat/e5-base-v2", help="Path of the retriever model.")
+    parser.add_argument('--faiss_gpu', action='store_true', help='Use GPU for computation')
+    # reranker
+    parser.add_argument("--reranking_topk", type=int, default=3, help="Number of reranked passages for one query.")
+    parser.add_argument("--reranker_model", type=str, default="cross-encoder/ms-marco-MiniLM-L12-v2", help="Path of the reranker model.")
+    parser.add_argument("--reranker_batch_size", type=int, default=32, help="Batch size for the reranker inference.")
+
+    args = parser.parse_args()
+
+    # ----------- Load Retriever and Reranker -----------
+    # Both are built once at startup and reused across requests.
+    retriever_config = RetrieverConfig(
+        retrieval_method = args.retriever_name,
+        index_path=args.index_path,
+        corpus_path=args.corpus_path,
+        retrieval_topk=args.retrieval_topk,
+        faiss_gpu=args.faiss_gpu,
+        retrieval_model_path=args.retriever_model,
+        retrieval_pooling_method="mean",
+        retrieval_query_max_length=256,
+        retrieval_use_fp16=True,
+        retrieval_batch_size=512,
+    )
+    retriever = get_retriever(retriever_config)
+
+    reranker_config = RerankerArguments(
+        rerank_topk = args.reranking_topk,
+        rerank_model_name_or_path = args.reranker_model,
+        batch_size = args.reranker_batch_size,
+    )
+    reranker = get_reranker(reranker_config)
+
+    import uvicorn
+    uvicorn.run(app, host="0.0.0.0", port=8000)
diff --git a/search_r1/search/retrieval_server.py b/search_r1/search/retrieval_server.py
index a58f9cf..f396989 100644
--- a/search_r1/search/retrieval_server.py
+++ b/search_r1/search/retrieval_server.py
@@ -15,17 +15,6 @@
 import uvicorn
 from fastapi import FastAPI
 from pydantic import BaseModel
-
-parser = argparse.ArgumentParser(description="Launch the local faiss retriever.")
-parser.add_argument("--index_path", type=str, default="/home/peterjin/mnt/index/wiki-18/e5_Flat.index", help="Corpus indexing file.")
-parser.add_argument("--corpus_path", type=str, default="/home/peterjin/mnt/data/retrieval-corpus/wiki-18.jsonl", help="Local corpus file.")
-parser.add_argument("--topk", type=int, default=3, help="Number of retrieved passages for one query.")
-parser.add_argument("--retriever_name", type=str, default="e5", help="Name of the retriever model.")
-parser.add_argument("--retriever_model", type=str, default="intfloat/e5-base-v2", help="Path of the retriever model.")
-parser.add_argument('--faiss_gpu', action='store_true', help='Use GPU for computation')
-
-args = parser.parse_args()
-
 def load_corpus(corpus_path: str):
     corpus = datasets.load_dataset(
         'json',
@@ -334,24 +323,6 @@ class QueryRequest(BaseModel):
 app = FastAPI()
 
-# 1) Build a config (could also parse from arguments).
-# In real usage, you'd parse your CLI arguments or environment variables.
-config = Config(
-    retrieval_method = args.retriever_name, # or "dense"
-    index_path=args.index_path,
-    corpus_path=args.corpus_path,
-    retrieval_topk=args.topk,
-    faiss_gpu=args.faiss_gpu,
-    retrieval_model_path=args.retriever_model,
-    retrieval_pooling_method="mean",
-    retrieval_query_max_length=256,
-    retrieval_use_fp16=True,
-    retrieval_batch_size=512,
-)
-
-# 2) Instantiate a global retriever so it is loaded once and reused.
-retriever = get_retriever(config)
-
 @app.post("/retrieve")
 def retrieve_endpoint(request: QueryRequest):
     """
@@ -388,5 +359,34 @@ def retrieve_endpoint(request: QueryRequest):
 if __name__ == "__main__":
+
+    parser = argparse.ArgumentParser(description="Launch the local faiss retriever.")
+    parser.add_argument("--index_path", type=str, default="/home/peterjin/mnt/index/wiki-18/e5_Flat.index", help="Corpus indexing file.")
+    parser.add_argument("--corpus_path", type=str, default="/home/peterjin/mnt/data/retrieval-corpus/wiki-18.jsonl", help="Local corpus file.")
+    parser.add_argument("--topk", type=int, default=3, help="Number of retrieved passages for one query.")
+    parser.add_argument("--retriever_name", type=str, default="e5", help="Name of the retriever model.")
+    parser.add_argument("--retriever_model", type=str, default="intfloat/e5-base-v2", help="Path of the retriever model.")
+    parser.add_argument('--faiss_gpu', action='store_true', help='Use GPU for computation')
+
+    args = parser.parse_args()
+
+    # 1) Build a config (could also parse from arguments).
+    # In real usage, you'd parse your CLI arguments or environment variables.
+    config = Config(
+        retrieval_method = args.retriever_name, # or "dense"
+        index_path=args.index_path,
+        corpus_path=args.corpus_path,
+        retrieval_topk=args.topk,
+        faiss_gpu=args.faiss_gpu,
+        retrieval_model_path=args.retriever_model,
+        retrieval_pooling_method="mean",
+        retrieval_query_max_length=256,
+        retrieval_use_fp16=True,
+        retrieval_batch_size=512,
+    )
+
+    # 2) Instantiate a global retriever so it is loaded once and reused.
+    retriever = get_retriever(config)
+
     # 3) Launch the server. By default, it listens on http://127.0.0.1:8000
     uvicorn.run(app, host="0.0.0.0", port=8000)