Initial commit

This commit is contained in:
PeterGriffinJin
2025-02-28 15:16:19 +00:00
commit 068516be64
207 changed files with 33063 additions and 0 deletions

View File

@@ -0,0 +1,17 @@
# Build a dense FAISS index over a JSONL corpus with index_builder.py.
corpus_file=/your/corpus/jsonl/file  # corpus to index (jsonl)
save_dir=/the/path/to/save/index     # directory receiving the index and embeddings
retriever_name=e5                    # used by index_builder.py to name the output files
retriever_model=intfloat/e5-base-v2  # encoder passed as --model_path

# Expansions are quoted so that paths containing spaces are passed as one argument.
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python index_builder.py \
    --retrieval_method "$retriever_name" \
    --model_path "$retriever_model" \
    --corpus_path "$corpus_file" \
    --save_dir "$save_dir" \
    --use_fp16 \
    --max_length 256 \
    --batch_size 512 \
    --pooling_method mean \
    --faiss_type Flat \
    --save_embedding

View File

@@ -0,0 +1,348 @@
import os
import faiss
import json
import warnings
import numpy as np
from typing import cast, List, Dict
import shutil
import subprocess
import argparse
import torch
from tqdm import tqdm
# from LongRAG.retriever.utils import load_model, load_corpus, pooling
import datasets
from transformers import AutoTokenizer, AutoModel, AutoConfig
def load_model(
        model_path: str,
        use_fp16: bool = False
    ):
    """Load an encoder and its tokenizer, ready for inference on the GPU.

    Args:
        model_path: Name or path handed to ``from_pretrained``.
        use_fp16: When true, the model weights are converted to half precision.

    Returns:
        A ``(model, tokenizer)`` tuple; the model is in eval mode and on CUDA.
    """
    # The config is loaded by AutoModel itself; a separate AutoConfig load
    # whose result was never used has been dropped.
    model = AutoModel.from_pretrained(model_path, trust_remote_code=True)
    model.eval()
    model.cuda()
    if use_fp16:
        model = model.half()
    tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True, trust_remote_code=True)
    return model, tokenizer
def pooling(
        pooler_output,
        last_hidden_state,
        attention_mask = None,
        pooling_method = "mean"
    ):
    """Reduce per-token hidden states to one vector per sequence.

    Args:
        pooler_output: Value returned unchanged for the ``"pooler"`` method.
        last_hidden_state: Tensor indexed as (sequence, token, feature).
        attention_mask: Per-token mask; non-zero marks tokens that count
            towards the mean. ``None`` averages over every position.
        pooling_method: ``"mean"``, ``"cls"`` or ``"pooler"``.

    Returns:
        The pooled tensor.

    Raises:
        NotImplementedError: If ``pooling_method`` is not one of the above.
    """
    if pooling_method == "mean":
        if attention_mask is None:
            # No mask supplied: every position is treated as a real token.
            return last_hidden_state.mean(dim=1)
        keep = attention_mask.bool()
        masked_hidden = last_hidden_state.masked_fill(~keep[..., None], 0.0)
        # clamp keeps a fully masked row from dividing by zero (yields zeros, not NaN)
        token_counts = attention_mask.sum(dim=1).clamp(min=1)
        return masked_hidden.sum(dim=1) / token_counts[..., None]
    elif pooling_method == "cls":
        return last_hidden_state[:, 0]
    elif pooling_method == "pooler":
        return pooler_output
    else:
        raise NotImplementedError("Pooling method not implemented!")
def load_corpus(corpus_path: str, num_proc: int = 4):
    """Load a JSON/JSONL corpus file as the ``train`` split of a dataset.

    Args:
        corpus_path: Path passed to ``datasets.load_dataset`` as ``data_files``.
        num_proc: Number of worker processes used while loading; defaults to
            the previously hard-coded value of 4.

    Returns:
        The loaded dataset.
    """
    return datasets.load_dataset(
        'json',
        data_files=corpus_path,
        split="train",
        num_proc=num_proc,
    )
