add reranker
This commit is contained in:
123
search_r1/search/retrieval_rerank_server.py
Normal file
123
search_r1/search/retrieval_rerank_server.py
Normal file
@@ -0,0 +1,123 @@
|
||||
# pip install -U sentence-transformers
|
||||
import os
|
||||
import re
|
||||
import argparse
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Optional
|
||||
from collections import defaultdict
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
from fastapi import FastAPI
|
||||
from pydantic import BaseModel
|
||||
from sentence_transformers import CrossEncoder
|
||||
|
||||
from retrieval_server import get_retriever, Config as RetrieverConfig
|
||||
from rerank_server import SentenceTransformerCrossEncoder
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
def convert_title_format(text):
|
||||
# Use regex to extract the title and the content
|
||||
match = re.match(r'\(Title:\s*([^)]+)\)\s*(.+)', text, re.DOTALL)
|
||||
if match:
|
||||
title, content = match.groups()
|
||||
return f'\"{title}\"\n{content}'
|
||||
else:
|
||||
return text
|
||||
|
||||
# ----------- Combined Request Schema -----------
|
||||
class SearchRequest(BaseModel):
|
||||
queries: List[str]
|
||||
topk_retrieval: Optional[int] = 10
|
||||
topk_rerank: Optional[int] = 3
|
||||
return_scores: bool = False
|
||||
|
||||
# ----------- Reranker Config Schema -----------
|
||||
@dataclass
|
||||
class RerankerArguments:
|
||||
max_length: int = field(default=512)
|
||||
rerank_topk: int = field(default=3)
|
||||
rerank_model_name_or_path: str = field(default="cross-encoder/ms-marco-MiniLM-L12-v2")
|
||||
batch_size: int = field(default=32)
|
||||
reranker_type: str = field(default="sentence_transformer")
|
||||
|
||||
def get_reranker(config):
|
||||
if config.reranker_type == "sentence_transformer":
|
||||
return SentenceTransformerCrossEncoder.load(
|
||||
config.rerank_model_name_or_path,
|
||||
batch_size=config.batch_size,
|
||||
device="cuda" if torch.cuda.is_available() else "cpu"
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown reranker type: {config.reranker_type}")
|
||||
|
||||
# ----------- Endpoint -----------
|
||||
@app.post("/retrieve")
|
||||
def search_endpoint(request: SearchRequest):
|
||||
# Step 1: Retrieve documents
|
||||
retrieved_docs = retriever.batch_search(
|
||||
query_list=request.queries,
|
||||
num=request.topk_retrieval,
|
||||
return_score=False
|
||||
)
|
||||
|
||||
# Step 2: Rerank
|
||||
reranked = reranker.rerank(request.queries, retrieved_docs)
|
||||
|
||||
# Step 3: Format response
|
||||
response = []
|
||||
for i, doc_scores in reranked.items():
|
||||
doc_scores = doc_scores[:request.topk_rerank]
|
||||
if request.return_scores:
|
||||
combined = []
|
||||
for doc, score in doc_scores:
|
||||
combined.append({"document": convert_title_format(doc), "score": score})
|
||||
response.append(combined)
|
||||
else:
|
||||
response.append([convert_title_format(doc) for doc, _ in doc_scores])
|
||||
|
||||
return {"result": response}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
parser = argparse.ArgumentParser(description="Launch the local faiss retriever.")
|
||||
# retriever
|
||||
parser.add_argument("--index_path", type=str, default="/home/peterjin/mnt/index/wiki-18/e5_Flat.index", help="Corpus indexing file.")
|
||||
parser.add_argument("--corpus_path", type=str, default="/home/peterjin/mnt/data/retrieval-corpus/wiki-18.jsonl", help="Local corpus file.")
|
||||
parser.add_argument("--retrieval_topk", type=int, default=10, help="Number of retrieved passages for one query.")
|
||||
parser.add_argument("--retriever_name", type=str, default="e5", help="Name of the retriever model.")
|
||||
parser.add_argument("--retriever_model", type=str, default="intfloat/e5-base-v2", help="Path of the retriever model.")
|
||||
parser.add_argument('--faiss_gpu', action='store_true', help='Use GPU for computation')
|
||||
# reranker
|
||||
parser.add_argument("--reranking_topk", type=int, default=3, help="Number of reranked passages for one query.")
|
||||
parser.add_argument("--reranker_model", type=str, default="cross-encoder/ms-marco-MiniLM-L12-v2", help="Path of the reranker model.")
|
||||
parser.add_argument("--reranker_batch_size", type=int, default=32, help="Batch size for the reranker inference.")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# ----------- Load Retriever and Reranker -----------
|
||||
retriever_config = RetrieverConfig(
|
||||
retrieval_method = args.retriever_name,
|
||||
index_path=args.index_path,
|
||||
corpus_path=args.corpus_path,
|
||||
retrieval_topk=args.retrieval_topk,
|
||||
faiss_gpu=args.faiss_gpu,
|
||||
retrieval_model_path=args.retriever_model,
|
||||
retrieval_pooling_method="mean",
|
||||
retrieval_query_max_length=256,
|
||||
retrieval_use_fp16=True,
|
||||
retrieval_batch_size=512,
|
||||
)
|
||||
retriever = get_retriever(retriever_config)
|
||||
|
||||
reranker_config = RerankerArguments(
|
||||
rerank_topk = args.reranking_topk,
|
||||
rerank_model_name_or_path = args.reranker_model,
|
||||
batch_size = args.reranker_batch_size,
|
||||
)
|
||||
reranker = get_reranker(reranker_config)
|
||||
|
||||
import uvicorn
|
||||
uvicorn.run(app, host="0.0.0.0", port=8000)
|
||||
Reference in New Issue
Block a user