fix potential float bug
@@ -5,11 +5,6 @@ import os
 from typing import List, Dict, Any, Tuple
 from dataclasses import dataclass
 from .tensor_helper import TensorHelper, TensorConfig
-# from search_r1.utils import set_seed
-# from search_r1.utils.plot import (
-#     save_trajectory_to_output,
-#     parse_llm_output
-# )
 from verl import DataProto
 from verl.utils.tracking import Tracking
 import shutil
@@ -22,7 +17,6 @@ class GenerationConfig:
     max_prompt_length: int
     max_response_length: int
     max_obs_length: int
-    # logging: dict
     num_gpus: int
     no_think_rl: bool=False
     search_url: str = None
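
For context, GenerationConfig is a plain dataclass bundling the rollout limits used by the generation manager; the hunk above only drops the commented-out logging field. Below is a simplified mirror of just the fields visible in this hunk (the real class defines additional fields before line 22), with purely illustrative values:

from dataclasses import dataclass
from typing import Optional

@dataclass
class GenerationConfigSketch:
    # Mirrors only the fields visible in the hunk above; not the full class.
    max_prompt_length: int
    max_response_length: int
    max_obs_length: int
    num_gpus: int
    no_think_rl: bool = False
    search_url: Optional[str] = None

cfg = GenerationConfigSketch(
    max_prompt_length=4096,     # illustrative values, not taken from the repo config
    max_response_length=512,
    max_obs_length=512,
    num_gpus=8,
)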
@@ -34,13 +28,11 @@ class LLMGenerationManager:
         tokenizer,
         actor_rollout_wg,
         config: GenerationConfig,
-        # logger: Tracking,
         is_validation: bool = False,
     ):
         self.tokenizer = tokenizer
         self.actor_rollout_wg = actor_rollout_wg
         self.config = config
-        # self.logger = logger
         self.is_validation = is_validation

         self.tensor_fn = TensorHelper(TensorConfig(
@@ -188,6 +180,8 @@ class LLMGenerationManager:
         batch_size = active_batch.batch['input_ids'].shape[0]
         remainder = batch_size % num_gpus

+        for key in active_batch.batch.keys():
+            active_batch.batch[key] = active_batch.batch[key].long()
         if remainder == 0:
             return self.actor_rollout_wg.generate_sequences(active_batch)

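Note on the hunk above: the "potential float bug" is that tensors in active_batch (token ids, attention masks, position ids) can end up as floating-point tensors after earlier padding or concatenation, while embedding lookups and mask arithmetic downstream require integer dtypes. The added loop casts every tensor to int64 before dispatching generation. A minimal standalone sketch of the same guard, with illustrative names (cast_batch_to_long is not part of the repo):

import torch

def cast_batch_to_long(batch: dict) -> dict:
    # Embedding lookups expect integer indices; a tensor accidentally created
    # as float32 (e.g. via torch.zeros padding) would otherwise fail or
    # silently misbehave inside generate_sequences.
    return {k: v.long() for k, v in batch.items()}

# torch.zeros / torch.ones default to float32, mimicking the bug being guarded against.
batch = {
    "input_ids": torch.zeros(2, 4),
    "attention_mask": torch.ones(2, 4),
}
batch = cast_batch_to_long(batch)
assert all(v.dtype == torch.long for v in batch.values())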
@@ -201,10 +195,12 @@ class LLMGenerationManager:
             padded_batch[k] = torch.cat([v, pad_sequence], dim=0)

         padded_active_batch = DataProto.from_dict(padded_batch)
+        for key in padded_active_batch.batch.keys():
+            padded_active_batch.batch[key] = padded_active_batch.batch[key].long()

         # Generate with padded batch
         padded_output = self.actor_rollout_wg.generate_sequences(padded_active_batch)

         # Remove padding from output
         trimmed_batch = {k: v[:-padding_size] for k, v in padded_output.batch.items()}
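
Taken together, the two hunks implement one pattern: when the active batch size is not divisible by num_gpus, replicate the first row until it is, cast everything to int64, generate, and then slice off the artificial rows. A self-contained sketch of that flow, assuming a plain dict of tensors and a generate_fn callable standing in for actor_rollout_wg.generate_sequences (the real code operates on a DataProto):

import torch

def generate_with_gpu_padding(generate_fn, batch: dict, num_gpus: int) -> dict:
    batch_size = next(iter(batch.values())).shape[0]
    remainder = batch_size % num_gpus

    # Cast to int64 in both branches -- the "float bug" guard from the diff.
    batch = {k: v.long() for k, v in batch.items()}
    if remainder == 0:
        return generate_fn(batch)

    padding_size = num_gpus - remainder
    padded = {}
    for k, v in batch.items():
        # Repeat the first row so every data-parallel rank receives a full shard.
        pad_sequence = v[0:1].repeat(padding_size, *[1] * (v.dim() - 1))
        padded[k] = torch.cat([v, pad_sequence], dim=0)

    output = generate_fn(padded)
    # Drop the rows that correspond to the artificial padding.
    return {k: v[:-padding_size] for k, v in output.items()}

The trimming step relies on generate_sequences preserving batch order, so the padded rows stay at the end of every output tensor.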