class Index_Builder:
    r"""A tool class used to build an index used in retrieval.

    ``"bm25"`` builds a Lucene index through Pyserini; every other retrieval
    method encodes the corpus with a transformer model and builds a FAISS
    index. The corpus is loaded in the constructor.
    """

    def __init__(
            self,
            retrieval_method,
            model_path,
            corpus_path,
            save_dir,
            max_length,
            batch_size,
            use_fp16,
            pooling_method,
            faiss_type=None,
            embedding_path=None,
            save_embedding=False,
            faiss_gpu=False
        ):
        r"""Store the build options, prepare ``save_dir`` and load the corpus.

        Args:
            retrieval_method: Retriever name (lower-cased); ``"bm25"`` selects
                the Pyserini path, anything else the dense path.
            model_path: Encoder name or path for the dense path.
            corpus_path: Corpus file to index.
            save_dir: Output directory; created when missing.
            max_length: Tokenizer truncation length.
            batch_size: Encoding batch size (multiplied by the GPU count
                when several GPUs are visible).
            use_fp16: Run the encoder in half precision.
            pooling_method: Method name forwarded to ``pooling``.
            faiss_type: FAISS factory string; ``None`` means ``'Flat'``.
            embedding_path: Existing embedding memmap to reuse instead of
                encoding the corpus.
            save_embedding: Also write the computed embeddings to disk.
            faiss_gpu: Build the FAISS index on all GPUs.
        """
        self.retrieval_method = retrieval_method.lower()
        self.model_path = model_path
        self.corpus_path = corpus_path
        self.save_dir = save_dir
        self.max_length = max_length
        self.batch_size = batch_size
        self.use_fp16 = use_fp16
        self.pooling_method = pooling_method
        self.faiss_type = faiss_type if faiss_type is not None else 'Flat'
        self.embedding_path = embedding_path
        self.save_embedding = save_embedding
        self.faiss_gpu = faiss_gpu
        self.gpu_num = torch.cuda.device_count()
        # prepare save dir
        print(self.save_dir)
        if not os.path.exists(self.save_dir):
            os.makedirs(self.save_dir)
        else:
            if not self._check_dir(self.save_dir):
                warnings.warn("Some files already exists in save dir and may be overwritten.", UserWarning)
        self.index_save_path = os.path.join(self.save_dir, f"{self.retrieval_method}_{self.faiss_type}.index")
        self.embedding_save_path = os.path.join(self.save_dir, f"emb_{self.retrieval_method}.memmap")
        self.corpus = load_corpus(self.corpus_path)
        print("Finish loading...")

    @staticmethod
    def _check_dir(dir_path):
        r"""Return False when ``dir_path`` is a non-empty directory, else True.

        A path that is not a directory is created first.
        """
        if os.path.isdir(dir_path):
            if len(os.listdir(dir_path)) > 0:
                return False
        else:
            os.makedirs(dir_path, exist_ok=True)
        return True

    def build_index(self):
        r"""Constructing different indexes based on selective retrieval method.
        """
        if self.retrieval_method == "bm25":
            self.build_bm25_index()
        else:
            self.build_dense_index()

    def build_bm25_index(self):
        """Building BM25 index based on Pyserini library.

        Reference: https://github.com/castorini/pyserini/blob/master/docs/usage-index.md#building-a-bm25-index-direct-java-implementation

        Raises:
            subprocess.CalledProcessError: If the Pyserini indexer exits with
                a non-zero status.
        """
        # to use pyserini pipeline, we first need to place jsonl file in the folder
        self.save_dir = os.path.join(self.save_dir, "bm25")
        os.makedirs(self.save_dir, exist_ok=True)
        temp_dir = os.path.join(self.save_dir, "temp")
        temp_file_path = os.path.join(temp_dir, "temp.jsonl")
        # exist_ok: a temp dir left behind by an interrupted run must not abort this one
        os.makedirs(temp_dir, exist_ok=True)
        try:
            shutil.copyfile(self.corpus_path, temp_file_path)
            print("Start building bm25 index...")
            pyserini_args = ["--collection", "JsonCollection",
                             "--input", temp_dir,
                             "--index", self.save_dir,
                             "--generator", "DefaultLuceneDocumentGenerator",
                             "--threads", "1"]
            # check=True: a failed indexer run must not be reported as "Finish!"
            subprocess.run(["python", "-m", "pyserini.index.lucene"] + pyserini_args, check=True)
        finally:
            # the corpus copy is removed whether or not indexing succeeded
            shutil.rmtree(temp_dir)
        print("Finish!")

    def _load_embedding(self, embedding_path, corpus_size, hidden_size):
        """Open a float32 embedding memmap read-only as (corpus_size, hidden_size)."""
        all_embeddings = np.memmap(
            embedding_path,
            mode="r",
            dtype=np.float32
        ).reshape(corpus_size, hidden_size)
        return all_embeddings

    def _save_embedding(self, all_embeddings):
        """Write ``all_embeddings`` to ``self.embedding_save_path`` as a memmap."""
        memmap = np.memmap(
            self.embedding_save_path,
            shape=all_embeddings.shape,
            mode="w+",
            dtype=all_embeddings.dtype
        )
        length = all_embeddings.shape[0]
        # add in batch
        save_batch_size = 10000
        if length > save_batch_size:
            for i in tqdm(range(0, length, save_batch_size), leave=False, desc="Saving Embeddings"):
                j = min(i + save_batch_size, length)
                memmap[i: j] = all_embeddings[i: j]
        else:
            memmap[:] = all_embeddings
        # push the data to disk explicitly instead of relying on garbage collection
        memmap.flush()

    @torch.no_grad()
    def encode_all(self):
        """Encode every corpus entry and return a float32 embedding matrix.

        Each entry is rendered as ``"<title>"\\n<text>`` (with a ``passage: ``
        prefix for the ``e5`` method). Embeddings are L2-normalised unless
        the retrieval method name contains ``dpr``.

        Returns:
            A float32 numpy array with one row per corpus entry.
        """
        # Decide on the T5 code path before wrapping: DataParallel hides the
        # model's own class name.
        base_encoder = self.encoder
        if isinstance(base_encoder, torch.nn.DataParallel):
            base_encoder = base_encoder.module
        is_t5 = "T5" in type(base_encoder).__name__
        if self.gpu_num > 1:
            print("Use multi gpu!")
            self.encoder = torch.nn.DataParallel(self.encoder)
            self.batch_size = self.batch_size * self.gpu_num
        all_embeddings = []
        for start_idx in tqdm(range(0, len(self.corpus), self.batch_size), desc='Inference Embeddings:'):
            # slice the corpus once per batch rather than once per column
            batch = self.corpus[start_idx:start_idx + self.batch_size]
            batch_data = ['"' + title + '"\n' + text for title, text in zip(batch['title'], batch['text'])]
            if self.retrieval_method == "e5":
                batch_data = [f"passage: {doc}" for doc in batch_data]
            inputs = self.tokenizer(
                batch_data,
                padding=True,
                truncation=True,
                return_tensors='pt',
                max_length=self.max_length,
            )
            inputs = {k: v.cuda() for k, v in inputs.items()}
            #TODO: support encoder-only T5 model
            if is_t5:
                # T5-based retrieval model
                decoder_input_ids = torch.zeros(
                    (inputs['input_ids'].shape[0], 1), dtype=torch.long
                ).to(inputs['input_ids'].device)
                output = self.encoder(
                    **inputs, decoder_input_ids=decoder_input_ids, return_dict=True
                )
                embeddings = output.last_hidden_state[:, 0, :]
            else:
                output = self.encoder(**inputs, return_dict=True)
                # not every model output carries pooler_output; it is only
                # needed for the "pooler" method
                embeddings = pooling(getattr(output, "pooler_output", None),
                                     output.last_hidden_state,
                                     inputs['attention_mask'],
                                     self.pooling_method)
                if "dpr" not in self.retrieval_method:
                    embeddings = torch.nn.functional.normalize(embeddings, dim=-1)
            embeddings = cast(torch.Tensor, embeddings)
            embeddings = embeddings.detach().cpu().numpy()
            all_embeddings.append(embeddings)
        all_embeddings = np.concatenate(all_embeddings, axis=0)
        all_embeddings = all_embeddings.astype(np.float32)
        return all_embeddings

    @torch.no_grad()
    def build_dense_index(self):
        """Obtain the representation of documents based on the embedding model(BERT-based) and
        construct a faiss index.

        Embeddings come from ``embedding_path`` when it is set, otherwise from
        ``encode_all``. The index uses the inner-product metric and is
        written to ``self.index_save_path``. ``self.corpus`` is deleted once
        the embeddings are available.
        """
        if os.path.exists(self.index_save_path):
            print("The index file already exists and will be overwritten.")
        self.encoder, self.tokenizer = load_model(model_path = self.model_path,
                                                  use_fp16 = self.use_fp16)
        if self.embedding_path is not None:
            hidden_size = self.encoder.config.hidden_size
            corpus_size = len(self.corpus)
            all_embeddings = self._load_embedding(self.embedding_path, corpus_size, hidden_size)
        else:
            all_embeddings = self.encode_all()
            if self.save_embedding:
                self._save_embedding(all_embeddings)
        del self.corpus
        # build index
        print("Creating index")
        dim = all_embeddings.shape[-1]
        faiss_index = faiss.index_factory(dim, self.faiss_type, faiss.METRIC_INNER_PRODUCT)
        if self.faiss_gpu:
            co = faiss.GpuMultipleClonerOptions()
            co.useFloat16 = True
            co.shard = True
            faiss_index = faiss.index_cpu_to_all_gpus(faiss_index, co)
        if not faiss_index.is_trained:
            faiss_index.train(all_embeddings)
        faiss_index.add(all_embeddings)
        if self.faiss_gpu:
            # bring the index back to the CPU before writing it
            faiss_index = faiss.index_gpu_to_cpu(faiss_index)
        faiss.write_index(faiss_index, self.index_save_path)
        print("Finish!")
