# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
FSDP PPO Trainer with Ray-based single controller.
This trainer supports model-agnostic model initialization with huggingface
"""

import json
import os
import re
import uuid
from collections import defaultdict
from contextlib import contextmanager
from dataclasses import dataclass, field
from enum import Enum
from pprint import pprint
from typing import Dict, Type

import numpy as np
import torch
from codetiming import Timer
from omegaconf import OmegaConf, open_dict

from verl import DataProto
from verl.protocol import pad_dataproto_to_divisor, unpad_dataproto
from verl.single_controller.base import Worker
from verl.single_controller.ray import RayResourcePool, RayWorkerGroup, RayClassWithInitArgs
from verl.single_controller.ray.base import create_colocated_worker_cls
from verl.trainer.ppo import core_algos
from verl.utils.seqlen_balancing import get_seqlen_balanced_partitions, log_seqlen_unbalance
from verl.utils.torch_functional import masked_mean

from search_r1.llm_agent.generation import LLMGenerationManager, GenerationConfig

WorkerType = Type[Worker]


class Role(Enum):
    """
    To create more roles dynamically, you can subclass Role and add new members
    """
    Actor = 0
    Rollout = 1
    ActorRollout = 2
    Critic = 3
    RefPolicy = 4
    RewardModel = 5
    ActorRolloutRef = 6


@dataclass
class ResourcePoolManager:
    """
    Define a resource pool specification. Resource pool will be initialized first.

    Attributes:
        resource_pool_spec: pool name -> list of process counts, one entry per node.
        mapping: role -> name of the pool that role's worker group runs in.
        resource_pool_dict: pool name -> created RayResourcePool (filled by create_resource_pool).
    """
    resource_pool_spec: dict[str, list[int]]
    mapping: dict[Role, str]
    resource_pool_dict: dict[str, RayResourcePool] = field(default_factory=dict)

    def create_resource_pool(self):
        """Instantiate one RayResourcePool per entry of `resource_pool_spec`."""
        for resource_pool_name, process_on_nodes in self.resource_pool_spec.items():
            # max_colocate_count means the number of WorkerGroups (i.e. processes) in each RayResourcePool
            # For FSDP backend, we recommend using max_colocate_count=1 that merge all WorkerGroups into one.
            # For Megatron backend, we recommend using max_colocate_count>1 that can utilize different WorkerGroup for different models
            resource_pool = RayResourcePool(process_on_nodes=process_on_nodes,
                                            use_gpu=True,
                                            max_colocate_count=1,
                                            name_prefix=resource_pool_name)
            self.resource_pool_dict[resource_pool_name] = resource_pool

    def get_resource_pool(self, role: Role) -> RayResourcePool:
        """Get the resource pool of the worker_cls"""
        return self.resource_pool_dict[self.mapping[role]]


def apply_kl_penalty(data: DataProto, kl_ctrl: core_algos.AdaptiveKLController, kl_penalty='kl'):
    """Subtract a KL(current || reference) penalty from the token-level scores.

    Writes `token_level_rewards` into `data.batch` and updates `kl_ctrl` with the
    observed mean KL. When `info_mask` is present it is used instead of
    `attention_mask`, so tokens masked out by `info_mask` do not contribute KL.

    Returns:
        (data, metrics) where metrics holds 'critic/kl' and 'critic/kl_coeff'.
    """
    responses = data.batch['responses']
    response_length = responses.size(1)
    token_level_scores = data.batch['token_level_scores']
    batch_size = data.batch.batch_size[0]
    attention_mask = data.batch['info_mask'] if 'info_mask' in data.batch else data.batch['attention_mask']
    response_mask = attention_mask[:, -response_length:]

    # compute kl between ref_policy and current policy
    if 'ref_log_prob' in data.batch.keys():
        kld = core_algos.kl_penalty(data.batch['old_log_probs'], data.batch['ref_log_prob'],
                                    kl_penalty=kl_penalty)  # (batch_size, response_length)
        kld = kld * response_mask
        beta = kl_ctrl.value
    else:
        # No reference policy: no penalty, KL reported as zero.
        beta = 0
        kld = torch.zeros_like(response_mask, dtype=torch.float32)

    token_level_rewards = token_level_scores - beta * kld

    current_kl = masked_mean(kld, mask=response_mask, axis=-1)  # average over sequence
    current_kl = torch.mean(current_kl, dim=0).item()

    # according to https://github.com/huggingface/trl/blob/951ca1841f29114b969b57b26c7d3e80a39f75a0/trl/trainer/ppo_trainer.py#L837
    kl_ctrl.update(current_kl=current_kl, n_steps=batch_size)
    data.batch['token_level_rewards'] = token_level_rewards

    metrics = {'critic/kl': current_kl, 'critic/kl_coeff': beta}

    return data, metrics


def compute_advantage(data: DataProto, adv_estimator, gamma=1.0, lam=1.0, num_repeat=1):
    """Compute `advantages` and `returns` in-place on `data.batch`.

    Args:
        adv_estimator: 'gae' (needs critic `values`) or 'grpo' (groups by
            `non_tensor_batch['uid']`).
        gamma, lam: GAE discount and lambda (ignored by 'grpo').
        num_repeat: currently unused; kept for interface compatibility.

    Raises:
        NotImplementedError: for any other estimator name.
    """
    # prepare response group
    # TODO: add other ways to estimate advantages
    if adv_estimator == 'gae':
        values = data.batch['values']
        responses = data.batch['responses']
        response_length = responses.size(-1)
        attention_mask = data.batch['attention_mask']
        response_mask = attention_mask[:, -response_length:]
        token_level_rewards = data.batch['token_level_rewards']
        advantages, returns = core_algos.compute_gae_advantage_return(token_level_rewards=token_level_rewards,
                                                                      values=values,
                                                                      eos_mask=response_mask,
                                                                      gamma=gamma,
                                                                      lam=lam)
        data.batch['advantages'] = advantages
        data.batch['returns'] = returns
    elif adv_estimator == 'grpo':
        token_level_rewards = data.batch['token_level_rewards']
        index = data.non_tensor_batch['uid']
        responses = data.batch['responses']
        response_length = responses.size(-1)
        attention_mask = data.batch['attention_mask']
        response_mask = attention_mask[:, -response_length:]
        advantages, returns = core_algos.compute_grpo_outcome_advantage(token_level_rewards=token_level_rewards,
                                                                        eos_mask=response_mask,
                                                                        index=index)
        data.batch['advantages'] = advantages
        data.batch['returns'] = returns
    else:
        raise NotImplementedError
    return data


def reduce_metrics(metrics: dict):
    """Replace every value of `metrics` with its np.mean, in place; returns the same dict."""
    for key, val in metrics.items():
        metrics[key] = np.mean(val)
    return metrics


def _compute_response_info(batch):
    """Split `attention_mask` into prompt/response parts and count valid tokens per row."""
    response_length = batch.batch['responses'].shape[-1]

    prompt_mask = batch.batch['attention_mask'][:, :-response_length]
    response_mask = batch.batch['attention_mask'][:, -response_length:]

    prompt_length = prompt_mask.sum(-1).float()
    response_length = response_mask.sum(-1).float()  # (batch_size,)

    return dict(
        response_mask=response_mask,
        prompt_length=prompt_length,
        response_length=response_length,
    )


def compute_data_metrics(batch, use_critic=True):
    """Collect scalar statistics of a training batch for logging.

    Covers sequence scores/rewards, masked advantages/returns, critic values and
    explained variance (only when `use_critic`), prompt/response lengths, and
    agent environment stats found in `batch.meta_info`.
    """
    # TODO: add response length
    sequence_score = batch.batch['token_level_scores'].sum(-1)
    sequence_reward = batch.batch['token_level_rewards'].sum(-1)

    advantages = batch.batch['advantages']
    returns = batch.batch['returns']

    max_response_length = batch.batch['responses'].shape[-1]

    prompt_mask = batch.batch['attention_mask'][:, :-max_response_length].bool()
    response_mask = batch.batch['attention_mask'][:, -max_response_length:].bool()

    max_prompt_length = prompt_mask.size(-1)

    response_info = _compute_response_info(batch)
    prompt_length = response_info['prompt_length']
    response_length = response_info['response_length']

    valid_adv = torch.masked_select(advantages, response_mask)
    valid_returns = torch.masked_select(returns, response_mask)

    if use_critic:
        values = batch.batch['values']
        valid_values = torch.masked_select(values, response_mask)
        return_diff_var = torch.var(valid_returns - valid_values)
        return_var = torch.var(valid_returns)

    metrics = {
        # score
        'critic/score/mean': torch.mean(sequence_score).detach().item(),
        'critic/score/max': torch.max(sequence_score).detach().item(),
        'critic/score/min': torch.min(sequence_score).detach().item(),
        # reward
        'critic/rewards/mean': torch.mean(sequence_reward).detach().item(),
        'critic/rewards/max': torch.max(sequence_reward).detach().item(),
        'critic/rewards/min': torch.min(sequence_reward).detach().item(),
        # adv
        'critic/advantages/mean': torch.mean(valid_adv).detach().item(),
        'critic/advantages/max': torch.max(valid_adv).detach().item(),
        'critic/advantages/min': torch.min(valid_adv).detach().item(),
        # returns
        'critic/returns/mean': torch.mean(valid_returns).detach().item(),
        'critic/returns/max': torch.max(valid_returns).detach().item(),
        'critic/returns/min': torch.min(valid_returns).detach().item(),
        **({
            # values
            'critic/values/mean': torch.mean(valid_values).detach().item(),
            'critic/values/max': torch.max(valid_values).detach().item(),
            'critic/values/min': torch.min(valid_values).detach().item(),
            # vf explained var
            'critic/vf_explained_var': (1.0 - return_diff_var / (return_var + 1e-5)).detach().item(),
        } if use_critic else {}),

        # response length
        'response_length/mean': torch.mean(response_length).detach().item(),
        'response_length/max': torch.max(response_length).detach().item(),
        'response_length/min': torch.min(response_length).detach().item(),
        'response_length/clip_ratio': torch.mean(torch.eq(response_length, max_response_length).float()).detach().item(),
        # prompt length
        'prompt_length/mean': torch.mean(prompt_length).detach().item(),
        'prompt_length/max': torch.max(prompt_length).detach().item(),
        'prompt_length/min': torch.min(prompt_length).detach().item(),
        'prompt_length/clip_ratio': torch.mean(torch.eq(prompt_length, max_prompt_length).float()).detach().item(),
    }

    # metrics for actions (populated by the agent generation loop, when present)
    if 'turns_stats' in batch.meta_info:
        turns = np.array(batch.meta_info['turns_stats'], dtype=np.int16)
        metrics['env/number_of_actions/mean'] = float(turns.mean())
        metrics['env/number_of_actions/max'] = float(turns.max())
        metrics['env/number_of_actions/min'] = float(turns.min())
    if 'active_mask' in batch.meta_info:
        metrics['env/finish_ratio'] = 1 - float(np.array(batch.meta_info['active_mask'], dtype=np.int16).mean())
    if 'valid_action_stats' in batch.meta_info:
        valid_actions = np.array(batch.meta_info['valid_action_stats'], dtype=np.int16)
        metrics['env/number_of_valid_action'] = float(valid_actions.mean())
        # NOTE(review): assumes 'turns_stats' is present whenever 'valid_action_stats' is.
        metrics['env/ratio_of_valid_action'] = float(
            (valid_actions / np.array(batch.meta_info['turns_stats'], dtype=np.int16)).mean())
    if 'valid_search_stats' in batch.meta_info:
        metrics['env/number_of_valid_search'] = float(
            np.array(batch.meta_info['valid_search_stats'], dtype=np.int16).mean())

    return metrics


def compute_timing_metrics(batch, timing_raw):
    """Turn raw section timings into 'timing_s/*' and 'timing_per_token_ms/*' metrics.

    'gen' is normalized by the number of response tokens; the other known
    sections by prompt + response tokens. Sections without a token count only
    get the 'timing_s/*' entry.
    """
    response_info = _compute_response_info(batch)
    num_prompt_tokens = torch.sum(response_info['prompt_length']).item()
    num_response_tokens = torch.sum(response_info['response_length']).item()
    num_overall_tokens = num_prompt_tokens + num_response_tokens

    num_tokens_of_section = {
        'gen': num_response_tokens,
        **{
            name: num_overall_tokens for name in ['ref', 'values', 'adv', 'update_critic', 'update_actor', 'rollout']
        },
    }

    return {
        **{
            f'timing_s/{name}': value for name, value in timing_raw.items()
        },
        **{
            f'timing_per_token_ms/{name}': timing_raw[name] * 1000 / num_tokens_of_section[name]
            for name in set(num_tokens_of_section.keys()) & set(timing_raw.keys())
        },
    }


@contextmanager
def _timer(name: str, timing_raw: Dict[str, float]):
    """Time the enclosed block and store the elapsed seconds in `timing_raw[name]`."""
    with Timer(name=name, logger=None) as timer:
        yield
    timing_raw[name] = timer.last


class RayPPOTrainer(object):
    """
    Note that this trainer runs on the driver process on a single CPU/GPU node.
    """

    # TODO: support each role have individual ray_worker_group_cls,
    # i.e., support different backend of different role
    def __init__(self,
                 config,
                 tokenizer,
                 role_worker_mapping: dict[Role, WorkerType],
                 resource_pool_manager: ResourcePoolManager,
                 ray_worker_group_cls: RayWorkerGroup = RayWorkerGroup,
                 reward_fn=None,
                 val_reward_fn=None):
        """Store configuration, choose the KL controller, build dataloaders and the logger.

        Raises:
            AssertionError: if hybrid engine is disabled, or an adaptive KL
                controller is configured with a non-positive horizon.
            NotImplementedError: for an unknown `algorithm.kl_ctrl.type`.
        """
        # assert torch.cuda.is_available(), 'cuda must be available on driver'

        self.tokenizer = tokenizer
        self.config = config
        self.reward_fn = reward_fn
        self.val_reward_fn = val_reward_fn

        self.hybrid_engine = config.actor_rollout_ref.hybrid_engine
        assert self.hybrid_engine, 'Currently, only support hybrid engine'

        if self.hybrid_engine:
            assert Role.ActorRollout in role_worker_mapping, f'{role_worker_mapping.keys()=}'

        self.role_worker_mapping = role_worker_mapping
        self.resource_pool_manager = resource_pool_manager
        self.use_reference_policy = Role.RefPolicy in role_worker_mapping
        self.use_rm = Role.RewardModel in role_worker_mapping
        self.ray_worker_group_cls = ray_worker_group_cls

        # define KL control
        if self.use_reference_policy:
            if config.algorithm.kl_ctrl.type == 'fixed':
                self.kl_ctrl = core_algos.FixedKLController(kl_coef=config.algorithm.kl_ctrl.kl_coef)
            elif config.algorithm.kl_ctrl.type == 'adaptive':
                # The message must read the same field that is being checked; reading
                # config.critic.kl_ctrl here would fail while formatting the message.
                assert config.algorithm.kl_ctrl.horizon > 0, \
                    f'horizon must be larger than 0. Got {config.algorithm.kl_ctrl.horizon}'
                self.kl_ctrl = core_algos.AdaptiveKLController(init_kl_coef=config.algorithm.kl_ctrl.kl_coef,
                                                               target_kl=config.algorithm.kl_ctrl.target_kl,
                                                               horizon=config.algorithm.kl_ctrl.horizon)
            else:
                raise NotImplementedError
        else:
            self.kl_ctrl = core_algos.FixedKLController(kl_coef=0.)

        self._create_dataloader()
        self._init_logger()

    def _init_logger(self):
        """Create the experiment tracker from `config.trainer` settings."""
        from verl.utils.tracking import Tracking
        self.logger = Tracking(project_name=self.config.trainer.project_name,
                               experiment_name=self.config.trainer.experiment_name,
                               default_backend=self.config.trainer.logger,
                               config=OmegaConf.to_container(self.config, resolve=True))

    def _create_dataloader(self):
        """Build train/val datasets and dataloaders, optionally subsampled.

        Also computes `self.total_training_steps` and injects it into the actor
        and critic optimizer configs.
        """
        from torch.utils.data import DataLoader
        # TODO: we have to make sure the batch size is divisible by the dp size
        from verl.utils.dataset.rl_dataset import RLHFDataset, collate_fn
        self.train_dataset = RLHFDataset(parquet_files=self.config.data.train_files,
                                         tokenizer=self.tokenizer,
                                         prompt_key=self.config.data.prompt_key,
                                         max_prompt_length=self.config.data.max_prompt_length,
                                         filter_prompts=True,
                                         return_raw_chat=self.config.data.get('return_raw_chat', False),
                                         truncation='error')
        if self.config.data.train_data_num is not None:
            if self.config.data.train_data_num > len(self.train_dataset.dataframe):
                print(f"[WARNING] training dataset size is smaller than desired size. Using the dataset as the original size {len(self.train_dataset.dataframe)}")
            else:
                self.train_dataset.dataframe = self.train_dataset.dataframe.sample(self.config.data.train_data_num, random_state=42)
        print(f"filtered training dataset size: {len(self.train_dataset.dataframe)}")

        self.train_dataloader = DataLoader(dataset=self.train_dataset,
                                           batch_size=self.config.data.train_batch_size,
                                           shuffle=self.config.data.shuffle_train_dataloader,
                                           drop_last=True,
                                           collate_fn=collate_fn)

        self.val_dataset = RLHFDataset(parquet_files=self.config.data.val_files,
                                       tokenizer=self.tokenizer,
                                       prompt_key=self.config.data.prompt_key,
                                       max_prompt_length=self.config.data.max_prompt_length,
                                       filter_prompts=True,
                                       return_raw_chat=self.config.data.get('return_raw_chat', False),
                                       truncation='error')
        if self.config.data.val_data_num is not None:
            if self.config.data.val_data_num > len(self.val_dataset.dataframe):
                print(f"[WARNING] validation dataset size is smaller than desired size. Using the dataset as the original size {len(self.val_dataset.dataframe)}")
            else:
                self.val_dataset.dataframe = self.val_dataset.dataframe.sample(self.config.data.val_data_num, random_state=42)
        print(f"filtered validation dataset size: {len(self.val_dataset.dataframe)}")

        self.val_dataloader = DataLoader(dataset=self.val_dataset,
                                         batch_size=self.config.data.val_batch_size,
                                         shuffle=False,
                                         drop_last=True,
                                         collate_fn=collate_fn)

        print(f'Size of train dataloader: {len(self.train_dataloader)}')
        print(f'Size of val dataloader: {len(self.val_dataloader)}')

        assert len(self.train_dataloader) >= 1
        assert len(self.val_dataloader) >= 1

        # inject total_training_steps to actor/critic optim_config. This is hacky.
        total_training_steps = len(self.train_dataloader) * self.config.trainer.total_epochs

        if self.config.trainer.total_training_steps is not None:
            total_training_steps = self.config.trainer.total_training_steps

        self.total_training_steps = total_training_steps
        print(f'Total training steps: {self.total_training_steps}')

        OmegaConf.set_struct(self.config, True)
        with open_dict(self.config):
            self.config.actor_rollout_ref.actor.optim.total_training_steps = total_training_steps
            self.config.critic.optim.total_training_steps = total_training_steps

    def _validate(self):
        """Run one pass over the validation set and score it with `val_reward_fn`.

        Without `config.do_search`, responses come from a single
        `generate_sequences` call; with it, from the multi-turn agent loop.
        Returns a dict mapping 'val/test_score/<data_source>' to the mean
        sequence reward of that data source (or {} when a model-style reward
        model is configured in the non-search path).
        """
        reward_tensor_lst = []
        data_source_lst = []

        gen_config = GenerationConfig(
            max_turns=self.config.max_turns,
            max_start_length=self.config.data.max_start_length,
            max_prompt_length=self.config.data.max_prompt_length,
            max_response_length=self.config.data.max_response_length,
            max_obs_length=self.config.data.max_obs_length,
            num_gpus=self.config.trainer.n_gpus_per_node * self.config.trainer.nnodes,
            no_think_rl=self.config.algorithm.no_think_rl,
            search_url=self.config.retriever.url,
            topk=self.config.retriever.topk,
        )

        # Agent config preparation
        generation_manager = LLMGenerationManager(
            tokenizer=self.tokenizer,
            actor_rollout_wg=self.actor_rollout_wg,
            config=gen_config,
            is_validation=True,
        )

        if not self.config.do_search:
            for test_data in self.val_dataloader:
                test_batch = DataProto.from_single_dict(test_data)

                # we only do validation on rule-based rm
                if self.config.reward_model.enable and test_batch[0].non_tensor_batch['reward_model']['style'] == 'model':
                    return {}

                test_gen_batch = test_batch.pop(['input_ids', 'attention_mask', 'position_ids'])
                test_gen_batch.meta_info = {
                    'eos_token_id': self.tokenizer.eos_token_id,
                    'pad_token_id': self.tokenizer.pad_token_id,
                    'recompute_log_prob': False,
                    'do_sample': False,
                    'validate': True,
                }

                # pad to be divisible by dp_size
                test_gen_batch_padded, pad_size = pad_dataproto_to_divisor(test_gen_batch, self.actor_rollout_wg.world_size)
                test_output_gen_batch_padded = self.actor_rollout_wg.generate_sequences(test_gen_batch_padded)
                # unpad
                test_output_gen_batch = unpad_dataproto(test_output_gen_batch_padded, pad_size=pad_size)
                print('validation generation end')

                test_batch = test_batch.union(test_output_gen_batch)

                # evaluate using reward_function
                # for certain reward function (e.g. sandbox), the generation can overlap with reward
                reward_tensor = self.val_reward_fn(test_batch)

                reward_tensor_lst.append(reward_tensor)
                data_source_lst.append(test_batch.non_tensor_batch.get('data_source', ['unknown'] * reward_tensor.shape[0]))
        else:
            for batch_dict in self.val_dataloader:
                timing_raw = {}
                test_batch: DataProto = DataProto.from_single_dict(batch_dict)
                # test_batch = test_batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n_agent, interleave=True)

                test_gen_batch = test_batch.pop(batch_keys=['input_ids', 'attention_mask', 'position_ids'])
                test_gen_batch.meta_info = {
                    'eos_token_id': self.tokenizer.eos_token_id,
                    'pad_token_id': self.tokenizer.pad_token_id,
                    'recompute_log_prob': False,
                    'do_sample': False,
                    'validate': True,
                }
                with _timer('step', timing_raw):
                    # Only the last `max_start_length` prompt tokens seed the agent loop.
                    first_input_ids = test_gen_batch.batch['input_ids'][:, -gen_config.max_start_length:].clone()
                    with _timer('gen', timing_raw):
                        generation_manager.timing_raw = timing_raw
                        final_gen_batch_output = generation_manager.run_llm_loop(
                            gen_batch=test_gen_batch,
                            initial_input_ids=first_input_ids,
                        )

                    test_batch = test_batch.union(final_gen_batch_output)

                    for key in test_batch.batch.keys():
                        test_batch.batch[key] = test_batch.batch[key].long()

                    # evaluate using reward_function
                    # for certain reward function (e.g. sandbox), the generation can overlap with reward
                    reward_tensor = self.val_reward_fn(test_batch)

                    reward_tensor_lst.append(reward_tensor)
                    data_source_lst.append(test_batch.non_tensor_batch.get('data_source', ['unknown'] * reward_tensor.shape[0]))

        # Sum per batch before concatenating so batches with different response widths can be combined.
        reward_tensor = torch.cat([rw.sum(-1) for rw in reward_tensor_lst], dim=0).cpu()  # (batch_size,)
        data_sources = np.concatenate(data_source_lst, axis=0)

        # evaluate test_score based on data source
        data_source_reward = defaultdict(list)
        for i in range(reward_tensor.shape[0]):
            data_source_reward[data_sources[i]].append(reward_tensor[i].item())

        metric_dict = {}
        for data_source, rewards in data_source_reward.items():
            metric_dict[f'val/test_score/{data_source}'] = np.mean(rewards)

        return metric_dict

    def init_workers(self):
        """Init resource pool and worker group"""
        self.resource_pool_manager.create_resource_pool()

        self.resource_pool_to_cls = {pool: {} for pool in self.resource_pool_manager.resource_pool_dict.values()}

        # create actor and rollout
        if self.hybrid_engine:
            resource_pool = self.resource_pool_manager.get_resource_pool(Role.ActorRollout)
            actor_rollout_cls = RayClassWithInitArgs(cls=self.role_worker_mapping[Role.ActorRollout],
                                                     config=self.config.actor_rollout_ref,
                                                     role='actor_rollout')
            self.resource_pool_to_cls[resource_pool]['actor_rollout'] = actor_rollout_cls
        else:
            raise NotImplementedError

        # create critic
        if self.config.algorithm.adv_estimator == 'gae':
            resource_pool = self.resource_pool_manager.get_resource_pool(Role.Critic)
            critic_cls = RayClassWithInitArgs(cls=self.role_worker_mapping[Role.Critic], config=self.config.critic)
            self.resource_pool_to_cls[resource_pool]['critic'] = critic_cls
            self.use_critic = True
        elif self.config.algorithm.adv_estimator == 'grpo':
            self.use_critic = False
        else:
            raise NotImplementedError

        # create reference policy if needed
        if self.use_reference_policy:
            resource_pool = self.resource_pool_manager.get_resource_pool(Role.RefPolicy)
            ref_policy_cls = RayClassWithInitArgs(self.role_worker_mapping[Role.RefPolicy],
                                                  config=self.config.actor_rollout_ref,
                                                  role='ref')
            self.resource_pool_to_cls[resource_pool]['ref'] = ref_policy_cls

        # create a reward model if reward_fn is None
        if self.use_rm:
            # we create a RM here
            resource_pool = self.resource_pool_manager.get_resource_pool(Role.RewardModel)
            rm_cls = RayClassWithInitArgs(self.role_worker_mapping[Role.RewardModel], config=self.config.reward_model)
            self.resource_pool_to_cls[resource_pool]['rm'] = rm_cls

        # initialize WorkerGroup
        # NOTE: if you want to use a different resource pool for each role, which can support different parallel size,
        # you should not use `create_colocated_worker_cls`. Instead, directly pass different resource pool to different worker groups.
        # See https://github.com/volcengine/verl/blob/master/examples/ray/tutorial.ipynb for more information.
        all_wg = {}
        self.wg_dicts = []
        for resource_pool, class_dict in self.resource_pool_to_cls.items():
            worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict)
            wg_dict = self.ray_worker_group_cls(resource_pool=resource_pool, ray_cls_with_init=worker_dict_cls)
            spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys())
            all_wg.update(spawn_wg)
            # keep the reference of WorkerDict to support ray >= 2.31. Ref: https://github.com/ray-project/ray/pull/45699
            self.wg_dicts.append(wg_dict)

        if self.use_critic:
            self.critic_wg = all_wg['critic']
            self.critic_wg.init_model()

        if self.use_reference_policy:
            self.ref_policy_wg = all_wg['ref']
            self.ref_policy_wg.init_model()

        if self.use_rm:
            self.rm_wg = all_wg['rm']
            self.rm_wg.init_model()

        # we should create rollout at the end so that vllm can have a better estimation of kv cache memory
        self.actor_rollout_wg = all_wg['actor_rollout']
        self.actor_rollout_wg.init_model()

    def _save_checkpoint(self):
        """Save actor (and critic, if used) checkpoints for the current global step."""
        actor_local_path = os.path.join(self.config.trainer.default_local_dir, 'actor',
                                        f'global_step_{self.global_steps}')
        actor_remote_path = None if self.config.trainer.default_hdfs_dir is None else os.path.join(
            self.config.trainer.default_hdfs_dir, 'actor')
        self.actor_rollout_wg.save_checkpoint(actor_local_path, actor_remote_path)

        if self.use_critic:
            critic_local_path = os.path.join(self.config.trainer.default_local_dir, 'critic',
                                             f'global_step_{self.global_steps}')
            critic_remote_path = None if self.config.trainer.default_hdfs_dir is None else os.path.join(
                self.config.trainer.default_hdfs_dir, 'critic')
            self.critic_wg.save_checkpoint(critic_local_path, critic_remote_path)

    def _balance_batch(self, batch: DataProto, metrics, logging_prefix='global_seqlen'):
        """Reorder the data on single controller such that each dp rank gets similar total tokens"""
        attention_mask = batch.batch['attention_mask']
        batch_size = attention_mask.shape[0]
        global_seqlen_lst = attention_mask.view(batch_size, -1).sum(-1).tolist()  # (train_batch_size,)
        world_size = self.actor_rollout_wg.world_size
        global_partition_lst = get_seqlen_balanced_partitions(global_seqlen_lst,
                                                              k_partitions=world_size,
                                                              equal_size=True)
        # reorder based on index. The data will be automatically equally partitioned by dispatch function
        global_idx = torch.tensor([j for partition in global_partition_lst for j in partition])
        batch.reorder(global_idx)
        global_balance_stats = log_seqlen_unbalance(seqlen_list=global_seqlen_lst,
                                                    partitions=global_partition_lst,
                                                    prefix=logging_prefix)
        metrics.update(global_balance_stats)

    def fit(self):
        """
        The training loop of PPO.
        The driver process only need to call the compute functions of the worker group through RPC to construct the PPO dataflow.
        The light-weight advantage computation is done on the driver process.

        Training stops after step `self.total_training_steps` has been fully
        executed (or when the dataloader epochs run out), followed by a final
        validation when `val_reward_fn` is set.
        """

        logger = self.logger
        self.global_steps = 0
        # perform validation before training
        # currently, we only support validation using the reward_function.
        if self.val_reward_fn is not None and self.config.trainer.get('val_before_train', True):
            val_metrics = self._validate()
            pprint(f'Initial validation metrics: {val_metrics}')
            logger.log(data=val_metrics, step=self.global_steps)
            if self.config.trainer.get('val_only', False):
                return

        # we start from step 1
        self.global_steps += 1

        # Agent config preparation
        gen_config = GenerationConfig(
            max_turns=self.config.max_turns,
            max_start_length=self.config.data.max_start_length,
            max_prompt_length=self.config.data.max_prompt_length,
            max_response_length=self.config.data.max_response_length,
            max_obs_length=self.config.data.max_obs_length,
            num_gpus=self.config.trainer.n_gpus_per_node * self.config.trainer.nnodes,
            no_think_rl=self.config.algorithm.no_think_rl,
            search_url=self.config.retriever.url,
            topk=self.config.retriever.topk,
        )

        generation_manager = LLMGenerationManager(
            tokenizer=self.tokenizer,
            actor_rollout_wg=self.actor_rollout_wg,
            config=gen_config,
        )

        # start training loop
        for epoch in range(self.config.trainer.total_epochs):
            for batch_dict in self.train_dataloader:
                print(f'epoch {epoch}, step {self.global_steps}')
                metrics = {}
                timing_raw = {}

                batch: DataProto = DataProto.from_single_dict(batch_dict)
                batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n_agent, interleave=True)

                # pop those keys for generation
                gen_batch = batch.pop(batch_keys=['input_ids', 'attention_mask', 'position_ids'])

                ####################
                # original code here

                with _timer('step', timing_raw):
                    if not self.config.do_search:
                        gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)

                        batch.non_tensor_batch['uid'] = np.array([str(uuid.uuid4()) for _ in range(len(batch.batch))],
                                                                 dtype=object)
                        # repeat to align with repeated responses in rollout
                        batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)
                        batch = batch.union(gen_batch_output)

                    ####################
                    # Below is all about agents - the "LLM + forloop"
                    ####################
                    else:
                        first_input_ids = gen_batch.batch['input_ids'][:, -gen_config.max_start_length:].clone().long()

                        with _timer('gen', timing_raw):
                            generation_manager.timing_raw = timing_raw
                            final_gen_batch_output = generation_manager.run_llm_loop(
                                gen_batch=gen_batch,
                                initial_input_ids=first_input_ids,
                            )

                        for key in final_gen_batch_output.batch.keys():
                            final_gen_batch_output.batch[key] = final_gen_batch_output.batch[key].long()

                        with torch.no_grad():
                            output = self.actor_rollout_wg.compute_log_prob(final_gen_batch_output)
                            final_gen_batch_output = final_gen_batch_output.union(output)

                        # Group rollouts by the dataset 'index' (shared by the n_agent copies)
                        # rather than a fresh uuid, so GRPO groups per original prompt.
                        batch.non_tensor_batch['uid'] = batch.non_tensor_batch['index'].copy()

                        # repeat to align with repeated responses in rollout
                        batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)
                        batch = batch.union(final_gen_batch_output)

                    ####################
                    ####################

                    # balance the number of valid tokens on each dp rank.
                    # Note that this breaks the order of data inside the batch.
                    # Please take care when you implement group based adv computation such as GRPO and rloo
                    self._balance_batch(batch, metrics=metrics)

                    # compute global_valid tokens
                    batch.meta_info['global_token_num'] = torch.sum(batch.batch['attention_mask'], dim=-1).tolist()

                    # Cast everything except the float log-probs to integer tensors.
                    for key in batch.batch.keys():
                        if key != 'old_log_probs':
                            batch.batch[key] = batch.batch[key].long()

                    if self.use_reference_policy:
                        # compute reference log_prob
                        with _timer('ref', timing_raw):
                            ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)
                            batch = batch.union(ref_log_prob)

                    # compute values
                    if self.use_critic:
                        with _timer('values', timing_raw):
                            values = self.critic_wg.compute_values(batch)
                            batch = batch.union(values)

                    with _timer('adv', timing_raw):
                        # compute scores. Support both model and function-based.
                        # We first compute the scores using reward model. Then, we call reward_fn to combine
                        # the results from reward model and rule-based results.
                        if self.use_rm:
                            # we first compute reward model score
                            reward_tensor = self.rm_wg.compute_rm_score(batch)
                            batch = batch.union(reward_tensor)

                        # we combine with rule-based rm
                        reward_tensor = self.reward_fn(batch)
                        batch.batch['token_level_scores'] = reward_tensor

                        # compute rewards. apply_kl_penalty if available
                        if not self.config.actor_rollout_ref.actor.use_kl_loss:
                            batch, kl_metrics = apply_kl_penalty(batch,
                                                                 kl_ctrl=self.kl_ctrl,
                                                                 kl_penalty=self.config.algorithm.kl_penalty)
                            metrics.update(kl_metrics)
                        else:
                            batch.batch['token_level_rewards'] = batch.batch['token_level_scores']

                        # compute advantages, executed on the driver process
                        batch = compute_advantage(batch,
                                                  adv_estimator=self.config.algorithm.adv_estimator,
                                                  gamma=self.config.algorithm.gamma,
                                                  lam=self.config.algorithm.lam,
                                                  num_repeat=self.config.actor_rollout_ref.rollout.n)

                    # update critic
                    if self.use_critic:
                        with _timer('update_critic', timing_raw):
                            critic_output = self.critic_wg.update_critic(batch)
                        critic_output_metrics = reduce_metrics(critic_output.meta_info['metrics'])
                        metrics.update(critic_output_metrics)

                    # implement critic warmup
                    if self.config.trainer.critic_warmup <= self.global_steps:
                        # update actor
                        with _timer('update_actor', timing_raw):
                            if self.config.do_search and self.config.actor_rollout_ref.actor.state_masking:
                                batch, metrics = self._create_loss_mask(batch, metrics)
                            actor_output = self.actor_rollout_wg.update_actor(batch)
                        actor_output_metrics = reduce_metrics(actor_output.meta_info['metrics'])
                        metrics.update(actor_output_metrics)

                    # validate
                    if self.val_reward_fn is not None and self.config.trainer.test_freq > 0 and \
                            self.global_steps % self.config.trainer.test_freq == 0:
                        with _timer('testing', timing_raw):
                            val_metrics: dict = self._validate()
                        metrics.update(val_metrics)

                    if self.config.trainer.save_freq > 0 and \
                            self.global_steps % self.config.trainer.save_freq == 0:
                        with _timer('save_checkpoint', timing_raw):
                            self._save_checkpoint()

                # collect metrics
                metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic))
                metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw))

                # TODO: make a canonical logger that supports various backend
                logger.log(data=metrics, step=self.global_steps)

                # Check for termination BEFORE incrementing: steps are 1-based, so step
                # `total_training_steps` itself must run. Checking after the increment
                # stopped one step early and never trained on the final batch.
                if self.global_steps >= self.total_training_steps:
                    # perform validation after training
                    if self.val_reward_fn is not None:
                        val_metrics = self._validate()
                        pprint(f'Final validation metrics: {val_metrics}')
                        logger.log(data=val_metrics, step=self.global_steps)
                    return

                self.global_steps += 1

    def _create_loss_mask(self, batch, metrics):
        """Create loss mask for state tokens.

        Uses the response part of `info_mask` as `loss_mask` and records the
        number of kept tokens and their fraction of valid response tokens.
        """
        response_length = batch.batch['responses'].shape[-1]
        response_mask = batch.batch['attention_mask'][:, -response_length:]

        loss_mask = batch.batch['info_mask'][:, -response_length:]
        batch.batch['loss_mask'] = loss_mask

        metrics.update({
            'state_tokens/total': loss_mask.sum().item(),
            'state_tokens/coverage': (loss_mask.sum() / response_mask.sum()).item(),
        })

        return batch, metrics