From 4936a3115ea9e5a8de14fe1a5069938177cc1d57 Mon Sep 17 00:00:00 2001
From: PeterGriffinJin
Date: Fri, 21 Mar 2025 20:27:54 +0000
Subject: [PATCH] add code for inference

---
 README.md |  15 +++++++
 infer.py  | 128 ++++++++++++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 143 insertions(+)
 create mode 100644 infer.py

diff --git a/README.md b/README.md
index 6dd0ed1..6bb5d54 100644
--- a/README.md
+++ b/README.md
@@ -17,6 +17,7 @@ You can refer to this [link](https://github.com/PeterGriffinJin/Search-R1/tree/m
 - [Installation](#installation)
 - [Quick start](#quick-start)
 - [Preliminary results](#preliminary-results)
+- [Inference](#inference)
 - [Use your own dataset](#use-your-own-dataset)
 - [Use your own search engine](#use-your-own-search-engine)
 - [Ackowledge](#acknowledge)
@@ -99,6 +100,20 @@ bash train_ppo.sh
 
 ![multi-turn](public/multi-turn.png)
 
+## Inference
+#### You can play with the trained Search-R1 model using your own question.
+(1) Launch a local retrieval server.
+```bash
+conda activate retriever
+bash retrieval_launch.sh
+```
+
+(2) Run inference.
+```bash
+conda activate searchr1
+python infer.py
+```
+You can modify the ```question``` on line 7 of ```infer.py``` to something you're interested in.
 
 ## Use your own dataset
 
diff --git a/infer.py b/infer.py
new file mode 100644
index 0000000..5b93fa8
--- /dev/null
+++ b/infer.py
@@ -0,0 +1,128 @@
+import transformers
+import torch
+import random
+from datasets import load_dataset
+import requests
+
+question = "Mike Barnett negotiated many contracts including which player that went on to become general manager of CSKA Moscow of the Kontinental Hockey League?"
+
+# Model ID and device setup
+model_id = "PeterJinGo/SearchR1-nq_hotpotqa_train-qwen2.5-7b-em-ppo"
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+question = question.strip()
+if question[-1] != '?':
+    question += '?'
+curr_eos = [151645, 151643] # for Qwen2.5 series models
+curr_search_template = '\n\n{output_text}<information>{search_results}</information>\n\n'
+
+# Prepare the message
+prompt = f"""Answer the given question. \
+You must conduct reasoning inside <think> and </think> first every time you get new information. \
+After reasoning, if you find you lack some knowledge, you can call a search engine by <search> query </search> and it will return the top searched results between <information> and </information>. \
+You can search as many times as you want. \
+If you find no further external knowledge needed, you can directly provide the answer inside <answer> and </answer>, without detailed illustrations. For example, <answer> Beijing </answer>. Question: {question}\n"""
+
+# Initialize the tokenizer and model
+tokenizer = transformers.AutoTokenizer.from_pretrained(model_id)
+model = transformers.AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto")
+
+# Define the custom stopping criterion
+class StopOnSequence(transformers.StoppingCriteria):
+    def __init__(self, target_sequences, tokenizer):
+        # Encode the string so we have the exact token-IDs pattern
+        self.target_ids = [tokenizer.encode(target_sequence, add_special_tokens=False) for target_sequence in target_sequences]
+        self.target_lengths = [len(target_id) for target_id in self.target_ids]
+        self._tokenizer = tokenizer
+
+    def __call__(self, input_ids, scores, **kwargs):
+        # Make sure the target IDs are on the same device
+        targets = [torch.as_tensor(target_id, device=input_ids.device) for target_id in self.target_ids]
+
+        if input_ids.shape[1] < min(self.target_lengths):
+            return False
+
+        # Compare the tail of input_ids with our target_ids
+        for i, target in enumerate(targets):
+            if torch.equal(input_ids[0, -self.target_lengths[i]:], target):
+                return True
+
+        return False
+
+def get_query(text):
+    import re
+    pattern = re.compile(r"<search>(.*?)</search>", re.DOTALL)
+    matches = pattern.findall(text)
+    if matches:
+        return matches[-1]
+    else:
+        return None
+
+def search(query: str):
+    payload = {
+        "queries": [query],
+        "topk": 3,
+        "return_scores": True
+    }
+    results = requests.post("http://127.0.0.1:8000/retrieve", json=payload).json()['result']
+
+    def _passages2string(retrieval_result):
+        format_reference = ''
+        for idx, doc_item in enumerate(retrieval_result):
+
+            content = doc_item['document']['contents']
+            title = content.split("\n")[0]
+            text = "\n".join(content.split("\n")[1:])
+            format_reference += f"Doc {idx+1}(Title: {title}) {text}\n"
+        return format_reference
+
+    return _passages2string(results[0])
+
+
+# Initialize the stopping criteria
+target_sequences = ["</search>", " </search>", "</search>\n", " </search>\n", "</search>\n\n", " </search>\n\n"]
+stopping_criteria = transformers.StoppingCriteriaList([StopOnSequence(target_sequences, tokenizer)])
+
+cnt = 0
+
+if tokenizer.chat_template:
+    prompt = tokenizer.apply_chat_template([{"role": "user", "content": prompt}], add_generation_prompt=True, tokenize=False)
+
+print('\n\n################# [Start Reasoning + Searching] ##################\n\n')
+print(prompt)
+# Encode the chat-formatted prompt and move it to the correct device
+while True:
+    input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)
+    attention_mask = torch.ones_like(input_ids)
+
+    # Generate text with the stopping criteria
+    outputs = model.generate(
+        input_ids,
+        attention_mask=attention_mask,
+        max_new_tokens=1024,
+        stopping_criteria=stopping_criteria,
+        pad_token_id=tokenizer.eos_token_id,
+        do_sample=True,
+        temperature=0.7
+    )
+
+    if outputs[0][-1].item() in curr_eos:
+        generated_tokens = outputs[0][input_ids.shape[1]:]
+        output_text = tokenizer.decode(generated_tokens, skip_special_tokens=True)
+        print(output_text)
+        break
+
+    generated_tokens = outputs[0][input_ids.shape[1]:]
+    output_text = tokenizer.decode(generated_tokens, skip_special_tokens=True)
+
+    tmp_query = get_query(tokenizer.decode(outputs[0], skip_special_tokens=True))
+    if tmp_query:
+        # print(f'searching "{tmp_query}"...')
+        search_results = search(tmp_query)
+    else:
+        search_results = ''
+
+    search_text = curr_search_template.format(output_text=output_text, search_results=search_results)
+    prompt += search_text
+    cnt += 1
+    print(search_text)
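Before running `python infer.py`, it can help to confirm that the retrieval server from step (1) is actually answering requests. Below is a minimal sketch (not part of the patch) that sends the same kind of request as `search()` in `infer.py` and prints the titles of the returned documents; the query string is an arbitrary example, and the server address and response shape are taken from the code above.

```python
# Sanity check for the local retrieval server started by retrieval_launch.sh.
# Mirrors the request/response handling of search() in infer.py (a sketch, not repo code).
import requests

payload = {
    "queries": ["Mike Barnett NHL agent"],  # any test query; infer.py sends the model's <search> query here
    "topk": 3,
    "return_scores": True,
}
response = requests.post("http://127.0.0.1:8000/retrieve", json=payload).json()

# 'result' holds one list of retrieved documents per input query; each document's
# contents begin with the title line, as parsed by _passages2string() in infer.py.
for idx, doc_item in enumerate(response["result"][0]):
    content = doc_item["document"]["contents"]
    title = content.split("\n")[0]
    print(f"Doc {idx + 1}: {title}")
```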