# Default pooling method per retriever family. main() uses the first key that
# occurs as a substring of the lower-cased --retrieval_method when
# --pooling_method is not given.
MODEL2POOLING = {
"e5": "mean",
"bge": "cls",
"contriever": "mean",
'jina': 'mean'
}
def main():
    """Parse the command line and build the requested retrieval index.

    When ``--pooling_method`` is omitted, the method is looked up in
    ``MODEL2POOLING`` by substring match on the retrieval method, falling
    back to ``'mean'``.

    Raises:
        NotImplementedError: If ``--pooling_method`` is not ``mean``, ``cls``
            or ``pooler``.
    """
    parser = argparse.ArgumentParser(description = "Creating index.")
    # Basic parameters
    # Both are required: the builder lower-cases the method and loads the
    # corpus unconditionally, so a missing value could only crash later.
    parser.add_argument('--retrieval_method', type=str, required=True)
    parser.add_argument('--model_path', type=str, default=None)
    parser.add_argument('--corpus_path', type=str, required=True)
    parser.add_argument('--save_dir', default= 'indexes/',type=str)
    # Parameters for building dense index
    parser.add_argument('--max_length', type=int, default=180)
    parser.add_argument('--batch_size', type=int, default=512)
    parser.add_argument('--use_fp16', default=False, action='store_true')
    parser.add_argument('--pooling_method', type=str, default=None)
    parser.add_argument('--faiss_type',default=None,type=str)
    parser.add_argument('--embedding_path', default=None, type=str)
    parser.add_argument('--save_embedding', action='store_true', default=False)
    parser.add_argument('--faiss_gpu', default=False, action='store_true')
    args = parser.parse_args()

    if args.pooling_method is None:
        method_name = args.retrieval_method.lower()
        pooling_method = next(
            (v for k, v in MODEL2POOLING.items() if k in method_name),
            'mean',
        )
    elif args.pooling_method in ('mean', 'cls', 'pooler'):
        pooling_method = args.pooling_method
    else:
        raise NotImplementedError(
            f"Unsupported pooling method {args.pooling_method!r}; "
            "expected 'mean', 'cls' or 'pooler'."
        )

    index_builder = Index_Builder(
        retrieval_method = args.retrieval_method,
        model_path = args.model_path,
        corpus_path = args.corpus_path,
        save_dir = args.save_dir,
        max_length = args.max_length,
        batch_size = args.batch_size,
        use_fp16 = args.use_fp16,
        pooling_method = pooling_method,
        faiss_type = args.faiss_type,
        embedding_path = args.embedding_path,
        save_embedding = args.save_embedding,
        faiss_gpu = args.faiss_gpu
    )
    index_builder.build_index()


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,368 @@
import json
import os
import warnings
from typing import List, Dict
import functools
from tqdm import tqdm
from multiprocessing import Pool
import faiss
import torch
import numpy as np
from transformers import AutoConfig, AutoTokenizer, AutoModel
import argparse
import datasets
def load_corpus(corpus_path: str, num_proc: int = 4):
    """Load a JSON/JSONL corpus file as the ``train`` split of a dataset.

    Args:
        corpus_path: Path passed to ``datasets.load_dataset`` as ``data_files``.
        num_proc: Number of worker processes used while loading; defaults to
            the previously hard-coded value of 4.

    Returns:
        The loaded dataset.
    """
    return datasets.load_dataset(
        'json',
        data_files=corpus_path,
        split="train",
        num_proc=num_proc,
    )
def read_jsonl(file_path):
    """Parse a JSON Lines file into a list, one element per non-blank line.

    Args:
        file_path: Path of the file to read (decoded as UTF-8).

    Returns:
        The decoded JSON values in file order.
    """
    # Stream the file line by line; blank lines (e.g. a trailing newline
    # artefact) are skipped instead of failing in json.loads.
    with open(file_path, "r", encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]
def load_docs(corpus, doc_idxs):
    """Look up ``corpus`` entries for ``doc_idxs``, keeping the given order.

    Every index is converted with ``int`` before the lookup.
    """
    documents = []
    for doc_id in doc_idxs:
        documents.append(corpus[int(doc_id)])
    return documents
def load_model(
        model_path: str,
        use_fp16: bool = False
    ):
    """Load an encoder and its tokenizer, ready for inference on the GPU.

    Args:
        model_path: Name or path handed to ``from_pretrained``.
        use_fp16: When true, the model weights are converted to half precision.

    Returns:
        A ``(model, tokenizer)`` tuple; the model is in eval mode and on CUDA.
    """
    # The config is loaded by AutoModel itself; a separate AutoConfig load
    # whose result was never used has been dropped.
    model = AutoModel.from_pretrained(model_path, trust_remote_code=True)
    model.eval()
    model.cuda()
    if use_fp16:
        model = model.half()
    tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True, trust_remote_code=True)
    return model, tokenizer
