fix potential float bug

PeterGriffinJin
2025-03-27 16:21:04 +00:00
parent f5204213d3
commit 95d16f4548


@@ -5,11 +5,6 @@ import os
 from typing import List, Dict, Any, Tuple
 from dataclasses import dataclass
 from .tensor_helper import TensorHelper, TensorConfig
-# from search_r1.utils import set_seed
-# from search_r1.utils.plot import (
-#     save_trajectory_to_output,
-#     parse_llm_output
-# )
 from verl import DataProto
 from verl.utils.tracking import Tracking
 import shutil
@@ -22,7 +17,6 @@ class GenerationConfig:
     max_prompt_length: int
     max_response_length: int
     max_obs_length: int
-    # logging: dict
     num_gpus: int
     no_think_rl: bool=False
     search_url: str = None
@@ -34,13 +28,11 @@ class LLMGenerationManager:
         tokenizer,
         actor_rollout_wg,
         config: GenerationConfig,
-        # logger: Tracking,
         is_validation: bool = False,
     ):
         self.tokenizer = tokenizer
         self.actor_rollout_wg = actor_rollout_wg
         self.config = config
-        # self.logger = logger
         self.is_validation = is_validation
 
         self.tensor_fn = TensorHelper(TensorConfig(
@@ -188,6 +180,8 @@ class LLMGenerationManager:
         batch_size = active_batch.batch['input_ids'].shape[0]
         remainder = batch_size % num_gpus
 
+        for key in active_batch.batch.keys():
+            active_batch.batch[key] = active_batch.batch[key].long()
         if remainder == 0:
             return self.actor_rollout_wg.generate_sequences(active_batch)
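
The two "+" lines above are the fix named in the commit message: every tensor in the active batch is cast to torch.long before rollout. A minimal sketch of the hazard being guarded against, in plain PyTorch only (nothing below is from the Search-R1 or verl codebases): concatenating integer token ids with a default float pad silently promotes the batch to float32, and the embedding lookup inside generation then rejects the indices.

import torch
import torch.nn as nn

# Illustrative sketch only; plain PyTorch, not Search-R1/verl code.
emb = nn.Embedding(num_embeddings=100, embedding_dim=8)

ids = torch.tensor([[1, 2, 3]])       # int64 token ids
pad = torch.zeros((1, 3))             # float32 by default
batch = torch.cat([ids, pad], dim=0)  # dtype promotion: result is float32

try:
    emb(batch)  # embedding lookup requires Long/Int indices
except RuntimeError as err:
    print(f"float ids fail: {err}")

out = emb(batch.long())  # the .long() cast restores a valid index dtype
print(out.shape)         # torch.Size([2, 3, 8])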
@@ -201,10 +195,12 @@
             padded_batch[k] = torch.cat([v, pad_sequence], dim=0)
 
         padded_active_batch = DataProto.from_dict(padded_batch)
+        for key in padded_active_batch.batch.keys():
+            padded_active_batch.batch[key] = padded_active_batch.batch[key].long()
 
         # Generate with padded batch
         padded_output = self.actor_rollout_wg.generate_sequences(padded_active_batch)
 
         # Remove padding from output
         trimmed_batch = {k: v[:-padding_size] for k, v in padded_output.batch.items()}
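
For context, this hunk sits inside a wrapper that pads the active batch up to a multiple of num_gpus before generation and trims the padding rows afterwards; the fix applies the same .long() cast on this padded path as on the unpadded one. A condensed, self-contained sketch of that pattern, assuming a plain dict of tensors and a hypothetical generate_fn standing in for actor_rollout_wg.generate_sequences:

import torch

def generate_with_gpu_padding(batch: dict, num_gpus: int, generate_fn):
    # Sketch of the padding pattern above; generate_fn is a hypothetical
    # stand-in for actor_rollout_wg.generate_sequences, not a verl API.
    batch_size = next(iter(batch.values())).shape[0]
    remainder = batch_size % num_gpus
    if remainder == 0:
        return generate_fn({k: v.long() for k, v in batch.items()})

    padding_size = num_gpus - remainder
    padded = {}
    for k, v in batch.items():
        # Repeat the first sequence as the padding template
        pad_sequence = v[0:1].repeat(padding_size, *[1] * (v.dim() - 1))
        padded[k] = torch.cat([v, pad_sequence], dim=0).long()  # keep integer dtype

    output = generate_fn(padded)
    # Drop the rows that correspond to padding
    return {k: v[:-padding_size] for k, v in output.items()}

Reusing a real sequence as the padding template keeps shapes and dtypes consistent across GPU workers, and trimming by padding_size afterwards restores the original batch size.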