Initial commit
0
search_r1/llm_agent/__init__.py
Normal file
416
search_r1/llm_agent/generation.py
Normal file
@@ -0,0 +1,416 @@
import torch
import re
from collections import defaultdict
import os
from typing import List, Dict, Any, Tuple
from dataclasses import dataclass
from .tensor_helper import TensorHelper, TensorConfig
# from search_r1.utils import set_seed
# from search_r1.utils.plot import (
#     save_trajectory_to_output,
#     parse_llm_output
# )
from verl import DataProto
from verl.utils.tracking import Tracking
import shutil
import requests


@dataclass
class GenerationConfig:
    max_turns: int
    max_start_length: int
    max_prompt_length: int
    max_response_length: int
    max_obs_length: int
    # logging: dict
    num_gpus: int
    no_think_rl: bool = False
    search_url: str = None
    topk: int = 3

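# A hypothetical example of filling in this config from a trainer setup (the values
# below are illustrative only, not defaults of this module):
#
#   config = GenerationConfig(
#       max_turns=2, max_start_length=2048, max_prompt_length=4096,
#       max_response_length=500, max_obs_length=500, num_gpus=8,
#       search_url="http://127.0.0.1:8000/retrieve", topk=3,
#   )
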
class LLMGenerationManager:
    def __init__(
        self,
        tokenizer,
        actor_rollout_wg,
        config: GenerationConfig,
        # logger: Tracking,
        is_validation: bool = False,
    ):
        self.tokenizer = tokenizer
        self.actor_rollout_wg = actor_rollout_wg
        self.config = config
        # self.logger = logger
        self.is_validation = is_validation

        self.tensor_fn = TensorHelper(TensorConfig(
            pad_token_id=tokenizer.pad_token_id,
            max_prompt_length=config.max_prompt_length,
            max_obs_length=config.max_obs_length,
            max_start_length=config.max_start_length
        ))

    def _batch_tokenize(self, responses: List[str]) -> torch.Tensor:
        """Tokenize a batch of responses."""
        return self.tokenizer(
            responses,
            add_special_tokens=False,
            return_tensors='pt',
            padding="longest"
        )['input_ids']

    def _postprocess_responses(self, responses: torch.Tensor) -> Tuple[torch.Tensor, List[str]]:
        """Process responses to stop at search operation or answer operation."""
        responses_str = self.tokenizer.batch_decode(
            responses,
            skip_special_tokens=True
        )

        responses_str = [resp.split('</search>')[0] + '</search>'
                         if '</search>' in resp
                         else resp.split('</answer>')[0] + '</answer>'
                         if '</answer>' in resp
                         else resp
                         for resp in responses_str]

        if self.config.no_think_rl:
            raise ValueError('stop')
            # NOTE: the lines below are unreachable (they follow the raise above) and
            # reference self.env / envs, which are not defined in this class.
            # if no_think_rl is enabled, only keep action in the str
            actions, _ = self.env.postprocess_predictions(responses_str)
            responses_str = [f"<answer>{envs[idx].ACTION_LOOKUP[action]}</answer>" for idx, action in enumerate(actions)]
            print("RESPONSES:", responses_str)
        responses = self._batch_tokenize(responses_str)
        return responses, responses_str

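    # Illustration (hypothetical decoded sample) of the truncation above: a decoded
    # response like
    #   "<think>...</think> <search>best BBQ in Austin</search> stray tokens"
    # is cut back to
    #   "<think>...</think> <search>best BBQ in Austin</search>"
    # so the rollout stops right after the first completed search or answer tag.
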
    def _process_next_obs(self, next_obs: List[str]) -> torch.Tensor:
        """Process next observations from environment."""

        next_obs_ids = self.tokenizer(
            next_obs,
            padding='longest',
            return_tensors='pt',
            add_special_tokens=False,  # Prevents adding special tokens
        )['input_ids']

        if next_obs_ids.shape[1] > self.config.max_obs_length:
            print(f"[WARNING] OBSERVATION TOO LONG, CONSIDER CHANGING YOUR CONFIG, {next_obs_ids.shape[1]} & {self.config.max_obs_length}")
            next_obs_ids = next_obs_ids[:, :self.config.max_obs_length]

        return next_obs_ids

    def _update_rolling_state(self, rollings, cur_responses: torch.Tensor,
                              next_obs_ids: torch.Tensor) -> Dict:
        """Update rolling state with new responses and observations."""
        # Concatenate and handle padding
        new_input_ids = self.tensor_fn.concatenate_with_padding([
            rollings.batch['input_ids'],
            cur_responses,
            next_obs_ids
        ])

        # Create attention mask and position ids
        new_attention_mask = self.tensor_fn.create_attention_mask(new_input_ids)
        new_position_ids = self.tensor_fn.create_position_ids(new_attention_mask)

        # Cut to appropriate length
        effective_len = new_attention_mask.sum(dim=1).max()
        max_len = min(self.config.max_prompt_length, effective_len)

        return DataProto.from_dict({
            'input_ids': new_input_ids[:, -max_len:],
            'position_ids': new_position_ids[:, -max_len:],
            'attention_mask': new_attention_mask[:, -max_len:]
        })

    def _update_right_side(self, right_side: Dict,
                           cur_responses: torch.Tensor,
                           next_obs_ids: torch.Tensor = None) -> Dict:
        """Update right side state."""
        if next_obs_ids is not None:
            responses = self.tensor_fn.concatenate_with_padding([
                right_side['responses'],
                cur_responses,
                next_obs_ids
            ], pad_to_left=False)
        else:
            responses = self.tensor_fn.concatenate_with_padding([
                right_side['responses'],
                cur_responses,
            ], pad_to_left=False)

        effective_len = self.tensor_fn.create_attention_mask(responses).sum(dim=1).max()
        max_len = min(self.config.max_prompt_length, effective_len)

        return {'responses': responses[:, :max_len]}

    def _generate_with_gpu_padding(self, active_batch: DataProto) -> DataProto:
        """
        Wrapper for generation that handles multi-GPU padding requirements.
        If num_gpus <= 1, call self.actor_rollout_wg.generate_sequences(active_batch) directly.
        If the active batch size is not divisible by num_gpus, pad with copies of the
        first sequence, then remove the padding from the output.
        """
        num_gpus = self.config.num_gpus
        if num_gpus <= 1:
            return self.actor_rollout_wg.generate_sequences(active_batch)

        batch_size = active_batch.batch['input_ids'].shape[0]
        remainder = batch_size % num_gpus

        if remainder == 0:
            return self.actor_rollout_wg.generate_sequences(active_batch)

        # Add padding sequences
        padding_size = num_gpus - remainder
        padded_batch = {}

        for k, v in active_batch.batch.items():
            # Use first sequence as padding template
            pad_sequence = v[0:1].repeat(padding_size, *[1] * (len(v.shape) - 1))
            padded_batch[k] = torch.cat([v, pad_sequence], dim=0)

        padded_active_batch = DataProto.from_dict(padded_batch)

        # Generate with padded batch
        padded_output = self.actor_rollout_wg.generate_sequences(padded_active_batch)

        # Remove padding from output
        trimmed_batch = {k: v[:-padding_size] for k, v in padded_output.batch.items()}

        # Handle meta_info if present
        if hasattr(padded_output, 'meta_info') and padded_output.meta_info:
            trimmed_meta = {}
            for k, v in padded_output.meta_info.items():
                if isinstance(v, torch.Tensor):
                    trimmed_meta[k] = v[:-padding_size]
                else:
                    trimmed_meta[k] = v
            padded_output.meta_info = trimmed_meta

        padded_output.batch = trimmed_batch
        return padded_output

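    # Example of the padding arithmetic above: with num_gpus=4 and 10 active
    # sequences, remainder=2 and padding_size=2, so two copies of the first
    # sequence are appended, generation runs on 12 sequences, and the final two
    # (padding) outputs are trimmed off before returning.
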
    def run_llm_loop(self, gen_batch, initial_input_ids: torch.Tensor) -> DataProto:
        """Run main LLM generation loop."""

        original_left_side = {'input_ids': initial_input_ids[:, -self.config.max_start_length:]}
        original_right_side = {'responses': initial_input_ids[:, []]}

        active_mask = torch.ones(gen_batch.batch['input_ids'].shape[0], dtype=torch.bool)
        active_num_list = [active_mask.sum().item()]
        rollings = gen_batch

        # Main generation loop
        for step in range(self.config.max_turns):
            if not active_mask.sum():
                break
            rollings.batch = self.tensor_fn.cut_to_effective_len(
                rollings.batch,
                keys=['input_ids', 'attention_mask', 'position_ids']
            )

            # gen_output = self.actor_rollout_wg.generate_sequences(rollings)
            rollings_active = DataProto.from_dict({
                k: v[active_mask] for k, v in rollings.batch.items()
            })
            gen_output = self._generate_with_gpu_padding(rollings_active)

            meta_info = gen_output.meta_info
            responses_ids, responses_str = self._postprocess_responses(gen_output.batch['responses'])
            responses_ids, responses_str = self.tensor_fn._example_level_pad(responses_ids, responses_str, active_mask)

            # Execute in environment and process observations
            next_obs, dones = self.execute_predictions(
                responses_str, self.tokenizer.pad_token, active_mask
            )

            curr_active_mask = torch.tensor([not done for done in dones], dtype=torch.bool)
            active_mask = active_mask * curr_active_mask
            active_num_list.append(active_mask.sum().item())

            next_obs_ids = self._process_next_obs(next_obs)

            # Update states
            rollings = self._update_rolling_state(
                rollings,
                responses_ids,
                next_obs_ids
            )
            original_right_side = self._update_right_side(
                original_right_side,
                responses_ids,
                next_obs_ids
            )

        # final LLM rollout
        if active_mask.sum():
            rollings.batch = self.tensor_fn.cut_to_effective_len(
                rollings.batch,
                keys=['input_ids', 'attention_mask', 'position_ids']
            )

            # gen_output = self.actor_rollout_wg.generate_sequences(rollings)
            rollings_active = DataProto.from_dict({
                k: v[active_mask] for k, v in rollings.batch.items()
            })
            gen_output = self._generate_with_gpu_padding(rollings_active)

            meta_info = gen_output.meta_info
            responses_ids, responses_str = self._postprocess_responses(gen_output.batch['responses'])
            responses_ids, responses_str = self.tensor_fn._example_level_pad(responses_ids, responses_str, active_mask)

            # Execute in environment and process observations
            _, dones = self.execute_predictions(
                responses_str, self.tokenizer.pad_token, active_mask, do_search=False
            )

            curr_active_mask = torch.tensor([not done for done in dones], dtype=torch.bool)
            active_mask = active_mask * curr_active_mask
            active_num_list.append(active_mask.sum().item())

            original_right_side = self._update_right_side(
                original_right_side,
                responses_ids,
            )

        print("ACTIVE_TRAJ_NUM:", active_num_list)

        return self._compose_final_output(original_left_side, original_right_side, meta_info)

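    # The loop above accumulates two buffers: `original_left_side` keeps the initial
    # prompt ids, while `original_right_side` grows (right-padded) with each response
    # and <information> observation; the method below stitches the two into the
    # final input_ids, attention_mask and position_ids.
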
    def _compose_final_output(self, left_side: Dict,
                              right_side: Dict,
                              meta_info: Dict) -> DataProto:
        """Compose final generation output."""
        final_output = right_side.copy()
        final_output['prompts'] = left_side['input_ids']

        # Combine input IDs
        final_output['input_ids'] = torch.cat([
            left_side['input_ids'],
            right_side['responses']
        ], dim=1)

        # Create attention mask and position ids
        final_output['attention_mask'] = torch.cat([
            self.tensor_fn.create_attention_mask(left_side['input_ids']),
            self.tensor_fn.create_attention_mask(final_output['responses'])
        ], dim=1)

        final_output['position_ids'] = self.tensor_fn.create_position_ids(
            final_output['attention_mask']
        )

        final_output = DataProto.from_dict(final_output)
        final_output.meta_info.update(meta_info)

        return final_output

    def execute_predictions(self, predictions: List[str], pad_token: str, active_mask=None, do_search=True) -> Tuple[List[str], List[int]]:
        """
        Execute predictions across multiple environments.
        NOTE: this function is the actual `step` function of the environment.
        NOTE: penalty_for_invalid is not included in the observation shown to the LLM.

        Args:
            predictions: List of action predictions
            pad_token: Token to use for padding
            active_mask: Boolean mask of which examples are still active
            do_search: Whether to actually call the search engine

        Returns:
            Tuple of (next observation strings, done flags)
        """
        cur_actions, contents = self.postprocess_predictions(predictions)
        next_obs, dones = [], []

        search_queries = [content for action, content in zip(cur_actions, contents) if action == 'search']
        if do_search:
            search_results = self.batch_search(search_queries)
            assert len(search_results) == sum([1 for action in cur_actions if action == 'search'])
        else:
            search_results = [''] * sum([1 for action in cur_actions if action == 'search'])

        for i, (action, active) in enumerate(zip(cur_actions, active_mask)):

            if not active:
                next_obs.append('')
                dones.append(1)
            else:
                if action == 'answer':
                    next_obs.append('')
                    dones.append(1)
                elif action == 'search':
                    next_obs.append(f'\n\n<information>{search_results.pop(0).strip()}</information>\n\n')
                    dones.append(0)
                else:
                    next_obs.append(f'\nMy previous action is invalid. \
If I want to search, I should put the query between <search> and </search>. \
If I want to give the final answer, I should put the answer between <answer> and </answer>. Let me try again.\n')
                    dones.append(0)

        assert len(search_results) == 0

        return next_obs, dones

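    # Illustration: for cur_actions ['search', 'answer', None] with all examples
    # active, next_obs becomes ['\n\n<information>...</information>\n\n', '',
    # '<retry prompt for the invalid action>'] and dones becomes [0, 1, 0].
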
    def postprocess_predictions(self, predictions: List[Any]) -> Tuple[List[str], List[str]]:
        """
        Process (text-based) predictions from the LLM into actions and their contents.

        Args:
            predictions: List of raw predictions

        Returns:
            Tuple of (actions list, contents list)
        """
        actions = []
        contents = []

        for prediction in predictions:
            if isinstance(prediction, str):  # for llm output
                pattern = r'<(search|answer)>(.*?)</\1>'
                match = re.search(pattern, prediction, re.DOTALL)
                if match:
                    content = match.group(2).strip()  # Return only the content inside the tags
                    action = match.group(1)
                else:
                    content = ''
                    action = None
            else:
                raise ValueError(f"Invalid prediction type: {type(prediction)}")

            actions.append(action)
            contents.append(content)

        return actions, contents

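    # Illustration (hypothetical model output) of the tag parsing above:
    #   "<think>I need to look this up.</think><search>capital of France</search>"
    # yields action='search', content='capital of France'; an output with no
    # complete <search>...</search> or <answer>...</answer> pair yields
    # action=None, content=''.
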
    def batch_search(self, queries: List[str] = None) -> List[str]:
        """
        Batchified search for queries.
        Args:
            queries: queries to call the search engine
        Returns:
            One formatted string of retrieved passages per query.
        """
        results = self._batch_search(queries)['result']

        return [self._passages2string(result) for result in results]

    def _batch_search(self, queries):

        payload = {
            "queries": queries,
            "topk": self.config.topk,
            "return_scores": True
        }

        return requests.post(self.config.search_url, json=payload).json()

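    # Judging from how the response is consumed in _passages2string below, the
    # retrieval server at config.search_url is assumed to return JSON shaped
    # roughly like
    #   {"result": [[{"document": {"contents": "Title\nBody ..."}, ...}, ...], ...]}
    # i.e. one list of top-k documents per query, each with a "contents" field
    # whose first line is the title.
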
    def _passages2string(self, retrieval_result):
        format_reference = ''
        for idx, doc_item in enumerate(retrieval_result):

            content = doc_item['document']['contents']
            title = content.split("\n")[0]
            text = "\n".join(content.split("\n")[1:])
            format_reference += f"Doc {idx+1}(Title: {title}) {text}\n"

        return format_reference
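
# A minimal end-to-end usage sketch (the tokenizer, actor_rollout_wg, and gen_batch
# are assumed to be built elsewhere by the trainer; the names below are illustrative,
# not part of this module):
#
#   manager = LLMGenerationManager(tokenizer, actor_rollout_wg, config)  # config as sketched above
#   final_output = manager.run_llm_loop(gen_batch, gen_batch.batch['input_ids'])
#   # final_output is a DataProto carrying prompts, responses, input_ids,
#   # attention_mask and position_ids for the whole multi-turn trajectory.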
75
search_r1/llm_agent/tensor_helper.py
Normal file
@@ -0,0 +1,75 @@
import torch
from typing import Dict, Tuple, List
from dataclasses import dataclass


@dataclass
class TensorConfig:
    pad_token_id: int
    max_prompt_length: int
    max_obs_length: int
    max_start_length: int


class TensorHelper:
    def __init__(self, config: TensorConfig):
        self.config = config

    def cut_to_effective_len(self, tensor_dict: Dict[str, torch.Tensor],
                             keys: List[str], cut_left: bool = True) -> Dict[str, torch.Tensor]:
        """Cut tensors to their effective length based on attention mask."""
        effective_len = tensor_dict['attention_mask'].sum(dim=1).max()
        result = tensor_dict.copy()

        for key in keys:
            if cut_left:
                result[key] = tensor_dict[key][:, -effective_len:]
            else:
                result[key] = tensor_dict[key][:, :effective_len]
        return result

    def convert_pad_structure(self, tensor: torch.Tensor, pad_to_left: bool = True) -> Tuple[torch.Tensor, torch.Tensor]:
        """Convert padding structure and return sorted tensor with indices."""
        mask = tensor != self.config.pad_token_id if pad_to_left else tensor == self.config.pad_token_id
        sorted_indices = mask.to(torch.int64).argsort(dim=1, stable=True)
        return tensor.gather(1, sorted_indices), sorted_indices

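    # Illustration with pad_token_id=0 and pad_to_left=True: a right-padded row
    # [5, 6, 7, 0, 0] has mask [1, 1, 1, 0, 0]; a stable ascending argsort of the
    # mask moves the pad positions to the front, so the row becomes the
    # left-padded [0, 0, 5, 6, 7] with the non-pad order preserved.
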
    def create_attention_mask(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Create attention mask from input ids."""
        return torch.where(input_ids != self.config.pad_token_id, 1, 0)

    def create_position_ids(self, attention_mask: torch.Tensor) -> torch.Tensor:
        """Create position ids from attention mask."""
        return (torch.cumsum(attention_mask, dim=1) - 1) * attention_mask

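    # Illustration: a left-padded attention mask [0, 0, 1, 1, 1] has
    # cumsum - 1 = [-1, -1, 0, 1, 2]; multiplying by the mask zeroes the pad
    # positions, giving position ids [0, 0, 0, 1, 2].
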
    def concatenate_with_padding(self, tensors: List[torch.Tensor],
                                 pad_to_left: bool = True) -> torch.Tensor:
        """Concatenate tensors and handle padding."""
        concatenated = torch.cat(tensors, dim=1)
        padded_tensor, _ = self.convert_pad_structure(concatenated, pad_to_left)
        return padded_tensor

    def _example_level_pad(self, responses: torch.Tensor,
                           responses_str: List[str],
                           active_mask: torch.Tensor) -> Tuple[torch.Tensor, List[str]]:
        """
        Pad responses for non-active examples with pad tokens.
        """
        assert active_mask.sum() == responses.shape[0]
        # Create masked responses tensor
        batch_size = active_mask.shape[0]
        seq_len = responses.shape[1]
        padded_responses = torch.full(
            (batch_size, seq_len), self.config.pad_token_id,
            dtype=responses.dtype, device=responses.device
        )
        padded_responses[active_mask] = responses

        # Create masked response strings
        padded_responses_str = [""] * batch_size

        s = 0
        for i, is_active in enumerate(active_mask):
            if is_active:
                padded_responses_str[i] = responses_str[s]
                s += 1

        return padded_responses, padded_responses_str
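
# Illustration of _example_level_pad: with active_mask = [True, False, True] and
# two generated responses, rows 0 and 2 of the padded output receive the generated
# token ids and strings, while row 1 is filled with pad_token_id and ''.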