def pooling(
        pooler_output,
        last_hidden_state,
        attention_mask = None,
        pooling_method = "mean"
    ):
    """Reduce per-token hidden states to one vector per sequence.

    Args:
        pooler_output: Value returned unchanged for the ``"pooler"`` method.
        last_hidden_state: Tensor indexed as (sequence, token, feature).
        attention_mask: Per-token mask; non-zero marks tokens that count
            towards the mean. ``None`` averages over every position.
        pooling_method: ``"mean"``, ``"cls"`` or ``"pooler"``.

    Returns:
        The pooled tensor.

    Raises:
        NotImplementedError: If ``pooling_method`` is not one of the above.
    """
    if pooling_method == "mean":
        if attention_mask is None:
            # No mask supplied: every position is treated as a real token.
            return last_hidden_state.mean(dim=1)
        keep = attention_mask.bool()
        masked_hidden = last_hidden_state.masked_fill(~keep[..., None], 0.0)
        # clamp keeps a fully masked row from dividing by zero (yields zeros, not NaN)
        token_counts = attention_mask.sum(dim=1).clamp(min=1)
        return masked_hidden.sum(dim=1) / token_counts[..., None]
    elif pooling_method == "cls":
        return last_hidden_state[:, 0]
    elif pooling_method == "pooler":
        return pooler_output
    else:
        raise NotImplementedError("Pooling method not implemented!")
class Encoder:
    """Embed queries or passages with a transformer encoder.

    Model-family text prefixes are applied for names containing ``e5`` or
    ``bge``; embeddings are L2-normalised unless the name contains ``dpr``.
    """

    def __init__(self, model_name, model_path, pooling_method, max_length, use_fp16):
        """Remember the encoding options and load the model and tokenizer."""
        self.model_name = model_name
        self.model_path = model_path
        self.pooling_method = pooling_method
        self.max_length = max_length
        self.use_fp16 = use_fp16
        self.model, self.tokenizer = load_model(model_path=model_path,
                                                use_fp16=use_fp16)

    @torch.no_grad()
    def encode(self, query_list: List[str], is_query=True) -> np.ndarray:
        """Encode one string or a list of strings.

        Args:
            query_list: Text(s) to embed; a single string is wrapped in a list.
            is_query: Selects the query-side prefix rather than the
                passage-side one for e5, and enables the bge instruction.

        Returns:
            A C-contiguous float32 array with one row per input text.
        """
        # processing query for different encoders
        if isinstance(query_list, str):
            query_list = [query_list]
        lowered_name = self.model_name.lower()
        if "e5" in lowered_name:
            prefix = "query: " if is_query else "passage: "
            query_list = [f"{prefix}{query}" for query in query_list]
        if "bge" in lowered_name and is_query:
            query_list = [f"Represent this sentence for searching relevant passages: {query}" for query in query_list]
        inputs = self.tokenizer(query_list,
                                max_length=self.max_length,
                                padding=True,
                                truncation=True,
                                return_tensors="pt"
                                )
        inputs = {k: v.cuda() for k, v in inputs.items()}
        if "T5" in type(self.model).__name__:
            # T5-based retrieval model
            decoder_input_ids = torch.zeros(
                (inputs['input_ids'].shape[0], 1), dtype=torch.long
            ).to(inputs['input_ids'].device)
            output = self.model(
                **inputs, decoder_input_ids=decoder_input_ids, return_dict=True
            )
            query_emb = output.last_hidden_state[:, 0, :]
        else:
            output = self.model(**inputs, return_dict=True)
            # Not every model output carries pooler_output; it is only needed
            # for the "pooler" method, so a missing attribute must not break
            # mean/cls pooling.
            query_emb = pooling(getattr(output, "pooler_output", None),
                                output.last_hidden_state,
                                inputs['attention_mask'],
                                self.pooling_method)
            if "dpr" not in lowered_name:
                query_emb = torch.nn.functional.normalize(query_emb, dim=-1)
        query_emb = query_emb.detach().cpu().numpy()
        query_emb = query_emb.astype(np.float32, order="C")
        return query_emb
class BaseRetriever:
    """Common base for retrievers: holds the config and forwards the public API."""

    def __init__(self, config):
        """Copy the retrieval settings off ``config`` onto the instance."""
        self.config = config
        self.corpus_path = config.corpus_path
        self.index_path = config.index_path
        self.topk = config.retrieval_topk
        self.retrieval_method = config.retrieval_method
        # self.cache_save_path = os.path.join(config.save_dir, 'retrieval_cache.json')

    def _search(self, query: str, num: int, return_score:bool) -> List[Dict[str, str]]:
        r"""Retrieve topk relevant documents in corpus.

        Placeholder for subclasses; this base version returns ``None``.

        Return:
            list: contains information related to the document, including:
                contents: used for building index
                title: (if provided)
                text: (if provided)
        """
        return None

    def _batch_search(self, query_list, num, return_score):
        """Batched counterpart of ``_search``; this base version returns ``None``."""
        return None

    def search(self, *args, **kwargs):
        """Public entry point; passes every argument on to ``_search``."""
        return self._search(*args, **kwargs)

    def batch_search(self, *args, **kwargs):
        """Public entry point; passes every argument on to ``_batch_search``."""
        return self._batch_search(*args, **kwargs)
class BM25Retriever(BaseRetriever):
    r"""Retriever backed by a pre-built Pyserini (Lucene) BM25 index."""

    def __init__(self, config):
        """Open the Lucene index; load the corpus only if the index stores no raw documents."""
        super().__init__(config)
        from pyserini.search.lucene import LuceneSearcher
        self.searcher = LuceneSearcher(self.index_path)
        self.contain_doc = self._check_contain_doc()
        if not self.contain_doc:
            self.corpus = load_corpus(self.corpus_path)
        self.max_process_num = 8

    def _check_contain_doc(self):
        r"""Tell whether document 0 of the index carries raw content."""
        first_doc = self.searcher.doc(0)
        return first_doc.raw() is not None

    def _search(self, query: str, num: int = None, return_score = False) -> List[Dict[str, str]]:
        """Return up to ``num`` documents for ``query`` (``self.topk`` when ``num`` is None).

        With ``return_score`` true the result is a ``(documents, scores)`` pair.
        A warning is issued when fewer than ``num`` hits come back.
        """
        limit = self.topk if num is None else num
        hits = self.searcher.search(query, limit)
        if len(hits) < 1:
            return ([], []) if return_score else []
        scores = [hit.score for hit in hits]
        if len(hits) >= limit:
            hits = hits[:limit]
        else:
            warnings.warn('Not enough documents retrieved!')
        if not self.contain_doc:
            results = load_docs(self.corpus, [hit.docid for hit in hits])
        else:
            results = []
            for hit in hits:
                content = json.loads(self.searcher.doc(hit.docid).raw())['contents']
                # first line is the (quoted) title, the remainder is the text
                first_line, _, remainder = content.partition("\n")
                results.append({'title': first_line.strip("\""),
                                'text': remainder,
                                'contents': content})
        return (results, scores) if return_score else results

    def _batch_search(self, query_list, num: int = None, return_score = False):
        """Run ``_search`` for every query, one after another."""
        # TODO: modify batch method
        per_query = [self._search(query, num, True) for query in query_list]
        results = [docs for docs, _ in per_query]
        scores = [doc_scores for _, doc_scores in per_query]
        if not return_score:
            return results
        return results, scores
