diff --git a/search_r1/llm_agent/generation.py b/search_r1/llm_agent/generation.py index 90d73dc..6b68cb0 100644 --- a/search_r1/llm_agent/generation.py +++ b/search_r1/llm_agent/generation.py @@ -5,11 +5,6 @@ import os from typing import List, Dict, Any, Tuple from dataclasses import dataclass from .tensor_helper import TensorHelper, TensorConfig -# from search_r1.utils import set_seed -# from search_r1.utils.plot import ( -# save_trajectory_to_output, -# parse_llm_output -# ) from verl import DataProto from verl.utils.tracking import Tracking import shutil @@ -22,7 +17,6 @@ class GenerationConfig: max_prompt_length: int max_response_length: int max_obs_length: int - # logging: dict num_gpus: int no_think_rl: bool=False search_url: str = None @@ -34,13 +28,11 @@ class LLMGenerationManager: tokenizer, actor_rollout_wg, config: GenerationConfig, - # logger: Tracking, is_validation: bool = False, ): self.tokenizer = tokenizer self.actor_rollout_wg = actor_rollout_wg self.config = config - # self.logger = logger self.is_validation = is_validation self.tensor_fn = TensorHelper(TensorConfig( @@ -188,6 +180,8 @@ class LLMGenerationManager: batch_size = active_batch.batch['input_ids'].shape[0] remainder = batch_size % num_gpus + for key in active_batch.batch.keys(): + active_batch.batch[key] = active_batch.batch[key].long() if remainder == 0: return self.actor_rollout_wg.generate_sequences(active_batch) @@ -201,10 +195,12 @@ class LLMGenerationManager: padded_batch[k] = torch.cat([v, pad_sequence], dim=0) padded_active_batch = DataProto.from_dict(padded_batch) + for key in padded_active_batch.batch.keys(): + padded_active_batch.batch[key] = padded_active_batch.batch[key].long() # Generate with padded batch padded_output = self.actor_rollout_wg.generate_sequences(padded_active_batch) - + # Remove padding from output trimmed_batch = {k: v[:-padding_size] for k, v in padded_output.batch.items()}