Initial commit

PeterGriffinJin
2025-02-28 15:16:19 +00:00
commit 068516be64
207 changed files with 33063 additions and 0 deletions

search_r1/__init__.py Normal file
View File

View File

@@ -0,0 +1,416 @@
import torch
import re
from collections import defaultdict
import os
from typing import List, Dict, Any, Tuple
from dataclasses import dataclass
from .tensor_helper import TensorHelper, TensorConfig
# from search_r1.utils import set_seed
# from search_r1.utils.plot import (
# save_trajectory_to_output,
# parse_llm_output
# )
from verl import DataProto
from verl.utils.tracking import Tracking
import shutil
import requests
@dataclass
class GenerationConfig:
max_turns: int
max_start_length: int
max_prompt_length: int
max_response_length: int
max_obs_length: int
# logging: dict
num_gpus: int
no_think_rl: bool=False
search_url: str = None
topk: int = 3
class LLMGenerationManager:
def __init__(
self,
tokenizer,
actor_rollout_wg,
config: GenerationConfig,
# logger: Tracking,
is_validation: bool = False,
):
self.tokenizer = tokenizer
self.actor_rollout_wg = actor_rollout_wg
self.config = config
# self.logger = logger
self.is_validation = is_validation
self.tensor_fn = TensorHelper(TensorConfig(
pad_token_id=tokenizer.pad_token_id,
max_prompt_length=config.max_prompt_length,
max_obs_length=config.max_obs_length,
max_start_length=config.max_start_length
))
def _batch_tokenize(self, responses: List[str]) -> torch.Tensor:
"""Tokenize a batch of responses."""
return self.tokenizer(
responses,
add_special_tokens=False,
return_tensors='pt',
padding="longest"
)['input_ids']
def _postprocess_responses(self, responses: torch.Tensor) -> torch.Tensor:
"""Process responses to stop at search operation or answer operation."""
responses_str = self.tokenizer.batch_decode(
responses,
skip_special_tokens=True
)
responses_str = [resp.split('</search>')[0] + '</search>'
if '</search>' in resp
else resp.split('</answer>')[0] + '</answer>'
if '</answer>' in resp
else resp
for resp in responses_str]
if self.config.no_think_rl:
raise ValueError('stop')
# if no_think_rl is enabled, only keep action in the str
actions, _ = self.env.postprocess_predictions(responses_str)
responses_str=[f"<answer>{envs[idx].ACTION_LOOKUP[action]}</answer>" for idx, action in enumerate(actions)]
print("RESPONSES:", responses_str)
responses = self._batch_tokenize(responses_str)
return responses, responses_str
def _process_next_obs(self, next_obs: List[str]) -> torch.Tensor:
"""Process next observations from environment."""
next_obs_ids = self.tokenizer(
next_obs,
padding='longest',
return_tensors='pt',
add_special_tokens=False, # Prevents adding special tokens
)['input_ids']
if next_obs_ids.shape[1] > self.config.max_obs_length:
print(f"[WARNING] OBSERVATION TOO LONG, CONSIDER CHANGING YOUR CONFIG, {next_obs_ids.shape[1]} & {self.config.max_obs_length}")
next_obs_ids = next_obs_ids[:, :self.config.max_obs_length]
return next_obs_ids
def _update_rolling_state(self, rollings, cur_responses: torch.Tensor,
next_obs_ids: torch.Tensor) -> Dict:
"""Update rolling state with new responses and observations."""
# Concatenate and handle padding
new_input_ids = self.tensor_fn.concatenate_with_padding([
rollings.batch['input_ids'],
cur_responses,
next_obs_ids
])
# Create attention mask and position ids
new_attention_mask = self.tensor_fn.create_attention_mask(new_input_ids)
new_position_ids = self.tensor_fn.create_position_ids(new_attention_mask)
# Cut to appropriate length
effective_len = new_attention_mask.sum(dim=1).max()
max_len = min(self.config.max_prompt_length, effective_len)
return DataProto.from_dict({
'input_ids': new_input_ids[:, -max_len:],
'position_ids': new_position_ids[:, -max_len:],
'attention_mask': new_attention_mask[:, -max_len:]
})
def _update_right_side(self, right_side: Dict,
cur_responses: torch.Tensor,
next_obs_ids: torch.Tensor = None) -> Dict:
"""Update right side state."""
        if next_obs_ids is not None:
responses = self.tensor_fn.concatenate_with_padding([
right_side['responses'],
cur_responses,
next_obs_ids
], pad_to_left=False)
else:
responses = self.tensor_fn.concatenate_with_padding([
right_side['responses'],
cur_responses,
], pad_to_left=False)
effective_len = self.tensor_fn.create_attention_mask(responses).sum(dim=1).max()
max_len = min(self.config.max_prompt_length, effective_len)
return {'responses': responses[:, :max_len]}
def _generate_with_gpu_padding(self, active_batch: DataProto) -> DataProto:
"""
        Wrapper for generation that handles multi-GPU padding requirements.
        If num_gpus <= 1, generation is called directly on active_batch.
        If the active batch size is not divisible by num_gpus, the batch is padded by
        repeating its first sequence, and the padding is removed from the output.
"""
num_gpus = self.config.num_gpus
if num_gpus <= 1:
return self.actor_rollout_wg.generate_sequences(active_batch)
batch_size = active_batch.batch['input_ids'].shape[0]
remainder = batch_size % num_gpus
if remainder == 0:
return self.actor_rollout_wg.generate_sequences(active_batch)
# Add padding sequences
padding_size = num_gpus - remainder
padded_batch = {}
for k, v in active_batch.batch.items():
# Use first sequence as padding template
pad_sequence = v[0:1].repeat(padding_size, *[1] * (len(v.shape) - 1))
padded_batch[k] = torch.cat([v, pad_sequence], dim=0)
padded_active_batch = DataProto.from_dict(padded_batch)
# Generate with padded batch
padded_output = self.actor_rollout_wg.generate_sequences(padded_active_batch)
# Remove padding from output
trimmed_batch = {k: v[:-padding_size] for k, v in padded_output.batch.items()}
# Handle meta_info if present
if hasattr(padded_output, 'meta_info') and padded_output.meta_info:
trimmed_meta = {}
for k, v in padded_output.meta_info.items():
if isinstance(v, torch.Tensor):
trimmed_meta[k] = v[:-padding_size]
else:
trimmed_meta[k] = v
padded_output.meta_info = trimmed_meta
padded_output.batch = trimmed_batch
return padded_output
def run_llm_loop(self, gen_batch, initial_input_ids: torch.Tensor) -> Tuple[Dict, Dict]:
"""Run main LLM generation loop."""
original_left_side = {'input_ids': initial_input_ids[:, -self.config.max_start_length:]}
original_right_side = {'responses': initial_input_ids[:, []]}
active_mask = torch.ones(gen_batch.batch['input_ids'].shape[0], dtype=torch.bool)
active_num_list = [active_mask.sum().item()]
rollings = gen_batch
# Main generation loop
for step in range(self.config.max_turns):
if not active_mask.sum():
break
rollings.batch = self.tensor_fn.cut_to_effective_len(
rollings.batch,
keys=['input_ids', 'attention_mask', 'position_ids']
)
# gen_output = self.actor_rollout_wg.generate_sequences(rollings)
rollings_active = DataProto.from_dict({
k: v[active_mask] for k, v in rollings.batch.items()
})
gen_output = self._generate_with_gpu_padding(rollings_active)
meta_info = gen_output.meta_info
responses_ids, responses_str = self._postprocess_responses(gen_output.batch['responses'])
responses_ids, responses_str = self.tensor_fn._example_level_pad(responses_ids, responses_str, active_mask)
# Execute in environment and process observations
next_obs, dones = self.execute_predictions(
responses_str, self.tokenizer.pad_token, active_mask
)
curr_active_mask = torch.tensor([not done for done in dones], dtype=torch.bool)
active_mask = active_mask * curr_active_mask
active_num_list.append(active_mask.sum().item())
next_obs_ids = self._process_next_obs(next_obs)
# Update states
rollings = self._update_rolling_state(
rollings,
responses_ids,
next_obs_ids
)
original_right_side = self._update_right_side(
original_right_side,
responses_ids,
next_obs_ids
)
# final LLM rollout
if active_mask.sum():
rollings.batch = self.tensor_fn.cut_to_effective_len(
rollings.batch,
keys=['input_ids', 'attention_mask', 'position_ids']
)
# gen_output = self.actor_rollout_wg.generate_sequences(rollings)
rollings_active = DataProto.from_dict({
k: v[active_mask] for k, v in rollings.batch.items()
})
gen_output = self._generate_with_gpu_padding(rollings_active)
meta_info = gen_output.meta_info
responses_ids, responses_str = self._postprocess_responses(gen_output.batch['responses'])
responses_ids, responses_str = self.tensor_fn._example_level_pad(responses_ids, responses_str, active_mask)
# # Execute in environment and process observations
_, dones = self.execute_predictions(
responses_str, self.tokenizer.pad_token, active_mask, do_search=False
)
curr_active_mask = torch.tensor([not done for done in dones], dtype=torch.bool)
active_mask = active_mask * curr_active_mask
active_num_list.append(active_mask.sum().item())
original_right_side = self._update_right_side(
original_right_side,
responses_ids,
)
print("ACTIVE_TRAJ_NUM:", active_num_list)
return self._compose_final_output(original_left_side, original_right_side, meta_info)
def _compose_final_output(self, left_side: Dict,
right_side: Dict,
meta_info: Dict) -> Tuple[Dict, Dict]:
"""Compose final generation output."""
final_output = right_side.copy()
final_output['prompts'] = left_side['input_ids']
# Combine input IDs
final_output['input_ids'] = torch.cat([
left_side['input_ids'],
right_side['responses']
], dim=1)
# Create attention mask and position ids
final_output['attention_mask'] = torch.cat([
self.tensor_fn.create_attention_mask(left_side['input_ids']),
self.tensor_fn.create_attention_mask(final_output['responses'])
], dim=1)
final_output['position_ids'] = self.tensor_fn.create_position_ids(
final_output['attention_mask']
)
final_output = DataProto.from_dict(final_output)
final_output.meta_info.update(meta_info)
return final_output
    def execute_predictions(self, predictions: List[str], pad_token: str, active_mask=None, do_search=True) -> Tuple[List[str], List[int]]:
        """
        Execute predictions against the search environment.
        NOTE: this function is the actual `step` function of the environment.
        NOTE: penalty_for_invalid is not included in the observation shown to the LLM.
        Args:
            predictions: List of action predictions
            pad_token: Token to use for padding
            active_mask: Boolean mask of trajectories that are still active
            do_search: Whether to actually query the search engine
        Returns:
            Tuple of (next observation strings, done flags)
        """
cur_actions, contents = self.postprocess_predictions(predictions)
next_obs, dones = [], []
search_queries = [content for action, content in zip(cur_actions, contents) if action == 'search']
if do_search:
search_results = self.batch_search(search_queries)
assert len(search_results) == sum([1 for action in cur_actions if action == 'search'])
else:
search_results = [''] * sum([1 for action in cur_actions if action == 'search'])
for i, (action, active) in enumerate(zip(cur_actions, active_mask)):
if not active:
next_obs.append('')
dones.append(1)
else:
if action == 'answer':
next_obs.append('')
dones.append(1)
elif action == 'search':
next_obs.append(f'\n\n<information>{search_results.pop(0).strip()}</information>\n\n')
dones.append(0)
else:
next_obs.append(f'\nMy previous action is invalid. \
If I want to search, I should put the query between <search> and </search>. \
If I want to give the final answer, I should put the answer between <answer> and </answer>. Let me try again.\n')
dones.append(0)
assert len(search_results) == 0
return next_obs, dones
    def postprocess_predictions(self, predictions: List[Any]) -> Tuple[List[str], List[str]]:
        """
        Parse (text-based) LLM predictions into actions and their tag contents.
        Args:
            predictions: List of raw predictions
        Returns:
            Tuple of (action list, content list); each action is 'search', 'answer', or None
        """
actions = []
contents = []
for prediction in predictions:
if isinstance(prediction, str): # for llm output
pattern = r'<(search|answer)>(.*?)</\1>'
match = re.search(pattern, prediction, re.DOTALL)
if match:
content = match.group(2).strip() # Return only the content inside the tags
action = match.group(1)
else:
content = ''
action = None
else:
raise ValueError(f"Invalid prediction type: {type(prediction)}")
actions.append(action)
contents.append(content)
return actions, contents
    def batch_search(self, queries: List[str] = None) -> List[str]:
        """
        Batchified search for queries.
        Args:
            queries: queries to send to the search engine
        Returns:
            one formatted passage string per query
        """
results = self._batch_search(queries)['result']
return [self._passages2string(result) for result in results]
def _batch_search(self, queries):
payload = {
"queries": queries,
"topk": self.config.topk,
"return_scores": True
}
return requests.post(self.config.search_url, json=payload).json()
def _passages2string(self, retrieval_result):
format_reference = ''
for idx, doc_item in enumerate(retrieval_result):
content = doc_item['document']['contents']
title = content.split("\n")[0]
text = "\n".join(content.split("\n")[1:])
format_reference += f"Doc {idx+1}(Title: {title}) {text}\n"
return format_reference
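
For reference, _batch_search POSTs its queries to config.search_url and _passages2string expects each returned item to carry a document dict whose contents field starts with the title on its first line. A minimal sketch of that request/response contract, assuming a retrieval server like the one later in this commit is running (the URL and query here are illustrative only):

import requests

search_url = "http://127.0.0.1:8000/retrieve"   # placeholder; set via GenerationConfig.search_url
payload = {"queries": ["who wrote hamlet"], "topk": 3, "return_scores": True}
results = requests.post(search_url, json=payload).json()["result"]

# Each query maps to a list of items shaped like
# {"document": {"contents": "Title\nbody text ..."}, "score": ...},
# which is what _passages2string() unpacks.
for idx, doc_item in enumerate(results[0]):
    content = doc_item["document"]["contents"]
    title = content.split("\n")[0]
    text = "\n".join(content.split("\n")[1:])
    print(f"Doc {idx+1}(Title: {title}) {text[:80]}")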

View File

@@ -0,0 +1,75 @@
import torch
from typing import Dict, Tuple, List
from dataclasses import dataclass
@dataclass
class TensorConfig:
pad_token_id: int
max_prompt_length: int
max_obs_length: int
max_start_length: int
class TensorHelper:
def __init__(self, config: TensorConfig):
self.config = config
def cut_to_effective_len(self, tensor_dict: Dict[str, torch.Tensor],
keys: List[str], cut_left: bool = True) -> Dict[str, torch.Tensor]:
"""Cut tensors to their effective length based on attention mask."""
effective_len = tensor_dict['attention_mask'].sum(dim=1).max()
result = tensor_dict.copy()
for key in keys:
if cut_left:
result[key] = tensor_dict[key][:, -effective_len:]
else:
result[key] = tensor_dict[key][:, :effective_len]
return result
def convert_pad_structure(self, tensor: torch.Tensor, pad_to_left: bool = True) -> Tuple[torch.Tensor, torch.Tensor]:
"""Convert padding structure and return sorted tensor with indices."""
mask = tensor != self.config.pad_token_id if pad_to_left else tensor == self.config.pad_token_id
sorted_indices = mask.to(torch.int64).argsort(dim=1, stable=True)
return tensor.gather(1, sorted_indices), sorted_indices
def create_attention_mask(self, input_ids: torch.Tensor) -> torch.Tensor:
"""Create attention mask from input ids."""
return torch.where(input_ids != self.config.pad_token_id, 1, 0)
def create_position_ids(self, attention_mask: torch.Tensor) -> torch.Tensor:
"""Create position ids from attention mask."""
return (torch.cumsum(attention_mask, dim=1) - 1) * attention_mask
def concatenate_with_padding(self, tensors: List[torch.Tensor],
pad_to_left: bool = True) -> torch.Tensor:
"""Concatenate tensors and handle padding."""
concatenated = torch.cat(tensors, dim=1)
padded_tensor, _ = self.convert_pad_structure(concatenated, pad_to_left)
return padded_tensor
def _example_level_pad(self, responses: torch.Tensor,
responses_str: List[str],
active_mask: torch.Tensor) -> Tuple[torch.Tensor, List[str]]:
"""
Pad responses for non-active examples with pad tokens.
"""
assert active_mask.sum() == responses.shape[0]
# Create masked responses tensor
batch_size = active_mask.shape[0]
seq_len = responses.shape[1]
padded_responses = torch.full(
(batch_size, seq_len), self.config.pad_token_id,
dtype=responses.dtype, device=responses.device
)
padded_responses[active_mask] = responses
# Create masked response strings
padded_responses_str = [""] * batch_size
s = 0
for i, is_active in enumerate(active_mask):
if is_active:
padded_responses_str[i] = responses_str[s]
s += 1
return padded_responses, padded_responses_str
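
As a quick sanity check of how these helpers compose, here is a small sketch with pad_token_id = 0 (the import path is a placeholder; point it at wherever tensor_helper.py lives in the package):

import torch
from tensor_helper import TensorHelper, TensorConfig  # placeholder import path

helper = TensorHelper(TensorConfig(pad_token_id=0, max_prompt_length=16,
                                   max_obs_length=8, max_start_length=8))
prompt = torch.tensor([[0, 5, 6]])     # left-padded prompt
response = torch.tensor([[7, 0, 0]])   # right-padded response

merged = helper.concatenate_with_padding([prompt, response])  # pads are pushed to the left
print(merged)                                # tensor([[0, 0, 0, 5, 6, 7]])
mask = helper.create_attention_mask(merged)
print(mask)                                  # tensor([[0, 0, 0, 1, 1, 1]])
print(helper.create_position_ids(mask))      # tensor([[0, 0, 0, 0, 1, 2]])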

View File

@@ -0,0 +1,17 @@
corpus_file=/your/corpus/jsonl/file # jsonl
save_dir=/the/path/to/save/index
retriever_name=e5 # used for naming the index files
retriever_model=intfloat/e5-base-v2
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python index_builder.py \
--retrieval_method $retriever_name \
--model_path $retriever_model \
--corpus_path $corpus_file \
--save_dir $save_dir \
--use_fp16 \
--max_length 256 \
--batch_size 512 \
--pooling_method mean \
--faiss_type Flat \
--save_embedding

View File

@@ -0,0 +1,348 @@
import os
import faiss
import json
import warnings
import numpy as np
from typing import cast, List, Dict
import shutil
import subprocess
import argparse
import torch
from tqdm import tqdm
# from LongRAG.retriever.utils import load_model, load_corpus, pooling
import datasets
from transformers import AutoTokenizer, AutoModel, AutoConfig
def load_model(
model_path: str,
use_fp16: bool = False
):
model_config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
model = AutoModel.from_pretrained(model_path, trust_remote_code=True)
model.eval()
model.cuda()
if use_fp16:
model = model.half()
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True, trust_remote_code=True)
return model, tokenizer
def pooling(
pooler_output,
last_hidden_state,
attention_mask = None,
pooling_method = "mean"
):
if pooling_method == "mean":
last_hidden = last_hidden_state.masked_fill(~attention_mask[..., None].bool(), 0.0)
return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
elif pooling_method == "cls":
return last_hidden_state[:, 0]
elif pooling_method == "pooler":
return pooler_output
else:
raise NotImplementedError("Pooling method not implemented!")
def load_corpus(corpus_path: str):
corpus = datasets.load_dataset(
'json',
data_files=corpus_path,
split="train",
num_proc=4)
return corpus
class Index_Builder:
r"""A tool class used to build an index used in retrieval.
"""
def __init__(
self,
retrieval_method,
model_path,
corpus_path,
save_dir,
max_length,
batch_size,
use_fp16,
pooling_method,
faiss_type=None,
embedding_path=None,
save_embedding=False,
faiss_gpu=False
):
self.retrieval_method = retrieval_method.lower()
self.model_path = model_path
self.corpus_path = corpus_path
self.save_dir = save_dir
self.max_length = max_length
self.batch_size = batch_size
self.use_fp16 = use_fp16
self.pooling_method = pooling_method
self.faiss_type = faiss_type if faiss_type is not None else 'Flat'
self.embedding_path = embedding_path
self.save_embedding = save_embedding
self.faiss_gpu = faiss_gpu
self.gpu_num = torch.cuda.device_count()
# prepare save dir
print(self.save_dir)
if not os.path.exists(self.save_dir):
os.makedirs(self.save_dir)
else:
if not self._check_dir(self.save_dir):
warnings.warn("Some files already exists in save dir and may be overwritten.", UserWarning)
self.index_save_path = os.path.join(self.save_dir, f"{self.retrieval_method}_{self.faiss_type}.index")
self.embedding_save_path = os.path.join(self.save_dir, f"emb_{self.retrieval_method}.memmap")
self.corpus = load_corpus(self.corpus_path)
print("Finish loading...")
@staticmethod
def _check_dir(dir_path):
r"""Check if the dir path exists and if there is content.
"""
if os.path.isdir(dir_path):
if len(os.listdir(dir_path)) > 0:
return False
else:
os.makedirs(dir_path, exist_ok=True)
return True
def build_index(self):
r"""Constructing different indexes based on selective retrieval method.
"""
if self.retrieval_method == "bm25":
self.build_bm25_index()
else:
self.build_dense_index()
def build_bm25_index(self):
"""Building BM25 index based on Pyserini library.
Reference: https://github.com/castorini/pyserini/blob/master/docs/usage-index.md#building-a-bm25-index-direct-java-implementation
"""
# to use pyserini pipeline, we first need to place jsonl file in the folder
self.save_dir = os.path.join(self.save_dir, "bm25")
os.makedirs(self.save_dir, exist_ok=True)
temp_dir = self.save_dir + "/temp"
temp_file_path = temp_dir + "/temp.jsonl"
os.makedirs(temp_dir)
# if self.have_contents:
# shutil.copyfile(self.corpus_path, temp_file_path)
# else:
# with open(temp_file_path, "w") as f:
# for item in self.corpus:
# f.write(json.dumps(item) + "\n")
shutil.copyfile(self.corpus_path, temp_file_path)
print("Start building bm25 index...")
pyserini_args = ["--collection", "JsonCollection",
"--input", temp_dir,
"--index", self.save_dir,
"--generator", "DefaultLuceneDocumentGenerator",
"--threads", "1"]
subprocess.run(["python", "-m", "pyserini.index.lucene"] + pyserini_args)
shutil.rmtree(temp_dir)
print("Finish!")
def _load_embedding(self, embedding_path, corpus_size, hidden_size):
all_embeddings = np.memmap(
embedding_path,
mode="r",
dtype=np.float32
).reshape(corpus_size, hidden_size)
return all_embeddings
def _save_embedding(self, all_embeddings):
memmap = np.memmap(
self.embedding_save_path,
shape=all_embeddings.shape,
mode="w+",
dtype=all_embeddings.dtype
)
length = all_embeddings.shape[0]
# add in batch
save_batch_size = 10000
if length > save_batch_size:
for i in tqdm(range(0, length, save_batch_size), leave=False, desc="Saving Embeddings"):
j = min(i + save_batch_size, length)
memmap[i: j] = all_embeddings[i: j]
else:
memmap[:] = all_embeddings
def encode_all(self):
if self.gpu_num > 1:
print("Use multi gpu!")
self.encoder = torch.nn.DataParallel(self.encoder)
self.batch_size = self.batch_size * self.gpu_num
all_embeddings = []
for start_idx in tqdm(range(0, len(self.corpus), self.batch_size), desc='Inference Embeddings:'):
batch_data_title = self.corpus[start_idx:start_idx+self.batch_size]['title']
batch_data_text = self.corpus[start_idx:start_idx+self.batch_size]['text']
batch_data = ['"' + title + '"\n' + text for title, text in zip(batch_data_title, batch_data_text)]
if self.retrieval_method == "e5":
batch_data = [f"passage: {doc}" for doc in batch_data]
inputs = self.tokenizer(
batch_data,
padding=True,
truncation=True,
return_tensors='pt',
max_length=self.max_length,
).to('cuda')
inputs = {k: v.cuda() for k, v in inputs.items()}
#TODO: support encoder-only T5 model
if "T5" in type(self.encoder).__name__:
# T5-based retrieval model
decoder_input_ids = torch.zeros(
(inputs['input_ids'].shape[0], 1), dtype=torch.long
).to(inputs['input_ids'].device)
output = self.encoder(
**inputs, decoder_input_ids=decoder_input_ids, return_dict=True
)
embeddings = output.last_hidden_state[:, 0, :]
else:
output = self.encoder(**inputs, return_dict=True)
embeddings = pooling(output.pooler_output,
output.last_hidden_state,
inputs['attention_mask'],
self.pooling_method)
if "dpr" not in self.retrieval_method:
embeddings = torch.nn.functional.normalize(embeddings, dim=-1)
embeddings = cast(torch.Tensor, embeddings)
embeddings = embeddings.detach().cpu().numpy()
all_embeddings.append(embeddings)
all_embeddings = np.concatenate(all_embeddings, axis=0)
all_embeddings = all_embeddings.astype(np.float32)
return all_embeddings
@torch.no_grad()
def build_dense_index(self):
"""Obtain the representation of documents based on the embedding model(BERT-based) and
construct a faiss index.
"""
if os.path.exists(self.index_save_path):
print("The index file already exists and will be overwritten.")
self.encoder, self.tokenizer = load_model(model_path = self.model_path,
use_fp16 = self.use_fp16)
if self.embedding_path is not None:
hidden_size = self.encoder.config.hidden_size
corpus_size = len(self.corpus)
all_embeddings = self._load_embedding(self.embedding_path, corpus_size, hidden_size)
else:
all_embeddings = self.encode_all()
if self.save_embedding:
self._save_embedding(all_embeddings)
del self.corpus
# build index
print("Creating index")
dim = all_embeddings.shape[-1]
faiss_index = faiss.index_factory(dim, self.faiss_type, faiss.METRIC_INNER_PRODUCT)
if self.faiss_gpu:
co = faiss.GpuMultipleClonerOptions()
co.useFloat16 = True
co.shard = True
faiss_index = faiss.index_cpu_to_all_gpus(faiss_index, co)
if not faiss_index.is_trained:
faiss_index.train(all_embeddings)
faiss_index.add(all_embeddings)
faiss_index = faiss.index_gpu_to_cpu(faiss_index)
else:
if not faiss_index.is_trained:
faiss_index.train(all_embeddings)
faiss_index.add(all_embeddings)
faiss.write_index(faiss_index, self.index_save_path)
print("Finish!")
MODEL2POOLING = {
"e5": "mean",
"bge": "cls",
"contriever": "mean",
'jina': 'mean'
}
def main():
parser = argparse.ArgumentParser(description = "Creating index.")
# Basic parameters
parser.add_argument('--retrieval_method', type=str)
parser.add_argument('--model_path', type=str, default=None)
parser.add_argument('--corpus_path', type=str)
parser.add_argument('--save_dir', default= 'indexes/',type=str)
# Parameters for building dense index
parser.add_argument('--max_length', type=int, default=180)
parser.add_argument('--batch_size', type=int, default=512)
parser.add_argument('--use_fp16', default=False, action='store_true')
parser.add_argument('--pooling_method', type=str, default=None)
parser.add_argument('--faiss_type',default=None,type=str)
parser.add_argument('--embedding_path', default=None, type=str)
parser.add_argument('--save_embedding', action='store_true', default=False)
parser.add_argument('--faiss_gpu', default=False, action='store_true')
args = parser.parse_args()
if args.pooling_method is None:
pooling_method = 'mean'
for k,v in MODEL2POOLING.items():
if k in args.retrieval_method.lower():
pooling_method = v
break
else:
if args.pooling_method not in ['mean','cls','pooler']:
raise NotImplementedError
else:
pooling_method = args.pooling_method
index_builder = Index_Builder(
retrieval_method = args.retrieval_method,
model_path = args.model_path,
corpus_path = args.corpus_path,
save_dir = args.save_dir,
max_length = args.max_length,
batch_size = args.batch_size,
use_fp16 = args.use_fp16,
pooling_method = pooling_method,
faiss_type = args.faiss_type,
embedding_path = args.embedding_path,
save_embedding = args.save_embedding,
faiss_gpu = args.faiss_gpu
)
index_builder.build_index()
if __name__ == "__main__":
main()
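
The shell script earlier in this commit drives this file through its CLI; the same dense e5 index can also be built programmatically. A minimal sketch (the import path and file paths are placeholders mirroring that script):

from index_builder import Index_Builder  # placeholder import path

builder = Index_Builder(
    retrieval_method="e5",
    model_path="intfloat/e5-base-v2",
    corpus_path="/your/corpus/jsonl/file",
    save_dir="/the/path/to/save/index",
    max_length=256,
    batch_size=512,
    use_fp16=True,
    pooling_method="mean",
    faiss_type="Flat",
    save_embedding=True,
)
builder.build_index()  # writes e5_Flat.index (and emb_e5.memmap) into save_dir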

View File

@@ -0,0 +1,368 @@
import json
import os
import warnings
from typing import List, Dict
import functools
from tqdm import tqdm
from multiprocessing import Pool
import faiss
import torch
import numpy as np
from transformers import AutoConfig, AutoTokenizer, AutoModel
import argparse
import datasets
def load_corpus(corpus_path: str):
corpus = datasets.load_dataset(
'json',
data_files=corpus_path,
split="train",
num_proc=4)
return corpus
def read_jsonl(file_path):
data = []
with open(file_path, "r") as f:
readin = f.readlines()
for line in readin:
data.append(json.loads(line))
return data
def load_docs(corpus, doc_idxs):
results = [corpus[int(idx)] for idx in doc_idxs]
return results
def load_model(
model_path: str,
use_fp16: bool = False
):
model_config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
model = AutoModel.from_pretrained(model_path, trust_remote_code=True)
model.eval()
model.cuda()
if use_fp16:
model = model.half()
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True, trust_remote_code=True)
return model, tokenizer
def pooling(
pooler_output,
last_hidden_state,
attention_mask = None,
pooling_method = "mean"
):
if pooling_method == "mean":
last_hidden = last_hidden_state.masked_fill(~attention_mask[..., None].bool(), 0.0)
return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
elif pooling_method == "cls":
return last_hidden_state[:, 0]
elif pooling_method == "pooler":
return pooler_output
else:
raise NotImplementedError("Pooling method not implemented!")
class Encoder:
def __init__(self, model_name, model_path, pooling_method, max_length, use_fp16):
self.model_name = model_name
self.model_path = model_path
self.pooling_method = pooling_method
self.max_length = max_length
self.use_fp16 = use_fp16
self.model, self.tokenizer = load_model(model_path=model_path,
use_fp16=use_fp16)
@torch.no_grad()
def encode(self, query_list: List[str], is_query=True) -> np.ndarray:
# processing query for different encoders
if isinstance(query_list, str):
query_list = [query_list]
if "e5" in self.model_name.lower():
if is_query:
query_list = [f"query: {query}" for query in query_list]
else:
query_list = [f"passage: {query}" for query in query_list]
if "bge" in self.model_name.lower():
if is_query:
query_list = [f"Represent this sentence for searching relevant passages: {query}" for query in query_list]
inputs = self.tokenizer(query_list,
max_length=self.max_length,
padding=True,
truncation=True,
return_tensors="pt"
)
inputs = {k: v.cuda() for k, v in inputs.items()}
if "T5" in type(self.model).__name__:
# T5-based retrieval model
decoder_input_ids = torch.zeros(
(inputs['input_ids'].shape[0], 1), dtype=torch.long
).to(inputs['input_ids'].device)
output = self.model(
**inputs, decoder_input_ids=decoder_input_ids, return_dict=True
)
query_emb = output.last_hidden_state[:, 0, :]
else:
output = self.model(**inputs, return_dict=True)
query_emb = pooling(output.pooler_output,
output.last_hidden_state,
inputs['attention_mask'],
self.pooling_method)
if "dpr" not in self.model_name.lower():
query_emb = torch.nn.functional.normalize(query_emb, dim=-1)
query_emb = query_emb.detach().cpu().numpy()
query_emb = query_emb.astype(np.float32, order="C")
return query_emb
class BaseRetriever:
"""Base object for all retrievers."""
def __init__(self, config):
self.config = config
self.retrieval_method = config.retrieval_method
self.topk = config.retrieval_topk
self.index_path = config.index_path
self.corpus_path = config.corpus_path
# self.cache_save_path = os.path.join(config.save_dir, 'retrieval_cache.json')
def _search(self, query: str, num: int, return_score:bool) -> List[Dict[str, str]]:
r"""Retrieve topk relevant documents in corpus.
Return:
list: contains information related to the document, including:
contents: used for building index
title: (if provided)
text: (if provided)
"""
pass
def _batch_search(self, query_list, num, return_score):
pass
def search(self, *args, **kwargs):
return self._search(*args, **kwargs)
def batch_search(self, *args, **kwargs):
return self._batch_search(*args, **kwargs)
class BM25Retriever(BaseRetriever):
r"""BM25 retriever based on pre-built pyserini index."""
def __init__(self, config):
super().__init__(config)
from pyserini.search.lucene import LuceneSearcher
self.searcher = LuceneSearcher(self.index_path)
self.contain_doc = self._check_contain_doc()
if not self.contain_doc:
self.corpus = load_corpus(self.corpus_path)
self.max_process_num = 8
def _check_contain_doc(self):
r"""Check if the index contains document content
"""
return self.searcher.doc(0).raw() is not None
def _search(self, query: str, num: int = None, return_score = False) -> List[Dict[str, str]]:
if num is None:
num = self.topk
hits = self.searcher.search(query, num)
if len(hits) < 1:
if return_score:
return [],[]
else:
return []
scores = [hit.score for hit in hits]
if len(hits) < num:
warnings.warn('Not enough documents retrieved!')
else:
hits = hits[:num]
if self.contain_doc:
all_contents = [json.loads(self.searcher.doc(hit.docid).raw())['contents'] for hit in hits]
results = [{'title': content.split("\n")[0].strip("\""),
'text': "\n".join(content.split("\n")[1:]),
'contents': content} for content in all_contents]
else:
results = load_docs(self.corpus, [hit.docid for hit in hits])
if return_score:
return results, scores
else:
return results
def _batch_search(self, query_list, num: int = None, return_score = False):
# TODO: modify batch method
results = []
scores = []
for query in query_list:
item_result, item_score = self._search(query, num,True)
results.append(item_result)
scores.append(item_score)
if return_score:
return results, scores
else:
return results
def get_available_gpu_memory():
memory_info = []
for i in range(torch.cuda.device_count()):
total_memory = torch.cuda.get_device_properties(i).total_memory
allocated_memory = torch.cuda.memory_allocated(i)
free_memory = total_memory - allocated_memory
memory_info.append((i, free_memory / 1e9)) # Convert to GB
return memory_info
class DenseRetriever(BaseRetriever):
r"""Dense retriever based on pre-built faiss index."""
def __init__(self, config: dict):
super().__init__(config)
self.index = faiss.read_index(self.index_path)
if config.faiss_gpu:
co = faiss.GpuMultipleClonerOptions()
co.useFloat16 = True
co.shard = True
self.index = faiss.index_cpu_to_all_gpus(self.index, co=co)
# self.index = faiss.index_cpu_to_all_gpus(self.index)
self.corpus = load_corpus(self.corpus_path)
self.encoder = Encoder(
model_name = self.retrieval_method,
model_path = config.retrieval_model_path,
pooling_method = config.retrieval_pooling_method,
max_length = config.retrieval_query_max_length,
use_fp16 = config.retrieval_use_fp16
)
self.topk = config.retrieval_topk
self.batch_size = self.config.retrieval_batch_size
def _search(self, query: str, num: int = None, return_score = False):
if num is None:
num = self.topk
query_emb = self.encoder.encode(query)
scores, idxs = self.index.search(query_emb, k=num)
idxs = idxs[0]
scores = scores[0]
results = load_docs(self.corpus, idxs)
if return_score:
return results, scores
else:
return results
def _batch_search(self, query_list: List[str], num: int = None, return_score = False):
if isinstance(query_list, str):
query_list = [query_list]
if num is None:
num = self.topk
batch_size = self.batch_size
results = []
scores = []
for start_idx in tqdm(range(0, len(query_list), batch_size), desc='Retrieval process: '):
query_batch = query_list[start_idx:start_idx + batch_size]
# from time import time
# a = time()
batch_emb = self.encoder.encode(query_batch)
# b = time()
# print(f'################### encode time {b-a} #####################')
batch_scores, batch_idxs = self.index.search(batch_emb, k=num)
batch_scores = batch_scores.tolist()
batch_idxs = batch_idxs.tolist()
# print(f'################### search time {time()-b} #####################')
# exit()
flat_idxs = sum(batch_idxs, [])
batch_results = load_docs(self.corpus, flat_idxs)
batch_results = [batch_results[i*num : (i+1)*num] for i in range(len(batch_idxs))]
scores.extend(batch_scores)
results.extend(batch_results)
if return_score:
return results, scores
else:
return results
def get_retriever(config):
r"""Automatically select retriever class based on config's retrieval method
Args:
config (dict): configuration with 'retrieval_method' key
Returns:
Retriever: retriever instance
"""
if config.retrieval_method == "bm25":
return BM25Retriever(config)
else:
return DenseRetriever(config)
def get_dataset(config):
"""Load dataset from config."""
split_path = os.path.join(config.dataset_path, f'{config.data_split}.jsonl')
return read_jsonl(split_path)
if __name__ == '__main__':
parser = argparse.ArgumentParser(description = "Retrieval")
# Basic parameters
parser.add_argument('--retrieval_method', type=str)
parser.add_argument('--retrieval_topk', type=int, default=10)
parser.add_argument('--index_path', type=str, default=None)
parser.add_argument('--corpus_path', type=str)
parser.add_argument('--dataset_path', default=None, type=str)
parser.add_argument('--faiss_gpu', default=True, type=bool)
parser.add_argument('--data_split', default="train", type=str)
parser.add_argument('--retrieval_model_path', type=str, default=None)
parser.add_argument('--retrieval_pooling_method', default='mean', type=str)
    parser.add_argument('--retrieval_query_max_length', default=256, type=int)
parser.add_argument('--retrieval_use_fp16', action='store_true', default=False)
parser.add_argument('--retrieval_batch_size', default=512, type=int)
args = parser.parse_args()
args.index_path = os.path.join(args.index_path, f'{args.retrieval_method}_Flat.index') if args.retrieval_method != 'bm25' else os.path.join(args.index_path, 'bm25')
# load dataset
all_split = get_dataset(args)
input_query = [sample['question'] for sample in all_split[:512]]
# initialize the retriever and conduct retrieval
retriever = get_retriever(args)
print('Start Retrieving ...')
results, scores = retriever.batch_search(input_query, return_score=True)
# from IPython import embed
# embed()
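
For orientation, batch_search with return_score=True gives back one list per query, each of length retrieval_topk. A brief sketch reusing the variables above (the fields of each corpus row depend on the corpus jsonl schema and are not assumed here):

# results: List[List[dict]]  -- one inner list of top-k corpus rows per query
# scores:  List[List[float]] -- the matching retrieval scores, highest first
assert len(results) == len(input_query)
assert len(results[0]) == args.retrieval_topk
print(scores[0])      # scores for the first query
print(results[0][0])  # best-scoring corpus row for the first query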

View File

@@ -0,0 +1,25 @@
DATA_NAME=nq
DATASET_PATH="/home/peterjin/mnt/data/$DATA_NAME"
SPLIT='test'
TOPK=3
INDEX_PATH=/home/peterjin/mnt/index/wiki-18
CORPUS_PATH=/home/peterjin/mnt/data/retrieval-corpus/wiki-18.jsonl
SAVE_NAME=e5_${TOPK}_wiki18.json
# INDEX_PATH=/home/peterjin/rm_retrieval_corpus/index/wiki-21
# CORPUS_PATH=/home/peterjin/rm_retrieval_corpus/corpora/wiki/enwiki-dec2021/text-list-100-sec.jsonl
# SAVE_NAME=e5_${TOPK}_wiki21.json
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python retrieval.py --retrieval_method e5 \
--retrieval_topk $TOPK \
--index_path $INDEX_PATH \
--corpus_path $CORPUS_PATH \
--dataset_path $DATASET_PATH \
--data_split $SPLIT \
--retrieval_model_path "intfloat/e5-base-v2" \
--retrieval_pooling_method "mean" \
--retrieval_batch_size 512 \

View File

@@ -0,0 +1,23 @@
import requests
# URL for your local FastAPI server
url = "http://127.0.0.1:8000/retrieve"
# Example payload
payload = {
"queries": ["What is the capital of France?", "Explain neural networks."] * 200,
"topk": 5,
"return_scores": True
}
# Send POST request
response = requests.post(url, json=payload)
# Raise an exception if the request failed
response.raise_for_status()
# Get the JSON response
retrieved_data = response.json()
print("Response from server:")
print(retrieved_data)
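
Based on the server implementation shown later in this commit, with return_scores set to true each query maps to a list of {"document": ..., "score": ...} dicts under the "result" key (with return_scores false, it is a list of bare document dicts). Continuing the script above, the scores for the first two queries can be pulled out like this:

for query, hits in zip(payload["queries"][:2], retrieved_data["result"][:2]):
    print(query, [round(hit["score"], 3) for hit in hits])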

View File

@@ -0,0 +1,382 @@
import json
import os
import warnings
from typing import List, Dict, Optional
import argparse
import faiss
import torch
import numpy as np
from transformers import AutoConfig, AutoTokenizer, AutoModel
from tqdm import tqdm
import datasets
import uvicorn
from fastapi import FastAPI
from pydantic import BaseModel
parser = argparse.ArgumentParser(description="Launch the local faiss retriever.")
parser.add_argument("--index_path", type=str, default="/home/peterjin/mnt/index/wiki-18/e5_Flat.index", help="Corpus indexing file.")
parser.add_argument("--corpus_path", type=str, default="/home/peterjin/mnt/data/retrieval-corpus/wiki-18.jsonl", help="Local corpus file.")
parser.add_argument("--topk", type=int, default=3, help="Number of retrieved passages for one query.")
parser.add_argument("--retriever_model", type=str, default="intfloat/e5-base-v2", help="Name of the retriever model.")
args = parser.parse_args()
def load_corpus(corpus_path: str):
corpus = datasets.load_dataset(
'json',
data_files=corpus_path,
split="train",
num_proc=4
)
return corpus
def read_jsonl(file_path):
data = []
with open(file_path, "r") as f:
for line in f:
data.append(json.loads(line))
return data
def load_docs(corpus, doc_idxs):
results = [corpus[int(idx)] for idx in doc_idxs]
return results
def load_model(model_path: str, use_fp16: bool = False):
model_config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
model = AutoModel.from_pretrained(model_path, trust_remote_code=True)
model.eval()
model.cuda()
if use_fp16:
model = model.half()
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True, trust_remote_code=True)
return model, tokenizer
def pooling(
pooler_output,
last_hidden_state,
attention_mask = None,
pooling_method = "mean"
):
if pooling_method == "mean":
last_hidden = last_hidden_state.masked_fill(~attention_mask[..., None].bool(), 0.0)
return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
elif pooling_method == "cls":
return last_hidden_state[:, 0]
elif pooling_method == "pooler":
return pooler_output
else:
raise NotImplementedError("Pooling method not implemented!")
class Encoder:
def __init__(self, model_name, model_path, pooling_method, max_length, use_fp16):
self.model_name = model_name
self.model_path = model_path
self.pooling_method = pooling_method
self.max_length = max_length
self.use_fp16 = use_fp16
self.model, self.tokenizer = load_model(model_path=model_path, use_fp16=use_fp16)
self.model.eval()
@torch.no_grad()
def encode(self, query_list: List[str], is_query=True) -> np.ndarray:
# processing query for different encoders
if isinstance(query_list, str):
query_list = [query_list]
if "e5" in self.model_name.lower():
if is_query:
query_list = [f"query: {query}" for query in query_list]
else:
query_list = [f"passage: {query}" for query in query_list]
if "bge" in self.model_name.lower():
if is_query:
query_list = [f"Represent this sentence for searching relevant passages: {query}" for query in query_list]
inputs = self.tokenizer(query_list,
max_length=self.max_length,
padding=True,
truncation=True,
return_tensors="pt"
)
inputs = {k: v.cuda() for k, v in inputs.items()}
if "T5" in type(self.model).__name__:
# T5-based retrieval model
decoder_input_ids = torch.zeros(
(inputs['input_ids'].shape[0], 1), dtype=torch.long
).to(inputs['input_ids'].device)
output = self.model(
**inputs, decoder_input_ids=decoder_input_ids, return_dict=True
)
query_emb = output.last_hidden_state[:, 0, :]
else:
output = self.model(**inputs, return_dict=True)
query_emb = pooling(output.pooler_output,
output.last_hidden_state,
inputs['attention_mask'],
self.pooling_method)
if "dpr" not in self.model_name.lower():
query_emb = torch.nn.functional.normalize(query_emb, dim=-1)
query_emb = query_emb.detach().cpu().numpy()
query_emb = query_emb.astype(np.float32, order="C")
return query_emb
class BaseRetriever:
def __init__(self, config):
self.config = config
self.retrieval_method = config.retrieval_method
self.topk = config.retrieval_topk
self.index_path = config.index_path
self.corpus_path = config.corpus_path
def _search(self, query: str, num: int, return_score: bool):
raise NotImplementedError
def _batch_search(self, query_list: List[str], num: int, return_score: bool):
raise NotImplementedError
def search(self, query: str, num: int = None, return_score: bool = False):
return self._search(query, num, return_score)
def batch_search(self, query_list: List[str], num: int = None, return_score: bool = False):
return self._batch_search(query_list, num, return_score)
class BM25Retriever(BaseRetriever):
def __init__(self, config):
super().__init__(config)
from pyserini.search.lucene import LuceneSearcher
self.searcher = LuceneSearcher(self.index_path)
self.contain_doc = self._check_contain_doc()
if not self.contain_doc:
self.corpus = load_corpus(self.corpus_path)
self.max_process_num = 8
def _check_contain_doc(self):
return self.searcher.doc(0).raw() is not None
def _search(self, query: str, num: int = None, return_score: bool = False):
if num is None:
num = self.topk
hits = self.searcher.search(query, num)
if len(hits) < 1:
if return_score:
return [], []
else:
return []
scores = [hit.score for hit in hits]
if len(hits) < num:
warnings.warn('Not enough documents retrieved!')
else:
hits = hits[:num]
if self.contain_doc:
all_contents = [
json.loads(self.searcher.doc(hit.docid).raw())['contents']
for hit in hits
]
results = [
{
'title': content.split("\n")[0].strip("\""),
'text': "\n".join(content.split("\n")[1:]),
'contents': content
}
for content in all_contents
]
else:
results = load_docs(self.corpus, [hit.docid for hit in hits])
if return_score:
return results, scores
else:
return results
def _batch_search(self, query_list: List[str], num: int = None, return_score: bool = False):
results = []
scores = []
for query in query_list:
item_result, item_score = self._search(query, num, True)
results.append(item_result)
scores.append(item_score)
if return_score:
return results, scores
else:
return results
class DenseRetriever(BaseRetriever):
def __init__(self, config):
super().__init__(config)
self.index = faiss.read_index(self.index_path)
if config.faiss_gpu:
co = faiss.GpuMultipleClonerOptions()
co.useFloat16 = True
co.shard = True
self.index = faiss.index_cpu_to_all_gpus(self.index, co=co)
self.corpus = load_corpus(self.corpus_path)
self.encoder = Encoder(
model_name = self.retrieval_method,
model_path = config.retrieval_model_path,
pooling_method = config.retrieval_pooling_method,
max_length = config.retrieval_query_max_length,
use_fp16 = config.retrieval_use_fp16
)
self.topk = config.retrieval_topk
self.batch_size = config.retrieval_batch_size
def _search(self, query: str, num: int = None, return_score: bool = False):
if num is None:
num = self.topk
query_emb = self.encoder.encode(query)
scores, idxs = self.index.search(query_emb, k=num)
idxs = idxs[0]
scores = scores[0]
results = load_docs(self.corpus, idxs)
if return_score:
return results, scores.tolist()
else:
return results
def _batch_search(self, query_list: List[str], num: int = None, return_score: bool = False):
if isinstance(query_list, str):
query_list = [query_list]
if num is None:
num = self.topk
results = []
scores = []
for start_idx in tqdm(range(0, len(query_list), self.batch_size), desc='Retrieval process: '):
query_batch = query_list[start_idx:start_idx + self.batch_size]
batch_emb = self.encoder.encode(query_batch)
batch_scores, batch_idxs = self.index.search(batch_emb, k=num)
batch_scores = batch_scores.tolist()
batch_idxs = batch_idxs.tolist()
            # load_docs is not vectorized; it gathers documents with a plain Python list lookup
flat_idxs = sum(batch_idxs, [])
batch_results = load_docs(self.corpus, flat_idxs)
# chunk them back
batch_results = [batch_results[i*num : (i+1)*num] for i in range(len(batch_idxs))]
results.extend(batch_results)
scores.extend(batch_scores)
if return_score:
return results, scores
else:
return results
def get_retriever(config):
if config.retrieval_method == "bm25":
return BM25Retriever(config)
else:
return DenseRetriever(config)
#####################################
# FastAPI server below
#####################################
class Config:
"""
Minimal config class (simulating your argparse)
Replace this with your real arguments or load them dynamically.
"""
def __init__(
self,
retrieval_method: str = "bm25",
retrieval_topk: int = 10,
index_path: str = "./index/bm25",
corpus_path: str = "./data/corpus.jsonl",
dataset_path: str = "./data",
data_split: str = "train",
faiss_gpu: bool = True,
retrieval_model_path: str = "./model",
retrieval_pooling_method: str = "mean",
retrieval_query_max_length: int = 256,
retrieval_use_fp16: bool = False,
retrieval_batch_size: int = 128
):
self.retrieval_method = retrieval_method
self.retrieval_topk = retrieval_topk
self.index_path = index_path
self.corpus_path = corpus_path
self.dataset_path = dataset_path
self.data_split = data_split
self.faiss_gpu = faiss_gpu
self.retrieval_model_path = retrieval_model_path
self.retrieval_pooling_method = retrieval_pooling_method
self.retrieval_query_max_length = retrieval_query_max_length
self.retrieval_use_fp16 = retrieval_use_fp16
self.retrieval_batch_size = retrieval_batch_size
class QueryRequest(BaseModel):
queries: List[str]
topk: Optional[int] = None
return_scores: bool = False
app = FastAPI()
# 1) Build a config (could also parse from arguments).
# In real usage, you'd parse your CLI arguments or environment variables.
config = Config(
retrieval_method = "e5", # or "dense"
index_path=args.index_path,
corpus_path=args.corpus_path,
retrieval_topk=args.topk,
faiss_gpu=True,
retrieval_model_path=args.retriever_model,
retrieval_pooling_method="mean",
retrieval_query_max_length=256,
retrieval_use_fp16=True,
retrieval_batch_size=512,
)
# 2) Instantiate a global retriever so it is loaded once and reused.
retriever = get_retriever(config)
@app.post("/retrieve")
def retrieve_endpoint(request: QueryRequest):
"""
Endpoint that accepts queries and performs retrieval.
Input format:
{
"queries": ["What is Python?", "Tell me about neural networks."],
"topk": 3,
"return_scores": true
}
"""
if not request.topk:
request.topk = config.retrieval_topk # fallback to default
    # Perform batch retrieval; batch_search only returns scores when they are requested
    if request.return_scores:
        results, scores = retriever.batch_search(
            query_list=request.queries, num=request.topk, return_score=True
        )
    else:
        results = retriever.batch_search(
            query_list=request.queries, num=request.topk, return_score=False
        )
        scores = []
# Format response
resp = []
for i, single_result in enumerate(results):
if request.return_scores:
# If scores are returned, combine them with results
combined = []
for doc, score in zip(single_result, scores[i]):
combined.append({"document": doc, "score": score})
resp.append(combined)
else:
resp.append(single_result)
return {"result": resp}
if __name__ == "__main__":
    # 3) Launch the server. It binds to 0.0.0.0:8000, so it is reachable locally at http://127.0.0.1:8000
uvicorn.run(app, host="0.0.0.0", port=8000)