def get_available_gpu_memory():
    """List ``(device_index, gigabytes)`` for every visible CUDA device.

    The figure is the device's total memory minus what
    ``torch.cuda.memory_allocated`` reports for it, divided by 1e9.
    """
    def unallocated_gb(device_id):
        capacity = torch.cuda.get_device_properties(device_id).total_memory
        in_use = torch.cuda.memory_allocated(device_id)
        return (capacity - in_use) / 1e9  # Convert to GB

    return [(device_id, unallocated_gb(device_id))
            for device_id in range(torch.cuda.device_count())]
class DenseRetriever(BaseRetriever):
    r"""Dense retriever based on pre-built faiss index."""

    def __init__(self, config):
        """Read the FAISS index, load the corpus and create the query encoder.

        ``config`` is accessed by attribute (e.g. ``config.faiss_gpu``), so it
        must be a namespace-like object rather than a dict.
        """
        super().__init__(config)
        self.index = faiss.read_index(self.index_path)
        if config.faiss_gpu:
            co = faiss.GpuMultipleClonerOptions()
            co.useFloat16 = True
            co.shard = True
            self.index = faiss.index_cpu_to_all_gpus(self.index, co=co)
        self.corpus = load_corpus(self.corpus_path)
        self.encoder = Encoder(
            model_name = self.retrieval_method,
            model_path = config.retrieval_model_path,
            pooling_method = config.retrieval_pooling_method,
            max_length = config.retrieval_query_max_length,
            use_fp16 = config.retrieval_use_fp16
        )
        self.topk = config.retrieval_topk
        self.batch_size = self.config.retrieval_batch_size

    def _search(self, query: str, num: int = None, return_score = False):
        """Return the ``num`` nearest documents for a single query.

        ``num`` defaults to ``self.topk``. With ``return_score`` true the
        result is ``(documents, scores)``, where ``scores`` is the row of the
        array returned by the index search.
        """
        if num is None:
            num = self.topk
        query_emb = self.encoder.encode(query)
        scores, idxs = self.index.search(query_emb, k=num)
        idxs = idxs[0]
        scores = scores[0]
        # NOTE(review): the indices are used as corpus positions unchecked;
        # confirm how the index reports fewer than ``num`` neighbours.
        results = load_docs(self.corpus, idxs)
        if return_score:
            return results, scores
        else:
            return results

    def _batch_search(self, query_list: List[str], num: int = None, return_score = False):
        """Retrieve documents for many queries, ``self.batch_size`` at a time.

        Returns one list of documents per query and, with ``return_score``
        true, a parallel list of per-query score lists.
        """
        if isinstance(query_list, str):
            query_list = [query_list]
        if num is None:
            num = self.topk
        batch_size = self.batch_size
        results = []
        scores = []
        for start_idx in tqdm(range(0, len(query_list), batch_size), desc='Retrieval process: '):
            query_batch = query_list[start_idx:start_idx + batch_size]
            batch_emb = self.encoder.encode(query_batch)
            batch_scores, batch_idxs = self.index.search(batch_emb, k=num)
            batch_scores = batch_scores.tolist()
            batch_idxs = batch_idxs.tolist()
            # Resolve each query's id row directly; this replaces the
            # quadratic sum(list_of_lists, []) flatten and the re-chunking.
            results.extend(load_docs(self.corpus, row) for row in batch_idxs)
            scores.extend(batch_scores)
        if return_score:
            return results, scores
        else:
            return results
def get_retriever(config):
    r"""Build the retriever that matches ``config.retrieval_method``.

    Args:
        config: configuration object exposing a ``retrieval_method`` attribute
    Returns:
        Retriever: a ``BM25Retriever`` for ``"bm25"``, otherwise a ``DenseRetriever``
    """
    retriever_cls = BM25Retriever if config.retrieval_method == "bm25" else DenseRetriever
    return retriever_cls(config)
def get_dataset(config):
    """Read ``<config.dataset_path>/<config.data_split>.jsonl`` with ``read_jsonl``."""
    file_name = f'{config.data_split}.jsonl'
    return read_jsonl(os.path.join(config.dataset_path, file_name))
if __name__ == '__main__':
    # Command-line smoke test: retrieve passages for the first 512 questions
    # of a dataset split.
    parser = argparse.ArgumentParser(description = "Retrieval")
    # Basic parameters
    parser.add_argument('--retrieval_method', type=str)
    parser.add_argument('--retrieval_topk', type=int, default=10)
    parser.add_argument('--index_path', type=str, default=None)
    parser.add_argument('--corpus_path', type=str)
    parser.add_argument('--dataset_path', default=None, type=str)
    # type=bool would turn every non-empty string (including "False") into
    # True, so the value is parsed explicitly.
    parser.add_argument('--faiss_gpu', default=True,
                        type=lambda value: value.strip().lower() not in ('', '0', 'false', 'no', 'off'))
    parser.add_argument('--data_split', default="train", type=str)
    parser.add_argument('--retrieval_model_path', type=str, default=None)
    parser.add_argument('--retrieval_pooling_method', default='mean', type=str)
    # the tokenizer needs an integer length, not the raw command-line string
    parser.add_argument('--retrieval_query_max_length', default=256, type=int)
    parser.add_argument('--retrieval_use_fp16', action='store_true', default=False)
    parser.add_argument('--retrieval_batch_size', default=512, type=int)
    args = parser.parse_args()

    # --index_path is a directory; pick the index file/folder for the method
    if args.retrieval_method != 'bm25':
        index_name = f'{args.retrieval_method}_Flat.index'
    else:
        index_name = 'bm25'
    args.index_path = os.path.join(args.index_path, index_name)

    # load dataset
    all_split = get_dataset(args)
    input_query = [sample['question'] for sample in all_split[:512]]

    # initialize the retriever and conduct retrieval
    retriever = get_retriever(args)
    print('Start Retrieving ...')
    results, scores = retriever.batch_search(input_query, return_score=True)

