Initial commit
0    search_r1/__init__.py    Normal file

0    search_r1/llm_agent/__init__.py    Normal file

416    search_r1/llm_agent/generation.py    Normal file
@@ -0,0 +1,416 @@
import torch
import re
from collections import defaultdict
import os
from typing import List, Dict, Any, Tuple
from dataclasses import dataclass
from .tensor_helper import TensorHelper, TensorConfig
# from search_r1.utils import set_seed
# from search_r1.utils.plot import (
#     save_trajectory_to_output,
#     parse_llm_output
# )
from verl import DataProto
from verl.utils.tracking import Tracking
import shutil
import requests

@dataclass
class GenerationConfig:
    max_turns: int
    max_start_length: int
    max_prompt_length: int
    max_response_length: int
    max_obs_length: int
    # logging: dict
    num_gpus: int
    no_think_rl: bool = False
    search_url: str = None
    topk: int = 3

class LLMGenerationManager:
    def __init__(
        self,
        tokenizer,
        actor_rollout_wg,
        config: GenerationConfig,
        # logger: Tracking,
        is_validation: bool = False,
    ):
        self.tokenizer = tokenizer
        self.actor_rollout_wg = actor_rollout_wg
        self.config = config
        # self.logger = logger
        self.is_validation = is_validation

        self.tensor_fn = TensorHelper(TensorConfig(
            pad_token_id=tokenizer.pad_token_id,
            max_prompt_length=config.max_prompt_length,
            max_obs_length=config.max_obs_length,
            max_start_length=config.max_start_length
        ))

    def _batch_tokenize(self, responses: List[str]) -> torch.Tensor:
        """Tokenize a batch of responses."""
        return self.tokenizer(
            responses,
            add_special_tokens=False,
            return_tensors='pt',
            padding="longest"
        )['input_ids']

    def _postprocess_responses(self, responses: torch.Tensor) -> torch.Tensor:
        """Process responses to stop at search operation or answer operation."""
        responses_str = self.tokenizer.batch_decode(
            responses,
            skip_special_tokens=True
        )

        responses_str = [resp.split('</search>')[0] + '</search>'
                         if '</search>' in resp
                         else resp.split('</answer>')[0] + '</answer>'
                         if '</answer>' in resp
                         else resp
                         for resp in responses_str]

        if self.config.no_think_rl:
            raise ValueError('stop')
            # if no_think_rl is enabled, only keep action in the str
            actions, _ = self.env.postprocess_predictions(responses_str)
            responses_str = [f"<answer>{envs[idx].ACTION_LOOKUP[action]}</answer>" for idx, action in enumerate(actions)]
            print("RESPONSES:", responses_str)
        responses = self._batch_tokenize(responses_str)
        return responses, responses_str

    def _process_next_obs(self, next_obs: List[str]) -> torch.Tensor:
        """Process next observations from environment."""

        next_obs_ids = self.tokenizer(
            next_obs,
            padding='longest',
            return_tensors='pt',
            add_special_tokens=False,  # Prevents adding special tokens
        )['input_ids']

        if next_obs_ids.shape[1] > self.config.max_obs_length:
            print(f"[WARNING] OBSERVATION TOO LONG, CONSIDER CHANGING YOUR CONFIG, {next_obs_ids.shape[1]} & {self.config.max_obs_length}")
            next_obs_ids = next_obs_ids[:, :self.config.max_obs_length]

        return next_obs_ids

    def _update_rolling_state(self, rollings, cur_responses: torch.Tensor,
                              next_obs_ids: torch.Tensor) -> Dict:
        """Update rolling state with new responses and observations."""
        # Concatenate and handle padding
        new_input_ids = self.tensor_fn.concatenate_with_padding([
            rollings.batch['input_ids'],
            cur_responses,
            next_obs_ids
        ])

        # Create attention mask and position ids
        new_attention_mask = self.tensor_fn.create_attention_mask(new_input_ids)
        new_position_ids = self.tensor_fn.create_position_ids(new_attention_mask)

        # Cut to appropriate length
        effective_len = new_attention_mask.sum(dim=1).max()
        max_len = min(self.config.max_prompt_length, effective_len)

        return DataProto.from_dict({
            'input_ids': new_input_ids[:, -max_len:],
            'position_ids': new_position_ids[:, -max_len:],
            'attention_mask': new_attention_mask[:, -max_len:]
        })

    def _update_right_side(self, right_side: Dict,
                           cur_responses: torch.Tensor,
                           next_obs_ids: torch.Tensor = None) -> Dict:
        """Update right side state."""
        if next_obs_ids != None:
            responses = self.tensor_fn.concatenate_with_padding([
                right_side['responses'],
                cur_responses,
                next_obs_ids
            ], pad_to_left=False)
        else:
            responses = self.tensor_fn.concatenate_with_padding([
                right_side['responses'],
                cur_responses,
            ], pad_to_left=False)

        effective_len = self.tensor_fn.create_attention_mask(responses).sum(dim=1).max()
        max_len = min(self.config.max_prompt_length, effective_len)

        return {'responses': responses[:, :max_len]}

    def _generate_with_gpu_padding(self, active_batch: DataProto) -> DataProto:
        """
            Wrapper for generation that handles multi-GPU padding requirements.
            if num_gpus <= 1, return self.actor_rollout_wg.generate_sequences(active_batch)
            if active_batch size is not divisible by num_gpus, pad with first sequence
            then remove padding from output
        """
        num_gpus = self.config.num_gpus
        if num_gpus <= 1:
            return self.actor_rollout_wg.generate_sequences(active_batch)

        batch_size = active_batch.batch['input_ids'].shape[0]
        remainder = batch_size % num_gpus

        if remainder == 0:
            return self.actor_rollout_wg.generate_sequences(active_batch)

        # Add padding sequences
        padding_size = num_gpus - remainder
        padded_batch = {}

        for k, v in active_batch.batch.items():
            # Use first sequence as padding template
            pad_sequence = v[0:1].repeat(padding_size, *[1] * (len(v.shape) - 1))
            padded_batch[k] = torch.cat([v, pad_sequence], dim=0)

        padded_active_batch = DataProto.from_dict(padded_batch)

        # Generate with padded batch
        padded_output = self.actor_rollout_wg.generate_sequences(padded_active_batch)

        # Remove padding from output
        trimmed_batch = {k: v[:-padding_size] for k, v in padded_output.batch.items()}

        # Handle meta_info if present
        if hasattr(padded_output, 'meta_info') and padded_output.meta_info:
            trimmed_meta = {}
            for k, v in padded_output.meta_info.items():
                if isinstance(v, torch.Tensor):
                    trimmed_meta[k] = v[:-padding_size]
                else:
                    trimmed_meta[k] = v
            padded_output.meta_info = trimmed_meta

        padded_output.batch = trimmed_batch
        return padded_output

    def run_llm_loop(self, gen_batch, initial_input_ids: torch.Tensor) -> Tuple[Dict, Dict]:
        """Run main LLM generation loop."""

        original_left_side = {'input_ids': initial_input_ids[:, -self.config.max_start_length:]}
        original_right_side = {'responses': initial_input_ids[:, []]}

        active_mask = torch.ones(gen_batch.batch['input_ids'].shape[0], dtype=torch.bool)
        active_num_list = [active_mask.sum().item()]
        rollings = gen_batch

        # Main generation loop
        for step in range(self.config.max_turns):
            if not active_mask.sum():
                break
            rollings.batch = self.tensor_fn.cut_to_effective_len(
                rollings.batch,
                keys=['input_ids', 'attention_mask', 'position_ids']
            )

            # gen_output = self.actor_rollout_wg.generate_sequences(rollings)
            rollings_active = DataProto.from_dict({
                k: v[active_mask] for k, v in rollings.batch.items()
            })
            gen_output = self._generate_with_gpu_padding(rollings_active)

            meta_info = gen_output.meta_info
            responses_ids, responses_str = self._postprocess_responses(gen_output.batch['responses'])
            responses_ids, responses_str = self.tensor_fn._example_level_pad(responses_ids, responses_str, active_mask)

            # Execute in environment and process observations
            next_obs, dones = self.execute_predictions(
                responses_str, self.tokenizer.pad_token, active_mask
            )

            curr_active_mask = torch.tensor([not done for done in dones], dtype=torch.bool)
            active_mask = active_mask * curr_active_mask
            active_num_list.append(active_mask.sum().item())

            next_obs_ids = self._process_next_obs(next_obs)

            # Update states
            rollings = self._update_rolling_state(
                rollings,
                responses_ids,
                next_obs_ids
            )
            original_right_side = self._update_right_side(
                original_right_side,
                responses_ids,
                next_obs_ids
            )

        # final LLM rollout
        if active_mask.sum():
            rollings.batch = self.tensor_fn.cut_to_effective_len(
                rollings.batch,
                keys=['input_ids', 'attention_mask', 'position_ids']
            )

            # gen_output = self.actor_rollout_wg.generate_sequences(rollings)
            rollings_active = DataProto.from_dict({
                k: v[active_mask] for k, v in rollings.batch.items()
            })
            gen_output = self._generate_with_gpu_padding(rollings_active)

            meta_info = gen_output.meta_info
            responses_ids, responses_str = self._postprocess_responses(gen_output.batch['responses'])
            responses_ids, responses_str = self.tensor_fn._example_level_pad(responses_ids, responses_str, active_mask)

            # # Execute in environment and process observations
            _, dones = self.execute_predictions(
                responses_str, self.tokenizer.pad_token, active_mask, do_search=False
            )

            curr_active_mask = torch.tensor([not done for done in dones], dtype=torch.bool)
            active_mask = active_mask * curr_active_mask
            active_num_list.append(active_mask.sum().item())

            original_right_side = self._update_right_side(
                original_right_side,
                responses_ids,
            )

        print("ACTIVE_TRAJ_NUM:", active_num_list)

        return self._compose_final_output(original_left_side, original_right_side, meta_info)

    def _compose_final_output(self, left_side: Dict,
                              right_side: Dict,
                              meta_info: Dict) -> Tuple[Dict, Dict]:
        """Compose final generation output."""
        final_output = right_side.copy()
        final_output['prompts'] = left_side['input_ids']

        # Combine input IDs
        final_output['input_ids'] = torch.cat([
            left_side['input_ids'],
            right_side['responses']
        ], dim=1)

        # Create attention mask and position ids
        final_output['attention_mask'] = torch.cat([
            self.tensor_fn.create_attention_mask(left_side['input_ids']),
            self.tensor_fn.create_attention_mask(final_output['responses'])
        ], dim=1)

        final_output['position_ids'] = self.tensor_fn.create_position_ids(
            final_output['attention_mask']
        )

        final_output = DataProto.from_dict(final_output)
        final_output.meta_info.update(meta_info)

        return final_output

    def execute_predictions(self, predictions: List[str], pad_token: str, active_mask=None, do_search=True) -> List[str]:
        """
        Execute predictions across multiple environments.
        NOTE: the function is the actual `step` function in the environment
        NOTE penalty_for_invalid is not included in observation shown to the LLM

        Args:
            envs: List of environment instances
            predictions: List of action predictions
            pad_token: Token to use for padding

        Returns:
            List of observation strings
        """
        cur_actions, contents = self.postprocess_predictions(predictions)
        next_obs, dones = [], []

        search_queries = [content for action, content in zip(cur_actions, contents) if action == 'search']
        if do_search:
            search_results = self.batch_search(search_queries)
            assert len(search_results) == sum([1 for action in cur_actions if action == 'search'])
        else:
            search_results = [''] * sum([1 for action in cur_actions if action == 'search'])

        for i, (action, active) in enumerate(zip(cur_actions, active_mask)):

            if not active:
                next_obs.append('')
                dones.append(1)
            else:
                if action == 'answer':
                    next_obs.append('')
                    dones.append(1)
                elif action == 'search':
                    next_obs.append(f'\n\n<information>{search_results.pop(0).strip()}</information>\n\n')
                    dones.append(0)
                else:
                    next_obs.append(f'\nMy previous action is invalid. \
If I want to search, I should put the query between <search> and </search>. \
If I want to give the final answer, I should put the answer between <answer> and </answer>. Let me try again.\n')
                    dones.append(0)

        assert len(search_results) == 0

        return next_obs, dones

    def postprocess_predictions(self, predictions: List[Any]) -> Tuple[List[int], List[bool]]:
        """
        Process (text-based) predictions from llm into actions and validity flags.

        Args:
            predictions: List of raw predictions

        Returns:
            Tuple of (actions list, validity flags list)
        """
        actions = []
        contents = []

        for prediction in predictions:
            if isinstance(prediction, str):  # for llm output
                pattern = r'<(search|answer)>(.*?)</\1>'
                match = re.search(pattern, prediction, re.DOTALL)
                if match:
                    content = match.group(2).strip()  # Return only the content inside the tags
                    action = match.group(1)
                else:
                    content = ''
                    action = None
            else:
                raise ValueError(f"Invalid prediction type: {type(prediction)}")

            actions.append(action)
            contents.append(content)

        return actions, contents

    def batch_search(self, queries: List[str] = None) -> str:
        """
        Batchified search for queries.
        Args:
            queries: queries to call the search engine
        Returns:
            search results which is concatenated into a string
        """
        results = self._batch_search(queries)['result']

        return [self._passages2string(result) for result in results]

    def _batch_search(self, queries):

        payload = {
            "queries": queries,
            "topk": self.config.topk,
            "return_scores": True
        }

        return requests.post(self.config.search_url, json=payload).json()

    def _passages2string(self, retrieval_result):
        format_reference = ''
        for idx, doc_item in enumerate(retrieval_result):

            content = doc_item['document']['contents']
            title = content.split("\n")[0]
            text = "\n".join(content.split("\n")[1:])
            format_reference += f"Doc {idx+1}(Title: {title}) {text}\n"

        return format_reference
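For reference, a minimal standalone sketch of how the tag handling above behaves: `_postprocess_responses` truncates a rollout at the first closing `</search>` or `</answer>` tag, and `postprocess_predictions` extracts the action and its content with the same regex. The sample rollout strings below are made up for illustration.

import re

# Illustrative rollout strings; the tag grammar matches postprocess_predictions above.
sample_responses = [
    "<think>I need more facts.</think> <search>capital of France</search> trailing text",
    "<think>Done.</think> <answer>Paris</answer>",
    "no tags at all",
]

pattern = r'<(search|answer)>(.*?)</\1>'
for resp in sample_responses:
    # Truncate at the first closing tag, as _postprocess_responses does.
    if '</search>' in resp:
        resp = resp.split('</search>')[0] + '</search>'
    elif '</answer>' in resp:
        resp = resp.split('</answer>')[0] + '</answer>'
    match = re.search(pattern, resp, re.DOTALL)
    action = match.group(1) if match else None
    content = match.group(2).strip() if match else ''
    print(action, '->', repr(content))
# search -> 'capital of France'
# answer -> 'Paris'
# None -> ''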

75    search_r1/llm_agent/tensor_helper.py    Normal file
@@ -0,0 +1,75 @@
import torch
from typing import Dict, Tuple, List
from dataclasses import dataclass

@dataclass
class TensorConfig:
    pad_token_id: int
    max_prompt_length: int
    max_obs_length: int
    max_start_length: int

class TensorHelper:
    def __init__(self, config: TensorConfig):
        self.config = config

    def cut_to_effective_len(self, tensor_dict: Dict[str, torch.Tensor],
                             keys: List[str], cut_left: bool = True) -> Dict[str, torch.Tensor]:
        """Cut tensors to their effective length based on attention mask."""
        effective_len = tensor_dict['attention_mask'].sum(dim=1).max()
        result = tensor_dict.copy()

        for key in keys:
            if cut_left:
                result[key] = tensor_dict[key][:, -effective_len:]
            else:
                result[key] = tensor_dict[key][:, :effective_len]
        return result

    def convert_pad_structure(self, tensor: torch.Tensor, pad_to_left: bool = True) -> Tuple[torch.Tensor, torch.Tensor]:
        """Convert padding structure and return sorted tensor with indices."""
        mask = tensor != self.config.pad_token_id if pad_to_left else tensor == self.config.pad_token_id
        sorted_indices = mask.to(torch.int64).argsort(dim=1, stable=True)
        return tensor.gather(1, sorted_indices), sorted_indices

    def create_attention_mask(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Create attention mask from input ids."""
        return torch.where(input_ids != self.config.pad_token_id, 1, 0)

    def create_position_ids(self, attention_mask: torch.Tensor) -> torch.Tensor:
        """Create position ids from attention mask."""
        return (torch.cumsum(attention_mask, dim=1) - 1) * attention_mask

    def concatenate_with_padding(self, tensors: List[torch.Tensor],
                                 pad_to_left: bool = True) -> torch.Tensor:
        """Concatenate tensors and handle padding."""
        concatenated = torch.cat(tensors, dim=1)
        padded_tensor, _ = self.convert_pad_structure(concatenated, pad_to_left)
        return padded_tensor

    def _example_level_pad(self, responses: torch.Tensor,
                           responses_str: List[str],
                           active_mask: torch.Tensor) -> Tuple[torch.Tensor, List[str]]:
        """
        Pad responses for non-active examples with pad tokens.
        """
        assert active_mask.sum() == responses.shape[0]
        # Create masked responses tensor
        batch_size = active_mask.shape[0]
        seq_len = responses.shape[1]
        padded_responses = torch.full(
            (batch_size, seq_len), self.config.pad_token_id,
            dtype=responses.dtype, device=responses.device
        )
        padded_responses[active_mask] = responses

        # Create masked response strings
        padded_responses_str = [""] * batch_size

        s = 0
        for i, is_active in enumerate(active_mask):
            if is_active:
                padded_responses_str[i] = responses_str[s]
                s += 1

        return padded_responses, padded_responses_str
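A small illustration of how TensorHelper keeps batches left-padded when prompts and responses are concatenated. The toy token ids are made up, and pad_token_id is assumed to be 0.

import torch
from search_r1.llm_agent.tensor_helper import TensorHelper, TensorConfig

helper = TensorHelper(TensorConfig(pad_token_id=0, max_prompt_length=8,
                                   max_obs_length=4, max_start_length=4))

prompts   = torch.tensor([[0, 0, 5, 6], [0, 7, 8, 9]])   # left-padded prompts
responses = torch.tensor([[1, 2, 0, 0], [3, 0, 0, 0]])   # right-padded responses

# Concatenation re-sorts pad tokens so the merged sequence stays left-padded,
# while the relative order of real tokens is preserved.
merged = helper.concatenate_with_padding([prompts, responses])
print(merged)
# tensor([[0, 0, 0, 0, 5, 6, 1, 2],
#         [0, 0, 0, 0, 7, 8, 9, 3]])
print(helper.create_attention_mask(merged))  # 0s over the pads, 1s over real tokens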

17    search_r1/search/build_index.sh    Normal file
@@ -0,0 +1,17 @@

corpus_file=/your/corpus/jsonl/file # jsonl
save_dir=/the/path/to/save/index
retriever_name=e5 # this is for index naming
retriever_model=intfloat/e5-base-v2

CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python index_builder.py \
    --retrieval_method $retriever_name \
    --model_path $retriever_model \
    --corpus_path $corpus_file \
    --save_dir $save_dir \
    --use_fp16 \
    --max_length 256 \
    --batch_size 512 \
    --pooling_method mean \
    --faiss_type Flat \
    --save_embedding
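The corpus file is JSONL. Judging from the fields the code reads ('title' and 'text' for dense encoding, and 'contents' with the quoted title on its first line for the BM25/Pyserini path and the retrievers), a corpus line presumably looks like the hypothetical example below; the 'id' field is what Pyserini's JsonCollection expects.

import json

# Hypothetical corpus line; field names follow what index_builder.py and retrieval.py read.
doc = {
    "id": "0",
    "title": "Eiffel Tower",
    "text": "The Eiffel Tower is a wrought-iron lattice tower in Paris, France.",
    "contents": "\"Eiffel Tower\"\nThe Eiffel Tower is a wrought-iron lattice tower in Paris, France.",
}
with open("toy_corpus.jsonl", "w") as f:
    f.write(json.dumps(doc) + "\n")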

348    search_r1/search/index_builder.py    Normal file
@@ -0,0 +1,348 @@
import os
import faiss
import json
import warnings
import numpy as np
from typing import cast, List, Dict
import shutil
import subprocess
import argparse
import torch
from tqdm import tqdm
# from LongRAG.retriever.utils import load_model, load_corpus, pooling
import datasets
from transformers import AutoTokenizer, AutoModel, AutoConfig


def load_model(
        model_path: str,
        use_fp16: bool = False
    ):
    model_config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
    model = AutoModel.from_pretrained(model_path, trust_remote_code=True)
    model.eval()
    model.cuda()
    if use_fp16:
        model = model.half()
    tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True, trust_remote_code=True)

    return model, tokenizer


def pooling(
        pooler_output,
        last_hidden_state,
        attention_mask = None,
        pooling_method = "mean"
    ):
    if pooling_method == "mean":
        last_hidden = last_hidden_state.masked_fill(~attention_mask[..., None].bool(), 0.0)
        return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
    elif pooling_method == "cls":
        return last_hidden_state[:, 0]
    elif pooling_method == "pooler":
        return pooler_output
    else:
        raise NotImplementedError("Pooling method not implemented!")


def load_corpus(corpus_path: str):
    corpus = datasets.load_dataset(
        'json',
        data_files=corpus_path,
        split="train",
        num_proc=4)
    return corpus


class Index_Builder:
    r"""A tool class used to build an index used in retrieval.
    """
    def __init__(
            self,
            retrieval_method,
            model_path,
            corpus_path,
            save_dir,
            max_length,
            batch_size,
            use_fp16,
            pooling_method,
            faiss_type=None,
            embedding_path=None,
            save_embedding=False,
            faiss_gpu=False
        ):

        self.retrieval_method = retrieval_method.lower()
        self.model_path = model_path
        self.corpus_path = corpus_path
        self.save_dir = save_dir
        self.max_length = max_length
        self.batch_size = batch_size
        self.use_fp16 = use_fp16
        self.pooling_method = pooling_method
        self.faiss_type = faiss_type if faiss_type is not None else 'Flat'
        self.embedding_path = embedding_path
        self.save_embedding = save_embedding
        self.faiss_gpu = faiss_gpu

        self.gpu_num = torch.cuda.device_count()
        # prepare save dir
        print(self.save_dir)
        if not os.path.exists(self.save_dir):
            os.makedirs(self.save_dir)
        else:
            if not self._check_dir(self.save_dir):
                warnings.warn("Some files already exist in save dir and may be overwritten.", UserWarning)

        self.index_save_path = os.path.join(self.save_dir, f"{self.retrieval_method}_{self.faiss_type}.index")
        self.embedding_save_path = os.path.join(self.save_dir, f"emb_{self.retrieval_method}.memmap")
        self.corpus = load_corpus(self.corpus_path)

        print("Finish loading...")

    @staticmethod
    def _check_dir(dir_path):
        r"""Check if the dir path exists and if there is content.
        """
        if os.path.isdir(dir_path):
            if len(os.listdir(dir_path)) > 0:
                return False
        else:
            os.makedirs(dir_path, exist_ok=True)
        return True

    def build_index(self):
        r"""Constructing different indexes based on selective retrieval method.
        """
        if self.retrieval_method == "bm25":
            self.build_bm25_index()
        else:
            self.build_dense_index()

    def build_bm25_index(self):
        """Building BM25 index based on Pyserini library.

        Reference: https://github.com/castorini/pyserini/blob/master/docs/usage-index.md#building-a-bm25-index-direct-java-implementation
        """
        # to use pyserini pipeline, we first need to place jsonl file in the folder
        self.save_dir = os.path.join(self.save_dir, "bm25")
        os.makedirs(self.save_dir, exist_ok=True)
        temp_dir = self.save_dir + "/temp"
        temp_file_path = temp_dir + "/temp.jsonl"
        os.makedirs(temp_dir)

        # if self.have_contents:
        #     shutil.copyfile(self.corpus_path, temp_file_path)
        # else:
        #     with open(temp_file_path, "w") as f:
        #         for item in self.corpus:
        #             f.write(json.dumps(item) + "\n")
        shutil.copyfile(self.corpus_path, temp_file_path)

        print("Start building bm25 index...")
        pyserini_args = ["--collection", "JsonCollection",
                         "--input", temp_dir,
                         "--index", self.save_dir,
                         "--generator", "DefaultLuceneDocumentGenerator",
                         "--threads", "1"]

        subprocess.run(["python", "-m", "pyserini.index.lucene"] + pyserini_args)

        shutil.rmtree(temp_dir)

        print("Finish!")

    def _load_embedding(self, embedding_path, corpus_size, hidden_size):
        all_embeddings = np.memmap(
            embedding_path,
            mode="r",
            dtype=np.float32
        ).reshape(corpus_size, hidden_size)
        return all_embeddings

    def _save_embedding(self, all_embeddings):
        memmap = np.memmap(
            self.embedding_save_path,
            shape=all_embeddings.shape,
            mode="w+",
            dtype=all_embeddings.dtype
        )
        length = all_embeddings.shape[0]
        # add in batch
        save_batch_size = 10000
        if length > save_batch_size:
            for i in tqdm(range(0, length, save_batch_size), leave=False, desc="Saving Embeddings"):
                j = min(i + save_batch_size, length)
                memmap[i: j] = all_embeddings[i: j]
        else:
            memmap[:] = all_embeddings

    def encode_all(self):
        if self.gpu_num > 1:
            print("Use multi gpu!")
            self.encoder = torch.nn.DataParallel(self.encoder)
            self.batch_size = self.batch_size * self.gpu_num

        all_embeddings = []

        for start_idx in tqdm(range(0, len(self.corpus), self.batch_size), desc='Inference Embeddings:'):

            batch_data_title = self.corpus[start_idx:start_idx+self.batch_size]['title']
            batch_data_text = self.corpus[start_idx:start_idx+self.batch_size]['text']
            batch_data = ['"' + title + '"\n' + text for title, text in zip(batch_data_title, batch_data_text)]

            if self.retrieval_method == "e5":
                batch_data = [f"passage: {doc}" for doc in batch_data]

            inputs = self.tokenizer(
                batch_data,
                padding=True,
                truncation=True,
                return_tensors='pt',
                max_length=self.max_length,
            ).to('cuda')

            inputs = {k: v.cuda() for k, v in inputs.items()}

            #TODO: support encoder-only T5 model
            if "T5" in type(self.encoder).__name__:
                # T5-based retrieval model
                decoder_input_ids = torch.zeros(
                    (inputs['input_ids'].shape[0], 1), dtype=torch.long
                ).to(inputs['input_ids'].device)
                output = self.encoder(
                    **inputs, decoder_input_ids=decoder_input_ids, return_dict=True
                )
                embeddings = output.last_hidden_state[:, 0, :]

            else:
                output = self.encoder(**inputs, return_dict=True)
                embeddings = pooling(output.pooler_output,
                                     output.last_hidden_state,
                                     inputs['attention_mask'],
                                     self.pooling_method)
                if "dpr" not in self.retrieval_method:
                    embeddings = torch.nn.functional.normalize(embeddings, dim=-1)

            embeddings = cast(torch.Tensor, embeddings)
            embeddings = embeddings.detach().cpu().numpy()
            all_embeddings.append(embeddings)

        all_embeddings = np.concatenate(all_embeddings, axis=0)
        all_embeddings = all_embeddings.astype(np.float32)

        return all_embeddings

    @torch.no_grad()
    def build_dense_index(self):
        """Obtain the representation of documents based on the embedding model(BERT-based) and
        construct a faiss index.
        """
        if os.path.exists(self.index_save_path):
            print("The index file already exists and will be overwritten.")

        self.encoder, self.tokenizer = load_model(model_path = self.model_path,
                                                  use_fp16 = self.use_fp16)
        if self.embedding_path is not None:
            hidden_size = self.encoder.config.hidden_size
            corpus_size = len(self.corpus)
            all_embeddings = self._load_embedding(self.embedding_path, corpus_size, hidden_size)
        else:
            all_embeddings = self.encode_all()
            if self.save_embedding:
                self._save_embedding(all_embeddings)
            del self.corpus

        # build index
        print("Creating index")
        dim = all_embeddings.shape[-1]
        faiss_index = faiss.index_factory(dim, self.faiss_type, faiss.METRIC_INNER_PRODUCT)

        if self.faiss_gpu:
            co = faiss.GpuMultipleClonerOptions()
            co.useFloat16 = True
            co.shard = True
            faiss_index = faiss.index_cpu_to_all_gpus(faiss_index, co)
            if not faiss_index.is_trained:
                faiss_index.train(all_embeddings)
            faiss_index.add(all_embeddings)
            faiss_index = faiss.index_gpu_to_cpu(faiss_index)
        else:
            if not faiss_index.is_trained:
                faiss_index.train(all_embeddings)
            faiss_index.add(all_embeddings)

        faiss.write_index(faiss_index, self.index_save_path)
        print("Finish!")


MODEL2POOLING = {
    "e5": "mean",
    "bge": "cls",
    "contriever": "mean",
    'jina': 'mean'
}


def main():
    parser = argparse.ArgumentParser(description="Creating index.")

    # Basic parameters
    parser.add_argument('--retrieval_method', type=str)
    parser.add_argument('--model_path', type=str, default=None)
    parser.add_argument('--corpus_path', type=str)
    parser.add_argument('--save_dir', default='indexes/', type=str)

    # Parameters for building dense index
    parser.add_argument('--max_length', type=int, default=180)
    parser.add_argument('--batch_size', type=int, default=512)
    parser.add_argument('--use_fp16', default=False, action='store_true')
    parser.add_argument('--pooling_method', type=str, default=None)
    parser.add_argument('--faiss_type', default=None, type=str)
    parser.add_argument('--embedding_path', default=None, type=str)
    parser.add_argument('--save_embedding', action='store_true', default=False)
    parser.add_argument('--faiss_gpu', default=False, action='store_true')

    args = parser.parse_args()

    if args.pooling_method is None:
        pooling_method = 'mean'
        for k, v in MODEL2POOLING.items():
            if k in args.retrieval_method.lower():
                pooling_method = v
                break
    else:
        if args.pooling_method not in ['mean', 'cls', 'pooler']:
            raise NotImplementedError
        else:
            pooling_method = args.pooling_method

    index_builder = Index_Builder(
        retrieval_method = args.retrieval_method,
        model_path = args.model_path,
        corpus_path = args.corpus_path,
        save_dir = args.save_dir,
        max_length = args.max_length,
        batch_size = args.batch_size,
        use_fp16 = args.use_fp16,
        pooling_method = pooling_method,
        faiss_type = args.faiss_type,
        embedding_path = args.embedding_path,
        save_embedding = args.save_embedding,
        faiss_gpu = args.faiss_gpu
    )
    index_builder.build_index()


if __name__ == "__main__":
    main()

368    search_r1/search/retrieval.py    Normal file
@@ -0,0 +1,368 @@
import json
import os
import warnings
from typing import List, Dict
import functools
from tqdm import tqdm
from multiprocessing import Pool

import faiss
import torch
import numpy as np
from transformers import AutoConfig, AutoTokenizer, AutoModel
import argparse
import datasets


def load_corpus(corpus_path: str):
    corpus = datasets.load_dataset(
        'json',
        data_files=corpus_path,
        split="train",
        num_proc=4)
    return corpus


def read_jsonl(file_path):
    data = []

    with open(file_path, "r") as f:
        readin = f.readlines()
        for line in readin:
            data.append(json.loads(line))
    return data


def load_docs(corpus, doc_idxs):
    results = [corpus[int(idx)] for idx in doc_idxs]

    return results


def load_model(
        model_path: str,
        use_fp16: bool = False
    ):
    model_config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
    model = AutoModel.from_pretrained(model_path, trust_remote_code=True)
    model.eval()
    model.cuda()
    if use_fp16:
        model = model.half()
    tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True, trust_remote_code=True)

    return model, tokenizer


def pooling(
        pooler_output,
        last_hidden_state,
        attention_mask = None,
        pooling_method = "mean"
    ):
    if pooling_method == "mean":
        last_hidden = last_hidden_state.masked_fill(~attention_mask[..., None].bool(), 0.0)
        return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
    elif pooling_method == "cls":
        return last_hidden_state[:, 0]
    elif pooling_method == "pooler":
        return pooler_output
    else:
        raise NotImplementedError("Pooling method not implemented!")


class Encoder:
    def __init__(self, model_name, model_path, pooling_method, max_length, use_fp16):
        self.model_name = model_name
        self.model_path = model_path
        self.pooling_method = pooling_method
        self.max_length = max_length
        self.use_fp16 = use_fp16

        self.model, self.tokenizer = load_model(model_path=model_path,
                                                use_fp16=use_fp16)

    @torch.no_grad()
    def encode(self, query_list: List[str], is_query=True) -> np.ndarray:
        # processing query for different encoders
        if isinstance(query_list, str):
            query_list = [query_list]

        if "e5" in self.model_name.lower():
            if is_query:
                query_list = [f"query: {query}" for query in query_list]
            else:
                query_list = [f"passage: {query}" for query in query_list]

        if "bge" in self.model_name.lower():
            if is_query:
                query_list = [f"Represent this sentence for searching relevant passages: {query}" for query in query_list]

        inputs = self.tokenizer(query_list,
                                max_length=self.max_length,
                                padding=True,
                                truncation=True,
                                return_tensors="pt"
                                )
        inputs = {k: v.cuda() for k, v in inputs.items()}

        if "T5" in type(self.model).__name__:
            # T5-based retrieval model
            decoder_input_ids = torch.zeros(
                (inputs['input_ids'].shape[0], 1), dtype=torch.long
            ).to(inputs['input_ids'].device)
            output = self.model(
                **inputs, decoder_input_ids=decoder_input_ids, return_dict=True
            )
            query_emb = output.last_hidden_state[:, 0, :]

        else:
            output = self.model(**inputs, return_dict=True)
            query_emb = pooling(output.pooler_output,
                                output.last_hidden_state,
                                inputs['attention_mask'],
                                self.pooling_method)
            if "dpr" not in self.model_name.lower():
                query_emb = torch.nn.functional.normalize(query_emb, dim=-1)

        query_emb = query_emb.detach().cpu().numpy()
        query_emb = query_emb.astype(np.float32, order="C")
        return query_emb


class BaseRetriever:
    """Base object for all retrievers."""

    def __init__(self, config):
        self.config = config
        self.retrieval_method = config.retrieval_method
        self.topk = config.retrieval_topk

        self.index_path = config.index_path
        self.corpus_path = config.corpus_path

        # self.cache_save_path = os.path.join(config.save_dir, 'retrieval_cache.json')

    def _search(self, query: str, num: int, return_score: bool) -> List[Dict[str, str]]:
        r"""Retrieve topk relevant documents in corpus.
        Return:
            list: contains information related to the document, including:
                contents: used for building index
                title: (if provided)
                text: (if provided)
        """
        pass

    def _batch_search(self, query_list, num, return_score):
        pass

    def search(self, *args, **kwargs):
        return self._search(*args, **kwargs)

    def batch_search(self, *args, **kwargs):
        return self._batch_search(*args, **kwargs)


class BM25Retriever(BaseRetriever):
    r"""BM25 retriever based on pre-built pyserini index."""

    def __init__(self, config):
        super().__init__(config)
        from pyserini.search.lucene import LuceneSearcher
        self.searcher = LuceneSearcher(self.index_path)
        self.contain_doc = self._check_contain_doc()
        if not self.contain_doc:
            self.corpus = load_corpus(self.corpus_path)
        self.max_process_num = 8

    def _check_contain_doc(self):
        r"""Check if the index contains document content
        """
        return self.searcher.doc(0).raw() is not None

    def _search(self, query: str, num: int = None, return_score = False) -> List[Dict[str, str]]:
        if num is None:
            num = self.topk

        hits = self.searcher.search(query, num)
        if len(hits) < 1:
            if return_score:
                return [], []
            else:
                return []

        scores = [hit.score for hit in hits]
        if len(hits) < num:
            warnings.warn('Not enough documents retrieved!')
        else:
            hits = hits[:num]

        if self.contain_doc:
            all_contents = [json.loads(self.searcher.doc(hit.docid).raw())['contents'] for hit in hits]
            results = [{'title': content.split("\n")[0].strip("\""),
                        'text': "\n".join(content.split("\n")[1:]),
                        'contents': content} for content in all_contents]
        else:
            results = load_docs(self.corpus, [hit.docid for hit in hits])

        if return_score:
            return results, scores
        else:
            return results

    def _batch_search(self, query_list, num: int = None, return_score = False):
        # TODO: modify batch method
        results = []
        scores = []
        for query in query_list:
            item_result, item_score = self._search(query, num, True)
            results.append(item_result)
            scores.append(item_score)

        if return_score:
            return results, scores
        else:
            return results


def get_available_gpu_memory():
    memory_info = []
    for i in range(torch.cuda.device_count()):
        total_memory = torch.cuda.get_device_properties(i).total_memory
        allocated_memory = torch.cuda.memory_allocated(i)
        free_memory = total_memory - allocated_memory
        memory_info.append((i, free_memory / 1e9))  # Convert to GB
    return memory_info


class DenseRetriever(BaseRetriever):
    r"""Dense retriever based on pre-built faiss index."""

    def __init__(self, config: dict):
        super().__init__(config)
        self.index = faiss.read_index(self.index_path)
        if config.faiss_gpu:
            co = faiss.GpuMultipleClonerOptions()
            co.useFloat16 = True
            co.shard = True
            self.index = faiss.index_cpu_to_all_gpus(self.index, co=co)
            # self.index = faiss.index_cpu_to_all_gpus(self.index)

        self.corpus = load_corpus(self.corpus_path)
        self.encoder = Encoder(
            model_name = self.retrieval_method,
            model_path = config.retrieval_model_path,
            pooling_method = config.retrieval_pooling_method,
            max_length = config.retrieval_query_max_length,
            use_fp16 = config.retrieval_use_fp16
        )
        self.topk = config.retrieval_topk
        self.batch_size = self.config.retrieval_batch_size

    def _search(self, query: str, num: int = None, return_score = False):
        if num is None:
            num = self.topk
        query_emb = self.encoder.encode(query)
        scores, idxs = self.index.search(query_emb, k=num)
        idxs = idxs[0]
        scores = scores[0]

        results = load_docs(self.corpus, idxs)
        if return_score:
            return results, scores
        else:
            return results

    def _batch_search(self, query_list: List[str], num: int = None, return_score = False):
        if isinstance(query_list, str):
            query_list = [query_list]
        if num is None:
            num = self.topk

        batch_size = self.batch_size

        results = []
        scores = []

        for start_idx in tqdm(range(0, len(query_list), batch_size), desc='Retrieval process: '):
            query_batch = query_list[start_idx:start_idx + batch_size]

            # from time import time
            # a = time()
            batch_emb = self.encoder.encode(query_batch)
            # b = time()
            # print(f'################### encode time {b-a} #####################')
            batch_scores, batch_idxs = self.index.search(batch_emb, k=num)
            batch_scores = batch_scores.tolist()
            batch_idxs = batch_idxs.tolist()
            # print(f'################### search time {time()-b} #####################')
            # exit()

            flat_idxs = sum(batch_idxs, [])
            batch_results = load_docs(self.corpus, flat_idxs)
            batch_results = [batch_results[i*num : (i+1)*num] for i in range(len(batch_idxs))]

            scores.extend(batch_scores)
            results.extend(batch_results)

        if return_score:
            return results, scores
        else:
            return results


def get_retriever(config):
    r"""Automatically select retriever class based on config's retrieval method

    Args:
        config (dict): configuration with 'retrieval_method' key

    Returns:
        Retriever: retriever instance
    """
    if config.retrieval_method == "bm25":
        return BM25Retriever(config)
    else:
        return DenseRetriever(config)


def get_dataset(config):
    """Load dataset from config."""

    split_path = os.path.join(config.dataset_path, f'{config.data_split}.jsonl')
    return read_jsonl(split_path)


if __name__ == '__main__':

    parser = argparse.ArgumentParser(description="Retrieval")

    # Basic parameters
    parser.add_argument('--retrieval_method', type=str)
    parser.add_argument('--retrieval_topk', type=int, default=10)
    parser.add_argument('--index_path', type=str, default=None)
    parser.add_argument('--corpus_path', type=str)
    parser.add_argument('--dataset_path', default=None, type=str)

    parser.add_argument('--faiss_gpu', default=True, type=bool)
    parser.add_argument('--data_split', default="train", type=str)

    parser.add_argument('--retrieval_model_path', type=str, default=None)
    parser.add_argument('--retrieval_pooling_method', default='mean', type=str)
    parser.add_argument('--retrieval_query_max_length', default=256, type=str)
    parser.add_argument('--retrieval_use_fp16', action='store_true', default=False)
    parser.add_argument('--retrieval_batch_size', default=512, type=int)

    args = parser.parse_args()

    args.index_path = os.path.join(args.index_path, f'{args.retrieval_method}_Flat.index') if args.retrieval_method != 'bm25' else os.path.join(args.index_path, 'bm25')

    # load dataset
    all_split = get_dataset(args)

    input_query = [sample['question'] for sample in all_split[:512]]

    # initialize the retriever and conduct retrieval
    retriever = get_retriever(args)
    print('Start Retrieving ...')
    results, scores = retriever.batch_search(input_query, return_score=True)

    # from IPython import embed
    # embed()

25    search_r1/search/retrieval.sh    Normal file
@@ -0,0 +1,25 @@

DATA_NAME=nq

DATASET_PATH="/home/peterjin/mnt/data/$DATA_NAME"

SPLIT='test'
TOPK=3

INDEX_PATH=/home/peterjin/mnt/index/wiki-18
CORPUS_PATH=/home/peterjin/mnt/data/retrieval-corpus/wiki-18.jsonl
SAVE_NAME=e5_${TOPK}_wiki18.json

# INDEX_PATH=/home/peterjin/rm_retrieval_corpus/index/wiki-21
# CORPUS_PATH=/home/peterjin/rm_retrieval_corpus/corpora/wiki/enwiki-dec2021/text-list-100-sec.jsonl
# SAVE_NAME=e5_${TOPK}_wiki21.json

CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python retrieval.py --retrieval_method e5 \
    --retrieval_topk $TOPK \
    --index_path $INDEX_PATH \
    --corpus_path $CORPUS_PATH \
    --dataset_path $DATASET_PATH \
    --data_split $SPLIT \
    --retrieval_model_path "intfloat/e5-base-v2" \
    --retrieval_pooling_method "mean" \
    --retrieval_batch_size 512 \

23    search_r1/search/retrieval_request.py    Normal file
@@ -0,0 +1,23 @@
import requests

# URL for your local FastAPI server
url = "http://127.0.0.1:8000/retrieve"

# Example payload
payload = {
    "queries": ["What is the capital of France?", "Explain neural networks."] * 200,
    "topk": 5,
    "return_scores": True
}

# Send POST request
response = requests.post(url, json=payload)

# Raise an exception if the request failed
response.raise_for_status()

# Get the JSON response
retrieved_data = response.json()

print("Response from server:")
print(retrieved_data)
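The server replies with {"result": [...]}, one entry per query; with return_scores=True each entry is a list of {"document": ..., "score": ...} pairs, which is the layout that _passages2string in generation.py unpacks. Continuing the request script above, a short sketch of flattening the first query's hits into the same "Doc i(Title: ...)" format:

# Continuing from retrieval_request.py above: format the first query's passages
# the way LLMGenerationManager._passages2string does.
first_query_hits = retrieved_data["result"][0]
reference = ""
for idx, doc_item in enumerate(first_query_hits):
    content = doc_item["document"]["contents"]   # '"Title"\nbody text...'
    title = content.split("\n")[0]
    text = "\n".join(content.split("\n")[1:])
    reference += f"Doc {idx+1}(Title: {title}) {text}\n"
print(reference)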

382    search_r1/search/retrieval_server.py    Normal file
@@ -0,0 +1,382 @@
import json
import os
import warnings
from typing import List, Dict, Optional
import argparse

import faiss
import torch
import numpy as np
from transformers import AutoConfig, AutoTokenizer, AutoModel
from tqdm import tqdm
import datasets

import uvicorn
from fastapi import FastAPI
from pydantic import BaseModel


parser = argparse.ArgumentParser(description="Launch the local faiss retriever.")
parser.add_argument("--index_path", type=str, default="/home/peterjin/mnt/index/wiki-18/e5_Flat.index", help="Corpus indexing file.")
parser.add_argument("--corpus_path", type=str, default="/home/peterjin/mnt/data/retrieval-corpus/wiki-18.jsonl", help="Local corpus file.")
parser.add_argument("--topk", type=int, default=3, help="Number of retrieved passages for one query.")
parser.add_argument("--retriever_model", type=str, default="intfloat/e5-base-v2", help="Name of the retriever model.")

args = parser.parse_args()

def load_corpus(corpus_path: str):
    corpus = datasets.load_dataset(
        'json',
        data_files=corpus_path,
        split="train",
        num_proc=4
    )
    return corpus

def read_jsonl(file_path):
    data = []
    with open(file_path, "r") as f:
        for line in f:
            data.append(json.loads(line))
    return data

def load_docs(corpus, doc_idxs):
    results = [corpus[int(idx)] for idx in doc_idxs]
    return results

def load_model(model_path: str, use_fp16: bool = False):
    model_config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
    model = AutoModel.from_pretrained(model_path, trust_remote_code=True)
    model.eval()
    model.cuda()
    if use_fp16:
        model = model.half()
    tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True, trust_remote_code=True)
    return model, tokenizer

def pooling(
    pooler_output,
    last_hidden_state,
    attention_mask = None,
    pooling_method = "mean"
):
    if pooling_method == "mean":
        last_hidden = last_hidden_state.masked_fill(~attention_mask[..., None].bool(), 0.0)
        return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
    elif pooling_method == "cls":
        return last_hidden_state[:, 0]
    elif pooling_method == "pooler":
        return pooler_output
    else:
        raise NotImplementedError("Pooling method not implemented!")

class Encoder:
    def __init__(self, model_name, model_path, pooling_method, max_length, use_fp16):
        self.model_name = model_name
        self.model_path = model_path
        self.pooling_method = pooling_method
        self.max_length = max_length
        self.use_fp16 = use_fp16

        self.model, self.tokenizer = load_model(model_path=model_path, use_fp16=use_fp16)
        self.model.eval()

    @torch.no_grad()
    def encode(self, query_list: List[str], is_query=True) -> np.ndarray:
        # processing query for different encoders
        if isinstance(query_list, str):
            query_list = [query_list]

        if "e5" in self.model_name.lower():
            if is_query:
                query_list = [f"query: {query}" for query in query_list]
            else:
                query_list = [f"passage: {query}" for query in query_list]

        if "bge" in self.model_name.lower():
            if is_query:
                query_list = [f"Represent this sentence for searching relevant passages: {query}" for query in query_list]

        inputs = self.tokenizer(query_list,
                                max_length=self.max_length,
                                padding=True,
                                truncation=True,
                                return_tensors="pt"
                                )
        inputs = {k: v.cuda() for k, v in inputs.items()}

        if "T5" in type(self.model).__name__:
            # T5-based retrieval model
            decoder_input_ids = torch.zeros(
                (inputs['input_ids'].shape[0], 1), dtype=torch.long
            ).to(inputs['input_ids'].device)
            output = self.model(
                **inputs, decoder_input_ids=decoder_input_ids, return_dict=True
            )
            query_emb = output.last_hidden_state[:, 0, :]
        else:
            output = self.model(**inputs, return_dict=True)
            query_emb = pooling(output.pooler_output,
                                output.last_hidden_state,
                                inputs['attention_mask'],
                                self.pooling_method)
            if "dpr" not in self.model_name.lower():
                query_emb = torch.nn.functional.normalize(query_emb, dim=-1)

        query_emb = query_emb.detach().cpu().numpy()
        query_emb = query_emb.astype(np.float32, order="C")
        return query_emb

class BaseRetriever:
    def __init__(self, config):
        self.config = config
        self.retrieval_method = config.retrieval_method
        self.topk = config.retrieval_topk

        self.index_path = config.index_path
        self.corpus_path = config.corpus_path

    def _search(self, query: str, num: int, return_score: bool):
        raise NotImplementedError

    def _batch_search(self, query_list: List[str], num: int, return_score: bool):
        raise NotImplementedError

    def search(self, query: str, num: int = None, return_score: bool = False):
        return self._search(query, num, return_score)

    def batch_search(self, query_list: List[str], num: int = None, return_score: bool = False):
        return self._batch_search(query_list, num, return_score)

class BM25Retriever(BaseRetriever):
    def __init__(self, config):
        super().__init__(config)
        from pyserini.search.lucene import LuceneSearcher
        self.searcher = LuceneSearcher(self.index_path)
        self.contain_doc = self._check_contain_doc()
        if not self.contain_doc:
            self.corpus = load_corpus(self.corpus_path)
        self.max_process_num = 8

    def _check_contain_doc(self):
        return self.searcher.doc(0).raw() is not None

    def _search(self, query: str, num: int = None, return_score: bool = False):
        if num is None:
            num = self.topk
        hits = self.searcher.search(query, num)
        if len(hits) < 1:
            if return_score:
                return [], []
            else:
                return []
        scores = [hit.score for hit in hits]
        if len(hits) < num:
            warnings.warn('Not enough documents retrieved!')
        else:
            hits = hits[:num]

        if self.contain_doc:
            all_contents = [
                json.loads(self.searcher.doc(hit.docid).raw())['contents']
                for hit in hits
            ]
            results = [
                {
                    'title': content.split("\n")[0].strip("\""),
                    'text': "\n".join(content.split("\n")[1:]),
                    'contents': content
                }
                for content in all_contents
            ]
        else:
            results = load_docs(self.corpus, [hit.docid for hit in hits])

        if return_score:
            return results, scores
        else:
            return results

    def _batch_search(self, query_list: List[str], num: int = None, return_score: bool = False):
        results = []
        scores = []
        for query in query_list:
            item_result, item_score = self._search(query, num, True)
            results.append(item_result)
            scores.append(item_score)
        if return_score:
            return results, scores
        else:
            return results

class DenseRetriever(BaseRetriever):
    def __init__(self, config):
        super().__init__(config)
        self.index = faiss.read_index(self.index_path)
        if config.faiss_gpu:
            co = faiss.GpuMultipleClonerOptions()
            co.useFloat16 = True
            co.shard = True
            self.index = faiss.index_cpu_to_all_gpus(self.index, co=co)

        self.corpus = load_corpus(self.corpus_path)
        self.encoder = Encoder(
            model_name = self.retrieval_method,
            model_path = config.retrieval_model_path,
            pooling_method = config.retrieval_pooling_method,
            max_length = config.retrieval_query_max_length,
            use_fp16 = config.retrieval_use_fp16
        )
        self.topk = config.retrieval_topk
        self.batch_size = config.retrieval_batch_size

    def _search(self, query: str, num: int = None, return_score: bool = False):
        if num is None:
            num = self.topk
        query_emb = self.encoder.encode(query)
        scores, idxs = self.index.search(query_emb, k=num)
        idxs = idxs[0]
        scores = scores[0]
        results = load_docs(self.corpus, idxs)
        if return_score:
            return results, scores.tolist()
        else:
            return results

    def _batch_search(self, query_list: List[str], num: int = None, return_score: bool = False):
        if isinstance(query_list, str):
            query_list = [query_list]
        if num is None:
            num = self.topk

        results = []
        scores = []
        for start_idx in tqdm(range(0, len(query_list), self.batch_size), desc='Retrieval process: '):
            query_batch = query_list[start_idx:start_idx + self.batch_size]
            batch_emb = self.encoder.encode(query_batch)
            batch_scores, batch_idxs = self.index.search(batch_emb, k=num)
            batch_scores = batch_scores.tolist()
            batch_idxs = batch_idxs.tolist()

            # load_docs is not vectorized, but is a python list approach
            flat_idxs = sum(batch_idxs, [])
            batch_results = load_docs(self.corpus, flat_idxs)
            # chunk them back
            batch_results = [batch_results[i*num : (i+1)*num] for i in range(len(batch_idxs))]

            results.extend(batch_results)
            scores.extend(batch_scores)
        if return_score:
            return results, scores
        else:
            return results

def get_retriever(config):
    if config.retrieval_method == "bm25":
        return BM25Retriever(config)
    else:
        return DenseRetriever(config)


#####################################
# FastAPI server below
#####################################

class Config:
    """
    Minimal config class (simulating your argparse)
    Replace this with your real arguments or load them dynamically.
    """
    def __init__(
        self,
        retrieval_method: str = "bm25",
        retrieval_topk: int = 10,
        index_path: str = "./index/bm25",
        corpus_path: str = "./data/corpus.jsonl",
        dataset_path: str = "./data",
        data_split: str = "train",
        faiss_gpu: bool = True,
        retrieval_model_path: str = "./model",
        retrieval_pooling_method: str = "mean",
        retrieval_query_max_length: int = 256,
        retrieval_use_fp16: bool = False,
        retrieval_batch_size: int = 128
    ):
        self.retrieval_method = retrieval_method
        self.retrieval_topk = retrieval_topk
        self.index_path = index_path
        self.corpus_path = corpus_path
        self.dataset_path = dataset_path
        self.data_split = data_split
        self.faiss_gpu = faiss_gpu
        self.retrieval_model_path = retrieval_model_path
        self.retrieval_pooling_method = retrieval_pooling_method
        self.retrieval_query_max_length = retrieval_query_max_length
        self.retrieval_use_fp16 = retrieval_use_fp16
        self.retrieval_batch_size = retrieval_batch_size


class QueryRequest(BaseModel):
    queries: List[str]
    topk: Optional[int] = None
    return_scores: bool = False


app = FastAPI()

# 1) Build a config (could also parse from arguments).
#    In real usage, you'd parse your CLI arguments or environment variables.
config = Config(
    retrieval_method = "e5",  # or "dense"
    index_path=args.index_path,
    corpus_path=args.corpus_path,
    retrieval_topk=args.topk,
    faiss_gpu=True,
    retrieval_model_path=args.retriever_model,
    retrieval_pooling_method="mean",
    retrieval_query_max_length=256,
    retrieval_use_fp16=True,
    retrieval_batch_size=512,
)

# 2) Instantiate a global retriever so it is loaded once and reused.
retriever = get_retriever(config)

@app.post("/retrieve")
def retrieve_endpoint(request: QueryRequest):
    """
    Endpoint that accepts queries and performs retrieval.
    Input format:
    {
      "queries": ["What is Python?", "Tell me about neural networks."],
      "topk": 3,
      "return_scores": true
    }
    """
    if not request.topk:
        request.topk = config.retrieval_topk  # fallback to default

    # Perform batch retrieval
    results, scores = retriever.batch_search(
        query_list=request.queries,
        num=request.topk,
        return_score=request.return_scores
    )

    # Format response
    resp = []
    for i, single_result in enumerate(results):
        if request.return_scores:
            # If scores are returned, combine them with results
            combined = []
            for doc, score in zip(single_result, scores[i]):
                combined.append({"document": doc, "score": score})
            resp.append(combined)
        else:
            resp.append(single_result)
    return {"result": resp}


if __name__ == "__main__":
    # 3) Launch the server. By default, it listens on http://127.0.0.1:8000
    uvicorn.run(app, host="0.0.0.0", port=8000)