diff --git a/README.md b/README.md
index 6dd0ed1..6bb5d54 100644
--- a/README.md
+++ b/README.md
@@ -17,6 +17,7 @@ You can refer to this [link](https://github.com/PeterGriffinJin/Search-R1/tree/m
- [Installation](#installation)
- [Quick start](#quick-start)
- [Preliminary results](#preliminary-results)
+- [Inference](#inference)
- [Use your own dataset](#use-your-own-dataset)
- [Use your own search engine](#use-your-own-search-engine)
- [Ackowledge](#acknowledge)
@@ -99,6 +100,24 @@ bash train_ppo.sh

+## Inference
+#### You can play with the trained Search-R1 model by asking your own questions.
+(1) Launch a local retrieval server.
+```bash
+conda activate retriever
+bash retrieval_launch.sh
+```
+
+(2) Run inference.
+```bash
+conda activate searchr1
+python infer.py
+```
+You can modify the `question` variable on line 7 of `infer.py` to any question you are interested in.
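+For example (the question below is just an illustration):
+```python
+question = "Who was the first person to walk on the moon?"
+```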
## Use your own dataset
diff --git a/infer.py b/infer.py
new file mode 100644
index 0000000..5b93fa8
--- /dev/null
+++ b/infer.py
@@ -0,0 +1,135 @@
+import transformers
+import torch
+import random
+from datasets import load_dataset
+import requests
+
+question = "Mike Barnett negotiated many contracts including which player that went on to become general manager of CSKA Moscow of the Kontinental Hockey League?"
+
+# Model ID and device setup
+model_id = "PeterJinGo/SearchR1-nq_hotpotqa_train-qwen2.5-7b-em-ppo"
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+question = question.strip()
+if question[-1] != '?':
+    question += '?'
+curr_eos = [151645, 151643] # for Qwen2.5 series models
+curr_search_template = '\n\n{output_text}<information>{search_results}</information>\n\n'
+
+# Prepare the message
+prompt = f"""Answer the given question. \
+You must conduct reasoning inside <think> and </think> first every time you get new information. \
+After reasoning, if you find you lack some knowledge, you can call a search engine by <search> query </search> and it will return the top searched results between <information> and </information>. \
+You can search as many times as you want. \
+If you find no further external knowledge needed, you can directly provide the answer inside <answer> and </answer>, without detailed illustrations. For example, <answer> Beijing </answer>. Question: {question}\n"""
+
+# Initialize the tokenizer and model
+tokenizer = transformers.AutoTokenizer.from_pretrained(model_id)
+model = transformers.AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto")
+
+# Define the custom stopping criterion
+class StopOnSequence(transformers.StoppingCriteria):
+    def __init__(self, target_sequences, tokenizer):
+        # Encode the string so we have the exact token-IDs pattern
+        self.target_ids = [tokenizer.encode(target_sequence, add_special_tokens=False) for target_sequence in target_sequences]
+        self.target_lengths = [len(target_id) for target_id in self.target_ids]
+        self._tokenizer = tokenizer
+
+    def __call__(self, input_ids, scores, **kwargs):
+        # Make sure the target IDs are on the same device
+        targets = [torch.as_tensor(target_id, device=input_ids.device) for target_id in self.target_ids]
+
+        if input_ids.shape[1] < min(self.target_lengths):
+            return False
+
+        # Compare the tail of input_ids with our target_ids
+        for i, target in enumerate(targets):
+            if torch.equal(input_ids[0, -self.target_lengths[i]:], target):
+                return True
+
+        return False
+
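+# Return the query from the model's most recent <search>...</search> block, or None if there is none.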
+def get_query(text):
+    import re
+    pattern = re.compile(r"<search>(.*?)</search>", re.DOTALL)
+    matches = pattern.findall(text)
+    if matches:
+        return matches[-1]
+    else:
+        return None
+
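+# Query the local retrieval server (launched via retrieval_launch.sh) and format the top-k passages.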
+def search(query: str):
+    payload = {
+        "queries": [query],
+        "topk": 3,
+        "return_scores": True
+    }
+    results = requests.post("http://127.0.0.1:8000/retrieve", json=payload).json()['result']
+
+    def _passages2string(retrieval_result):
+        format_reference = ''
+        for idx, doc_item in enumerate(retrieval_result):
+
+            content = doc_item['document']['contents']
+            title = content.split("\n")[0]
+            text = "\n".join(content.split("\n")[1:])
+            format_reference += f"Doc {idx+1}(Title: {title}) {text}\n"
+        return format_reference
+
+    return _passages2string(results[0])
+
+
+# Initialize the stopping criteria
+target_sequences = ["</search>", " </search>", "</search>\n", " </search>\n", "</search>\n\n", " </search>\n\n"]
+stopping_criteria = transformers.StoppingCriteriaList([StopOnSequence(target_sequences, tokenizer)])
+
+cnt = 0
+
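+# Apply the model's chat template to the prompt, if one is defined.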
+if tokenizer.chat_template:
+    prompt = tokenizer.apply_chat_template([{"role": "user", "content": prompt}], add_generation_prompt=True, tokenize=False)
+
+print('\n\n################# [Start Reasoning + Searching] ##################\n\n')
+print(prompt)
+# Encode the chat-formatted prompt and move it to the correct device
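+# Reason-and-search loop: generate until the model either closes a <search> query
+# (then retrieve passages and append them to the prompt) or emits an EOS token (final answer).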
+while True:
+    input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)
+    attention_mask = torch.ones_like(input_ids)
+
+    # Generate text with the stopping criteria
+    outputs = model.generate(
+        input_ids,
+        attention_mask=attention_mask,
+        max_new_tokens=1024,
+        stopping_criteria=stopping_criteria,
+        pad_token_id=tokenizer.eos_token_id,
+        do_sample=True,
+        temperature=0.7
+    )
+
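+    # If generation ended with an EOS token, print the final answer and stop.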
+    if outputs[0][-1].item() in curr_eos:
+        generated_tokens = outputs[0][input_ids.shape[1]:]
+        output_text = tokenizer.decode(generated_tokens, skip_special_tokens=True)
+        print(output_text)
+        break
+
+    generated_tokens = outputs[0][input_ids.shape[1]:]
+    output_text = tokenizer.decode(generated_tokens, skip_special_tokens=True)
+
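+    # The model stopped at </search> (or hit the token limit); retrieve passages for the latest query, if any.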
+    tmp_query = get_query(tokenizer.decode(outputs[0], skip_special_tokens=True))
+    if tmp_query:
+        # print(f'searching "{tmp_query}"...')
+        search_results = search(tmp_query)
+    else:
+        search_results = ''
+
+    search_text = curr_search_template.format(output_text=output_text, search_results=search_results)
+    prompt += search_text
+    cnt += 1
+    print(search_text)