View File

@@ -0,0 +1,25 @@
# Run retrieval.py with the e5 retriever over one split of a dataset.
DATA_NAME=nq
DATASET_PATH="/home/peterjin/mnt/data/$DATA_NAME"
SPLIT='test'
TOPK=3

INDEX_PATH=/home/peterjin/mnt/index/wiki-18
CORPUS_PATH=/home/peterjin/mnt/data/retrieval-corpus/wiki-18.jsonl
# NOTE: SAVE_NAME is not referenced by the command below.
SAVE_NAME=e5_${TOPK}_wiki18.json

# INDEX_PATH=/home/peterjin/rm_retrieval_corpus/index/wiki-21
# CORPUS_PATH=/home/peterjin/rm_retrieval_corpus/corpora/wiki/enwiki-dec2021/text-list-100-sec.jsonl
# SAVE_NAME=e5_${TOPK}_wiki21.json

# Expansions are quoted so paths with spaces stay single arguments; the last
# argument line no longer ends in a dangling line continuation.
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python retrieval.py --retrieval_method e5 \
    --retrieval_topk "$TOPK" \
    --index_path "$INDEX_PATH" \
    --corpus_path "$CORPUS_PATH" \
    --dataset_path "$DATASET_PATH" \
    --data_split "$SPLIT" \
    --retrieval_model_path "intfloat/e5-base-v2" \
    --retrieval_pooling_method "mean" \
    --retrieval_batch_size 512

View File

@@ -0,0 +1,23 @@
import requests
# URL for your local FastAPI server
url = "http://127.0.0.1:8000/retrieve"

# Example payload: 400 queries (two questions repeated 200 times)
payload = {
    "queries": ["What is the capital of France?", "Explain neural networks."] * 200,
    "topk": 5,
    "return_scores": True
}

# Send POST request. Without a timeout, requests waits indefinitely on an
# unresponsive server; the limit is generous because the batch is large.
response = requests.post(url, json=payload, timeout=300)

# Raise an exception if the request failed
response.raise_for_status()

# Get the JSON response
retrieved_data = response.json()
print("Response from server:")
print(retrieved_data)

View File

@@ -0,0 +1,382 @@
import json
import os
import warnings
from typing import List, Dict, Optional
import argparse
import faiss
import torch
import numpy as np
from transformers import AutoConfig, AutoTokenizer, AutoModel
from tqdm import tqdm
import datasets
import uvicorn
from fastapi import FastAPI
from pydantic import BaseModel
# Command-line options of the retrieval server.
# NOTE: parse_args() runs at module level, i.e. whenever this file is
# imported or executed, and reads the process's command-line arguments.
parser = argparse.ArgumentParser(description="Launch the local faiss retriever.")
parser.add_argument("--index_path", type=str, default="/home/peterjin/mnt/index/wiki-18/e5_Flat.index", help="Corpus indexing file.")
parser.add_argument("--corpus_path", type=str, default="/home/peterjin/mnt/data/retrieval-corpus/wiki-18.jsonl", help="Local corpus file.")
parser.add_argument("--topk", type=int, default=3, help="Number of retrieved passages for one query.")
parser.add_argument("--retriever_model", type=str, default="intfloat/e5-base-v2", help="Name of the retriever model.")
args = parser.parse_args()
def load_corpus(corpus_path: str, num_proc: int = 4):
    """Load a JSON/JSONL corpus file as the ``train`` split of a dataset.

    Args:
        corpus_path: Path passed to ``datasets.load_dataset`` as ``data_files``.
        num_proc: Number of worker processes used while loading; defaults to
            the previously hard-coded value of 4.

    Returns:
        The loaded dataset.
    """
    return datasets.load_dataset(
        'json',
        data_files=corpus_path,
        split="train",
        num_proc=num_proc,
    )
def read_jsonl(file_path):
    """Parse a JSON Lines file into a list, one element per non-blank line.

    Args:
        file_path: Path of the file to read (decoded as UTF-8).

    Returns:
        The decoded JSON values in file order.
    """
    # Blank lines (e.g. a trailing newline artefact) are skipped instead of
    # failing in json.loads; the encoding no longer depends on the locale.
    with open(file_path, "r", encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]
def load_docs(corpus, doc_idxs):
    """Look up ``corpus`` entries for ``doc_idxs``, keeping the given order.

    Every index is converted with ``int`` before the lookup.
    """
    documents = []
    for doc_id in doc_idxs:
        documents.append(corpus[int(doc_id)])
    return documents
def load_model(model_path: str, use_fp16: bool = False):
    """Load an encoder and its tokenizer, ready for inference on the GPU.

    Args:
        model_path: Name or path handed to ``from_pretrained``.
        use_fp16: When true, the model weights are converted to half precision.

    Returns:
        A ``(model, tokenizer)`` tuple; the model is in eval mode and on CUDA.
    """
    # The config is loaded by AutoModel itself; a separate AutoConfig load
    # whose result was never used has been dropped.
    model = AutoModel.from_pretrained(model_path, trust_remote_code=True)
    model.eval()
    model.cuda()
    if use_fp16:
        model = model.half()
    tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True, trust_remote_code=True)
    return model, tokenizer
def pooling(
        pooler_output,
        last_hidden_state,
        attention_mask = None,
        pooling_method = "mean"
    ):
    """Reduce per-token hidden states to one vector per sequence.

    Args:
        pooler_output: Value returned unchanged for the ``"pooler"`` method.
        last_hidden_state: Tensor indexed as (sequence, token, feature).
        attention_mask: Per-token mask; non-zero marks tokens that count
            towards the mean. ``None`` averages over every position.
        pooling_method: ``"mean"``, ``"cls"`` or ``"pooler"``.

    Returns:
        The pooled tensor.

    Raises:
        NotImplementedError: If ``pooling_method`` is not one of the above.
    """
    if pooling_method == "mean":
        if attention_mask is None:
            # No mask supplied: every position is treated as a real token.
            return last_hidden_state.mean(dim=1)
        keep = attention_mask.bool()
        masked_hidden = last_hidden_state.masked_fill(~keep[..., None], 0.0)
        # clamp keeps a fully masked row from dividing by zero (yields zeros, not NaN)
        token_counts = attention_mask.sum(dim=1).clamp(min=1)
        return masked_hidden.sum(dim=1) / token_counts[..., None]
    elif pooling_method == "cls":
        return last_hidden_state[:, 0]
    elif pooling_method == "pooler":
        return pooler_output
    else:
        raise NotImplementedError("Pooling method not implemented!")
