add reranker
This commit is contained in:
@@ -15,17 +15,6 @@ import uvicorn
|
||||
from fastapi import FastAPI
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
# Command-line interface of the retrieval server. Options are registered in
# the order they should appear in the --help output.
parser = argparse.ArgumentParser(
    description="Launch the local faiss retriever.",
)
parser.add_argument(
    "--index_path",
    type=str,
    default="/home/peterjin/mnt/index/wiki-18/e5_Flat.index",
    help="Corpus indexing file.",
)
parser.add_argument(
    "--corpus_path",
    type=str,
    default="/home/peterjin/mnt/data/retrieval-corpus/wiki-18.jsonl",
    help="Local corpus file.",
)
parser.add_argument(
    "--topk",
    type=int,
    default=3,
    help="Number of retrieved passages for one query.",
)
parser.add_argument(
    "--retriever_name",
    type=str,
    default="e5",
    help="Name of the retriever model.",
)
parser.add_argument(
    "--retriever_model",
    type=str,
    default="intfloat/e5-base-v2",
    help="Path of the retriever model.",
)
# Boolean switch: False unless the flag is present on the command line.
parser.add_argument(
    "--faiss_gpu",
    action="store_true",
    help="Use GPU for computation",
)

# Parse the process arguments right away; `args` is read further down when
# the retriever config is built.
args = parser.parse_args()
|
||||
|
||||
def load_corpus(corpus_path: str):
|
||||
corpus = datasets.load_dataset(
|
||||
'json',
|
||||
@@ -334,24 +323,6 @@ class QueryRequest(BaseModel):
|
||||
|
||||
# FastAPI application object; the route handlers below are registered on it.
app = FastAPI()

# Retriever settings. The first six values come from the parsed command-line
# arguments in `args`; the remaining four are fixed here.
config = Config(
    retrieval_method=args.retriever_name,  # or "dense"
    index_path=args.index_path,
    corpus_path=args.corpus_path,
    retrieval_topk=args.topk,
    faiss_gpu=args.faiss_gpu,
    retrieval_model_path=args.retriever_model,
    retrieval_pooling_method="mean",
    retrieval_query_max_length=256,
    retrieval_use_fp16=True,
    retrieval_batch_size=512,
)

# Build the retriever once at module level so that every request reuses the
# same instance instead of constructing a new one.
retriever = get_retriever(config)
|
||||
|
||||
@app.post("/retrieve")
|
||||
def retrieve_endpoint(request: QueryRequest):
|
||||
"""
|
||||
@@ -388,5 +359,34 @@ def retrieve_endpoint(request: QueryRequest):
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Script entry point: parse the command line, build the retriever, then
    # serve the FastAPI app. Nothing in this block runs unless the file is
    # executed as a script.
    parser = argparse.ArgumentParser(description="Launch the local faiss retriever.")
    parser.add_argument("--index_path", type=str, default="/home/peterjin/mnt/index/wiki-18/e5_Flat.index", help="Corpus indexing file.")
    parser.add_argument("--corpus_path", type=str, default="/home/peterjin/mnt/data/retrieval-corpus/wiki-18.jsonl", help="Local corpus file.")
    parser.add_argument("--topk", type=int, default=3, help="Number of retrieved passages for one query.")
    parser.add_argument("--retriever_name", type=str, default="e5", help="Name of the retriever model.")
    parser.add_argument("--retriever_model", type=str, default="intfloat/e5-base-v2", help="Path of the retriever model.")
    parser.add_argument('--faiss_gpu', action='store_true', help='Use GPU for computation')
    # Bind address and port used to be hard-coded in the uvicorn.run() call;
    # the defaults below keep that exact behavior.
    parser.add_argument("--host", type=str, default="0.0.0.0", help="Network interface the HTTP server binds to.")
    parser.add_argument("--port", type=int, default=8000, help="TCP port the HTTP server listens on.")

    args = parser.parse_args()

    # 1) Build the retriever config: the first six settings come from the
    #    parsed CLI arguments, the remaining four are fixed here.
    config = Config(
        retrieval_method=args.retriever_name,  # or "dense"
        index_path=args.index_path,
        corpus_path=args.corpus_path,
        retrieval_topk=args.topk,
        faiss_gpu=args.faiss_gpu,
        retrieval_model_path=args.retriever_model,
        retrieval_pooling_method="mean",
        retrieval_query_max_length=256,
        retrieval_use_fp16=True,
        retrieval_batch_size=512,
    )

    # 2) Instantiate a global retriever so it is loaded once and reused.
    retriever = get_retriever(config)

    # 3) Launch the server. With the default options it binds to all
    #    interfaces (0.0.0.0) on port 8000, so it is reachable from other
    #    hosts, not only from localhost.
    uvicorn.run(app, host=args.host, port=args.port)
|
||||
|
||||
Reference in New Issue
Block a user