Fix potential float bug: cast batch tensors to long before calling generate_sequences
This commit is contained in:
@@ -5,11 +5,6 @@ import os
|
|||||||
from typing import List, Dict, Any, Tuple
|
from typing import List, Dict, Any, Tuple
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from .tensor_helper import TensorHelper, TensorConfig
|
from .tensor_helper import TensorHelper, TensorConfig
|
||||||
# from search_r1.utils import set_seed
|
|
||||||
# from search_r1.utils.plot import (
|
|
||||||
# save_trajectory_to_output,
|
|
||||||
# parse_llm_output
|
|
||||||
# )
|
|
||||||
from verl import DataProto
|
from verl import DataProto
|
||||||
from verl.utils.tracking import Tracking
|
from verl.utils.tracking import Tracking
|
||||||
import shutil
|
import shutil
|
||||||
@@ -22,7 +17,6 @@ class GenerationConfig:
|
|||||||
max_prompt_length: int
|
max_prompt_length: int
|
||||||
max_response_length: int
|
max_response_length: int
|
||||||
max_obs_length: int
|
max_obs_length: int
|
||||||
# logging: dict
|
|
||||||
num_gpus: int
|
num_gpus: int
|
||||||
no_think_rl: bool=False
|
no_think_rl: bool=False
|
||||||
search_url: str = None
|
search_url: str = None
|
||||||
@@ -34,13 +28,11 @@ class LLMGenerationManager:
|
|||||||
tokenizer,
|
tokenizer,
|
||||||
actor_rollout_wg,
|
actor_rollout_wg,
|
||||||
config: GenerationConfig,
|
config: GenerationConfig,
|
||||||
# logger: Tracking,
|
|
||||||
is_validation: bool = False,
|
is_validation: bool = False,
|
||||||
):
|
):
|
||||||
self.tokenizer = tokenizer
|
self.tokenizer = tokenizer
|
||||||
self.actor_rollout_wg = actor_rollout_wg
|
self.actor_rollout_wg = actor_rollout_wg
|
||||||
self.config = config
|
self.config = config
|
||||||
# self.logger = logger
|
|
||||||
self.is_validation = is_validation
|
self.is_validation = is_validation
|
||||||
|
|
||||||
self.tensor_fn = TensorHelper(TensorConfig(
|
self.tensor_fn = TensorHelper(TensorConfig(
|
||||||
@@ -188,6 +180,8 @@ class LLMGenerationManager:
|
|||||||
batch_size = active_batch.batch['input_ids'].shape[0]
|
batch_size = active_batch.batch['input_ids'].shape[0]
|
||||||
remainder = batch_size % num_gpus
|
remainder = batch_size % num_gpus
|
||||||
|
|
||||||
|
for key in active_batch.batch.keys():
|
||||||
|
active_batch.batch[key] = active_batch.batch[key].long()
|
||||||
if remainder == 0:
|
if remainder == 0:
|
||||||
return self.actor_rollout_wg.generate_sequences(active_batch)
|
return self.actor_rollout_wg.generate_sequences(active_batch)
|
||||||
|
|
||||||
@@ -201,6 +195,8 @@ class LLMGenerationManager:
|
|||||||
padded_batch[k] = torch.cat([v, pad_sequence], dim=0)
|
padded_batch[k] = torch.cat([v, pad_sequence], dim=0)
|
||||||
|
|
||||||
padded_active_batch = DataProto.from_dict(padded_batch)
|
padded_active_batch = DataProto.from_dict(padded_batch)
|
||||||
|
for key in padded_active_batch.batch.keys():
|
||||||
|
padded_active_batch.batch[key] = padded_active_batch.batch[key].long()
|
||||||
|
|
||||||
# Generate with padded batch
|
# Generate with padded batch
|
||||||
padded_output = self.actor_rollout_wg.generate_sequences(padded_active_batch)
|
padded_output = self.actor_rollout_wg.generate_sequences(padded_active_batch)
|
||||||
|
|||||||
Reference in New Issue
Block a user