From 7d6a15bfc53cc06344a72e4fa39967e8eb5b53c4 Mon Sep 17 00:00:00 2001
From: PeterGriffinJin
Date: Tue, 29 Apr 2025 16:02:26 +0000
Subject: [PATCH] remove unused file

---
 verl/workers/retriever_workers.py | 383 ------------------------------
 1 file changed, 383 deletions(-)
 delete mode 100644 verl/workers/retriever_workers.py

diff --git a/verl/workers/retriever_workers.py b/verl/workers/retriever_workers.py
deleted file mode 100644
index 5bb0952..0000000
--- a/verl/workers/retriever_workers.py
+++ /dev/null
@@ -1,383 +0,0 @@
-import json
-import os
-import warnings
-from typing import Dict, List
-
-import datasets
-import faiss
-import numpy as np
-import ray
-import torch
-from tqdm import tqdm
-from transformers import AutoConfig, AutoModel, AutoTokenizer
-
-from verl.single_controller.base import Worker
-from verl.single_controller.base.decorator import register, Dispatch
-# from search_r1.env.search.retrieval import get_retriever
-
-
-def load_corpus(corpus_path: str):
-    """Load a JSONL corpus as a HuggingFace dataset."""
-    corpus = datasets.load_dataset(
-        'json',
-        data_files=corpus_path,
-        split="train",
-        num_proc=4)
-    return corpus
-
-
-def read_jsonl(file_path):
-    """Read a JSONL file into a list of dicts."""
-    data = []
-    with open(file_path, "r") as f:
-        for line in f:
-            data.append(json.loads(line))
-    return data
-
-
-def load_docs(corpus, doc_idxs):
-    """Fetch corpus entries by index."""
-    return [corpus[int(idx)] for idx in doc_idxs]
-
-
-def load_model(model_path: str, use_fp16: bool = False):
-    """Load a HuggingFace encoder and its tokenizer onto the current GPU."""
-    model_config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
-    model = AutoModel.from_pretrained(model_path, trust_remote_code=True)
-    model.eval()
-    model.cuda()
-    if use_fp16:
-        model = model.half()
-    tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True, trust_remote_code=True)
-    return model, tokenizer
-
-
-def pooling(pooler_output, last_hidden_state, attention_mask=None, pooling_method="mean"):
-    """Reduce token-level hidden states to a single sequence embedding."""
-    if pooling_method == "mean":
-        last_hidden = last_hidden_state.masked_fill(~attention_mask[..., None].bool(), 0.0)
-        return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
-    elif pooling_method == "cls":
-        return last_hidden_state[:, 0]
-    elif pooling_method == "pooler":
-        return pooler_output
-    else:
-        raise NotImplementedError("Pooling method not implemented!")
-
-
-class Encoder:
-    """Wraps a dense retrieval encoder and its model-specific query formatting."""
-
-    def __init__(self, model_name, model_path, pooling_method, max_length, use_fp16):
-        self.model_name = model_name
-        self.model_path = model_path
-        self.pooling_method = pooling_method
-        self.max_length = max_length
-        self.use_fp16 = use_fp16
-
-        self.model, self.tokenizer = load_model(model_path=model_path,
-                                                use_fp16=use_fp16)
-
-    @torch.no_grad()
-    def encode(self, query_list: List[str], is_query=True) -> np.ndarray:
-        # Apply the instruction prefixes that each encoder family expects.
-        if isinstance(query_list, str):
-            query_list = [query_list]
-
-        if "e5" in self.model_name.lower():
-            if is_query:
-                query_list = [f"query: {query}" for query in query_list]
-            else:
-                query_list = [f"passage: {query}" for query in query_list]
-
-        if "bge" in self.model_name.lower():
-            if is_query:
-                query_list = [f"Represent this sentence for searching relevant passages: {query}" for query in query_list]
-
-        inputs = self.tokenizer(query_list,
-                                max_length=self.max_length,
-                                padding=True,
-                                truncation=True,
-                                return_tensors="pt")
-        inputs = {k: v.cuda() for k, v in inputs.items()}
-
-        if "T5" in type(self.model).__name__:
-            # T5-based retrieval model
-            decoder_input_ids = torch.zeros(
-                (inputs['input_ids'].shape[0], 1), dtype=torch.long
-            ).to(inputs['input_ids'].device)
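-            # The decoder is fed a single start token; its first hidden state
-            # then serves as the sequence embedding.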
-            output = self.model(
-                **inputs, decoder_input_ids=decoder_input_ids, return_dict=True
-            )
-            query_emb = output.last_hidden_state[:, 0, :]
-
-        else:
-            output = self.model(**inputs, return_dict=True)
-            query_emb = pooling(output.pooler_output,
-                                output.last_hidden_state,
-                                inputs['attention_mask'],
-                                self.pooling_method)
-            if "dpr" not in self.model_name.lower():
-                query_emb = torch.nn.functional.normalize(query_emb, dim=-1)
-
-        query_emb = query_emb.detach().cpu().numpy()
-        query_emb = query_emb.astype(np.float32, order="C")
-        return query_emb
-
-
-class BaseRetriever:
-    """Base object for all retrievers."""
-
-    def __init__(self, config):
-        self.config = config
-        self.retrieval_method = config.retrieval_method
-        self.topk = config.retrieval_topk
-
-        self.index_path = config.index_path
-        self.corpus_path = config.corpus_path
-
-    def _search(self, query: str, num: int, return_score: bool) -> List[Dict[str, str]]:
-        r"""Retrieve the top-k documents for a single query.
-
-        Returns:
-            list: one dict per document, with keys:
-                contents: text used for building the index
-                title: (if provided)
-                text: (if provided)
-        """
-        pass
-
-    def _batch_search(self, query_list, num, return_score):
-        pass
-
-    def search(self, *args, **kwargs):
-        return self._search(*args, **kwargs)
-
-    def batch_search(self, *args, **kwargs):
-        return self._batch_search(*args, **kwargs)
-
-
-class BM25Retriever(BaseRetriever):
-    r"""BM25 retriever based on a pre-built pyserini index."""
-
-    def __init__(self, config):
-        super().__init__(config)
-        # BM25 support was never wired up; everything below the raise is unreachable.
-        raise NotImplementedError
-        from pyserini.search.lucene import LuceneSearcher
-        self.searcher = LuceneSearcher(self.index_path)
-        self.contain_doc = self._check_contain_doc()
-        if not self.contain_doc:
-            self.corpus = load_corpus(self.corpus_path)
-        self.max_process_num = 8
-
-    def _check_contain_doc(self):
-        r"""Check if the index stores raw document content."""
-        return self.searcher.doc(0).raw() is not None
-
-    def _search(self, query: str, num: int = None, return_score=False) -> List[Dict[str, str]]:
-        if num is None:
-            num = self.topk
-
-        hits = self.searcher.search(query, num)
-        if len(hits) < 1:
-            if return_score:
-                return [], []
-            else:
-                return []
-
-        scores = [hit.score for hit in hits]
-        if len(hits) < num:
-            warnings.warn('Not enough documents retrieved!')
-        else:
-            hits = hits[:num]
-
-        if self.contain_doc:
-            all_contents = [json.loads(self.searcher.doc(hit.docid).raw())['contents'] for hit in hits]
-            results = [{'title': content.split("\n")[0].strip("\""),
-                        'text': "\n".join(content.split("\n")[1:]),
-                        'contents': content} for content in all_contents]
-        else:
-            results = load_docs(self.corpus, [hit.docid for hit in hits])
-
-        if return_score:
-            return results, scores
-        else:
-            return results
-
-    def _batch_search(self, query_list, num: int = None, return_score=False):
-        # TODO: replace this per-query loop with a true batched search
-        results = []
-        scores = []
-        for query in query_list:
-            item_result, item_score = self._search(query, num, True)
-            results.append(item_result)
-            scores.append(item_score)
-
-        if return_score:
-            return results, scores
-        else:
-            return results
-
-
-class DenseRetriever(BaseRetriever):
-    r"""Dense retriever backed by a pre-built faiss index."""
-
-    def __init__(self, config: dict, index):
-        super().__init__(config)
-        self.index = index
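-        # 'index' is a handle to the shared FAISSIndexServer Ray actor (defined
-        # below), not a local faiss index; searches go through ray.get(...remote(...)).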
-        self.corpus = load_corpus(self.corpus_path)
-        self.encoder = Encoder(
-            model_name=self.retrieval_method,
-            model_path=config.retrieval_model_path,
-            pooling_method=config.retrieval_pooling_method,
-            max_length=config.retrieval_query_max_length,
-            use_fp16=config.retrieval_use_fp16
-        )
-        self.topk = config.retrieval_topk
-        self.batch_size = self.config.retrieval_batch_size
-
-    def _search(self, query: str, num: int = None, return_score=False):
-        # Single-query search against a local faiss index; disabled in favor of
-        # the batched, actor-based path below.
-        raise NotImplementedError
-        if num is None:
-            num = self.topk
-        query_emb = self.encoder.encode(query)
-        scores, idxs = self.index.search(query_emb, k=num)
-        idxs = idxs[0]
-        scores = scores[0]
-
-        results = load_docs(self.corpus, idxs)
-        if return_score:
-            return results, scores
-        else:
-            return results
-
-    def _batch_search(self, query_list: List[str], num: int = None, return_score=False):
-        if isinstance(query_list, str):
-            query_list = [query_list]
-        if num is None:
-            num = self.topk
-
-        batch_size = self.batch_size
-
-        results = []
-        scores = []
-
-        for start_idx in tqdm(range(0, len(query_list), batch_size), desc='Retrieval process: '):
-            query_batch = query_list[start_idx:start_idx + batch_size]
-
-            batch_emb = self.encoder.encode(query_batch)
-            batch_scores, batch_idxs = ray.get(self.index.batch_search.remote(batch_emb, k=num))
-            batch_scores = batch_scores.tolist()
-            batch_idxs = batch_idxs.tolist()
-
-            # Flatten the id lists, fetch the documents, then regroup per query.
-            flat_idxs = sum(batch_idxs, [])
-            batch_results = load_docs(self.corpus, flat_idxs)
-            batch_results = [batch_results[i*num: (i+1)*num] for i in range(len(batch_idxs))]
-
-            scores.extend(batch_scores)
-            results.extend(batch_results)
-
-        if return_score:
-            return results, scores
-        else:
-            return results
-
-
-def get_retriever(config, index):
-    r"""Select a retriever class based on the configured retrieval method.
-
-    Args:
-        config (dict): configuration with a 'retrieval_method' key
-        index: handle to the shared FAISS index server
-
-    Returns:
-        Retriever: retriever instance
-    """
-    if config.retrieval_method == "bm25":
-        raise NotImplementedError
-        return BM25Retriever(config)
-    else:
-        return DenseRetriever(config, index)
-
-
-class RetrieveWorker(Worker):
-    """Worker that wraps a retriever and serves batched search requests."""
-
-    def __init__(self, config, faiss_server):
-        super().__init__()
-        if config.retrieval_method != 'bm25':
-            config.index_path = os.path.join(config.index_path, f'{config.retrieval_method}_Flat.index')
-        else:
-            config.index_path = os.path.join(config.index_path, 'bm25')
-
-        self.config = config  # the retriever itself is built later, in init_model
-        self.faiss_server = faiss_server
-
-    @register(dispatch_mode=Dispatch.ONE_TO_ALL)
-    def init_model(self):
-        self.retriever = get_retriever(self.config, self.faiss_server)
-        torch.cuda.empty_cache()
-
-    @register(dispatch_mode=Dispatch.ALL_TO_ALL)
-    def batch_search(self, queries):
-        return self.retriever.batch_search(queries)
-
-
-@ray.remote(num_gpus=8)  # Reserve all 8 GPUs for the index server
-class FAISSIndexServer:
-    """Ray actor that loads and serves a shared FAISS index with GPU sharding."""
-
-    def __init__(self, config):
-        """Initialize the FAISS index only once."""
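-        # The index is read and (optionally) sharded across GPUs once, at actor
-        # construction time; every RetrieveWorker then shares it via the handle.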
- print("[FAISSIndexServer] Loading FAISS index...") - self.config = config - self.index = self.load_index(config) - - def load_index(self, config): - """Loads the FAISS index into GPU memory with sharding.""" - index_path = os.path.join(config.index_path, f'{config.retrieval_method}_Flat.index') - index = faiss.read_index(index_path) - - if self.config.faiss_gpu: - - # Apply FAISS GPU settings - co = faiss.GpuMultipleClonerOptions() - co.useFloat16 = True # Reduce memory footprint - co.shard = True # Distribute index across all GPUs - - print("[FAISSIndexServer] Moving FAISS index to all GPUs with sharding enabled...") - index = faiss.index_cpu_to_all_gpus(index, co=co) - - print("[FAISSIndexServer] FAISS index successfully moved to GPUs.") - return index - - def batch_search(self, batch_emb, k): - """Perform batch search on the FAISS index.""" - print(f"[FAISSIndexServer] Received {len(batch_emb)} queries.") - return self.index.search(batch_emb, k) # Adjust 'k' as needed