class Encoder:
    """Embed queries or passages with a transformer encoder.

    Model-family text prefixes are applied for names containing ``e5`` or
    ``bge``; embeddings are L2-normalised unless the name contains ``dpr``.
    """

    def __init__(self, model_name, model_path, pooling_method, max_length, use_fp16):
        """Remember the encoding options and load the model and tokenizer."""
        self.model_name = model_name
        self.model_path = model_path
        self.pooling_method = pooling_method
        self.max_length = max_length
        self.use_fp16 = use_fp16
        # load_model already returns the model in eval mode, so no second
        # eval() call is needed here
        self.model, self.tokenizer = load_model(model_path=model_path, use_fp16=use_fp16)

    @torch.no_grad()
    def encode(self, query_list: List[str], is_query=True) -> np.ndarray:
        """Encode one string or a list of strings.

        Args:
            query_list: Text(s) to embed; a single string is wrapped in a list.
            is_query: Selects the query-side prefix rather than the
                passage-side one for e5, and enables the bge instruction.

        Returns:
            A C-contiguous float32 array with one row per input text.
        """
        # processing query for different encoders
        if isinstance(query_list, str):
            query_list = [query_list]
        lowered_name = self.model_name.lower()
        if "e5" in lowered_name:
            prefix = "query: " if is_query else "passage: "
            query_list = [f"{prefix}{query}" for query in query_list]
        if "bge" in lowered_name and is_query:
            query_list = [f"Represent this sentence for searching relevant passages: {query}" for query in query_list]
        inputs = self.tokenizer(query_list,
                                max_length=self.max_length,
                                padding=True,
                                truncation=True,
                                return_tensors="pt"
                                )
        inputs = {k: v.cuda() for k, v in inputs.items()}
        if "T5" in type(self.model).__name__:
            # T5-based retrieval model
            decoder_input_ids = torch.zeros(
                (inputs['input_ids'].shape[0], 1), dtype=torch.long
            ).to(inputs['input_ids'].device)
            output = self.model(
                **inputs, decoder_input_ids=decoder_input_ids, return_dict=True
            )
            query_emb = output.last_hidden_state[:, 0, :]
        else:
            output = self.model(**inputs, return_dict=True)
            # Not every model output carries pooler_output; it is only needed
            # for the "pooler" method, so a missing attribute must not break
            # mean/cls pooling.
            query_emb = pooling(getattr(output, "pooler_output", None),
                                output.last_hidden_state,
                                inputs['attention_mask'],
                                self.pooling_method)
            if "dpr" not in lowered_name:
                query_emb = torch.nn.functional.normalize(query_emb, dim=-1)
        query_emb = query_emb.detach().cpu().numpy()
        query_emb = query_emb.astype(np.float32, order="C")
        return query_emb
class BaseRetriever:
    """Common base for retrievers: keeps the config and defines the public API."""

    def __init__(self, config):
        """Copy the retrieval settings off ``config`` onto the instance."""
        self.config = config
        self.corpus_path = config.corpus_path
        self.index_path = config.index_path
        self.topk = config.retrieval_topk
        self.retrieval_method = config.retrieval_method

    def _search(self, query: str, num: int, return_score: bool):
        """Single-query hook; subclasses must override it."""
        raise NotImplementedError

    def _batch_search(self, query_list: List[str], num: int, return_score: bool):
        """Multi-query hook; subclasses must override it."""
        raise NotImplementedError

    def search(self, query: str, num: int = None, return_score: bool = False):
        """Retrieve documents for one query by delegating to ``_search``."""
        return self._search(query, num, return_score)

    def batch_search(self, query_list: List[str], num: int = None, return_score: bool = False):
        """Retrieve documents for several queries by delegating to ``_batch_search``."""
        return self._batch_search(query_list, num, return_score)
class BM25Retriever(BaseRetriever):
    """Retriever backed by a pre-built Pyserini (Lucene) BM25 index."""

    def __init__(self, config):
        """Open the Lucene index; load the corpus only if the index stores no raw documents."""
        super().__init__(config)
        from pyserini.search.lucene import LuceneSearcher
        self.searcher = LuceneSearcher(self.index_path)
        self.contain_doc = self._check_contain_doc()
        if not self.contain_doc:
            self.corpus = load_corpus(self.corpus_path)
        self.max_process_num = 8

    def _check_contain_doc(self):
        """Tell whether document 0 of the index carries raw content."""
        first_doc = self.searcher.doc(0)
        return first_doc.raw() is not None

    def _search(self, query: str, num: int = None, return_score: bool = False):
        """Return up to ``num`` documents for ``query`` (``self.topk`` when ``num`` is None).

        With ``return_score`` true the result is a ``(documents, scores)`` pair.
        A warning is issued when fewer than ``num`` hits come back.
        """
        limit = self.topk if num is None else num
        hits = self.searcher.search(query, limit)
        if len(hits) < 1:
            return ([], []) if return_score else []
        scores = [hit.score for hit in hits]
        if len(hits) >= limit:
            hits = hits[:limit]
        else:
            warnings.warn('Not enough documents retrieved!')
        if not self.contain_doc:
            results = load_docs(self.corpus, [hit.docid for hit in hits])
        else:
            results = []
            for hit in hits:
                content = json.loads(self.searcher.doc(hit.docid).raw())['contents']
                # first line is the (quoted) title, the remainder is the text
                first_line, _, remainder = content.partition("\n")
                results.append({
                    'title': first_line.strip("\""),
                    'text': remainder,
                    'contents': content
                })
        return (results, scores) if return_score else results

    def _batch_search(self, query_list: List[str], num: int = None, return_score: bool = False):
        """Run ``_search`` for every query, one after another."""
        per_query = [self._search(query, num, True) for query in query_list]
        results = [docs for docs, _ in per_query]
        scores = [doc_scores for _, doc_scores in per_query]
        if not return_score:
            return results
        return results, scores
