Initial commit

PeterGriffinJin
2025-02-28 15:16:19 +00:00
commit 068516be64
207 changed files with 33063 additions and 0 deletions

verl/trainer/__init__.py

@@ -0,0 +1,13 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


@@ -0,0 +1,6 @@
data:
  path: /tmp/math_Qwen2-7B-Instruct.parquet
  prompt_key: prompt
  response_key: responses
  data_source_key: data_source
  reward_model_key: reward_model
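
For reference, a minimal, assumed sketch of the parquet layout these keys point at (one row per prompt; the actual schema comes from the data-preparation scripts, so the field values here are purely illustrative):

import pandas as pd

row = {
    'prompt': [{'role': 'user', 'content': 'What is 2 + 2?'}],  # prompt_key
    'responses': ['The answer is 4.', '2 + 2 = 4'],             # response_key: N generations per prompt
    'data_source': 'lighteval/MATH',                            # data_source_key
    'reward_model': {'ground_truth': '4'},                      # reward_model_key
}
pd.DataFrame([row]).to_parquet('/tmp/example_eval_input.parquet')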


@@ -0,0 +1,35 @@
trainer:
  nnodes: 1
  n_gpus_per_node: 8

data:
  path: ~/data/rlhf/math/test.parquet
  prompt_key: prompt
  n_samples: 5
  output_path: /opt/tiger/math_Qwen2-7B-Instruct.parquet
  batch_size: 128

model:
  path: ~/models/Qwen2-7B-Instruct
  external_lib: null
rollout:
  name: vllm
  temperature: 1.0
  top_k: 50 # 0 for hf rollout, -1 for vllm rollout
  top_p: 0.7
  prompt_length: 1536
  response_length: 512
  # for vllm rollout
  dtype: bfloat16 # should align with FSDP
  gpu_memory_utilization: 0.5
  ignore_eos: False
  micro_batch_size: 256
  enforce_eager: True
  free_cache_engine: True
  load_format: dummy_dtensor
  tensor_model_parallel_size: 1
  max_num_batched_tokens: 8192
  max_num_seqs: 1024
  log_prob_micro_batch_size: 8
  # for hf rollout
  do_sample: True
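
For reference, a hedged sketch of how the output of this config can be checked after generation: the generation script later in this commit writes the input parquet back out with an added 'responses' column holding data.n_samples generations per prompt (the path and sample count below are copied from the values above).

import pandas as pd

df = pd.read_parquet('/opt/tiger/math_Qwen2-7B-Instruct.parquet')  # data.output_path above
n_samples = 5  # data.n_samples above
assert all(len(r) == n_samples for r in df['responses'])
print(df[['prompt', 'responses']].head())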


@@ -0,0 +1,148 @@
data:
  tokenizer: null
  train_files: ~/data/rlhf/gsm8k/train.parquet
  val_files: ~/data/rlhf/gsm8k/test.parquet
  prompt_key: prompt
  max_prompt_length: 512
  max_response_length: 512
  train_batch_size: 1024
  val_batch_size: 1312
  return_raw_input_ids: False # This should be set to true when the tokenizers of the policy and the reward model differ
  return_raw_chat: False

actor_rollout_ref:
  hybrid_engine: True
  model:
    path: ~/models/deepseek-llm-7b-chat
    external_lib: null
    override_config: {}
    enable_gradient_checkpointing: False
  actor:
    strategy: megatron # This is for backward-compatibility
    ppo_mini_batch_size: 256
    ppo_micro_batch_size: 64
    clip_ratio: 0.2
    entropy_coeff: 0.001
    ppo_epochs: 1
    shuffle: True
    optim:
      lr: 1e-6
      clip_grad: 1.0
      lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime
      min_lr_ratio: null # only useful for warmup with cosine
      warmup_style: constant # select from constant/cosine
      total_training_steps: -1 # must be overridden by the program at runtime
    megatron:
      tensor_model_parallel_size: 4
      pipeline_model_parallel_size: 1
      num_layers_per_virtual_pipeline_stage: null # vpp will hang. need debug.
      sequence_parallel: True
      seed: 1
    load_weight: True
  ref:
    megatron:
      tensor_model_parallel_size: 4
      pipeline_model_parallel_size: 1
      num_layers_per_virtual_pipeline_stage: null # vpp will hang. need debug.
      sequence_parallel: True
      seed: 1
    load_weight: True
    param_offload: False
    log_prob_micro_batch_size: 32
  rollout:
    name: vllm
    temperature: 1.0
    top_k: -1 # 0 for hf rollout, -1 for vllm rollout
    top_p: 1
    prompt_length: ${data.max_prompt_length} # for xperf_gpt
    response_length: ${data.max_response_length}
    # for vllm rollout
    dtype: bfloat16 # should align with FSDP
    gpu_memory_utilization: 0.5
    ignore_eos: False
    enforce_eager: True
    free_cache_engine: True
    load_format: dummy_megatron
    tensor_model_parallel_size: 2
    max_num_batched_tokens: 8192
    max_num_seqs: 1024
    log_prob_micro_batch_size: 2
    # for hf rollout
    do_sample: True
    layer_name_map:
      qkv_layer_name: qkv
      gate_proj_layer_name: gate_up
    # number of responses (i.e. num sample times)
    n: 1

critic:
  strategy: megatron
  optim:
    lr: 1e-5
    clip_grad: 1.0
    lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime
    min_lr_ratio: null # only useful for warmup with cosine
    warmup_style: constant # select from constant/cosine
    total_training_steps: -1 # must be overridden by the program at runtime
  model:
    path: ~/models/deepseek-llm-7b-chat
    tokenizer_path: ${actor_rollout_ref.model.path}
    override_config: {}
    external_lib: ${actor_rollout_ref.model.external_lib}
    enable_gradient_checkpointing: False
  megatron:
    tensor_model_parallel_size: 4
    pipeline_model_parallel_size: 1
    num_layers_per_virtual_pipeline_stage: null # vpp will hang. need debug.
    sequence_parallel: True
    seed: 1
  load_weight: True
  ppo_mini_batch_size: ${actor_rollout_ref.actor.ppo_mini_batch_size}
  ppo_micro_batch_size: 2
  ppo_epochs: ${actor_rollout_ref.actor.ppo_epochs}
  shuffle: ${actor_rollout_ref.actor.shuffle}
  cliprange_value: 0.5
  kl_ctrl:
    type: fixed
    kl_coef: 0.001

reward_model:
  enable: False
  strategy: megatron
  megatron:
    tensor_model_parallel_size: 4
    pipeline_model_parallel_size: 1
    num_layers_per_virtual_pipeline_stage: null # vpp will hang. need debug.
    sequence_parallel: True
    seed: 1
  model:
    input_tokenizer: ${actor_rollout_ref.model.path} # set this to null if the chat template is identical
    path: ~/models/FsfairX-LLaMA3-RM-v0.1
    external_lib: ${actor_rollout_ref.model.external_lib}
  load_weight: True
  param_offload: False
  micro_batch_size: 64
  max_length: null

algorithm:
  gamma: 1.0
  lam: 1.0
  adv_estimator: gae
  kl_penalty: kl # how to estimate kl divergence
  kl_ctrl:
    type: fixed
    kl_coef: 0.001

trainer:
  total_epochs: 30
  total_training_steps: null
  project_name: verl_examples
  experiment_name: gsm8k
  logger: ['console', 'wandb']
  nnodes: 1
  n_gpus_per_node: 8
  save_freq: -1
  test_freq: 2
  critic_warmup: 0
  default_hdfs_dir: ~/experiments/gsm8k/ppo/${trainer.experiment_name}
  default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name}
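
A small, hedged sanity-check sketch of the batch-size hierarchy in this config (these divisibility constraints are conventional for PPO-style mini/micro-batching and are not enforced by the YAML file itself; the numbers are copied from above):

train_batch_size = 1024       # data.train_batch_size
ppo_mini_batch_size = 256     # actor_rollout_ref.actor.ppo_mini_batch_size
ppo_micro_batch_size = 64     # actor_rollout_ref.actor.ppo_micro_batch_size
n_gpus = 1 * 8                # trainer.nnodes * trainer.n_gpus_per_node

assert train_batch_size % ppo_mini_batch_size == 0      # whole mini-batches per rollout batch
assert ppo_mini_batch_size % ppo_micro_batch_size == 0  # micro-batches tile each mini-batch
assert train_batch_size % n_gpus == 0                   # rollout batch splits evenly across GPUs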


@@ -0,0 +1,177 @@
data:
  tokenizer: null
  train_files: ~/data/rlhf/gsm8k/train.parquet
  val_files: ~/data/rlhf/gsm8k/test.parquet
  train_data_num: null
  val_data_num: null
  prompt_key: prompt
  max_prompt_length: 512
  max_response_length: 512
  max_start_length: 256
  max_obs_length: 512
  train_batch_size: 1024
  val_batch_size: 1312
  return_raw_input_ids: False # This should be set to true when the tokenizers of the policy and the reward model differ
  return_raw_chat: False
  shuffle_train_dataloader: True

actor_rollout_ref:
  hybrid_engine: True
  model:
    path: ~/models/deepseek-llm-7b-chat
    external_lib: null
    override_config: { }
    enable_gradient_checkpointing: False
    use_remove_padding: False
  actor:
    strategy: fsdp # This is for backward-compatibility
    ppo_mini_batch_size: 256
    ppo_micro_batch_size: 64
    use_dynamic_bsz: False
    ppo_max_token_len_per_gpu: 16384 # n * ${data.max_prompt_length} + ${data.max_response_length}
    grad_clip: 1.0
    state_masking: False
    clip_ratio: 0.2
    entropy_coeff: 0.001
    use_kl_loss: False # True for GRPO
    kl_loss_coef: 0.001 # for grpo
    kl_loss_type: low_var_kl # for grpo
    ppo_epochs: 1
    shuffle: False
    ulysses_sequence_parallel_size: 1 # sp size
    optim:
      lr: 1e-6
      lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime
      min_lr_ratio: null # only useful for warmup with cosine
      warmup_style: constant # select from constant/cosine
      total_training_steps: -1 # must be overridden by the program at runtime
    fsdp_config:
      wrap_policy:
        # transformer_layer_cls_to_wrap: None
        min_num_params: 0
      param_offload: False
      grad_offload: False
      optimizer_offload: False
      fsdp_size: -1
  ref:
    fsdp_config:
      param_offload: False
      wrap_policy:
        # transformer_layer_cls_to_wrap: None
        min_num_params: 0
      fsdp_size: -1
    log_prob_micro_batch_size: 128
    log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
    log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}
    ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size} # sp size
  rollout:
    name: vllm
    temperature: 1.0
    top_k: -1 # 0 for hf rollout, -1 for vllm rollout
    top_p: 0.95
    prompt_length: ${data.max_prompt_length} # not used in the open-source version
    response_length: ${data.max_response_length}
    # for vllm rollout
    dtype: bfloat16 # should align with FSDP
    gpu_memory_utilization: 0.5
    ignore_eos: False
    enforce_eager: True
    free_cache_engine: True
    load_format: dummy_dtensor
    tensor_model_parallel_size: 2
    max_num_batched_tokens: 8192
    max_num_seqs: 1024
    log_prob_micro_batch_size: 128
    log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
    log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}
    # for hf rollout
    do_sample: True
    # number of responses (i.e. num sample times)
    n: 1 # > 1 for grpo
    n_agent: 1 # used for agent tasks only (separate from n above)

critic:
  strategy: fsdp
  optim:
    lr: 1e-5
    lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime
    min_lr_ratio: null # only useful for warmup with cosine
    warmup_style: constant # select from constant/cosine
    total_training_steps: -1 # must be overridden by the program at runtime
  model:
    path: ~/models/deepseek-llm-7b-chat
    tokenizer_path: ${actor_rollout_ref.model.path}
    override_config: { }
    external_lib: ${actor_rollout_ref.model.external_lib}
    enable_gradient_checkpointing: False
    use_remove_padding: False
    fsdp_config:
      param_offload: False
      grad_offload: False
      optimizer_offload: False
      wrap_policy:
        # transformer_layer_cls_to_wrap: None
        min_num_params: 0
      fsdp_size: -1
  ppo_mini_batch_size: ${actor_rollout_ref.actor.ppo_mini_batch_size}
  ppo_micro_batch_size: 64
  forward_micro_batch_size: ${critic.ppo_micro_batch_size}
  use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
  ppo_max_token_len_per_gpu: 32768 # (${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}) * 2
  forward_max_token_len_per_gpu: ${critic.ppo_max_token_len_per_gpu}
  ulysses_sequence_parallel_size: 1 # sp size
  ppo_epochs: ${actor_rollout_ref.actor.ppo_epochs}
  shuffle: ${actor_rollout_ref.actor.shuffle}
  grad_clip: 1.0
  cliprange_value: 0.5

reward_model:
  enable: False
  strategy: fsdp
  model:
    input_tokenizer: ${actor_rollout_ref.model.path} # set this to null if the chat template is identical
    path: ~/models/FsfairX-LLaMA3-RM-v0.1
    external_lib: ${actor_rollout_ref.model.external_lib}
    use_remove_padding: False
    fsdp_config:
      min_num_params: 0
      param_offload: False
  micro_batch_size: 64
  max_length: null
  ulysses_sequence_parallel_size: 1 # sp size
  use_dynamic_bsz: ${critic.use_dynamic_bsz}
  forward_max_token_len_per_gpu: ${critic.forward_max_token_len_per_gpu}

retriever:
  url: "http://127.0.0.1:8000/retrieve"
  topk: 3

algorithm:
  gamma: 1.0
  lam: 1.0
  adv_estimator: gae
  no_think_rl: False
  kl_penalty: kl # how to estimate kl divergence
  kl_ctrl:
    type: fixed
    kl_coef: 0.001
  state_masking:
    start_state_marker: "<information>"
    end_state_marker: "</information>"

trainer:
  total_epochs: 30
  total_training_steps: null
  project_name: verl_examples
  experiment_name: gsm8k
  logger: [ 'console', 'wandb' ]
  nnodes: 1
  n_gpus_per_node: 8
  save_freq: -1
  test_freq: -1
  critic_warmup: 0
  default_hdfs_dir: ~/experiments/gsm8k/ppo/${trainer.experiment_name}
  default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name}
  max_turns: 10
  do_search: true
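
The state_masking markers above delimit retrieved evidence inserted into the rollout. As an illustration only (the actual masking lives in the agent generation code, not in this config file), a minimal sketch of blanking out everything between the markers:

import re

START, END = "<information>", "</information>"  # algorithm.state_masking above

def mask_information_spans(text: str) -> str:
    # Replace retrieved-evidence spans so they can be excluded from the policy loss.
    pattern = re.compile(re.escape(START) + r".*?" + re.escape(END), flags=re.DOTALL)
    return pattern.sub(f"{START}[masked]{END}", text)

print(mask_information_spans("<search>capital of France</search> <information>Paris is the capital ...</information>"))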


@@ -0,0 +1,42 @@
data:
  train_batch_size: 256
  micro_batch_size: 16 # this is also val batch size
  train_files: ~/data/gsm8k/train.parquet
  val_files: ~/data/gsm8k/test.parquet
  prompt_key: question
  response_key: answer
  max_length: 1024
  truncation: error
  balance_dp_token: False
  chat_template: null
model:
  partial_pretrain: ~/models/gemma-1.1-7b-it
  fsdp_config:
    wrap_policy:
      min_num_params: 0
    cpu_offload: False
    offload_params: False
  external_lib: null
  enable_gradient_checkpointing: False
  trust_remote_code: False
  lora_rank: 0 # Set to positive value to enable LoRA (e.g., 32)
  lora_alpha: 16 # LoRA scaling factor
  target_modules: [q_proj, v_proj] # Target modules for LoRA adaptation
optim:
  lr: 1e-5
  betas: [0.9, 0.95]
  weight_decay: 0.01
  warmup_steps_ratio: 0.1
  clip_grad: 1.0

trainer:
  default_local_dir: /tmp/sft_model
  default_hdfs_dir: hdfs://tmp/experiments/gsm8k/gemma-1.1-7b-it/ # change the hdfs path here
  resume_path: null
  project_name: gsm8k-sft
  experiment_name: test
  total_epochs: 4
  total_training_steps: null
  validate_before_training: False
  logger: ['console']
  seed: 1
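
For orientation, the FSDP SFT trainer later in this commit divides these batch sizes by the data-parallel world size before building its dataloaders; a small sketch of that arithmetic with the values above (the dp_size of 8 is an assumed single-node setup):

train_batch_size = 256  # data.train_batch_size
micro_batch_size = 16   # data.micro_batch_size (also used as the val batch size)
dp_size = 8             # assumed: one node with 8 GPUs

assert train_batch_size % dp_size == 0 and micro_batch_size % dp_size == 0
per_rank_train_bsz = train_batch_size // dp_size             # 32 samples per rank per step
per_rank_micro_bsz = micro_batch_size // dp_size             # 2 samples per forward/backward
grad_accum_steps = per_rank_train_bsz // per_rank_micro_bsz  # 16 micro-batches per optimizer step
print(per_rank_train_bsz, per_rank_micro_bsz, grad_accum_steps)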


@@ -0,0 +1,435 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
A lightweight one-file FSDP SFT Trainer
TODO(zhangchi.usc1992)
- Add calculation of mfu
- Add validation
"""
import os
os.environ['NCCL_DEBUG'] = 'WARN'
os.environ['TOKENIZERS_PARALLELISM'] = 'true'
import logging
import re
import torch
import torch.distributed
from torch import nn, optim
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecision, ShardingStrategy, CPUOffload
from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedModel, AutoConfig
from verl.utils.torch_functional import get_cosine_schedule_with_warmup
from tensordict import TensorDict
from torch.utils.data import DataLoader, DistributedSampler
from verl.utils.fsdp_utils import get_fsdp_wrap_policy, init_fn, get_init_weight_context_manager
from verl.utils.dataset import SFTDataset
from verl.utils.fs import copy_local_path_from_hdfs
from verl.utils.tracking import Tracking
from torch.distributed.device_mesh import DeviceMesh
import verl.utils.hdfs_io as hdfs_io
from verl.utils.debug import log_gpu_memory_usage
from peft import LoraConfig, TaskType, get_peft_model
logger = logging.getLogger(__file__)
logger.setLevel(os.getenv('VERL_SFT_LOGGING_LEVEL', 'WARN'))
def extract_step(path):
match = re.search(r'global_step_(\d+)', path)
if match:
return int(match.group(1))
return None
def convert_to_regular_types(obj):
"""Convert Hydra configs and other special types to regular Python types."""
from omegaconf import ListConfig, DictConfig
if isinstance(obj, (ListConfig, DictConfig)):
return {k: convert_to_regular_types(v) for k, v in obj.items()} if isinstance(obj, DictConfig) else list(obj)
elif isinstance(obj, (list, tuple)):
return [convert_to_regular_types(x) for x in obj]
elif isinstance(obj, dict):
return {k: convert_to_regular_types(v) for k, v in obj.items()}
return obj
class FSDPSFTTrainer(object):
def __init__(self, config, device_mesh: DeviceMesh):
self.config = config
self.device_mesh = device_mesh
# build tokenizer first
local_model_path = copy_local_path_from_hdfs(src=self.config.model.partial_pretrain, verbose=True)
from verl.utils import hf_tokenizer
self.tokenizer = hf_tokenizer(local_model_path, trust_remote_code=self.config.model.trust_remote_code)
if self.config.data.chat_template is not None:
raise ValueError('Apply Chat template from config is not supported yet.')
# normalize dp size
self._normalize_config_bsz()
self._build_dataloader()
# build model
self._build_model_optimizer()
# TODO: add checkpoint manager
if self.device_mesh.get_rank() == 0:
print(self.config)
def _normalize_config_bsz(self):
dp_size = self.device_mesh.size()
if self.device_mesh.get_rank() == 0:
print(f'Normalize batch size by dp {dp_size}')
assert self.config.data.train_batch_size % dp_size == 0
assert self.config.data.micro_batch_size % dp_size == 0
self.config.data.train_batch_size //= dp_size
self.config.data.micro_batch_size //= dp_size
def _build_dataloader(self):
config = self.config
# build dataset
self.train_dataset = SFTDataset(parquet_files=config.data.train_files,
tokenizer=self.tokenizer,
prompt_key=config.data.prompt_key,
prompt_dict_keys=config.data.get('prompt_dict_keys', None),
response_key=config.data.response_key,
response_dict_keys=config.data.get('response_dict_keys', None),
max_length=config.data.max_length,
truncation=config.data.truncation)
self.val_dataset = SFTDataset(parquet_files=config.data.val_files,
tokenizer=self.tokenizer,
prompt_key=config.data.prompt_key,
prompt_dict_keys=config.data.get('prompt_dict_keys', None),
response_key=config.data.response_key,
response_dict_keys=config.data.get('response_dict_keys', None),
max_length=config.data.max_length,
truncation=config.data.truncation)
# build dataloader
rank = self.device_mesh.get_rank()
world_size = self.device_mesh.size()
self.train_sampler = DistributedSampler(self.train_dataset,
shuffle=True,
num_replicas=world_size,
rank=rank,
drop_last=True)
self.train_dataloader = DataLoader(dataset=self.train_dataset,
batch_size=config.data.train_batch_size,
sampler=self.train_sampler,
num_workers=8,
pin_memory=True,
drop_last=True)
self.val_sampler = DistributedSampler(self.val_dataset,
shuffle=True,
num_replicas=world_size,
rank=rank,
drop_last=True)
self.val_dataloader = DataLoader(dataset=self.val_dataset,
batch_size=config.data.micro_batch_size,
sampler=self.val_sampler,
num_workers=8,
pin_memory=True,
drop_last=True)
def _build_model_optimizer(self):
# TODO (zhangchi.usc1992):
# 1. support pretrain from random weights
# 2. support init directly from sharded weights
local_model_path = copy_local_path_from_hdfs(src=self.config.model.partial_pretrain, verbose=True)
if self.config.model.get('external_lib', None) is not None:
# This is used to import external_lib into the huggingface systems
import importlib
importlib.import_module(self.config.model.external_lib)
log_gpu_memory_usage('Before model allocation', logger=logger)
trust_remote_code = self.config.model.trust_remote_code
# load config first
config = AutoConfig.from_pretrained(local_model_path, trust_remote_code=trust_remote_code)
# This may be very large
init_context = get_init_weight_context_manager(use_meta_tensor=not config.tie_word_embeddings)
with init_context():
self.model: PreTrainedModel = AutoModelForCausalLM.from_pretrained(local_model_path,
config=config,
torch_dtype=torch.float32,
attn_implementation='flash_attention_2',
trust_remote_code=trust_remote_code)
if self.config.model.get('lora_rank', 0) > 0:
self.model.enable_input_require_grads()
# Convert config to regular Python types before creating PEFT model
lora_config = {
'task_type': TaskType.CAUSAL_LM,
'r': self.config.model.lora_rank,
'lora_alpha': self.config.model.lora_alpha,
'target_modules': convert_to_regular_types(self.config.model.target_modules),
'bias': "none"
}
self.model = get_peft_model(self.model, LoraConfig(**lora_config))
if self.config.model.enable_gradient_checkpointing:
self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={'use_reentrant': False})
log_gpu_memory_usage('After model allocation', logger=logger)
mixed_precision = MixedPrecision(param_dtype=torch.bfloat16,
reduce_dtype=torch.float32,
buffer_dtype=torch.float32)
auto_wrap_policy = get_fsdp_wrap_policy(self.model,
config=self.config.model.fsdp_config.wrap_policy,
is_lora=self.config.model.get('lora_rank', 0) > 0)
if self.device_mesh.get_rank() == 0:
print(auto_wrap_policy)
if not self.config.model.fsdp_config.cpu_offload:
cpu_offload = None
else:
cpu_offload = CPUOffload(offload_params=self.config.model.fsdp_config.offload_params)
self.fsdp_model = FSDP(module=self.model,
auto_wrap_policy=auto_wrap_policy,
param_init_fn=init_fn,
sharding_strategy=ShardingStrategy.FULL_SHARD,
mixed_precision=mixed_precision,
device_mesh=self.device_mesh,
sync_module_states=True,
device_id=torch.cuda.current_device(),
cpu_offload=cpu_offload,
use_orig_params=False)
log_gpu_memory_usage('After FSDP wrapping', logger=logger)
self.optimizer = optim.AdamW(self.fsdp_model.parameters(),
lr=self.config.optim.lr,
betas=self.config.optim.betas,
weight_decay=self.config.optim.weight_decay)
log_gpu_memory_usage('After initialize optimizer', logger=logger)
steps_per_epoch = len(self.train_dataloader)
total_steps = steps_per_epoch * self.config.trainer.total_epochs
if self.device_mesh.get_rank() == 0:
print(
f'Number of steps/epoch {steps_per_epoch}, number of epochs {self.config.trainer.total_epochs}, total number of steps {total_steps}'
)
num_warmup_steps = int(total_steps * self.config.optim.warmup_steps_ratio)
self.lr_scheduler = get_cosine_schedule_with_warmup(optimizer=self.optimizer,
num_warmup_steps=num_warmup_steps,
num_training_steps=total_steps)
def _compute_loss(self, batch):
loss_mask = batch.pop('loss_mask')[:, :-1].reshape(-1).cuda()
labels = batch['input_ids'][:, 1:].cuda()
with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
output = self.fsdp_model(input_ids=batch['input_ids'],
attention_mask=batch['attention_mask'],
position_ids=batch['position_ids'],
use_cache=False)  # prevent the model from thinking it is generating
logits = output.logits
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels.contiguous()
# Flatten the tokens
loss_fct = nn.CrossEntropyLoss(reduction='none')
shift_logits = shift_logits.view(-1, self.model.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
loss = loss * loss_mask
valid_token_this_rank = torch.sum(loss_mask)
if self.config.data.balance_dp_token:
torch.distributed.all_reduce(valid_token_this_rank) # becomes total valid tokens in all ranks
dp_size = torch.distributed.get_world_size()
else:
dp_size = 1
loss = torch.sum(loss) / valid_token_this_rank * dp_size # possible bugs here for dp
return loss
def training_step(self, batch: TensorDict):
self.fsdp_model.train()
log_gpu_memory_usage('Before optimizer zero_grad', logger=logger)
self.optimizer.zero_grad()
log_gpu_memory_usage('After optimizer zero_grad', logger=logger)
micro_batches = batch.split(self.config.data.micro_batch_size)
n_micro_batches = len(micro_batches)
step_loss = 0
for micro_batch in micro_batches:
loss = self._compute_loss(batch=micro_batch) / n_micro_batches
loss.backward()
step_loss += loss.item()
self.fsdp_model.clip_grad_norm_(max_norm=self.config.optim.clip_grad)
log_gpu_memory_usage('Before optimizer step', logger=logger)
self.optimizer.step()
log_gpu_memory_usage('After optimizer step', logger=logger)
self.lr_scheduler.step()
# reduce loss across dp ranks
lr = self.lr_scheduler.get_last_lr()[0]
log_gpu_memory_usage('After offload weights', logger=logger)
step_loss = torch.tensor(step_loss).cuda()
torch.distributed.all_reduce(step_loss, op=torch.distributed.ReduceOp.AVG)
return {'train/loss': step_loss.detach().item(), 'train/lr(1e-3)': lr * 1e3}
def validation_step(self, batch: TensorDict):
self.fsdp_model.eval()
with torch.no_grad():
loss = self._compute_loss(batch)
torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.AVG)
return loss
def save_checkpoint(self, step):
# save checkpoint
from torch.distributed.fsdp import FullStateDictConfig, StateDictType
cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
with FSDP.state_dict_type(self.fsdp_model, StateDictType.FULL_STATE_DICT, cfg):
state_dict = self.fsdp_model.state_dict()
path = os.path.join(self.config.trainer.default_local_dir, f'global_step_{step}')
# save huggingface model
if self.device_mesh.get_rank() == 0:
os.makedirs(path, exist_ok=True)
self.model.save_pretrained(path, state_dict=state_dict)
self.tokenizer.save_pretrained(path)
if self.config.trainer.default_hdfs_dir:
hdfs_io.makedirs(self.config.trainer.default_hdfs_dir, exist_ok=True)
hdfs_io.copy(src=path, dst=self.config.trainer.default_hdfs_dir, dirs_exist_ok=True)
torch.distributed.barrier()
def fit(self):
rank = self.device_mesh.get_rank()
# TODO: add a unified tracking
if rank == 0:
tracking = Tracking(project_name=self.config.trainer.project_name,
experiment_name=self.config.trainer.experiment_name,
default_backend=self.config.trainer.logger)
global_step = 0
# compute the total training steps.
# the total training steps in SFT is mainly for early exit
total_training_steps = len(self.train_dataloader) * self.config.trainer.total_epochs
if self.config.trainer.total_training_steps is not None:
total_training_steps = self.config.trainer.total_training_steps
self.total_training_steps = total_training_steps
print(f'Total training steps: {self.total_training_steps}')
# TODO (zhangchi.usc1992) add back checkpoint manager. Currently, it blocks when uploading to hdfs. So very slow.
if self.config.trainer.validate_before_training:
# validate before training
val_losses = []
for data in self.val_dataloader:
data = TensorDict(data, batch_size=self.config.data.micro_batch_size).cuda()
val_loss = self.validation_step(data)
val_losses.append(val_loss)
if rank == 0:
val_loss = torch.mean(torch.stack(val_losses))
metric = {'val/loss': val_loss.detach().item()}
tracking.log(data=metric, step=global_step)
torch.distributed.barrier()
for epoch in range(self.config.trainer.total_epochs):
self.train_sampler.set_epoch(epoch=epoch)
for data in self.train_dataloader:
data = TensorDict(data, batch_size=self.config.data.train_batch_size).cuda()
metric = self.training_step(data)
if rank == 0:
tracking.log(data=metric, step=global_step)
global_step += 1
# for early exit validation
if global_step >= self.total_training_steps:
# Perform final validation
val_losses = []
for val_data in self.val_dataloader:
val_data = TensorDict(val_data, batch_size=self.config.data.micro_batch_size).cuda()
val_loss = self.validation_step(val_data)
val_losses.append(val_loss)
if rank == 0:
avg_val_loss = torch.mean(torch.stack(val_losses))
metric = {'val/loss': avg_val_loss.detach().item()}
tracking.log(data=metric, step=global_step)
torch.distributed.barrier()
# Save final checkpoint
self.save_checkpoint(step=global_step)
return
# validation
val_losses = []
for data in self.val_dataloader:
data = TensorDict(data, batch_size=self.config.data.micro_batch_size).cuda()
val_loss = self.validation_step(data)
val_losses.append(val_loss)
if rank == 0:
val_loss = torch.mean(torch.stack(val_losses))
metric = {'val/loss': val_loss.detach().item()}
tracking.log(data=metric, step=global_step)
torch.distributed.barrier()
# save checkpoint
self.save_checkpoint(step=global_step)
from verl.trainer.fsdp_sft_trainer import FSDPSFTTrainer
import hydra
from torch.distributed.device_mesh import init_device_mesh
from verl.utils.distributed import initialize_global_process_group
@hydra.main(config_path='config', config_name='sft_trainer', version_base=None)
def main(config):
local_rank, rank, world_size = initialize_global_process_group()
device_mesh = init_device_mesh(device_type='cuda', mesh_shape=(world_size,), mesh_dim_names=('dp',))
trainer = FSDPSFTTrainer(config=config, device_mesh=device_mesh)
trainer.fit()
if __name__ == '__main__':
main()
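
To make the label/mask alignment in _compute_loss above easier to follow, here is a tiny self-contained sketch of the same next-token shift on toy tensors (random logits stand in for the model output; shapes only, no model involved):

import torch
from torch import nn

batch, seq_len, vocab = 2, 6, 11
input_ids = torch.randint(0, vocab, (batch, seq_len))
loss_mask = torch.ones(batch, seq_len)       # 1 where a token should contribute to the loss
logits = torch.randn(batch, seq_len, vocab)  # stand-in for self.fsdp_model(...).logits

labels = input_ids[:, 1:]                    # target is the next token ...
shift_logits = logits[:, :-1, :]             # ... predicted from the current position
shift_mask = loss_mask[:, :-1].reshape(-1)

loss = nn.CrossEntropyLoss(reduction='none')(shift_logits.reshape(-1, vocab), labels.reshape(-1))
loss = (loss * shift_mask).sum() / shift_mask.sum()
print(loss.item())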

verl/trainer/main_eval.py

@@ -0,0 +1,69 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Offline evaluation of the performance of a generation file using a reward model or a ground-truth verifier.
The input is a parquet file that contains N generated sequences per prompt and (optionally) the ground truth.
"""
import hydra
from verl.utils.fs import copy_local_path_from_hdfs
from verl.utils.reward_score import math, gsm8k
import pandas as pd
import numpy as np
def select_reward_fn(data_source):
    if data_source == 'lighteval/MATH':
        return math.compute_score
    else:
        raise NotImplementedError


@hydra.main(config_path='config', config_name='evaluation', version_base=None)
def main(config):
    local_path = copy_local_path_from_hdfs(config.data.path)
    dataset = pd.read_parquet(local_path)
    prompts = dataset[config.data.prompt_key]
    responses = dataset[config.data.response_key]
    data_sources = dataset[config.data.data_source_key]
    reward_model_data = dataset[config.data.reward_model_key]

    passes = 0
    total = len(dataset)

    for i in range(total):
        response_lst = responses[i]
        data_source = data_sources[i]
        # select reward score based on data_source
        prompt = prompts[i]
        reward_data = reward_model_data[i]
        reward_fn = select_reward_fn(data_source)
        ground_truth = reward_data['ground_truth']
        score_lst = []
        for r in response_lst:
            score = reward_fn(r, ground_truth)
            score_lst.append(score)
        max_score = np.max(score_lst)
        if max_score == 1:
            passes += 1

    print(f'pass@5: {passes / total}')


if __name__ == '__main__':
    main()
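
A toy illustration of the pass@N counting performed in the loop above (a prompt counts as a pass if any of its N responses scores 1):

import numpy as np

scores_per_prompt = [
    [0, 0, 1, 0, 0],  # at least one correct response -> pass
    [0, 0, 0, 0, 0],  # no correct response -> fail
]
passes = sum(int(np.max(scores) == 1) for scores in scores_per_prompt)
print(f'pass@{len(scores_per_prompt[0])}: {passes / len(scores_per_prompt)}')  # pass@5: 0.5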


@@ -0,0 +1,137 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Generate responses given a dataset of prompts
"""
import ray
import numpy as np
import hydra
import os
os.environ['NCCL_DEBUG'] = 'WARN'
os.environ['TOKENIZERS_PARALLELISM'] = 'true'
# os.environ['TORCH_COMPILE_DISABLE'] = '1'
from verl.utils.model import compute_position_id_with_mask
import pandas as pd
from transformers import AutoTokenizer
from verl import DataProto
from verl.utils.fs import copy_local_path_from_hdfs
from verl.workers.fsdp_workers import ActorRolloutRefWorker
from verl.utils.hdfs_io import makedirs
from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup
@hydra.main(config_path='config', config_name='generation', version_base=None)
def main(config):
    from pprint import pprint
    from omegaconf import OmegaConf
    pprint(OmegaConf.to_container(config, resolve=True))  # resolve=True will eval symbol values
    OmegaConf.resolve(config)
    local_path = copy_local_path_from_hdfs(config.model.path)
    from verl.utils import hf_tokenizer
    tokenizer = hf_tokenizer(local_path)

    if config.rollout.temperature == 0.:
        assert config.data.n_samples == 1, 'When temperature=0, n_samples must be 1.'

    # read dataset. Note that the dataset should directly contain chat template format (e.g., a list of dictionaries)
    dataset = pd.read_parquet(config.data.path)
    chat_lst = dataset[config.data.prompt_key].tolist()
    chat_lst = [chat.tolist() for chat in chat_lst]

    tokenizer.padding_side = 'left'
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    ray_cls_with_init = RayClassWithInitArgs(cls=ray.remote(ActorRolloutRefWorker), config=config, role='rollout')
    resource_pool = RayResourcePool(process_on_nodes=[config.trainer.n_gpus_per_node] * config.trainer.nnodes)
    wg = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=ray_cls_with_init)
    wg.init_model()

    total_samples = len(dataset)
    # real_batch_size = data.batch['input_ids'].shape[0]
    config_batch_size = config.data.batch_size
    dp_size = wg.world_size // config.rollout.tensor_model_parallel_size
    num_batch = (total_samples // config_batch_size) + 1
    output_lst = [[] for _ in range(config.data.n_samples)]

    for batch_idx in range(num_batch):
        print(f'[{batch_idx+1}/{num_batch}] Start to process.')
        batch_chat_lst = chat_lst[batch_idx * config_batch_size:(batch_idx + 1) * config_batch_size]
        inputs = tokenizer.apply_chat_template(batch_chat_lst,
                                               add_generation_prompt=True,
                                               padding=True,
                                               truncation=True,
                                               max_length=config.rollout.prompt_length,
                                               return_tensors='pt',
                                               return_dict=True,
                                               tokenize=True)
        input_ids = inputs['input_ids']
        attention_mask = inputs['attention_mask']
        position_ids = compute_position_id_with_mask(attention_mask)
        batch_dict = {'input_ids': input_ids, 'attention_mask': attention_mask, 'position_ids': position_ids}
        data = DataProto.from_dict(batch_dict)
        real_batch_size = data.batch['input_ids'].shape[0]
        if real_batch_size % dp_size != 0:
            dummy_data_size = dp_size - real_batch_size % dp_size
            dummy_data = data[:dummy_data_size]
            data = DataProto.concat([data, dummy_data])
            print(
                f'real_batch_size {real_batch_size} is not divisible by dp_size {dp_size}, add {dummy_data_size} dummy data'
            )

        batch_size = data.batch['input_ids'].shape[0]
        assert batch_size % dp_size == 0, f'batch_size {batch_size} is not divisible by dp_size {dp_size}'

        print(f'[{batch_idx+1}/{num_batch}] Start to generate.')
        # START TO GENERATE FOR n_samples TIMES
        for i in range(config.data.n_samples):
            output = wg.generate_sequences(data)
            # remove dummy data
            output = output[:real_batch_size]
            output_text = tokenizer.batch_decode(output.batch['input_ids'][:, -config.rollout.response_length:],
                                                 skip_special_tokens=False)

            # remove the padding
            pad_token = tokenizer.pad_token
            output_text_unpad = []
            for text in output_text:
                output_text_unpad.append(text.replace(pad_token, ''))

            output_lst[i].extend(output_text_unpad)

    # convert output_lst from (n_samples, n_data) to (n_data, n_samples)
    output_lst = np.array(output_lst, dtype=object)
    output_lst = np.transpose(output_lst, axes=(1, 0)).tolist()

    # add to the data frame
    dataset['responses'] = output_lst

    # write to a new parquet
    output_dir = os.path.dirname(config.data.output_path)
    makedirs(output_dir, exist_ok=True)
    dataset.to_parquet(config.data.output_path)

    return output_text


if __name__ == '__main__':
    main()
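
The dummy-data padding in the loop above keeps every data-parallel rank busy when the last batch is ragged; a minimal sketch of the same trick on a plain Python list (illustrative only, no DataProto involved):

def pad_to_multiple(batch, dp_size):
    # Repeat leading items until len(batch) is divisible by dp_size; return the padded batch and the real size.
    real_size = len(batch)
    remainder = real_size % dp_size
    if remainder:
        batch = batch + batch[:dp_size - remainder]
    return batch, real_size

padded, real_batch_size = pad_to_multiple(list(range(10)), dp_size=4)
outputs = [x * 2 for x in padded]       # stand-in for wg.generate_sequences(data)
outputs = outputs[:real_batch_size]     # drop the dummy tail, as in the loop above
print(len(padded), len(outputs))        # 12 10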

verl/trainer/main_ppo.py

@@ -0,0 +1,202 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Note that we don't combine the main with ray_trainer, as ray_trainer is used by other main entry points as well.
"""
from verl import DataProto
import torch
from verl.utils.reward_score import qa_em
from verl.trainer.ppo.ray_trainer import RayPPOTrainer
import re
import numpy as np
def _select_rm_score_fn(data_source):
if "nq" in data_source:
return qa_em.compute_score_em
else:
raise NotImplementedError
class RewardManager():
"""The reward manager.
"""
def __init__(self, tokenizer, num_examine, format_score=0.) -> None:
self.tokenizer = tokenizer
self.num_examine = num_examine # the number of batches of decoded responses to print to the console
self.format_score = format_score
def __call__(self, data: DataProto):
"""We will expand this function gradually based on the available datasets"""
# If there is rm score, we directly return rm score. Otherwise, we compute via rm_score_fn
if 'rm_scores' in data.batch.keys():
return data.batch['rm_scores']
reward_tensor = torch.zeros_like(data.batch['responses'], dtype=torch.float32)
# all_scores = []
already_print_data_sources = {}
for i in range(len(data)):
data_item = data[i] # DataProtoItem
prompt_ids = data_item.batch['prompts']
prompt_length = prompt_ids.shape[-1]
valid_prompt_length = data_item.batch['attention_mask'][:prompt_length].sum()
valid_prompt_ids = prompt_ids[-valid_prompt_length:]
response_ids = data_item.batch['responses']
valid_response_length = data_item.batch['attention_mask'][prompt_length:].sum()
valid_response_ids = response_ids[:valid_response_length]
# decode
sequences = torch.cat((valid_prompt_ids, valid_response_ids))
sequences_str = self.tokenizer.decode(sequences)
ground_truth = data_item.non_tensor_batch['reward_model']['ground_truth']
# select rm_score
data_source = data_item.non_tensor_batch['data_source']
compute_score_fn = _select_rm_score_fn(data_source)
score = compute_score_fn(solution_str=sequences_str, ground_truth=ground_truth, format_score=self.format_score)
reward_tensor[i, valid_response_length - 1] = score
# all_scores.append(score)
if data_source not in already_print_data_sources:
already_print_data_sources[data_source] = 0
if already_print_data_sources[data_source] < self.num_examine:
already_print_data_sources[data_source] += 1
print(sequences_str)
# print(f"[DEBUG] all_scores: {all_scores}")
# print(f"[DEBUG] all_scores shape: {np.array(all_scores).shape}")
# print(f"[DEBUG] all_scores mean: {np.mean(all_scores)}")
# print(f"[DEBUG] all_scores max: {np.max(all_scores)}")
# print(f"[DEBUG] all_scores min: {np.min(all_scores)}")
# print(f"[DEBUG] all_scores std: {np.std(all_scores)}")
return reward_tensor
import ray
import hydra
@hydra.main(config_path='config', config_name='ppo_trainer', version_base=None)
def main(config):
if not ray.is_initialized():
# this is for local ray cluster
ray.init(runtime_env={'env_vars': {'TOKENIZERS_PARALLELISM': 'true', 'NCCL_DEBUG': 'WARN'}})
ray.get(main_task.remote(config))
@ray.remote
def main_task(config):
from verl.utils.fs import copy_local_path_from_hdfs
from transformers import AutoTokenizer
# print initial config
from pprint import pprint
from omegaconf import OmegaConf
pprint(OmegaConf.to_container(config, resolve=True)) # resolve=True will eval symbol values
OmegaConf.resolve(config)
# env_class = ENV_CLASS_MAPPING[config.env.name]
# download the checkpoint from hdfs
local_path = copy_local_path_from_hdfs(config.actor_rollout_ref.model.path)
# instantiate tokenizer
from verl.utils import hf_tokenizer
tokenizer = hf_tokenizer(local_path)
# define worker classes
if config.actor_rollout_ref.actor.strategy == 'fsdp':
assert config.actor_rollout_ref.actor.strategy == config.critic.strategy
from verl.workers.fsdp_workers import ActorRolloutRefWorker, CriticWorker
from verl.single_controller.ray import RayWorkerGroup
ray_worker_group_cls = RayWorkerGroup
elif config.actor_rollout_ref.actor.strategy == 'megatron':
assert config.actor_rollout_ref.actor.strategy == config.critic.strategy
from verl.workers.megatron_workers import ActorRolloutRefWorker, CriticWorker
from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup
ray_worker_group_cls = NVMegatronRayWorkerGroup
else:
raise NotImplementedError
from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role
role_worker_mapping = {
Role.ActorRollout: ray.remote(ActorRolloutRefWorker),
Role.Critic: ray.remote(CriticWorker),
Role.RefPolicy: ray.remote(ActorRolloutRefWorker),
}
global_pool_id = 'global_pool'
resource_pool_spec = {
global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes,
}
mapping = {
Role.ActorRollout: global_pool_id,
Role.Critic: global_pool_id,
Role.RefPolicy: global_pool_id,
}
# we should adopt a multi-source reward function here
# - for rule-based rm, we directly call a reward score
# - for model-based rm, we call a model
# - for code related prompt, we send to a sandbox if there are test cases
# - finally, we combine all the rewards together
# - The reward type depends on the tag of the data
if config.reward_model.enable:
if config.reward_model.strategy == 'fsdp':
from verl.workers.fsdp_workers import RewardModelWorker
elif config.reward_model.strategy == 'megatron':
from verl.workers.megatron_workers import RewardModelWorker
else:
raise NotImplementedError
role_worker_mapping[Role.RewardModel] = ray.remote(RewardModelWorker)
mapping[Role.RewardModel] = global_pool_id
reward_fn = RewardManager(tokenizer=tokenizer, num_examine=0)
# Note that we always use function-based RM for validation
val_reward_fn = RewardManager(tokenizer=tokenizer, num_examine=1)
resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)
trainer = RayPPOTrainer(config=config,
tokenizer=tokenizer,
role_worker_mapping=role_worker_mapping,
resource_pool_manager=resource_pool_manager,
ray_worker_group_cls=ray_worker_group_cls,
reward_fn=reward_fn,
val_reward_fn=val_reward_fn,
)
trainer.init_workers()
trainer.fit()
if __name__ == '__main__':
main()
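
If more QA datasets are wired in later, _select_rm_score_fn above is the natural extension point; a hedged sketch of adding another branch (the 'triviaqa' check is hypothetical and not part of this commit, while qa_em.compute_score_em is the exact-match scorer already imported above):

from verl.utils.reward_score import qa_em

def _select_rm_score_fn(data_source):
    if "nq" in data_source:
        return qa_em.compute_score_em
    elif "triviaqa" in data_source:  # hypothetical extra dataset; assumes EM scoring also fits it
        return qa_em.compute_score_em
    else:
        raise NotImplementedError(f"no reward function registered for {data_source}")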


@@ -0,0 +1,13 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


@@ -0,0 +1,274 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Core functions to implement PPO algorithms.
The functions implemented in this file are meant to be used by trainers with different distributed strategies to implement PPO.
"""
import numpy as np
import torch
from collections import defaultdict
import verl.utils.torch_functional as verl_F
class AdaptiveKLController:
"""
Adaptive KL controller described in the paper:
https://arxiv.org/pdf/1909.08593.pdf
"""
def __init__(self, init_kl_coef, target_kl, horizon):
self.value = init_kl_coef
self.target = target_kl
self.horizon = horizon
def update(self, current_kl, n_steps):
target = self.target
proportional_error = np.clip(current_kl / target - 1, -0.2, 0.2)
mult = 1 + proportional_error * n_steps / self.horizon
self.value *= mult
class FixedKLController:
"""Fixed KL controller."""
def __init__(self, kl_coef):
self.value = kl_coef
def update(self, current_kl, n_steps):
pass
def get_kl_controller(config): # seems never used?
if config.critic.kl_ctrl.type == 'fixed':
kl_ctrl = FixedKLController(kl_coef=config.critic.kl_ctrl.kl_coef)
elif config.critic.kl_ctrl.type == 'adaptive':
assert config.kl_ctrl.horizon > 0, f'horizon must be larger than 0. Got {config.critic.kl_ctrl.horizon}'
kl_ctrl = AdaptiveKLController(init_kl_coef=config.critic.kl_ctrl.kl_coef,
target_kl=config.critic.kl_ctrl.target_kl,
horizon=config.critic.kl_ctrl.horizon)
else:
raise ValueError('Unknown kl_ctrl type')
return kl_ctrl
def compute_gae_advantage_return(token_level_rewards: torch.Tensor, values: torch.Tensor, eos_mask: torch.Tensor,
gamma: torch.Tensor, lam: torch.Tensor):
"""Adapted from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py
Args:
token_level_rewards: `(torch.Tensor)`
shape: (bs, response_length)
values: `(torch.Tensor)`
shape: (bs, response_length)
eos_mask: `(torch.Tensor)`
shape: (bs, response_length). [EOS] mask. The tokens after [EOS] have mask zero.
gamma: `(float)`
discount factor used in RL
lam: `(float)`
lambda value when computing Generalized Advantage Estimation (https://arxiv.org/abs/1506.02438)
Returns:
advantages: `(torch.Tensor)`
shape: (bs, response_length)
Returns: `(torch.Tensor)`
shape: (bs, response_length)
"""
with torch.no_grad():
lastgaelam = 0
advantages_reversed = []
gen_len = token_level_rewards.shape[-1]
for t in reversed(range(gen_len)):
nextvalues = values[:, t + 1] if t < gen_len - 1 else 0.0
delta = token_level_rewards[:, t] + gamma * nextvalues - values[:, t]
lastgaelam = delta + gamma * lam * lastgaelam
advantages_reversed.append(lastgaelam)
advantages = torch.stack(advantages_reversed[::-1], dim=1)
returns = advantages + values
advantages = verl_F.masked_whiten(advantages, eos_mask)
return advantages, returns
# NOTE(sgm): this implementation only consider outcome supervision, where the reward is a scalar.
def compute_grpo_outcome_advantage(token_level_rewards: torch.Tensor,
eos_mask: torch.Tensor,
index: torch.Tensor,
epsilon: float = 1e-6):
"""
Compute advantage for GRPO, operating only on Outcome reward
(with only one scalar reward for each response).
Args:
token_level_rewards: `(torch.Tensor)`
shape: (bs, response_length)
eos_mask: `(torch.Tensor)`
shape: (bs, response_length)
Returns:
advantages: `(torch.Tensor)`
shape: (bs, response_length)
Returns: `(torch.Tensor)`
shape: (bs, response_length)
"""
response_length = token_level_rewards.shape[-1]
non_zero_mask = (token_level_rewards != 0)
scores = (token_level_rewards * non_zero_mask).sum(dim=-1)
id2score = defaultdict(list)
id2mean = {}
id2std = {}
with torch.no_grad():
bsz = scores.shape[0]
for i in range(bsz):
id2score[index[i]].append(scores[i])
for idx in id2score:
if len(id2score[idx]) == 1:
id2mean[idx] = torch.tensor(0.0)
id2std[idx] = torch.tensor(1.0)
elif len(id2score[idx]) > 1:
id2mean[idx] = torch.mean(torch.tensor(id2score[idx]))
id2std[idx] = torch.std(torch.tensor([id2score[idx]]))
else:
raise ValueError(f"no score in prompt index: {idx}")
for i in range(bsz):
scores[i] = (scores[i] - id2mean[index[i]]) / (id2std[index[i]] + epsilon)
scores = scores.unsqueeze(-1).tile([1, response_length]) * eos_mask
return scores, scores
def compute_rewards(token_level_scores, old_log_prob, ref_log_prob, kl_ratio):
kl = old_log_prob - ref_log_prob
return token_level_scores - kl * kl_ratio
def compute_policy_loss(old_log_prob, log_prob, advantages, eos_mask, cliprange):
"""Adapted from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1122
Args:
old_log_prob: `(torch.Tensor)`
shape: (bs, response_length)
log_prob: `(torch.Tensor)`
shape: (bs, response_length)
advantages: `(torch.Tensor)`
shape: (bs, response_length)
eos_mask: `(torch.Tensor)`
shape: (bs, response_length)
cliprange: (float)
The clip range used in PPO. See https://arxiv.org/abs/1707.06347
Returns:
pg_loss: `a scalar torch.Tensor`
policy gradient loss computed via PPO
pg_clipfrac: (float)
a float number indicating the fraction of policy gradient loss being clipped
"""
negative_approx_kl = log_prob - old_log_prob
ratio = torch.exp(negative_approx_kl)
ppo_kl = verl_F.masked_mean(-negative_approx_kl, eos_mask)
pg_losses = -advantages * ratio
pg_losses2 = -advantages * torch.clamp(ratio, 1.0 - cliprange, 1.0 + cliprange)
pg_loss = verl_F.masked_mean(torch.max(pg_losses, pg_losses2), eos_mask)
pg_clipfrac = verl_F.masked_mean(torch.gt(pg_losses2, pg_losses).float(), eos_mask)
return pg_loss, pg_clipfrac, ppo_kl
def compute_entropy_loss(logits, eos_mask):
"""Compute Categorical entropy loss
Args:
logits: `(torch.Tensor)`
shape: (bs, response_length, vocab_size)
eos_mask: `(torch.Tensor)`
shape: (bs, response_length)
Returns:
entropy: a scalar torch.Tensor
"""
# compute entropy
entropy = verl_F.entropy_from_logits(logits) # (bs, response_len)
entropy_loss = verl_F.masked_mean(entropy, mask=eos_mask)
return entropy_loss
def compute_value_loss(vpreds, returns, values, eos_mask, cliprange_value):
"""Compute the value loss. Copied from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1151
Args:
vpreds (`torch.FloatTensor`):
Predicted values of the value head, shape (`batch_size`, `response_length`)
values (`torch.FloatTensor`):
Old values of value head, shape (`batch_size`, `response_length`)
returns: (`torch.FloatTensor`):
Ground truth returns, shape (`batch_size`, `response_length`)
Returns:
vf_loss: a scalar (`torch.FloatTensor`):
value function loss
vf_clipfrac: a float
The ratio of vf being clipped
"""
vpredclipped = verl_F.clip_by_value(vpreds, values - cliprange_value, values + cliprange_value)
vf_losses1 = (vpreds - returns)**2
vf_losses2 = (vpredclipped - returns)**2
vf_loss = 0.5 * verl_F.masked_mean(torch.max(vf_losses1, vf_losses2), eos_mask)
vf_clipfrac = verl_F.masked_mean(torch.gt(vf_losses2, vf_losses1).float(), eos_mask)
return vf_loss, vf_clipfrac
def kl_penalty(logprob: torch.FloatTensor, ref_logprob: torch.FloatTensor, kl_penalty) -> torch.FloatTensor:
"""Compute KL divergence given logprob and ref_logprob.
Copied from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1104
Args:
logprob:
ref_logprob:
Returns:
"""
if kl_penalty == "kl":
return logprob - ref_logprob
if kl_penalty == "abs":
return (logprob - ref_logprob).abs()
if kl_penalty == "mse":
return 0.5 * (logprob - ref_logprob).square()
# J. Schulman. Approximating kl divergence, 2020.
# URL http://joschu.net/blog/kl-approx.html
if kl_penalty == 'low_var_kl':
kl = ref_logprob - logprob
ratio = torch.exp(kl)
kld = (ratio - kl - 1).contiguous()
return torch.clamp(kld, min=-10, max=10)
if kl_penalty == "full":
# so, here logprob and ref_logprob should contain the logits for every token in vocabulary
raise NotImplementedError
raise NotImplementedError
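
A small numerical sketch of the estimators handled by kl_penalty above on toy log-probs (the 'low_var_kl' branch is the k3 estimator from the Schulman note referenced in the comments):

import torch

logprob = torch.tensor([-1.0, -2.0, -0.5])
ref_logprob = torch.tensor([-1.2, -1.5, -0.7])

k1 = logprob - ref_logprob                    # kl_penalty == "kl"
abs_kl = (logprob - ref_logprob).abs()        # kl_penalty == "abs"
mse = 0.5 * (logprob - ref_logprob).square()  # kl_penalty == "mse"
diff = ref_logprob - logprob
low_var_kl = torch.clamp(torch.exp(diff) - diff - 1, min=-10, max=10)  # kl_penalty == "low_var_kl"
print(k1, abs_kl, mse, low_var_kl, sep='\n')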


@@ -0,0 +1,920 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
FSDP PPO Trainer with Ray-based single controller.
This trainer supports model-agnostic model initialization with HuggingFace.
"""
import os
import uuid
from contextlib import contextmanager
from dataclasses import dataclass, field
from enum import Enum
from pprint import pprint
from typing import Type, Dict
import re
import json
from collections import defaultdict
import numpy as np
from codetiming import Timer
from omegaconf import OmegaConf, open_dict
from verl import DataProto
from verl.protocol import pad_dataproto_to_divisor, unpad_dataproto
from verl.single_controller.base import Worker
from verl.single_controller.ray import RayResourcePool, RayWorkerGroup, RayClassWithInitArgs
from verl.single_controller.ray.base import create_colocated_worker_cls
from verl.trainer.ppo import core_algos
from verl.utils.seqlen_balancing import get_seqlen_balanced_partitions, log_seqlen_unbalance
import re
from search_r1.llm_agent.generation import LLMGenerationManager, GenerationConfig
WorkerType = Type[Worker]
class Role(Enum):
"""
To create more roles dynamically, you can subclass Role and add new members
"""
Actor = 0
Rollout = 1
ActorRollout = 2
Critic = 3
RefPolicy = 4
RewardModel = 5
ActorRolloutRef = 6
@dataclass
class ResourcePoolManager:
"""
Define a resource pool specification and the mapping from roles to resource pools. Resource pools will be initialized first.
"""
resource_pool_spec: dict[str, list[int]]
mapping: dict[Role, str]
resource_pool_dict: dict[str, RayResourcePool] = field(default_factory=dict)
def create_resource_pool(self):
for resource_pool_name, process_on_nodes in self.resource_pool_spec.items():
# max_colocate_count means the number of WorkerGroups (i.e. processes) in each RayResourcePool
# For FSDP backend, we recommend using max_colocate_count=1, which merges all WorkerGroups into one.
# For Megatron backend, we recommend using max_colocate_count>1, which can utilize different WorkerGroups for different models
resource_pool = RayResourcePool(process_on_nodes=process_on_nodes,
use_gpu=True,
max_colocate_count=1,
name_prefix=resource_pool_name)
self.resource_pool_dict[resource_pool_name] = resource_pool
def get_resource_pool(self, role: Role) -> RayResourcePool:
"""Get the resource pool of the worker_cls"""
return self.resource_pool_dict[self.mapping[role]]
import torch
from verl.utils.torch_functional import masked_mean
def apply_kl_penalty(data: DataProto, kl_ctrl: core_algos.AdaptiveKLController, kl_penalty='kl'):
responses = data.batch['responses']
response_length = responses.size(1)
token_level_scores = data.batch['token_level_scores']
batch_size = data.batch.batch_size[0]
attention_mask = data.batch['attention_mask']
response_mask = attention_mask[:, -response_length:]
# compute kl between ref_policy and current policy
if 'ref_log_prob' in data.batch.keys():
kld = core_algos.kl_penalty(data.batch['old_log_probs'], data.batch['ref_log_prob'],
kl_penalty=kl_penalty) # (batch_size, response_length)
kld = kld * response_mask
beta = kl_ctrl.value
else:
beta = 0
kld = torch.zeros_like(response_mask, dtype=torch.float32)
token_level_rewards = token_level_scores - beta * kld
current_kl = masked_mean(kld, mask=response_mask, axis=-1) # average over sequence
current_kl = torch.mean(current_kl, dim=0).item()
# according to https://github.com/huggingface/trl/blob/951ca1841f29114b969b57b26c7d3e80a39f75a0/trl/trainer/ppo_trainer.py#L837
kl_ctrl.update(current_kl=current_kl, n_steps=batch_size)
data.batch['token_level_rewards'] = token_level_rewards
metrics = {'critic/kl': current_kl, 'critic/kl_coeff': beta}
return data, metrics
def compute_advantage(data: DataProto, adv_estimator, gamma=1.0, lam=1.0, num_repeat=1):
# prepare response group
# TODO: add other ways to estimate advantages
if adv_estimator == 'gae':
values = data.batch['values']
responses = data.batch['responses']
response_length = responses.size(-1)
attention_mask = data.batch['attention_mask']
response_mask = attention_mask[:, -response_length:]
token_level_rewards = data.batch['token_level_rewards']
advantages, returns = core_algos.compute_gae_advantage_return(token_level_rewards=token_level_rewards,
values=values,
eos_mask=response_mask,
gamma=gamma,
lam=lam)
data.batch['advantages'] = advantages
data.batch['returns'] = returns
elif adv_estimator == 'grpo':
token_level_rewards = data.batch['token_level_rewards']
index = data.non_tensor_batch['uid']
responses = data.batch['responses']
response_length = responses.size(-1)
attention_mask = data.batch['attention_mask']
response_mask = attention_mask[:, -response_length:]
advantages, returns = core_algos.compute_grpo_outcome_advantage(token_level_rewards=token_level_rewards,
eos_mask=response_mask,
index=index)
data.batch['advantages'] = advantages
data.batch['returns'] = returns
else:
raise NotImplementedError
return data
def reduce_metrics(metrics: dict):
for key, val in metrics.items():
metrics[key] = np.mean(val)
return metrics
def _compute_response_info(batch):
response_length = batch.batch['responses'].shape[-1]
prompt_mask = batch.batch['attention_mask'][:, :-response_length]
response_mask = batch.batch['attention_mask'][:, -response_length:]
prompt_length = prompt_mask.sum(-1).float()
response_length = response_mask.sum(-1).float() # (batch_size,)
return dict(
response_mask=response_mask,
prompt_length=prompt_length,
response_length=response_length,
)
def compute_data_metrics(batch, use_critic=True):
# TODO: add response length
sequence_score = batch.batch['token_level_scores'].sum(-1)
sequence_reward = batch.batch['token_level_rewards'].sum(-1)
advantages = batch.batch['advantages']
returns = batch.batch['returns']
max_response_length = batch.batch['responses'].shape[-1]
prompt_mask = batch.batch['attention_mask'][:, :-max_response_length].bool()
response_mask = batch.batch['attention_mask'][:, -max_response_length:].bool()
max_prompt_length = prompt_mask.size(-1)
response_info = _compute_response_info(batch)
prompt_length = response_info['prompt_length']
response_length = response_info['response_length']
valid_adv = torch.masked_select(advantages, response_mask)
valid_returns = torch.masked_select(returns, response_mask)
if use_critic:
values = batch.batch['values']
valid_values = torch.masked_select(values, response_mask)
return_diff_var = torch.var(valid_returns - valid_values)
return_var = torch.var(valid_returns)
metrics = {
# score
'critic/score/mean':
torch.mean(sequence_score).detach().item(),
'critic/score/max':
torch.max(sequence_score).detach().item(),
'critic/score/min':
torch.min(sequence_score).detach().item(),
# reward
'critic/rewards/mean':
torch.mean(sequence_reward).detach().item(),
'critic/rewards/max':
torch.max(sequence_reward).detach().item(),
'critic/rewards/min':
torch.min(sequence_reward).detach().item(),
# adv
'critic/advantages/mean':
torch.mean(valid_adv).detach().item(),
'critic/advantages/max':
torch.max(valid_adv).detach().item(),
'critic/advantages/min':
torch.min(valid_adv).detach().item(),
# returns
'critic/returns/mean':
torch.mean(valid_returns).detach().item(),
'critic/returns/max':
torch.max(valid_returns).detach().item(),
'critic/returns/min':
torch.min(valid_returns).detach().item(),
**({
# values
'critic/values/mean': torch.mean(valid_values).detach().item(),
'critic/values/max': torch.max(valid_values).detach().item(),
'critic/values/min': torch.min(valid_values).detach().item(),
# vf explained var
'critic/vf_explained_var': (1.0 - return_diff_var / (return_var + 1e-5)).detach().item(),
} if use_critic else {}),
# response length
'response_length/mean':
torch.mean(response_length).detach().item(),
'response_length/max':
torch.max(response_length).detach().item(),
'response_length/min':
torch.min(response_length).detach().item(),
'response_length/clip_ratio':
torch.mean(torch.eq(response_length, max_response_length).float()).detach().item(),
# prompt length
'prompt_length/mean':
torch.mean(prompt_length).detach().item(),
'prompt_length/max':
torch.max(prompt_length).detach().item(),
'prompt_length/min':
torch.min(prompt_length).detach().item(),
'prompt_length/clip_ratio':
torch.mean(torch.eq(prompt_length, max_prompt_length).float()).detach().item(),
# metrics for actions
# 'metric/total_env':
# int(np.array(batch.non_tensor_batch['total_env'], dtype=np.int16).sum()),
# 'metric/finished_env':
# int(np.array(batch.non_tensor_batch['finished_env'], dtype=np.int16).sum()),
# 'metric/traj_length':
# float(np.array(batch.non_tensor_batch['traj_length'], dtype=np.int16).mean()),
# 'metric/valid_action':
# float(np.array(batch.non_tensor_batch['valid_action'], dtype=np.int16).mean()),
# 'metric/effective_action':
# float(np.array(batch.non_tensor_batch['effective_action'], dtype=np.int16).mean()),
# 'metric/effective_action_ratio':
# float(np.array(batch.non_tensor_batch['effective_action_ratio'], dtype=np.float32).mean()),
}
return metrics
def compute_timing_metrics(batch, timing_raw):
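"""Convert raw per-section timings into absolute seconds and per-token milliseconds; 'gen' is normalized by response tokens only, other sections by prompt plus response tokens."""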
response_info = _compute_response_info(batch)
num_prompt_tokens = torch.sum(response_info['prompt_length']).item()
num_response_tokens = torch.sum(response_info['response_length']).item()
num_overall_tokens = num_prompt_tokens + num_response_tokens
num_tokens_of_section = {
'gen': num_response_tokens,
**{
name: num_overall_tokens for name in ['ref', 'values', 'adv', 'update_critic', 'update_actor', 'rollout']
},
}
return {
**{
f'timing_s/{name}': value for name, value in timing_raw.items()
},
**{
f'timing_per_token_ms/{name}': timing_raw[name] * 1000 / num_tokens_of_section[name]
for name in set(num_tokens_of_section.keys()) & set(timing_raw.keys())
},
}
@contextmanager
def _timer(name: str, timing_raw: Dict[str, float]):
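"""Measure the wall-clock time of the enclosed block and record it in timing_raw[name]."""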
with Timer(name=name, logger=None) as timer:
yield
timing_raw[name] = timer.last
class RayPPOTrainer(object):
"""
Note that this trainer runs on the driver process on a single CPU/GPU node.
"""
# TODO: support each role have individual ray_worker_group_cls,
# i.e., support different backend of different role
def __init__(self,
config,
tokenizer,
role_worker_mapping: dict[Role, WorkerType],
resource_pool_manager: ResourcePoolManager,
ray_worker_group_cls: RayWorkerGroup = RayWorkerGroup,
reward_fn=None,
val_reward_fn=None):
# assert torch.cuda.is_available(), 'cuda must be available on driver'
self.tokenizer = tokenizer
self.config = config
self.reward_fn = reward_fn
self.val_reward_fn = val_reward_fn
self.hybrid_engine = config.actor_rollout_ref.hybrid_engine
assert self.hybrid_engine, 'Currently, only the hybrid engine is supported'
if self.hybrid_engine:
assert Role.ActorRollout in role_worker_mapping, f'{role_worker_mapping.keys()=}'
self.role_worker_mapping = role_worker_mapping
self.resource_pool_manager = resource_pool_manager
self.use_reference_policy = Role.RefPolicy in role_worker_mapping
self.use_rm = Role.RewardModel in role_worker_mapping
self.ray_worker_group_cls = ray_worker_group_cls
# define KL control
if self.use_reference_policy:
if config.algorithm.kl_ctrl.type == 'fixed':
self.kl_ctrl = core_algos.FixedKLController(kl_coef=config.algorithm.kl_ctrl.kl_coef)
elif config.algorithm.kl_ctrl.type == 'adaptive':
assert config.algorithm.kl_ctrl.horizon > 0, f'horizon must be larger than 0. Got {config.algorithm.kl_ctrl.horizon}'
self.kl_ctrl = core_algos.AdaptiveKLController(init_kl_coef=config.algorithm.kl_ctrl.kl_coef,
target_kl=config.algorithm.kl_ctrl.target_kl,
horizon=config.algorithm.kl_ctrl.horizon)
else:
raise NotImplementedError
else:
self.kl_ctrl = core_algos.FixedKLController(kl_coef=0.)
self._create_dataloader()
self._init_logger()
def _init_logger(self):
from verl.utils.tracking import Tracking
self.logger = Tracking(project_name=self.config.trainer.project_name,
experiment_name=self.config.trainer.experiment_name,
default_backend=self.config.trainer.logger,
config=OmegaConf.to_container(self.config, resolve=True))
def _create_dataloader(self):
from torch.utils.data import DataLoader
# TODO: we have to make sure the batch size is divisible by the dp size
from verl.utils.dataset.rl_dataset import RLHFDataset, collate_fn
self.train_dataset = RLHFDataset(parquet_files=self.config.data.train_files,
tokenizer=self.tokenizer,
prompt_key=self.config.data.prompt_key,
max_prompt_length=self.config.data.max_prompt_length,
filter_prompts=True,
return_raw_chat=self.config.data.get('return_raw_chat', False),
truncation='error')
if self.config.data.train_data_num is not None:
if self.config.data.train_data_num > len(self.train_dataset.dataframe):
print(f"[WARNING] training dataset size is smaller than desired size. Using the dataset as the original size {len(self.train_dataset.dataframe)}")
else:
self.train_dataset.dataframe = self.train_dataset.dataframe.sample(self.config.data.train_data_num, random_state=42)
print(f"filtered training dataset size: {len(self.train_dataset.dataframe)}")
self.train_dataloader = DataLoader(dataset=self.train_dataset,
batch_size=self.config.data.train_batch_size,
shuffle=self.config.data.shuffle_train_dataloader,
drop_last=True,
collate_fn=collate_fn)
self.val_dataset = RLHFDataset(parquet_files=self.config.data.val_files,
tokenizer=self.tokenizer,
prompt_key=self.config.data.prompt_key,
max_prompt_length=self.config.data.max_prompt_length,
filter_prompts=True,
return_raw_chat=self.config.data.get('return_raw_chat', False),
truncation='error')
if self.config.data.val_data_num is not None:
if self.config.data.val_data_num > len(self.val_dataset.dataframe):
print(f"[WARNING] validation dataset size is smaller than desired size. Using the dataset as the original size {len(self.val_dataset.dataframe)}")
else:
self.val_dataset.dataframe = self.val_dataset.dataframe.sample(self.config.data.val_data_num, random_state=42)
print(f"filtered validation dataset size: {len(self.val_dataset.dataframe)}")
self.val_dataloader = DataLoader(dataset=self.val_dataset,
batch_size=self.config.data.val_batch_size,
shuffle=True,
drop_last=True,
collate_fn=collate_fn)
print(f'Size of train dataloader: {len(self.train_dataloader)}')
print(f'Size of val dataloader: {len(self.val_dataloader)}')
assert len(self.train_dataloader) >= 1
assert len(self.val_dataloader) >= 1
# inject total_training_steps to actor/critic optim_config. This is hacky.
total_training_steps = len(self.train_dataloader) * self.config.trainer.total_epochs
if self.config.trainer.total_training_steps is not None:
total_training_steps = self.config.trainer.total_training_steps
self.total_training_steps = total_training_steps
print(f'Total training steps: {self.total_training_steps}')
OmegaConf.set_struct(self.config, True)
with open_dict(self.config):
self.config.actor_rollout_ref.actor.optim.total_training_steps = total_training_steps
self.config.critic.optim.total_training_steps = total_training_steps
def _validate(self):
"""
Validate the current policy on the validation set.
Rewards are accumulated across all batches and then aggregated per data source.
"""
import torch
reward_tensor_lst = []
data_source_lst = []
gen_config = GenerationConfig(
max_turns=self.config.max_turns,
max_start_length=self.config.data.max_start_length,
max_prompt_length=self.config.data.max_prompt_length,
max_response_length=self.config.data.max_response_length,
max_obs_length=self.config.data.max_obs_length,
num_gpus=self.config.trainer.n_gpus_per_node,
no_think_rl=self.config.algorithm.no_think_rl,
search_url=self.config.retriever.url,
topk=self.config.retriever.topk,
)
# Prepare the generation manager for the agent/search loop
generation_manager = LLMGenerationManager(
tokenizer=self.tokenizer,
actor_rollout_wg=self.actor_rollout_wg,
config=gen_config,
is_validation=True,
)
if not self.config.do_search:
for test_data in self.val_dataloader:
test_batch = DataProto.from_single_dict(test_data)
# we only do validation on rule-based rm
if self.config.reward_model.enable and test_batch[0].non_tensor_batch['reward_model']['style'] == 'model':
return {}
test_gen_batch = test_batch.pop(['input_ids', 'attention_mask', 'position_ids'])
test_gen_batch.meta_info = {
'eos_token_id': self.tokenizer.eos_token_id,
'pad_token_id': self.tokenizer.pad_token_id,
'recompute_log_prob': False,
'do_sample': False,
'validate': True,
}
# pad to be divisible by dp_size
test_gen_batch_padded, pad_size = pad_dataproto_to_divisor(test_gen_batch, self.actor_rollout_wg.world_size)
test_output_gen_batch_padded = self.actor_rollout_wg.generate_sequences(test_gen_batch_padded)
# unpad
test_output_gen_batch = unpad_dataproto(test_output_gen_batch_padded, pad_size=pad_size)
print('validation generation end')
test_batch = test_batch.union(test_output_gen_batch)
# evaluate using reward_function
# for certain reward function (e.g. sandbox), the generation can overlap with reward
reward_tensor = self.val_reward_fn(test_batch)
reward_tensor_lst.append(reward_tensor)
data_source_lst.append(test_batch.non_tensor_batch.get('data_source', ['unknown'] * reward_tensor.shape[0]))
else:
for batch_dict in self.val_dataloader:
timing_raw = {}
test_batch: DataProto = DataProto.from_single_dict(batch_dict)
# test_batch = test_batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n_agent, interleave=True)
test_gen_batch = test_batch.pop(batch_keys=['input_ids', 'attention_mask', 'position_ids'])
test_gen_batch.meta_info = {
'eos_token_id': self.tokenizer.eos_token_id,
'pad_token_id': self.tokenizer.pad_token_id,
'recompute_log_prob': False,
'do_sample': False,
'validate': True,
}
with _timer('step', timing_raw):
first_input_ids = test_gen_batch.batch['input_ids'][:, -gen_config.max_start_length:].clone()
with _timer('gen', timing_raw):
generation_manager.timing_raw = timing_raw
final_gen_batch_output = generation_manager.run_llm_loop(
gen_batch=test_gen_batch,
initial_input_ids=first_input_ids,
)
test_batch = test_batch.union(final_gen_batch_output)
for key in test_batch.batch.keys():
test_batch.batch[key] = test_batch.batch[key].long()
# evaluate using reward_function
# for certain reward function (e.g. sandbox), the generation can overlap with reward
try:
reward_tensor = self.val_reward_fn(test_batch)
except Exception:
# dump the offending batch for debugging, then re-raise instead of exiting silently
print(test_batch)
raise
reward_tensor_lst.append(reward_tensor)
data_source_lst.append(test_batch.non_tensor_batch.get('data_source', ['unknown'] * reward_tensor.shape[0]))
reward_tensor = torch.cat([rw.sum(-1) for rw in reward_tensor_lst], dim=0).cpu() # (batch_size,)
# reward_tensor = torch.cat(reward_tensor_lst, dim=0).sum(-1).cpu() # (batch_size,)
data_sources = np.concatenate(data_source_lst, axis=0)
# evaluate test_score based on data source
data_source_reward = {}
for i in range(reward_tensor.shape[0]):
data_source = data_sources[i]
if data_source not in data_source_reward:
data_source_reward[data_source] = []
data_source_reward[data_source].append(reward_tensor[i].item())
metric_dict = {}
for data_source, rewards in data_source_reward.items():
metric_dict[f'val/test_score/{data_source}'] = np.mean(rewards)
return metric_dict
def init_workers(self):
"""Init resource pool and worker group"""
self.resource_pool_manager.create_resource_pool()
self.resource_pool_to_cls = {pool: {} for pool in self.resource_pool_manager.resource_pool_dict.values()}
# create actor and rollout
if self.hybrid_engine:
resource_pool = self.resource_pool_manager.get_resource_pool(Role.ActorRollout)
actor_rollout_cls = RayClassWithInitArgs(cls=self.role_worker_mapping[Role.ActorRollout],
config=self.config.actor_rollout_ref,
role='actor_rollout')
self.resource_pool_to_cls[resource_pool]['actor_rollout'] = actor_rollout_cls
else:
raise NotImplementedError
# create critic
if self.config.algorithm.adv_estimator == 'gae':
resource_pool = self.resource_pool_manager.get_resource_pool(Role.Critic)
critic_cls = RayClassWithInitArgs(cls=self.role_worker_mapping[Role.Critic], config=self.config.critic)
self.resource_pool_to_cls[resource_pool]['critic'] = critic_cls
self.use_critic = True
elif self.config.algorithm.adv_estimator == 'grpo':
self.use_critic = False
else:
raise NotImplementedError
# create reference policy if needed
if self.use_reference_policy:
resource_pool = self.resource_pool_manager.get_resource_pool(Role.RefPolicy)
ref_policy_cls = RayClassWithInitArgs(self.role_worker_mapping[Role.RefPolicy],
config=self.config.actor_rollout_ref,
role='ref')
self.resource_pool_to_cls[resource_pool]['ref'] = ref_policy_cls
# create a reward model if reward_fn is None
if self.use_rm:
# we create a RM here
resource_pool = self.resource_pool_manager.get_resource_pool(Role.RewardModel)
rm_cls = RayClassWithInitArgs(self.role_worker_mapping[Role.RewardModel], config=self.config.reward_model)
self.resource_pool_to_cls[resource_pool]['rm'] = rm_cls
# initialize WorkerGroup
# NOTE: if you want to use a different resource pool for each role, which can support different parallel size,
# you should not use `create_colocated_worker_cls`. Instead, directly pass different resource pools to different worker groups.
# See https://github.com/volcengine/verl/blob/master/examples/ray/tutorial.ipynb for more information.
all_wg = {}
self.wg_dicts = []
for resource_pool, class_dict in self.resource_pool_to_cls.items():
worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict)
wg_dict = self.ray_worker_group_cls(resource_pool=resource_pool, ray_cls_with_init=worker_dict_cls)
spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys())
all_wg.update(spawn_wg)
# keep the reference of WorkerDict to support ray >= 2.31. Ref: https://github.com/ray-project/ray/pull/45699
self.wg_dicts.append(wg_dict)
if self.use_critic:
self.critic_wg = all_wg['critic']
self.critic_wg.init_model()
if self.use_reference_policy:
self.ref_policy_wg = all_wg['ref']
self.ref_policy_wg.init_model()
if self.use_rm:
self.rm_wg = all_wg['rm']
self.rm_wg.init_model()
# we should create rollout at the end so that vllm can have a better estimation of kv cache memory
self.actor_rollout_wg = all_wg['actor_rollout']
self.actor_rollout_wg.init_model()
def _save_checkpoint(self):
actor_local_path = os.path.join(self.config.trainer.default_local_dir, 'actor',
f'global_step_{self.global_steps}')
actor_remote_path = None if self.config.trainer.default_hdfs_dir is None else os.path.join(
self.config.trainer.default_hdfs_dir, 'actor')
self.actor_rollout_wg.save_checkpoint(actor_local_path, actor_remote_path)
if self.use_critic:
critic_local_path = os.path.join(self.config.trainer.default_local_dir, 'critic',
f'global_step_{self.global_steps}')
critic_remote_path = None if self.config.trainer.default_hdfs_dir is None else os.path.join(
self.config.trainer.default_hdfs_dir, 'critic')
self.critic_wg.save_checkpoint(critic_local_path, critic_remote_path)
def _balance_batch(self, batch: DataProto, metrics, logging_prefix='global_seqlen'):
"""Reorder the data on single controller such that each dp rank gets similar total tokens"""
attention_mask = batch.batch['attention_mask']
batch_size = attention_mask.shape[0]
global_seqlen_lst = attention_mask.view(batch_size, -1).sum(-1).tolist() # (train_batch_size,)
world_size = self.actor_rollout_wg.world_size
global_partition_lst = get_seqlen_balanced_partitions(global_seqlen_lst,
k_partitions=world_size,
equal_size=True)
# reorder based on index. The data will be automatically equally partitioned by dispatch function
global_idx = torch.tensor([j for partition in global_partition_lst for j in partition])
batch.reorder(global_idx)
global_balance_stats = log_seqlen_unbalance(seqlen_list=global_seqlen_lst,
partitions=global_partition_lst,
prefix=logging_prefix)
metrics.update(global_balance_stats)
def fit(self):
"""
The training loop of PPO.
The driver process only needs to call the compute functions of the worker group through RPC to construct the PPO dataflow.
The lightweight advantage computation is done on the driver process.
"""
logger = self.logger
self.global_steps = 0
# perform validation before training
# currently, we only support validation using the reward_function.
if self.val_reward_fn is not None and self.config.trainer.get('val_before_train', True):
val_metrics = self._validate()
pprint(f'Initial validation metrics: {val_metrics}')
logger.log(data=val_metrics, step=self.global_steps)
if self.config.trainer.get('val_only', False):
return
# we start from step 1
self.global_steps += 1
# Agent config preparation
gen_config = GenerationConfig(
max_turns=self.config.max_turns,
max_start_length=self.config.data.max_start_length,
max_prompt_length=self.config.data.max_prompt_length,
max_response_length=self.config.data.max_response_length,
max_obs_length=self.config.data.max_obs_length,
num_gpus=self.config.trainer.n_gpus_per_node,
no_think_rl=self.config.algorithm.no_think_rl,
search_url=self.config.retriever.url,
topk=self.config.retriever.topk,
)
generation_manager = LLMGenerationManager(
tokenizer=self.tokenizer,
actor_rollout_wg=self.actor_rollout_wg,
config=gen_config,
)
# start training loop
for epoch in range(self.config.trainer.total_epochs):
for batch_dict in self.train_dataloader:
print(f'epoch {epoch}, step {self.global_steps}')
metrics = {}
timing_raw = {}
batch: DataProto = DataProto.from_single_dict(batch_dict)
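# repeat each prompt rollout.n_agent times (interleaved) so that multiple rollouts are sampled per prompt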
batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n_agent, interleave=True)
# pop those keys for generation
gen_batch = batch.pop(batch_keys=['input_ids', 'attention_mask', 'position_ids'])
####################
# original code here
with _timer('step', timing_raw):
if not self.config.do_search:
gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)
batch.non_tensor_batch['uid'] = np.array([str(uuid.uuid4()) for _ in range(len(batch.batch))],
dtype=object)
# repeat to align with repeated responses in rollout
batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)
batch = batch.union(gen_batch_output)
####################
# Below is all about agents - the "LLM + for-loop" rollout
####################
# with _timer('step', timing_raw):
else:
first_input_ids = gen_batch.batch['input_ids'][:, -gen_config.max_start_length:].clone().long()
with _timer('gen', timing_raw):
generation_manager.timing_raw = timing_raw
final_gen_batch_output = generation_manager.run_llm_loop(
gen_batch=gen_batch,
initial_input_ids=first_input_ids,
)
# final_gen_batch_output.batch.apply(lambda x: x.long(), inplace=True)
for key in final_gen_batch_output.batch.keys():
final_gen_batch_output.batch[key] = final_gen_batch_output.batch[key].long()
with torch.no_grad():
try:
output = self.actor_rollout_wg.compute_log_prob(final_gen_batch_output)
final_gen_batch_output = final_gen_batch_output.union(output)
except Exception:
# dump the batch that broke log-prob computation, then re-raise instead of silently continuing
print(final_gen_batch_output)
raise
batch.non_tensor_batch['uid'] = np.array([str(uuid.uuid4()) for _ in range(len(batch.batch))],
dtype=object)
# repeat to align with repeated responses in rollout
batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)
batch = batch.union(final_gen_batch_output)
####################
####################
# balance the number of valid tokens on each dp rank.
# Note that this breaks the order of data inside the batch.
# Please take care when you implement group-based advantage computation such as GRPO and RLOO
self._balance_batch(batch, metrics=metrics)
# compute global_valid tokens
batch.meta_info['global_token_num'] = torch.sum(batch.batch['attention_mask'], dim=-1).tolist()
# batch.batch.apply(lambda x, key: x.long() if key != "old_log_probs" else x, inplace=True, key=True)
for key in batch.batch.keys():
if key != 'old_log_probs':
batch.batch[key] = batch.batch[key].long()
if self.use_reference_policy:
# compute reference log_prob
with _timer('ref', timing_raw):
try:
ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)
batch = batch.union(ref_log_prob)
except Exception:
# dump the batch that broke reference log-prob computation, then re-raise
print(batch)
raise
# compute values
if self.use_critic:
with _timer('values', timing_raw):
values = self.critic_wg.compute_values(batch)
batch = batch.union(values)
with _timer('adv', timing_raw):
# compute scores, supporting both model-based and function-based rewards.
# We first compute scores with the reward model, then call reward_fn to combine
# the reward-model results with rule-based results.
if self.use_rm:
# we first compute reward model score
reward_tensor = self.rm_wg.compute_rm_score(batch)
batch = batch.union(reward_tensor)
# we combine with rule-based rm
reward_tensor = self.reward_fn(batch)
batch.batch['token_level_scores'] = reward_tensor
# compute token-level rewards; apply the KL penalty only when KL is not handled as a loss term
if not self.config.actor_rollout_ref.actor.use_kl_loss:
batch, kl_metrics = apply_kl_penalty(batch,
kl_ctrl=self.kl_ctrl,
kl_penalty=self.config.algorithm.kl_penalty)
metrics.update(kl_metrics)
else:
batch.batch['token_level_rewards'] = batch.batch['token_level_scores']
# compute advantages, executed on the driver process
batch = compute_advantage(batch,
adv_estimator=self.config.algorithm.adv_estimator,
gamma=self.config.algorithm.gamma,
lam=self.config.algorithm.lam,
num_repeat=self.config.actor_rollout_ref.rollout.n)
# update critic
if self.use_critic:
with _timer('update_critic', timing_raw):
critic_output = self.critic_wg.update_critic(batch)
critic_output_metrics = reduce_metrics(critic_output.meta_info['metrics'])
metrics.update(critic_output_metrics)
# implement critic warmup
if self.config.trainer.critic_warmup <= self.global_steps:
# update actor
with _timer('update_actor', timing_raw):
if self.config.do_search and self.config.actor_rollout_ref.actor.state_masking:
batch, metrics = self._create_loss_mask(batch, metrics)
actor_output = self.actor_rollout_wg.update_actor(batch)
actor_output_metrics = reduce_metrics(actor_output.meta_info['metrics'])
metrics.update(actor_output_metrics)
# validate
if self.val_reward_fn is not None and self.config.trainer.test_freq > 0 and \
self.global_steps % self.config.trainer.test_freq == 0:
with _timer('testing', timing_raw):
val_metrics: dict = self._validate()
metrics.update(val_metrics)
if self.config.trainer.save_freq > 0 and \
self.global_steps % self.config.trainer.save_freq == 0:
with _timer('save_checkpoint', timing_raw):
self._save_checkpoint()
# collect metrics
metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic))
metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw))
# TODO: make a canonical logger that supports various backend
logger.log(data=metrics, step=self.global_steps)
self.global_steps += 1
if self.global_steps >= self.total_training_steps:
# perform validation after training
if self.val_reward_fn is not None:
val_metrics = self._validate()
pprint(f'Final validation metrics: {val_metrics}')
logger.log(data=val_metrics, step=self.global_steps)
return
def _create_loss_mask(self, batch, metrics):
"""Create loss mask for state tokens."""
response_length = batch.batch['responses'].shape[-1]
response_mask = batch.batch['attention_mask'][:, -response_length:]
# Initialize state mask
state_mask = torch.ones_like(response_mask)
responses = [self.tokenizer.decode(resp, skip_special_tokens=False) for resp in batch.batch['responses']]
for i, response in enumerate(responses):
# Find all pairs of start and end marker positions
start_marker = self.config.algorithm.state_masking.start_state_marker
end_marker = self.config.algorithm.state_masking.end_state_marker
# Get all start and end positions
start_positions = [m.start() for m in re.finditer(re.escape(start_marker), response)]
end_positions = [m.start() + len(end_marker) for m in re.finditer(re.escape(end_marker), response)]
# Convert character positions to token positions
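# NOTE: this assumes re-tokenizing the character prefix reproduces the original token
# boundaries; for tokenizers where this does not hold exactly, the masked span is approximate.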
for start, end in zip(start_positions, end_positions):
prefix_to_start = response[:start]
state_section = response[start:end]
start_tokens = self.tokenizer.encode(prefix_to_start, add_special_tokens=False)
state_tokens = self.tokenizer.encode(state_section, add_special_tokens=False)
start_token_pos = len(start_tokens)
end_token_pos = start_token_pos + len(state_tokens)
state_mask[i, start_token_pos:end_token_pos] = 0
loss_mask = state_mask * response_mask
batch.batch['loss_mask'] = loss_mask
# # Debug print
# print("\nRaw batch[0] (before masking):\n", self.tokenizer.decode(batch.batch['responses'][0]))
# response_ids = batch.batch['responses'][0]
# unmasked_ids = response_ids[loss_mask[0] == 0]
# print("\nMasked batch[0] (after masking):\n", self.tokenizer.decode(unmasked_ids))
# masked_ids = response_ids[loss_mask[0] == 1]
# print("\nUnmasked batch[0] (masked parts):\n", self.tokenizer.decode(masked_ids))
# masked_ids = response_ids[response_mask[0] == 1]
# print("\nresponse_mask[0] == 1:\n", self.tokenizer.decode(masked_ids))
# masked_ids = response_ids[response_mask[0] == 0]
# print("\nresponse_mask[0] == 0:\n", self.tokenizer.decode(masked_ids))
metrics.update({
'state_tokens/total': loss_mask.sum().item(),
'state_tokens/coverage': (loss_mask.sum() / response_mask.sum()).item(),
})
return batch, metrics

View File

@@ -0,0 +1,5 @@
working_dir: ./
excludes: ["/.git/"]
env_vars:
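# TORCH_NCCL_AVOID_RECORD_STREAMS=1 asks PyTorch's NCCL backend to skip recordStream calls, which can reduce GPU memory held by communication ops.
# VLLM_ATTENTION_BACKEND=XFORMERS selects the xFormers attention backend in vLLM.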
TORCH_NCCL_AVOID_RECORD_STREAMS: "1"
VLLM_ATTENTION_BACKEND: "XFORMERS"