add reranker
example/retriever/retrieval_launch_hierarchical.sh (new file, 17 lines)
@@ -0,0 +1,17 @@
file_path=/the/path/you/save/corpus
index_file=$file_path/e5_Flat.index
corpus_file=$file_path/wiki-18.jsonl
retriever_name=e5
retriever_path=intfloat/e5-base-v2
reranker_path=cross-encoder/ms-marco-MiniLM-L12-v2

python search_r1/search/retrieval_rerank_server.py --index_path $index_file \
    --corpus_path $corpus_file \
    --retrieval_topk 10 \
    --retriever_name $retriever_name \
    --retriever_model $retriever_path \
    --faiss_gpu \
    --reranking_topk 3 \
    --reranker_model $reranker_path \
    --reranker_batch_size 32
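Once this script is running, the combined server exposes POST /retrieve on port 8000. A minimal client sketch (the query string is illustrative; the request fields follow the SearchRequest schema defined in retrieval_rerank_server.py below):

import requests

resp = requests.post("http://127.0.0.1:8000/retrieve", json={
    "queries": ["What is Python?"],
    "topk_retrieval": 10,  # candidates fetched by the dense retriever
    "topk_rerank": 3,      # candidates kept after cross-encoder reranking
    "return_scores": True,
})
print(resp.json()["result"])  # one list of reranked passages per query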
search_r1/search/rerank_server.py (new file, 161 lines)
@@ -0,0 +1,161 @@
from collections import defaultdict
from typing import Optional
from dataclasses import dataclass, field

from sentence_transformers import CrossEncoder
import torch
from transformers import HfArgumentParser
import numpy as np

import uvicorn
from fastapi import FastAPI
from pydantic import BaseModel


class BaseCrossEncoder:
    def __init__(self, model, batch_size=32, device="cuda"):
        self.model = model
        self.batch_size = batch_size
        self.model.to(device)

    def _passage_to_string(self, doc_item):
        # Accept both raw corpus items ({"contents": ...}) and retrieval-server
        # outputs ({"document": {"contents": ...}, ...}).
        if "document" not in doc_item:
            content = doc_item['contents']
        else:
            content = doc_item['document']['contents']
        # The first line of "contents" is the title; the rest is the passage text.
        title = content.split("\n")[0]
        text = "\n".join(content.split("\n")[1:])

        return f"(Title: {title}) {text}"

    def rerank(self,
               queries: list[str],
               documents: list[list[dict]]):
        """
        Assume documents is a list of lists of dicts, where each dict is a document
        with keys "id" and "contents". This assumption is made to be consistent
        with the output of the retrieval server.
        """
        assert len(queries) == len(documents)

        # Pair each query with its own candidate documents only.
        pairs = []
        qids = []
        for qid, query in enumerate(queries):
            for doc_item in documents[qid]:
                doc = self._passage_to_string(doc_item)
                pairs.append((query, doc))
                qids.append(qid)

        scores = self._predict(pairs)
        query_to_doc_scores = defaultdict(list)

        assert len(scores) == len(pairs) == len(qids)
        for i in range(len(pairs)):
            query, doc = pairs[i]
            score = scores[i]
            qid = qids[i]
            query_to_doc_scores[qid].append((doc, score))

        # Sort each query's candidates by score, descending.
        sorted_query_to_doc_scores = {}
        for qid, doc_scores in query_to_doc_scores.items():
            sorted_query_to_doc_scores[qid] = sorted(doc_scores, key=lambda x: x[1], reverse=True)

        return sorted_query_to_doc_scores

    def _predict(self, pairs: list[tuple[str, str]]):
        raise NotImplementedError

    @classmethod
    def load(cls, model_name_or_path, **kwargs):
        raise NotImplementedError


class SentenceTransformerCrossEncoder(BaseCrossEncoder):
    def __init__(self, model, batch_size=32, device="cuda"):
        super().__init__(model, batch_size, device)

    def _predict(self, pairs: list[tuple[str, str]]):
        scores = self.model.predict(pairs, batch_size=self.batch_size)
        scores = scores.tolist() if isinstance(scores, (torch.Tensor, np.ndarray)) else scores
        return scores

    @classmethod
    def load(cls, model_name_or_path, **kwargs):
        model = CrossEncoder(model_name_or_path)
        return cls(model, **kwargs)


class RerankRequest(BaseModel):
    queries: list[str]
    documents: list[list[dict]]
    rerank_topk: Optional[int] = None
    return_scores: bool = False


@dataclass
class RerankerArguments:
    max_length: int = field(default=512)
    rerank_topk: int = field(default=3)
    rerank_model_name_or_path: str = field(default="cross-encoder/ms-marco-MiniLM-L12-v2")
    batch_size: int = field(default=32)
    reranker_type: str = field(default="sentence_transformer")


def get_reranker(config):
    if config.reranker_type == "sentence_transformer":
        return SentenceTransformerCrossEncoder.load(
            config.rerank_model_name_or_path,
            batch_size=config.batch_size,
            device="cuda" if torch.cuda.is_available() else "cpu"
        )
    else:
        raise ValueError(f"Unknown reranker type: {config.reranker_type}")


app = FastAPI()

@app.post("/rerank")
def rerank_endpoint(request: RerankRequest):
    """
    Endpoint that accepts queries with candidate documents and reranks them.
    Input format:
    {
        "queries": ["What is Python?", "Tell me about neural networks."],
        "documents": [[doc_item_1, ..., doc_item_k], [doc_item_1, ..., doc_item_k]],
        "rerank_topk": 3,
        "return_scores": true
    }
    """
    if not request.rerank_topk:
        request.rerank_topk = config.rerank_topk  # fall back to the server default

    # Perform batch reranking; each query's doc_scores come back sorted by score.
    query_to_doc_scores = reranker.rerank(request.queries, request.documents)

    # Format response
    resp = []
    for _, doc_scores in query_to_doc_scores.items():
        doc_scores = doc_scores[:request.rerank_topk]
        if request.return_scores:
            combined = []
            for doc, score in doc_scores:
                combined.append({"document": doc, "score": score})
            resp.append(combined)
        else:
            resp.append([doc for doc, _ in doc_scores])
    return {"result": resp}


if __name__ == "__main__":

    # 1) Parse the reranker config from CLI arguments.
    parser = HfArgumentParser(RerankerArguments)
    config = parser.parse_args_into_dataclasses()[0]

    # 2) Instantiate a global reranker so the model is loaded once and reused.
    reranker = get_reranker(config)

    # 3) Launch the server. It listens on http://0.0.0.0:6980
    uvicorn.run(app, host="0.0.0.0", port=6980)
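The standalone reranker above listens on port 6980 and expects candidate documents in the retrieval server's output format. A minimal client sketch (the documents are made up; "contents" carries the title on its first line, as _passage_to_string assumes):

import requests

docs = [[
    {"id": "0", "contents": "Python (programming language)\nPython is a high-level, general-purpose language."},
    {"id": "1", "contents": "Monty Python\nMonty Python were a British comedy troupe."},
]]

resp = requests.post("http://127.0.0.1:6980/rerank", json={
    "queries": ["What is the Python language?"],
    "documents": docs,
    "rerank_topk": 1,
    "return_scores": True,
})
print(resp.json()["result"])  # [[{"document": "(Title: ...) ...", "score": ...}]]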
search_r1/search/retrieval_rerank_server.py (new file, 123 lines)
@@ -0,0 +1,123 @@
# pip install -U sentence-transformers
import re
import argparse
from dataclasses import dataclass, field
from typing import List, Optional

import torch
import uvicorn
from fastapi import FastAPI
from pydantic import BaseModel

from retrieval_server import get_retriever, Config as RetrieverConfig
from rerank_server import SentenceTransformerCrossEncoder

app = FastAPI()

def convert_title_format(text):
    # Convert "(Title: ...) content" back to the corpus format: "Title"\ncontent
    match = re.match(r'\(Title:\s*([^)]+)\)\s*(.+)', text, re.DOTALL)
    if match:
        title, content = match.groups()
        return f'"{title}"\n{content}'
    else:
        return text

# ----------- Combined Request Schema -----------
class SearchRequest(BaseModel):
    queries: List[str]
    topk_retrieval: Optional[int] = 10
    topk_rerank: Optional[int] = 3
    return_scores: bool = False

# ----------- Reranker Config Schema -----------
@dataclass
class RerankerArguments:
    max_length: int = field(default=512)
    rerank_topk: int = field(default=3)
    rerank_model_name_or_path: str = field(default="cross-encoder/ms-marco-MiniLM-L12-v2")
    batch_size: int = field(default=32)
    reranker_type: str = field(default="sentence_transformer")

def get_reranker(config):
    if config.reranker_type == "sentence_transformer":
        return SentenceTransformerCrossEncoder.load(
            config.rerank_model_name_or_path,
            batch_size=config.batch_size,
            device="cuda" if torch.cuda.is_available() else "cpu"
        )
    else:
        raise ValueError(f"Unknown reranker type: {config.reranker_type}")

# ----------- Endpoint -----------
@app.post("/retrieve")
def search_endpoint(request: SearchRequest):
    # Step 1: Retrieve candidate documents with the dense retriever
    retrieved_docs = retriever.batch_search(
        query_list=request.queries,
        num=request.topk_retrieval,
        return_score=False
    )

    # Step 2: Rerank the candidates with the cross-encoder
    reranked = reranker.rerank(request.queries, retrieved_docs)

    # Step 3: Format response
    response = []
    for _, doc_scores in reranked.items():
        doc_scores = doc_scores[:request.topk_rerank]
        if request.return_scores:
            combined = []
            for doc, score in doc_scores:
                combined.append({"document": convert_title_format(doc), "score": score})
            response.append(combined)
        else:
            response.append([convert_title_format(doc) for doc, _ in doc_scores])

    return {"result": response}


if __name__ == "__main__":

    parser = argparse.ArgumentParser(description="Launch the local faiss retriever with cross-encoder reranking.")
    # retriever
    parser.add_argument("--index_path", type=str, default="/home/peterjin/mnt/index/wiki-18/e5_Flat.index", help="Corpus indexing file.")
    parser.add_argument("--corpus_path", type=str, default="/home/peterjin/mnt/data/retrieval-corpus/wiki-18.jsonl", help="Local corpus file.")
    parser.add_argument("--retrieval_topk", type=int, default=10, help="Number of retrieved passages for one query.")
    parser.add_argument("--retriever_name", type=str, default="e5", help="Name of the retriever model.")
    parser.add_argument("--retriever_model", type=str, default="intfloat/e5-base-v2", help="Path of the retriever model.")
    parser.add_argument('--faiss_gpu', action='store_true', help='Use GPU for computation')
    # reranker
    parser.add_argument("--reranking_topk", type=int, default=3, help="Number of reranked passages for one query.")
    parser.add_argument("--reranker_model", type=str, default="cross-encoder/ms-marco-MiniLM-L12-v2", help="Path of the reranker model.")
    parser.add_argument("--reranker_batch_size", type=int, default=32, help="Batch size for reranker inference.")

    args = parser.parse_args()

    # ----------- Load Retriever and Reranker -----------
    retriever_config = RetrieverConfig(
        retrieval_method=args.retriever_name,
        index_path=args.index_path,
        corpus_path=args.corpus_path,
        retrieval_topk=args.retrieval_topk,
        faiss_gpu=args.faiss_gpu,
        retrieval_model_path=args.retriever_model,
        retrieval_pooling_method="mean",
        retrieval_query_max_length=256,
        retrieval_use_fp16=True,
        retrieval_batch_size=512,
    )
    retriever = get_retriever(retriever_config)

    reranker_config = RerankerArguments(
        rerank_topk=args.reranking_topk,
        rerank_model_name_or_path=args.reranker_model,
        batch_size=args.reranker_batch_size,
    )
    reranker = get_reranker(reranker_config)

    uvicorn.run(app, host="0.0.0.0", port=8000)
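As a quick illustration, convert_title_format undoes the "(Title: ...)" wrapping that the reranker's _passage_to_string applies, restoring the quoted-title corpus format (a toy example):

>>> convert_title_format("(Title: Python) Python is a high-level language.")
'"Python"\nPython is a high-level language.'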
search_r1/search/retrieval_server.py (modified)
@@ -15,17 +15,6 @@ import uvicorn
 from fastapi import FastAPI
 from pydantic import BaseModel
-
-
-parser = argparse.ArgumentParser(description="Launch the local faiss retriever.")
-parser.add_argument("--index_path", type=str, default="/home/peterjin/mnt/index/wiki-18/e5_Flat.index", help="Corpus indexing file.")
-parser.add_argument("--corpus_path", type=str, default="/home/peterjin/mnt/data/retrieval-corpus/wiki-18.jsonl", help="Local corpus file.")
-parser.add_argument("--topk", type=int, default=3, help="Number of retrieved passages for one query.")
-parser.add_argument("--retriever_name", type=str, default="e5", help="Name of the retriever model.")
-parser.add_argument("--retriever_model", type=str, default="intfloat/e5-base-v2", help="Path of the retriever model.")
-parser.add_argument('--faiss_gpu', action='store_true', help='Use GPU for computation')
-
-args = parser.parse_args()

 def load_corpus(corpus_path: str):
     corpus = datasets.load_dataset(
         'json',
@@ -334,24 +323,6 @@ class QueryRequest(BaseModel):

 app = FastAPI()

-# 1) Build a config (could also parse from arguments).
-# In real usage, you'd parse your CLI arguments or environment variables.
-config = Config(
-    retrieval_method = args.retriever_name, # or "dense"
-    index_path=args.index_path,
-    corpus_path=args.corpus_path,
-    retrieval_topk=args.topk,
-    faiss_gpu=args.faiss_gpu,
-    retrieval_model_path=args.retriever_model,
-    retrieval_pooling_method="mean",
-    retrieval_query_max_length=256,
-    retrieval_use_fp16=True,
-    retrieval_batch_size=512,
-)
-
-# 2) Instantiate a global retriever so it is loaded once and reused.
-retriever = get_retriever(config)
-
 @app.post("/retrieve")
 def retrieve_endpoint(request: QueryRequest):
     """
@@ -388,5 +359,34 @@ def retrieve_endpoint(request: QueryRequest):


 if __name__ == "__main__":
+
+    parser = argparse.ArgumentParser(description="Launch the local faiss retriever.")
+    parser.add_argument("--index_path", type=str, default="/home/peterjin/mnt/index/wiki-18/e5_Flat.index", help="Corpus indexing file.")
+    parser.add_argument("--corpus_path", type=str, default="/home/peterjin/mnt/data/retrieval-corpus/wiki-18.jsonl", help="Local corpus file.")
+    parser.add_argument("--topk", type=int, default=3, help="Number of retrieved passages for one query.")
+    parser.add_argument("--retriever_name", type=str, default="e5", help="Name of the retriever model.")
+    parser.add_argument("--retriever_model", type=str, default="intfloat/e5-base-v2", help="Path of the retriever model.")
+    parser.add_argument('--faiss_gpu', action='store_true', help='Use GPU for computation')
+
+    args = parser.parse_args()
+
+    # 1) Build a config from the CLI arguments parsed above
+    #    (environment variables could be used instead).
+    config = Config(
+        retrieval_method=args.retriever_name,  # or "dense"
+        index_path=args.index_path,
+        corpus_path=args.corpus_path,
+        retrieval_topk=args.topk,
+        faiss_gpu=args.faiss_gpu,
+        retrieval_model_path=args.retriever_model,
+        retrieval_pooling_method="mean",
+        retrieval_query_max_length=256,
+        retrieval_use_fp16=True,
+        retrieval_batch_size=512,
+    )
+
+    # 2) Instantiate a global retriever so it is loaded once and reused.
+    retriever = get_retriever(config)
+
     # 3) Launch the server. By default, it listens on http://127.0.0.1:8000
     uvicorn.run(app, host="0.0.0.0", port=8000)