class DenseRetriever(BaseRetriever):
    """Dense retriever that searches a pre-built FAISS index with encoded queries."""

    def __init__(self, config):
        """Read the FAISS index, load the corpus and create the query encoder."""
        super().__init__(config)
        self.index = faiss.read_index(self.index_path)
        if config.faiss_gpu:
            co = faiss.GpuMultipleClonerOptions()
            co.useFloat16 = True
            co.shard = True
            self.index = faiss.index_cpu_to_all_gpus(self.index, co=co)
        self.corpus = load_corpus(self.corpus_path)
        self.encoder = Encoder(
            model_name = self.retrieval_method,
            model_path = config.retrieval_model_path,
            pooling_method = config.retrieval_pooling_method,
            max_length = config.retrieval_query_max_length,
            use_fp16 = config.retrieval_use_fp16
        )
        self.topk = config.retrieval_topk
        self.batch_size = config.retrieval_batch_size

    def _search(self, query: str, num: int = None, return_score: bool = False):
        """Return the ``num`` nearest documents for a single query.

        ``num`` defaults to ``self.topk``. With ``return_score`` true the
        result is ``(documents, scores)`` with the scores as a plain list.
        """
        if num is None:
            num = self.topk
        query_emb = self.encoder.encode(query)
        scores, idxs = self.index.search(query_emb, k=num)
        idxs = idxs[0]
        scores = scores[0]
        # NOTE(review): the indices are used as corpus positions unchecked;
        # confirm how the index reports fewer than ``num`` neighbours.
        results = load_docs(self.corpus, idxs)
        if return_score:
            return results, scores.tolist()
        else:
            return results

    def _batch_search(self, query_list: List[str], num: int = None, return_score: bool = False):
        """Retrieve documents for many queries, ``self.batch_size`` at a time.

        Returns one list of documents per query and, with ``return_score``
        true, a parallel list of per-query score lists.
        """
        if isinstance(query_list, str):
            query_list = [query_list]
        if num is None:
            num = self.topk
        results = []
        scores = []
        for start_idx in tqdm(range(0, len(query_list), self.batch_size), desc='Retrieval process: '):
            query_batch = query_list[start_idx:start_idx + self.batch_size]
            batch_emb = self.encoder.encode(query_batch)
            batch_scores, batch_idxs = self.index.search(batch_emb, k=num)
            batch_scores = batch_scores.tolist()
            batch_idxs = batch_idxs.tolist()
            # Resolve each query's id row directly; this replaces the
            # quadratic sum(list_of_lists, []) flatten and the re-chunking.
            results.extend(load_docs(self.corpus, row) for row in batch_idxs)
            scores.extend(batch_scores)
        if return_score:
            return results, scores
        else:
            return results
def get_retriever(config):
    """Build a ``BM25Retriever`` for method ``"bm25"``, otherwise a ``DenseRetriever``."""
    retriever_cls = BM25Retriever if config.retrieval_method == "bm25" else DenseRetriever
    return retriever_cls(config)
#####################################
# FastAPI server below
#####################################
class Config:
    """
    Plain attribute container for the retriever settings.

    It exposes the same attribute names the retriever classes read from an
    argparse namespace; every constructor argument is stored unchanged under
    its own name.
    """
    def __init__(
        self,
        retrieval_method: str = "bm25",
        retrieval_topk: int = 10,
        index_path: str = "./index/bm25",
        corpus_path: str = "./data/corpus.jsonl",
        dataset_path: str = "./data",
        data_split: str = "train",
        faiss_gpu: bool = True,
        retrieval_model_path: str = "./model",
        retrieval_pooling_method: str = "mean",
        retrieval_query_max_length: int = 256,
        retrieval_use_fp16: bool = False,
        retrieval_batch_size: int = 128
    ):
        # which retriever to use and how many passages it returns
        self.retrieval_method, self.retrieval_topk = retrieval_method, retrieval_topk
        # locations on disk
        self.index_path, self.corpus_path = index_path, corpus_path
        self.dataset_path, self.data_split = dataset_path, data_split
        # index placement
        self.faiss_gpu = faiss_gpu
        # query encoder options
        self.retrieval_model_path = retrieval_model_path
        self.retrieval_pooling_method = retrieval_pooling_method
        self.retrieval_query_max_length = retrieval_query_max_length
        self.retrieval_use_fp16 = retrieval_use_fp16
        self.retrieval_batch_size = retrieval_batch_size
# Request body accepted by the /retrieve endpoint.
class QueryRequest(BaseModel):
    # Query strings to retrieve passages for.
    queries: List[str]
    # Passages per query; a falsy value makes the endpoint fall back to
    # config.retrieval_topk.
    topk: Optional[int] = None
    # When true, every returned document is paired with its score.
    return_scores: bool = False
# FastAPI application on which the /retrieve route is registered.
app = FastAPI()
# 1) Build the retriever configuration. Index path, corpus path, top-k and
# model come from the command-line arguments; the remaining settings are
# fixed here.
config = Config(
retrieval_method = "e5", # any value other than "bm25" makes get_retriever() build a DenseRetriever
index_path=args.index_path,
corpus_path=args.corpus_path,
retrieval_topk=args.topk,
faiss_gpu=True,
retrieval_model_path=args.retriever_model,
retrieval_pooling_method="mean",
retrieval_query_max_length=256,
retrieval_use_fp16=True,
retrieval_batch_size=512,
)
# 2) Instantiate a global retriever so it is loaded once and reused by every request.
retriever = get_retriever(config)
@app.post("/retrieve")
def retrieve_endpoint(request: QueryRequest):
    """
    Endpoint that accepts queries and performs retrieval.

    Input format:
    {
      "queries": ["What is Python?", "Tell me about neural networks."],
      "topk": 3,
      "return_scores": true
    }

    Returns ``{"result": [...]}`` with one entry per query: a list of
    documents, or a list of ``{"document": ..., "score": ...}`` objects when
    ``return_scores`` is true.
    """
    # fallback to the configured default when topk is missing (or 0)
    topk = request.topk if request.topk else config.retrieval_topk

    # Always request scores: batch_search returns a (results, scores) pair
    # only when return_score is true, so passing the client's flag through
    # made the two-value unpacking fail for return_scores=false.
    results, scores = retriever.batch_search(
        query_list=request.queries,
        num=topk,
        return_score=True
    )

    if not request.return_scores:
        return {"result": results}

    # Pair every document with its score, query by query.
    resp = [
        [{"document": doc, "score": score} for doc, score in zip(docs, doc_scores)]
        for docs, doc_scores in zip(results, scores)
    ]
    return {"result": resp}
if __name__ == "__main__":
    # 3) Launch the server on port 8000. host="0.0.0.0" binds every network
    # interface, so the API is reachable at http://127.0.0.1:8000 locally and
    # also from other machines that can reach this host.
    uvicorn.run(app, host="0.0.0.0", port=8000)