Initial commit
13
verl/trainer/__init__.py
Normal file
@@ -0,0 +1,13 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
6
verl/trainer/config/evaluation.yaml
Normal file
@@ -0,0 +1,6 @@
data:
  path: /tmp/math_Qwen2-7B-Instruct.parquet
  prompt_key: prompt
  response_key: responses
  data_source_key: data_source
  reward_model_key: reward_model
35
verl/trainer/config/generation.yaml
Normal file
@@ -0,0 +1,35 @@
trainer:
  nnodes: 1
  n_gpus_per_node: 8

data:
  path: ~/data/rlhf/math/test.parquet
  prompt_key: prompt
  n_samples: 5
  output_path: /opt/tiger/math_Qwen2-7B-Instruct.parquet
  batch_size: 128

model:
  path: ~/models/Qwen2-7B-Instruct
  external_lib: null
rollout:
  name: vllm
  temperature: 1.0
  top_k: 50 # 0 for hf rollout, -1 for vllm rollout
  top_p: 0.7
  prompt_length: 1536
  response_length: 512
  # for vllm rollout
  dtype: bfloat16 # should align with FSDP
  gpu_memory_utilization: 0.5
  ignore_eos: False
  micro_batch_size: 256
  enforce_eager: True
  free_cache_engine: True
  load_format: dummy_dtensor
  tensor_model_parallel_size: 1
  max_num_batched_tokens: 8192
  max_num_seqs: 1024
  log_prob_micro_batch_size: 8
  # for hf rollout
  do_sample: True
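A minimal sketch (not part of the commit) of how these generation settings could be inspected with OmegaConf before launching a rollout; the relative config path is an assumption:

# Hypothetical usage sketch: load generation.yaml and read the rollout sampling settings.
from omegaconf import OmegaConf

cfg = OmegaConf.load('verl/trainer/config/generation.yaml')  # assumed path
assert cfg.rollout.name == 'vllm'
# temperature / top_k / top_p drive the sampling behaviour of the rollout engine
print(cfg.rollout.temperature, cfg.rollout.top_k, cfg.rollout.top_p)
print(cfg.data.n_samples, cfg.data.batch_size)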
148
verl/trainer/config/ppo_megatron_trainer.yaml
Normal file
@@ -0,0 +1,148 @@
data:
  tokenizer: null
  train_files: ~/data/rlhf/gsm8k/train.parquet
  val_files: ~/data/rlhf/gsm8k/test.parquet
  prompt_key: prompt
  max_prompt_length: 512
  max_response_length: 512
  train_batch_size: 1024
  val_batch_size: 1312
  return_raw_input_ids: False # This should be set to true when the tokenizer between policy and rm differs
  return_raw_chat: False

actor_rollout_ref:
  hybrid_engine: True
  model:
    path: ~/models/deepseek-llm-7b-chat
    external_lib: null
    override_config: {}
    enable_gradient_checkpointing: False
  actor:
    strategy: megatron # This is for backward-compatibility
    ppo_mini_batch_size: 256
    ppo_micro_batch_size: 64
    clip_ratio: 0.2
    entropy_coeff: 0.001
    ppo_epochs: 1
    shuffle: True
    optim:
      lr: 1e-6
      clip_grad: 1.0
      lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime
      min_lr_ratio: null # only useful for warmup with cosine
      warmup_style: constant # select from constant/cosine
      total_training_steps: -1 # must be overridden by the program
    megatron:
      tensor_model_parallel_size: 4
      pipeline_model_parallel_size: 1
      num_layers_per_virtual_pipeline_stage: null # vpp will hang; needs debugging
      sequence_parallel: True
      seed: 1
    load_weight: True
  ref:
    megatron:
      tensor_model_parallel_size: 4
      pipeline_model_parallel_size: 1
      num_layers_per_virtual_pipeline_stage: null # vpp will hang; needs debugging
      sequence_parallel: True
      seed: 1
    load_weight: True
    param_offload: False
    log_prob_micro_batch_size: 32
  rollout:
    name: vllm
    temperature: 1.0
    top_k: -1 # 0 for hf rollout, -1 for vllm rollout
    top_p: 1
    prompt_length: ${data.max_prompt_length} # for xperf_gpt
    response_length: ${data.max_response_length}
    # for vllm rollout
    dtype: bfloat16 # should align with FSDP
    gpu_memory_utilization: 0.5
    ignore_eos: False
    enforce_eager: True
    free_cache_engine: True
    load_format: dummy_megatron
    tensor_model_parallel_size: 2
    max_num_batched_tokens: 8192
    max_num_seqs: 1024
    log_prob_micro_batch_size: 2
    # for hf rollout
    do_sample: True
    layer_name_map:
      qkv_layer_name: qkv
      gate_proj_layer_name: gate_up
    # number of responses (i.e. number of sample times)
    n: 1

critic:
  strategy: megatron
  optim:
    lr: 1e-5
    clip_grad: 1.0
    lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime
    min_lr_ratio: null # only useful for warmup with cosine
    warmup_style: constant # select from constant/cosine
    total_training_steps: -1 # must be overridden by the program
  model:
    path: ~/models/deepseek-llm-7b-chat
    tokenizer_path: ${actor_rollout_ref.model.path}
    override_config: {}
    external_lib: ${actor_rollout_ref.model.external_lib}
    enable_gradient_checkpointing: False
  megatron:
    tensor_model_parallel_size: 4
    pipeline_model_parallel_size: 1
    num_layers_per_virtual_pipeline_stage: null # vpp will hang; needs debugging
    sequence_parallel: True
    seed: 1
  load_weight: True
  ppo_mini_batch_size: ${actor_rollout_ref.actor.ppo_mini_batch_size}
  ppo_micro_batch_size: 2
  ppo_epochs: ${actor_rollout_ref.actor.ppo_epochs}
  shuffle: ${actor_rollout_ref.actor.shuffle}
  cliprange_value: 0.5
  kl_ctrl:
    type: fixed
    kl_coef: 0.001

reward_model:
  enable: False
  strategy: megatron
  megatron:
    tensor_model_parallel_size: 4
    pipeline_model_parallel_size: 1
    num_layers_per_virtual_pipeline_stage: null # vpp will hang; needs debugging
    sequence_parallel: True
    seed: 1
  model:
    input_tokenizer: ${actor_rollout_ref.model.path} # set this to null if the chat template is identical
    path: ~/models/FsfairX-LLaMA3-RM-v0.1
    external_lib: ${actor_rollout_ref.model.external_lib}
  load_weight: True
  param_offload: False
  micro_batch_size: 64
  max_length: null

algorithm:
  gamma: 1.0
  lam: 1.0
  adv_estimator: gae
  kl_penalty: kl # how to estimate the KL divergence
  kl_ctrl:
    type: fixed
    kl_coef: 0.001

trainer:
  total_epochs: 30
  total_training_steps: null
  project_name: verl_examples
  experiment_name: gsm8k
  logger: ['console', 'wandb']
  nnodes: 1
  n_gpus_per_node: 8
  save_freq: -1
  test_freq: 2
  critic_warmup: 0
  default_hdfs_dir: ~/experiments/gsm8k/ppo/${trainer.experiment_name}
  default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name}
177
verl/trainer/config/ppo_trainer.yaml
Normal file
@@ -0,0 +1,177 @@
data:
  tokenizer: null
  train_files: ~/data/rlhf/gsm8k/train.parquet
  val_files: ~/data/rlhf/gsm8k/test.parquet
  train_data_num: null
  val_data_num: null
  prompt_key: prompt
  max_prompt_length: 512
  max_response_length: 512
  max_start_length: 256
  max_obs_length: 512
  train_batch_size: 1024
  val_batch_size: 1312
  return_raw_input_ids: False # This should be set to true when the tokenizer between policy and rm differs
  return_raw_chat: False
  shuffle_train_dataloader: True

actor_rollout_ref:
  hybrid_engine: True
  model:
    path: ~/models/deepseek-llm-7b-chat
    external_lib: null
    override_config: { }
    enable_gradient_checkpointing: False
    use_remove_padding: False
  actor:
    strategy: fsdp # This is for backward-compatibility
    ppo_mini_batch_size: 256
    ppo_micro_batch_size: 64
    use_dynamic_bsz: False
    ppo_max_token_len_per_gpu: 16384 # n * ${data.max_prompt_length} + ${data.max_response_length}
    grad_clip: 1.0
    state_masking: False
    clip_ratio: 0.2
    entropy_coeff: 0.001
    use_kl_loss: False # True for GRPO
    kl_loss_coef: 0.001 # for grpo
    kl_loss_type: low_var_kl # for grpo
    ppo_epochs: 1
    shuffle: False
    ulysses_sequence_parallel_size: 1 # sp size
    optim:
      lr: 1e-6
      lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime
      min_lr_ratio: null # only useful for warmup with cosine
      warmup_style: constant # select from constant/cosine
      total_training_steps: -1 # must be overridden by the program
    fsdp_config:
      wrap_policy:
        # transformer_layer_cls_to_wrap: None
        min_num_params: 0
      param_offload: False
      grad_offload: False
      optimizer_offload: False
      fsdp_size: -1
  ref:
    fsdp_config:
      param_offload: False
      wrap_policy:
        # transformer_layer_cls_to_wrap: None
        min_num_params: 0
      fsdp_size: -1
    log_prob_micro_batch_size: 128
    log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
    log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}
    ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size} # sp size
  rollout:
    name: vllm
    temperature: 1.0
    top_k: -1 # 0 for hf rollout, -1 for vllm rollout
    top_p: 0.95
    prompt_length: ${data.max_prompt_length} # not used for open source
    response_length: ${data.max_response_length}
    # for vllm rollout
    dtype: bfloat16 # should align with FSDP
    gpu_memory_utilization: 0.5
    ignore_eos: False
    enforce_eager: True
    free_cache_engine: True
    load_format: dummy_dtensor
    tensor_model_parallel_size: 2
    max_num_batched_tokens: 8192
    max_num_seqs: 1024
    log_prob_micro_batch_size: 128
    log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
    log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}
    # for hf rollout
    do_sample: True
    # number of responses (i.e. number of sample times)
    n: 1 # > 1 for grpo
    n_agent: 1 # different here; used for agent tasks only

critic:
  strategy: fsdp
  optim:
    lr: 1e-5
    lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime
    min_lr_ratio: null # only useful for warmup with cosine
    warmup_style: constant # select from constant/cosine
    total_training_steps: -1 # must be overridden by the program
  model:
    path: ~/models/deepseek-llm-7b-chat
    tokenizer_path: ${actor_rollout_ref.model.path}
    override_config: { }
    external_lib: ${actor_rollout_ref.model.external_lib}
    enable_gradient_checkpointing: False
    use_remove_padding: False
    fsdp_config:
      param_offload: False
      grad_offload: False
      optimizer_offload: False
      wrap_policy:
        # transformer_layer_cls_to_wrap: None
        min_num_params: 0
      fsdp_size: -1
  ppo_mini_batch_size: ${actor_rollout_ref.actor.ppo_mini_batch_size}
  ppo_micro_batch_size: 64
  forward_micro_batch_size: ${critic.ppo_micro_batch_size}
  use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
  ppo_max_token_len_per_gpu: 32768 # (${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}) * 2
  forward_max_token_len_per_gpu: ${critic.ppo_max_token_len_per_gpu}
  ulysses_sequence_parallel_size: 1 # sp size
  ppo_epochs: ${actor_rollout_ref.actor.ppo_epochs}
  shuffle: ${actor_rollout_ref.actor.shuffle}
  grad_clip: 1.0
  cliprange_value: 0.5

reward_model:
  enable: False
  strategy: fsdp
  model:
    input_tokenizer: ${actor_rollout_ref.model.path} # set this to null if the chat template is identical
    path: ~/models/FsfairX-LLaMA3-RM-v0.1
    external_lib: ${actor_rollout_ref.model.external_lib}
    use_remove_padding: False
    fsdp_config:
      min_num_params: 0
      param_offload: False
  micro_batch_size: 64
  max_length: null
  ulysses_sequence_parallel_size: 1 # sp size
  use_dynamic_bsz: ${critic.use_dynamic_bsz}
  forward_max_token_len_per_gpu: ${critic.forward_max_token_len_per_gpu}

retriever:
  url: "http://127.0.0.1:8000/retrieve"
  topk: 3

algorithm:
  gamma: 1.0
  lam: 1.0
  adv_estimator: gae
  no_think_rl: False
  kl_penalty: kl # how to estimate the KL divergence
  kl_ctrl:
    type: fixed
    kl_coef: 0.001
  state_masking:
    start_state_marker: "<information>"
    end_state_marker: "</information>"

trainer:
  total_epochs: 30
  total_training_steps: null
  project_name: verl_examples
  experiment_name: gsm8k
  logger: [ 'console', 'wandb' ]
  nnodes: 1
  n_gpus_per_node: 8
  save_freq: -1
  test_freq: -1
  critic_warmup: 0
  default_hdfs_dir: ~/experiments/gsm8k/ppo/${trainer.experiment_name}
  default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name}

max_turns: 10
do_search: true
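A minimal sketch, assuming standard OmegaConf semantics, of how the ${...} interpolations in this file resolve and how a CLI-style override would propagate (the relative path and the dotlist values are illustrative only):

# Hypothetical sketch: resolve interpolations and apply command-line style overrides.
from omegaconf import OmegaConf

cfg = OmegaConf.load('verl/trainer/config/ppo_trainer.yaml')  # assumed path
override = OmegaConf.from_dotlist(['data.max_prompt_length=1024', 'actor_rollout_ref.rollout.n=4'])
cfg = OmegaConf.merge(cfg, override)
OmegaConf.resolve(cfg)
# rollout.prompt_length mirrors data.max_prompt_length through the interpolation
assert cfg.actor_rollout_ref.rollout.prompt_length == 1024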
42
verl/trainer/config/sft_trainer.yaml
Normal file
@@ -0,0 +1,42 @@
data:
  train_batch_size: 256
  micro_batch_size: 16 # this is also the val batch size
  train_files: ~/data/gsm8k/train.parquet
  val_files: ~/data/gsm8k/test.parquet
  prompt_key: question
  response_key: answer
  max_length: 1024
  truncation: error
  balance_dp_token: False
  chat_template: null
model:
  partial_pretrain: ~/models/gemma-1.1-7b-it
  fsdp_config:
    wrap_policy:
      min_num_params: 0
    cpu_offload: False
    offload_params: False
  external_lib: null
  enable_gradient_checkpointing: False
  trust_remote_code: False
  lora_rank: 0 # Set to a positive value to enable LoRA (e.g., 32)
  lora_alpha: 16 # LoRA scaling factor
  target_modules: [q_proj, v_proj] # Target modules for LoRA adaptation
optim:
  lr: 1e-5
  betas: [0.9, 0.95]
  weight_decay: 0.01
  warmup_steps_ratio: 0.1
  clip_grad: 1.0

trainer:
  default_local_dir: /tmp/sft_model
  default_hdfs_dir: hdfs://tmp/experiments/gsm8k/gemma-1.1-7b-it/ # change the hdfs path here
  resume_path: null
  project_name: gsm8k-sft
  experiment_name: test
  total_epochs: 4
  total_training_steps: null
  validate_before_training: False
  logger: ['console']
  seed: 1
435
verl/trainer/fsdp_sft_trainer.py
Normal file
@@ -0,0 +1,435 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
A lightweight one-file FSDP SFT Trainer
TODO(zhangchi.usc1992)
- Add calculation of mfu
- Add validation
"""

import os

os.environ['NCCL_DEBUG'] = 'WARN'
os.environ['TOKENIZERS_PARALLELISM'] = 'true'

import logging
import re
import torch
import torch.distributed
from torch import nn, optim
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecision, ShardingStrategy, CPUOffload
from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedModel, AutoConfig
from verl.utils.torch_functional import get_cosine_schedule_with_warmup
from tensordict import TensorDict
from torch.utils.data import DataLoader, DistributedSampler

from verl.utils.fsdp_utils import get_fsdp_wrap_policy, init_fn, get_init_weight_context_manager
from verl.utils.dataset import SFTDataset
from verl.utils.fs import copy_local_path_from_hdfs
from verl.utils.tracking import Tracking

from torch.distributed.device_mesh import DeviceMesh

import verl.utils.hdfs_io as hdfs_io
from verl.utils.debug import log_gpu_memory_usage
from peft import LoraConfig, TaskType, get_peft_model

logger = logging.getLogger(__file__)
logger.setLevel(os.getenv('VERL_SFT_LOGGING_LEVEL', 'WARN'))


def extract_step(path):
    match = re.search(r'global_step_(\d+)', path)
    if match:
        return int(match.group(1))
    return None


def convert_to_regular_types(obj):
    """Convert Hydra configs and other special types to regular Python types."""
    from omegaconf import ListConfig, DictConfig
    if isinstance(obj, (ListConfig, DictConfig)):
        return {k: convert_to_regular_types(v) for k, v in obj.items()} if isinstance(obj, DictConfig) else list(obj)
    elif isinstance(obj, (list, tuple)):
        return [convert_to_regular_types(x) for x in obj]
    elif isinstance(obj, dict):
        return {k: convert_to_regular_types(v) for k, v in obj.items()}
    return obj


class FSDPSFTTrainer(object):

    def __init__(self, config, device_mesh: DeviceMesh):
        self.config = config
        self.device_mesh = device_mesh
        # build tokenizer first
        local_model_path = copy_local_path_from_hdfs(src=self.config.model.partial_pretrain, verbose=True)
        from verl.utils import hf_tokenizer
        self.tokenizer = hf_tokenizer(local_model_path, trust_remote_code=self.config.model.trust_remote_code)
        if self.config.data.chat_template is not None:
            raise ValueError('Applying a chat template from the config is not supported yet.')

        # normalize dp size
        self._normalize_config_bsz()

        self._build_dataloader()
        # build model
        self._build_model_optimizer()

        # TODO: add checkpoint manager
        if self.device_mesh.get_rank() == 0:
            print(self.config)

    def _normalize_config_bsz(self):
        dp_size = self.device_mesh.size()
        if self.device_mesh.get_rank() == 0:
            print(f'Normalize batch size by dp {dp_size}')

        assert self.config.data.train_batch_size % dp_size == 0
        assert self.config.data.micro_batch_size % dp_size == 0

        self.config.data.train_batch_size //= dp_size
        self.config.data.micro_batch_size //= dp_size

    def _build_dataloader(self):
        config = self.config
        # build dataset
        self.train_dataset = SFTDataset(parquet_files=config.data.train_files,
                                        tokenizer=self.tokenizer,
                                        prompt_key=config.data.prompt_key,
                                        prompt_dict_keys=config.data.get('prompt_dict_keys', None),
                                        response_key=config.data.response_key,
                                        response_dict_keys=config.data.get('response_dict_keys', None),
                                        max_length=config.data.max_length,
                                        truncation=config.data.truncation)
        self.val_dataset = SFTDataset(parquet_files=config.data.val_files,
                                      tokenizer=self.tokenizer,
                                      prompt_key=config.data.prompt_key,
                                      prompt_dict_keys=config.data.get('prompt_dict_keys', None),
                                      response_key=config.data.response_key,
                                      response_dict_keys=config.data.get('response_dict_keys', None),
                                      max_length=config.data.max_length,
                                      truncation=config.data.truncation)

        # build dataloader
        rank = self.device_mesh.get_rank()
        world_size = self.device_mesh.size()
        self.train_sampler = DistributedSampler(self.train_dataset,
                                                shuffle=True,
                                                num_replicas=world_size,
                                                rank=rank,
                                                drop_last=True)
        self.train_dataloader = DataLoader(dataset=self.train_dataset,
                                           batch_size=config.data.train_batch_size,
                                           sampler=self.train_sampler,
                                           num_workers=8,
                                           pin_memory=True,
                                           drop_last=True)

        self.val_sampler = DistributedSampler(self.val_dataset,
                                              shuffle=True,
                                              num_replicas=world_size,
                                              rank=rank,
                                              drop_last=True)
        self.val_dataloader = DataLoader(dataset=self.val_dataset,
                                         batch_size=config.data.micro_batch_size,
                                         sampler=self.val_sampler,
                                         num_workers=8,
                                         pin_memory=True,
                                         drop_last=True)

    def _build_model_optimizer(self):
        # TODO (zhangchi.usc1992):
        # 1. support pretrain from random weights
        # 2. support init directly from sharded weights
        local_model_path = copy_local_path_from_hdfs(src=self.config.model.partial_pretrain, verbose=True)

        if self.config.model.get('external_lib', None) is not None:
            # This is used to import external_lib into the huggingface systems
            import importlib
            importlib.import_module(self.config.model.external_lib)

        log_gpu_memory_usage('Before model allocation', logger=logger)

        trust_remote_code = self.config.model.trust_remote_code
        # load config first
        config = AutoConfig.from_pretrained(local_model_path, trust_remote_code=trust_remote_code)

        # This may be very large
        init_context = get_init_weight_context_manager(use_meta_tensor=not config.tie_word_embeddings)

        with init_context():
            self.model: PreTrainedModel = AutoModelForCausalLM.from_pretrained(local_model_path,
                                                                               config=config,
                                                                               torch_dtype=torch.float32,
                                                                               attn_implementation='flash_attention_2',
                                                                               trust_remote_code=trust_remote_code)
            if self.config.model.get('lora_rank', 0) > 0:
                self.model.enable_input_require_grads()
                # Convert config to regular Python types before creating PEFT model
                lora_config = {
                    'task_type': TaskType.CAUSAL_LM,
                    'r': self.config.model.lora_rank,
                    'lora_alpha': self.config.model.lora_alpha,
                    'target_modules': convert_to_regular_types(self.config.model.target_modules),
                    'bias': "none"
                }
                self.model = get_peft_model(self.model, LoraConfig(**lora_config))

        if self.config.model.enable_gradient_checkpointing:
            self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={'use_reentrant': False})

        log_gpu_memory_usage('After model allocation', logger=logger)

        mixed_precision = MixedPrecision(param_dtype=torch.bfloat16,
                                         reduce_dtype=torch.float32,
                                         buffer_dtype=torch.float32)

        auto_wrap_policy = get_fsdp_wrap_policy(self.model,
                                                config=self.config.model.fsdp_config.wrap_policy,
                                                is_lora=self.config.model.get('lora_rank', 0) > 0)
        if self.device_mesh.get_rank() == 0:
            print(auto_wrap_policy)

        if not self.config.model.fsdp_config.cpu_offload:
            cpu_offload = None
        else:
            cpu_offload = CPUOffload(offload_params=self.config.model.fsdp_config.offload_params)

        self.fsdp_model = FSDP(module=self.model,
                               auto_wrap_policy=auto_wrap_policy,
                               param_init_fn=init_fn,
                               sharding_strategy=ShardingStrategy.FULL_SHARD,
                               mixed_precision=mixed_precision,
                               device_mesh=self.device_mesh,
                               sync_module_states=True,
                               device_id=torch.cuda.current_device(),
                               cpu_offload=cpu_offload,
                               use_orig_params=False)

        log_gpu_memory_usage('After FSDP wrapping', logger=logger)

        self.optimizer = optim.AdamW(self.fsdp_model.parameters(),
                                     lr=self.config.optim.lr,
                                     betas=self.config.optim.betas,
                                     weight_decay=self.config.optim.weight_decay)

        log_gpu_memory_usage('After initialize optimizer', logger=logger)

        steps_per_epoch = len(self.train_dataloader)
        total_steps = steps_per_epoch * self.config.trainer.total_epochs

        if self.device_mesh.get_rank() == 0:
            print(
                f'Number of steps/epoch {steps_per_epoch}, number of epochs {self.config.trainer.total_epochs}, total number of steps {total_steps}'
            )

        num_warmup_steps = int(total_steps * self.config.optim.warmup_steps_ratio)

        self.lr_scheduler = get_cosine_schedule_with_warmup(optimizer=self.optimizer,
                                                            num_warmup_steps=num_warmup_steps,
                                                            num_training_steps=total_steps)

    def _compute_loss(self, batch):
        loss_mask = batch.pop('loss_mask')[:, :-1].reshape(-1).cuda()
        labels = batch['input_ids'][:, 1:].cuda()

        with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
            output = self.fsdp_model(input_ids=batch['input_ids'],
                                     attention_mask=batch['attention_mask'],
                                     position_ids=batch['position_ids'],
                                     use_cache=False)  # prevent the model from thinking it is generating

            logits = output.logits

            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels.contiguous()
            # Flatten the tokens
            loss_fct = nn.CrossEntropyLoss(reduction='none')
            shift_logits = shift_logits.view(-1, self.model.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)
            loss = loss * loss_mask

            valid_token_this_rank = torch.sum(loss_mask)

            if self.config.data.balance_dp_token:
                torch.distributed.all_reduce(valid_token_this_rank)  # becomes total valid tokens in all ranks
                dp_size = torch.distributed.get_world_size()
            else:
                dp_size = 1

            loss = torch.sum(loss) / valid_token_this_rank * dp_size  # possible bugs here for dp
            return loss

    def training_step(self, batch: TensorDict):
        self.fsdp_model.train()

        log_gpu_memory_usage('Before optimizer zero_grad', logger=logger)

        self.optimizer.zero_grad()

        log_gpu_memory_usage('After optimizer zero_grad', logger=logger)

        micro_batches = batch.split(self.config.data.micro_batch_size)
        n_micro_batches = len(micro_batches)
        step_loss = 0
        for micro_batch in micro_batches:
            loss = self._compute_loss(batch=micro_batch) / n_micro_batches
            loss.backward()
            step_loss += loss.item()

        self.fsdp_model.clip_grad_norm_(max_norm=self.config.optim.clip_grad)

        log_gpu_memory_usage('Before optimizer step', logger=logger)

        self.optimizer.step()

        log_gpu_memory_usage('After optimizer step', logger=logger)

        self.lr_scheduler.step()

        # reduce loss across dp ranks
        lr = self.lr_scheduler.get_last_lr()[0]

        log_gpu_memory_usage('After offload weights', logger=logger)

        step_loss = torch.tensor(step_loss).cuda()
        torch.distributed.all_reduce(step_loss, op=torch.distributed.ReduceOp.AVG)
        return {'train/loss': step_loss.detach().item(), 'train/lr(1e-3)': lr * 1e3}

    def validation_step(self, batch: TensorDict):
        self.fsdp_model.eval()
        with torch.no_grad():
            loss = self._compute_loss(batch)
            torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.AVG)
        return loss

    def save_checkpoint(self, step):
        # save checkpoint
        from torch.distributed.fsdp import FullStateDictConfig, StateDictType
        cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
        with FSDP.state_dict_type(self.fsdp_model, StateDictType.FULL_STATE_DICT, cfg):
            state_dict = self.fsdp_model.state_dict()

        path = os.path.join(self.config.trainer.default_local_dir, f'global_step_{step}')
        # save huggingface model
        if self.device_mesh.get_rank() == 0:
            os.makedirs(path, exist_ok=True)
            self.model.save_pretrained(path, state_dict=state_dict)
            self.tokenizer.save_pretrained(path)
            if self.config.trainer.default_hdfs_dir:
                hdfs_io.makedirs(self.config.trainer.default_hdfs_dir, exist_ok=True)
                hdfs_io.copy(src=path, dst=self.config.trainer.default_hdfs_dir, dirs_exist_ok=True)
        torch.distributed.barrier()

    def fit(self):
        rank = self.device_mesh.get_rank()

        # TODO: add a unified tracking
        if rank == 0:
            tracking = Tracking(project_name=self.config.trainer.project_name,
                                experiment_name=self.config.trainer.experiment_name,
                                default_backend=self.config.trainer.logger)

        global_step = 0
        # compute the total training steps.
        # the total training steps in SFT is mainly used for early exit
        total_training_steps = len(self.train_dataloader) * self.config.trainer.total_epochs

        if self.config.trainer.total_training_steps is not None:
            total_training_steps = self.config.trainer.total_training_steps

        self.total_training_steps = total_training_steps
        print(f'Total training steps: {self.total_training_steps}')

        # TODO (zhangchi.usc1992) add back checkpoint manager. Currently it blocks when uploading to hdfs, so it is very slow.

        if self.config.trainer.validate_before_training:
            # validate before training
            val_losses = []
            for data in self.val_dataloader:
                data = TensorDict(data, batch_size=self.config.data.micro_batch_size).cuda()
                val_loss = self.validation_step(data)
                val_losses.append(val_loss)
            if rank == 0:
                val_loss = torch.mean(torch.stack(val_losses))
                metric = {'val/loss': val_loss.detach().item()}
                tracking.log(data=metric, step=global_step)
            torch.distributed.barrier()

        for epoch in range(self.config.trainer.total_epochs):
            self.train_sampler.set_epoch(epoch=epoch)
            for data in self.train_dataloader:
                data = TensorDict(data, batch_size=self.config.data.train_batch_size).cuda()
                metric = self.training_step(data)
                if rank == 0:
                    tracking.log(data=metric, step=global_step)
                global_step += 1

                # for early exit validation
                if global_step >= self.total_training_steps:
                    # Perform final validation
                    val_losses = []
                    for val_data in self.val_dataloader:
                        val_data = TensorDict(val_data, batch_size=self.config.data.micro_batch_size).cuda()
                        val_loss = self.validation_step(val_data)
                        val_losses.append(val_loss)
                    if rank == 0:
                        avg_val_loss = torch.mean(torch.stack(val_losses))
                        metric = {'val/loss': avg_val_loss.detach().item()}
                        tracking.log(data=metric, step=global_step)
                    torch.distributed.barrier()

                    # Save final checkpoint
                    self.save_checkpoint(step=global_step)
                    return

            # validation
            val_losses = []
            for data in self.val_dataloader:
                data = TensorDict(data, batch_size=self.config.data.micro_batch_size).cuda()
                val_loss = self.validation_step(data)
                val_losses.append(val_loss)
            if rank == 0:
                val_loss = torch.mean(torch.stack(val_losses))
                metric = {'val/loss': val_loss.detach().item()}
                tracking.log(data=metric, step=global_step)
            torch.distributed.barrier()

            # save checkpoint
            self.save_checkpoint(step=global_step)


from verl.trainer.fsdp_sft_trainer import FSDPSFTTrainer
import hydra

from torch.distributed.device_mesh import init_device_mesh

from verl.utils.distributed import initialize_global_process_group


@hydra.main(config_path='config', config_name='sft_trainer', version_base=None)
def main(config):
    local_rank, rank, world_size = initialize_global_process_group()

    device_mesh = init_device_mesh(device_type='cuda', mesh_shape=(world_size,), mesh_dim_names=('dp',))
    trainer = FSDPSFTTrainer(config=config, device_mesh=device_mesh)
    trainer.fit()


if __name__ == '__main__':
    main()
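The masked, shifted cross-entropy in _compute_loss is the core of this trainer; below is a self-contained sketch of the same computation on dummy tensors (shapes and names are illustrative, not part of the commit):

# Illustrative sketch of the shifted-label, loss-masked cross entropy used above.
import torch
from torch import nn

bs, seqlen, vocab = 2, 8, 32
logits = torch.randn(bs, seqlen, vocab)
input_ids = torch.randint(0, vocab, (bs, seqlen))
loss_mask = torch.ones(bs, seqlen)              # 0 where tokens (e.g. the prompt) should not contribute

labels = input_ids[:, 1:]                       # predict token t+1 from position t
shift_logits = logits[:, :-1, :].reshape(-1, vocab)
shift_labels = labels.reshape(-1)
mask = loss_mask[:, :-1].reshape(-1)

per_token = nn.CrossEntropyLoss(reduction='none')(shift_logits, shift_labels)
loss = (per_token * mask).sum() / mask.sum()    # mean over valid tokens only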
69
verl/trainer/main_eval.py
Normal file
@@ -0,0 +1,69 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Offline evaluate the performance of a generated file using reward model and ground truth verifier.
The input is a parquet file that contains N generated sequences and (optional) the ground truth.
"""

import hydra
from verl.utils.fs import copy_local_path_from_hdfs
from verl.utils.reward_score import math, gsm8k
import pandas as pd
import numpy as np


def select_reward_fn(data_source):
    if data_source == 'lighteval/MATH':
        return math.compute_score
    else:
        raise NotImplementedError


@hydra.main(config_path='config', config_name='evaluation', version_base=None)
def main(config):
    local_path = copy_local_path_from_hdfs(config.data.path)
    dataset = pd.read_parquet(local_path)
    prompts = dataset[config.data.prompt_key]
    responses = dataset[config.data.response_key]
    data_sources = dataset[config.data.data_source_key]
    reward_model_data = dataset[config.data.reward_model_key]

    passes = 0

    total = len(dataset)

    for i in range(total):
        response_lst = responses[i]
        data_source = data_sources[i]
        # select reward score based on data_source
        prompt = prompts[i]
        reward_data = reward_model_data[i]
        reward_fn = select_reward_fn(data_source)
        ground_truth = reward_data['ground_truth']
        score_lst = []
        for r in response_lst:
            score = reward_fn(r, ground_truth)
            score_lst.append(score)

        max_score = np.max(score_lst)

        if max_score == 1:
            passes += 1

    print(f'pass@5: {passes / total}')


if __name__ == '__main__':
    main()
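main_eval counts a prompt as passed when any of its N sampled responses scores 1, i.e. a pass@N metric; a hedged standalone version of that aggregation (names and data are illustrative):

# Illustrative pass@N aggregation matching the loop above.
import numpy as np

def pass_at_n(score_lists):
    # score_lists: one list of 0/1 scores per prompt (N responses each)
    passes = sum(1 for scores in score_lists if np.max(scores) == 1)
    return passes / len(score_lists)

print(pass_at_n([[0, 0, 1], [0, 0, 0], [1, 1, 0]]))  # 2/3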
137
verl/trainer/main_generation.py
Normal file
@@ -0,0 +1,137 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Generate responses given a dataset of prompts
"""
import ray
import numpy as np
import hydra
import os

os.environ['NCCL_DEBUG'] = 'WARN'
os.environ['TOKENIZERS_PARALLELISM'] = 'true'
# os.environ['TORCH_COMPILE_DISABLE'] = '1'

from verl.utils.model import compute_position_id_with_mask

import pandas as pd

from transformers import AutoTokenizer

from verl import DataProto
from verl.utils.fs import copy_local_path_from_hdfs
from verl.workers.fsdp_workers import ActorRolloutRefWorker
from verl.utils.hdfs_io import makedirs
from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup


@hydra.main(config_path='config', config_name='generation', version_base=None)
def main(config):
    from pprint import pprint
    from omegaconf import OmegaConf
    pprint(OmegaConf.to_container(config, resolve=True))  # resolve=True will evaluate symbolic values
    OmegaConf.resolve(config)
    local_path = copy_local_path_from_hdfs(config.model.path)
    from verl.utils import hf_tokenizer
    tokenizer = hf_tokenizer(local_path)

    if config.rollout.temperature == 0.:
        assert config.data.n_samples == 1, 'When temperature=0, n_samples must be 1.'

    # read dataset. Note that the dataset should directly contain chat template format (e.g., a list of dictionaries)
    dataset = pd.read_parquet(config.data.path)
    chat_lst = dataset[config.data.prompt_key].tolist()

    chat_lst = [chat.tolist() for chat in chat_lst]

    tokenizer.padding_side = 'left'
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    ray_cls_with_init = RayClassWithInitArgs(cls=ray.remote(ActorRolloutRefWorker), config=config, role='rollout')
    resource_pool = RayResourcePool(process_on_nodes=[config.trainer.n_gpus_per_node] * config.trainer.nnodes)
    wg = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=ray_cls_with_init)
    wg.init_model()

    total_samples = len(dataset)
    # real_batch_size = data.batch['input_ids'].shape[0]
    config_batch_size = config.data.batch_size
    dp_size = wg.world_size // config.rollout.tensor_model_parallel_size
    num_batch = (total_samples // config_batch_size) + 1
    output_lst = [[] for _ in range(config.data.n_samples)]

    for batch_idx in range(num_batch):
        print(f'[{batch_idx+1}/{num_batch}] Start to process.')
        batch_chat_lst = chat_lst[batch_idx * config_batch_size:(batch_idx + 1) * config_batch_size]
        inputs = tokenizer.apply_chat_template(batch_chat_lst,
                                               add_generation_prompt=True,
                                               padding=True,
                                               truncation=True,
                                               max_length=config.rollout.prompt_length,
                                               return_tensors='pt',
                                               return_dict=True,
                                               tokenize=True)
        input_ids = inputs['input_ids']
        attention_mask = inputs['attention_mask']
        position_ids = compute_position_id_with_mask(attention_mask)

        batch_dict = {'input_ids': input_ids, 'attention_mask': attention_mask, 'position_ids': position_ids}

        data = DataProto.from_dict(batch_dict)
        real_batch_size = data.batch['input_ids'].shape[0]
        if real_batch_size % dp_size != 0:
            dummy_data_size = dp_size - real_batch_size % dp_size
            dummy_data = data[:dummy_data_size]
            data = DataProto.concat([data, dummy_data])
            print(
                f'real_batch_size {real_batch_size} is not divisible by dp_size {dp_size}, add {dummy_data_size} dummy data'
            )

        batch_size = data.batch['input_ids'].shape[0]
        assert batch_size % dp_size == 0, f'batch_size {batch_size} is not divisible by dp_size {dp_size}'

        print(f'[{batch_idx+1}/{num_batch}] Start to generate.')
        # START TO GENERATE FOR n_samples TIMES
        for i in range(config.data.n_samples):
            output = wg.generate_sequences(data)
            # remove dummy data
            output = output[:real_batch_size]
            output_text = tokenizer.batch_decode(output.batch['input_ids'][:, -config.rollout.response_length:],
                                                 skip_special_tokens=False)

            # remove the padding
            pad_token = tokenizer.pad_token
            output_text_unpad = []
            for text in output_text:
                output_text_unpad.append(text.replace(pad_token, ''))

            output_lst[i].extend(output_text_unpad)

    # convert output_lst from (n_samples, n_data) to (n_data, n_samples)
    output_lst = np.array(output_lst, dtype=object)
    output_lst = np.transpose(output_lst, axes=(1, 0)).tolist()

    # add to the data frame
    dataset['responses'] = output_lst

    # write to a new parquet
    output_dir = os.path.dirname(config.data.output_path)
    makedirs(output_dir, exist_ok=True)
    dataset.to_parquet(config.data.output_path)

    return output_text


if __name__ == '__main__':
    main()
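The dummy-data branch above simply rounds each batch up to a multiple of the data-parallel size and drops the padding after generation; a small sketch of that arithmetic (pure Python, illustrative values only):

# Illustrative: pad a batch so it divides evenly across dp ranks, then keep only the real rows.
real_batch_size, dp_size = 130, 8
dummy = (dp_size - real_batch_size % dp_size) % dp_size   # 6 extra rows copied from the front of the batch
padded = real_batch_size + dummy
assert padded % dp_size == 0
# after generation, only the first real_batch_size outputs are kept, as in `output[:real_batch_size]`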
202
verl/trainer/main_ppo.py
Normal file
@@ -0,0 +1,202 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Note that we do not merge this main with ray_trainer, because ray_trainer is also used by other mains.
"""

from verl import DataProto
import torch
from verl.utils.reward_score import qa_em
from verl.trainer.ppo.ray_trainer import RayPPOTrainer
import re
import numpy as np


def _select_rm_score_fn(data_source):
    if "nq" in data_source:
        return qa_em.compute_score_em
    else:
        raise NotImplementedError


class RewardManager():
    """The reward manager."""

    def __init__(self, tokenizer, num_examine, format_score=0.) -> None:
        self.tokenizer = tokenizer
        self.num_examine = num_examine  # the number of batches of decoded responses to print to the console
        self.format_score = format_score

    def __call__(self, data: DataProto):
        """We will expand this function gradually based on the available datasets"""

        # If there is an rm score, we directly return it. Otherwise, we compute the score via rm_score_fn
        if 'rm_scores' in data.batch.keys():
            return data.batch['rm_scores']

        reward_tensor = torch.zeros_like(data.batch['responses'], dtype=torch.float32)

        # all_scores = []

        already_print_data_sources = {}

        for i in range(len(data)):
            data_item = data[i]  # DataProtoItem

            prompt_ids = data_item.batch['prompts']

            prompt_length = prompt_ids.shape[-1]

            valid_prompt_length = data_item.batch['attention_mask'][:prompt_length].sum()
            valid_prompt_ids = prompt_ids[-valid_prompt_length:]

            response_ids = data_item.batch['responses']
            valid_response_length = data_item.batch['attention_mask'][prompt_length:].sum()
            valid_response_ids = response_ids[:valid_response_length]

            # decode
            sequences = torch.cat((valid_prompt_ids, valid_response_ids))
            sequences_str = self.tokenizer.decode(sequences)

            ground_truth = data_item.non_tensor_batch['reward_model']['ground_truth']

            # select rm_score
            data_source = data_item.non_tensor_batch['data_source']
            compute_score_fn = _select_rm_score_fn(data_source)

            score = compute_score_fn(solution_str=sequences_str, ground_truth=ground_truth, format_score=self.format_score)

            reward_tensor[i, valid_response_length - 1] = score
            # all_scores.append(score)

            if data_source not in already_print_data_sources:
                already_print_data_sources[data_source] = 0

            if already_print_data_sources[data_source] < self.num_examine:
                already_print_data_sources[data_source] += 1
                print(sequences_str)

        # print(f"[DEBUG] all_scores: {all_scores}")
        # print(f"[DEBUG] all_scores shape: {np.array(all_scores).shape}")
        # print(f"[DEBUG] all_scores mean: {np.mean(all_scores)}")
        # print(f"[DEBUG] all_scores max: {np.max(all_scores)}")
        # print(f"[DEBUG] all_scores min: {np.min(all_scores)}")
        # print(f"[DEBUG] all_scores std: {np.std(all_scores)}")

        return reward_tensor


import ray
import hydra


@hydra.main(config_path='config', config_name='ppo_trainer', version_base=None)
def main(config):
    if not ray.is_initialized():
        # this is for a local ray cluster
        ray.init(runtime_env={'env_vars': {'TOKENIZERS_PARALLELISM': 'true', 'NCCL_DEBUG': 'WARN'}})

    ray.get(main_task.remote(config))


@ray.remote
def main_task(config):
    from verl.utils.fs import copy_local_path_from_hdfs
    from transformers import AutoTokenizer

    # print initial config
    from pprint import pprint
    from omegaconf import OmegaConf
    pprint(OmegaConf.to_container(config, resolve=True))  # resolve=True will evaluate symbolic values
    OmegaConf.resolve(config)

    # env_class = ENV_CLASS_MAPPING[config.env.name]

    # download the checkpoint from hdfs
    local_path = copy_local_path_from_hdfs(config.actor_rollout_ref.model.path)

    # instantiate tokenizer
    from verl.utils import hf_tokenizer
    tokenizer = hf_tokenizer(local_path)

    # define worker classes
    if config.actor_rollout_ref.actor.strategy == 'fsdp':
        assert config.actor_rollout_ref.actor.strategy == config.critic.strategy
        from verl.workers.fsdp_workers import ActorRolloutRefWorker, CriticWorker
        from verl.single_controller.ray import RayWorkerGroup
        ray_worker_group_cls = RayWorkerGroup

    elif config.actor_rollout_ref.actor.strategy == 'megatron':
        assert config.actor_rollout_ref.actor.strategy == config.critic.strategy
        from verl.workers.megatron_workers import ActorRolloutRefWorker, CriticWorker
        from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup
        ray_worker_group_cls = NVMegatronRayWorkerGroup

    else:
        raise NotImplementedError

    from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role

    role_worker_mapping = {
        Role.ActorRollout: ray.remote(ActorRolloutRefWorker),
        Role.Critic: ray.remote(CriticWorker),
        Role.RefPolicy: ray.remote(ActorRolloutRefWorker),
    }

    global_pool_id = 'global_pool'
    resource_pool_spec = {
        global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes,
    }
    mapping = {
        Role.ActorRollout: global_pool_id,
        Role.Critic: global_pool_id,
        Role.RefPolicy: global_pool_id,
    }

    # we should adopt a multi-source reward function here
    # - for rule-based rm, we directly call a reward score
    # - for model-based rm, we call a model
    # - for code-related prompts, we send them to a sandbox if there are test cases
    # - finally, we combine all the rewards together
    # - The reward type depends on the tag of the data
    if config.reward_model.enable:
        if config.reward_model.strategy == 'fsdp':
            from verl.workers.fsdp_workers import RewardModelWorker
        elif config.reward_model.strategy == 'megatron':
            from verl.workers.megatron_workers import RewardModelWorker
        else:
            raise NotImplementedError
        role_worker_mapping[Role.RewardModel] = ray.remote(RewardModelWorker)
        mapping[Role.RewardModel] = global_pool_id

    reward_fn = RewardManager(tokenizer=tokenizer, num_examine=0)

    # Note that we always use function-based RM for validation
    val_reward_fn = RewardManager(tokenizer=tokenizer, num_examine=1)

    resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)
    trainer = RayPPOTrainer(config=config,
                            tokenizer=tokenizer,
                            role_worker_mapping=role_worker_mapping,
                            resource_pool_manager=resource_pool_manager,
                            ray_worker_group_cls=ray_worker_group_cls,
                            reward_fn=reward_fn,
                            val_reward_fn=val_reward_fn,
                            )
    trainer.init_workers()
    trainer.fit()


if __name__ == '__main__':
    main()
13
verl/trainer/ppo/__init__.py
Normal file
@@ -0,0 +1,13 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
274
verl/trainer/ppo/core_algos.py
Normal file
@@ -0,0 +1,274 @@
|
||||
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Core functions to implement PPO algorithms.
|
||||
The function implemented in this file should be used by trainer with different distributed strategies to
|
||||
implement PPO
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from collections import defaultdict
|
||||
|
||||
import verl.utils.torch_functional as verl_F
|
||||
|
||||
|
||||
class AdaptiveKLController:
|
||||
"""
|
||||
Adaptive KL controller described in the paper:
|
||||
https://arxiv.org/pdf/1909.08593.pdf
|
||||
"""
|
||||
|
||||
def __init__(self, init_kl_coef, target_kl, horizon):
|
||||
self.value = init_kl_coef
|
||||
self.target = target_kl
|
||||
self.horizon = horizon
|
||||
|
||||
def update(self, current_kl, n_steps):
|
||||
target = self.target
|
||||
proportional_error = np.clip(current_kl / target - 1, -0.2, 0.2)
|
||||
mult = 1 + proportional_error * n_steps / self.horizon
|
||||
self.value *= mult
|
||||
|
||||
|
||||
class FixedKLController:
|
||||
"""Fixed KL controller."""
|
||||
|
||||
def __init__(self, kl_coef):
|
||||
self.value = kl_coef
|
||||
|
||||
def update(self, current_kl, n_steps):
|
||||
pass
|
||||
|
||||
|
||||
def get_kl_controller(config): # seems never used?
|
||||
if config.critic.kl_ctrl.type == 'fixed':
|
||||
kl_ctrl = FixedKLController(kl_coef=config.critic.kl_ctrl.kl_coef)
|
||||
elif config.critic.kl_ctrl.type == 'adaptive':
|
||||
assert config.kl_ctrl.horizon > 0, f'horizon must be larger than 0. Got {config.critic.kl_ctrl.horizon}'
|
||||
kl_ctrl = AdaptiveKLController(init_kl_coef=config.critic.kl_ctrl.kl_coef,
|
||||
target_kl=config.critic.kl_ctrl.target_kl,
|
||||
horizon=config.critic.kl_ctrl.horizon)
|
||||
else:
|
||||
raise ValueError('Unknown kl_ctrl type')
|
||||
|
||||
return kl_ctrl
|
||||
|
||||
|
||||
def compute_gae_advantage_return(token_level_rewards: torch.Tensor, values: torch.Tensor, eos_mask: torch.Tensor,
|
||||
gamma: torch.Tensor, lam: torch.Tensor):
|
||||
"""Adapted from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py
|
||||
|
||||
Args:
|
||||
token_level_rewards: `(torch.Tensor)`
|
||||
shape: (bs, response_length)
|
||||
values: `(torch.Tensor)`
|
||||
shape: (bs, response_length)
|
||||
eos_mask: `(torch.Tensor)`
|
||||
shape: (bs, response_length). [EOS] mask. The token after [EOS] have mask zero.
|
||||
gamma: `(float)`
|
||||
discounted factor used in RL
|
||||
lam: `(float)`
|
||||
lambda value when computing Generalized Advantage Estimation (https://arxiv.org/abs/1506.02438)
|
||||
|
||||
Returns:
|
||||
advantages: `(torch.Tensor)`
|
||||
shape: (bs, response_length)
|
||||
Returns: `(torch.Tensor)`
|
||||
shape: (bs, response_length)
|
||||
|
||||
"""
|
||||
with torch.no_grad():
|
||||
lastgaelam = 0
|
||||
advantages_reversed = []
|
||||
gen_len = token_level_rewards.shape[-1]
|
||||
|
||||
for t in reversed(range(gen_len)):
|
||||
nextvalues = values[:, t + 1] if t < gen_len - 1 else 0.0
|
||||
delta = token_level_rewards[:, t] + gamma * nextvalues - values[:, t]
|
||||
lastgaelam = delta + gamma * lam * lastgaelam
|
||||
advantages_reversed.append(lastgaelam)
|
||||
advantages = torch.stack(advantages_reversed[::-1], dim=1)
|
||||
|
||||
returns = advantages + values
|
||||
advantages = verl_F.masked_whiten(advantages, eos_mask)
|
||||
return advantages, returns
|
||||
|
||||
|
||||
# NOTE(sgm): this implementation only consider outcome supervision, where the reward is a scalar.
|
||||
def compute_grpo_outcome_advantage(token_level_rewards: torch.Tensor,
|
||||
eos_mask: torch.Tensor,
|
||||
index: torch.Tensor,
|
||||
epsilon: float = 1e-6):
|
||||
"""
|
||||
Compute advantage for GRPO, operating only on Outcome reward
|
||||
(with only one scalar reward for each response).
|
||||
Args:
|
||||
token_level_rewards: `(torch.Tensor)`
|
||||
shape: (bs, response_length)
|
||||
eos_mask: `(torch.Tensor)`
|
||||
shape: (bs, response_length)
|
||||
|
||||
Returns:
|
||||
advantages: `(torch.Tensor)`
|
||||
shape: (bs, response_length)
|
||||
Returns: `(torch.Tensor)`
|
||||
shape: (bs, response_length)
|
||||
"""
|
||||
response_length = token_level_rewards.shape[-1]
|
||||
non_zero_mask = (token_level_rewards != 0)
|
||||
scores = (token_level_rewards * non_zero_mask).sum(dim=-1)
|
||||
|
||||
id2score = defaultdict(list)
|
||||
id2mean = {}
|
||||
id2std = {}
|
||||
|
||||
with torch.no_grad():
|
||||
bsz = scores.shape[0]
|
||||
for i in range(bsz):
|
||||
id2score[index[i]].append(scores[i])
|
||||
for idx in id2score:
|
||||
if len(id2score[idx]) == 1:
|
||||
id2mean[idx] = torch.tensor(0.0)
|
||||
id2std[idx] = torch.tensor(1.0)
|
||||
elif len(id2score[idx]) > 1:
|
||||
id2mean[idx] = torch.mean(torch.tensor(id2score[idx]))
|
||||
id2std[idx] = torch.std(torch.tensor([id2score[idx]]))
|
||||
else:
|
||||
raise ValueError(f"no score in prompt index: {idx}")
|
||||
for i in range(bsz):
|
||||
scores[i] = (scores[i] - id2mean[index[i]]) / (id2std[index[i]] + epsilon)
|
||||
scores = scores.unsqueeze(-1).tile([1, response_length]) * eos_mask
|
||||
|
||||
return scores, scores
|
||||
|
||||
|
||||
def compute_rewards(token_level_scores, old_log_prob, ref_log_prob, kl_ratio):
    kl = old_log_prob - ref_log_prob
    return token_level_scores - kl * kl_ratio


def compute_policy_loss(old_log_prob, log_prob, advantages, eos_mask, cliprange):
    """Adapted from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1122

    Args:
        old_log_prob: `(torch.Tensor)`
            shape: (bs, response_length)
        log_prob: `(torch.Tensor)`
            shape: (bs, response_length)
        advantages: `(torch.Tensor)`
            shape: (bs, response_length)
        eos_mask: `(torch.Tensor)`
            shape: (bs, response_length)
        cliprange: (float)
            The clip range used in PPO. See https://arxiv.org/abs/1707.06347

    Returns:
        pg_loss: `a scalar torch.Tensor`
            policy gradient loss computed via PPO
        pg_clipfrac: (float)
            a float number indicating the fraction of policy gradient loss being clipped

    """
    negative_approx_kl = log_prob - old_log_prob
    ratio = torch.exp(negative_approx_kl)
    ppo_kl = verl_F.masked_mean(-negative_approx_kl, eos_mask)

    pg_losses = -advantages * ratio
    pg_losses2 = -advantages * torch.clamp(ratio, 1.0 - cliprange, 1.0 + cliprange)

    pg_loss = verl_F.masked_mean(torch.max(pg_losses, pg_losses2), eos_mask)
    pg_clipfrac = verl_F.masked_mean(torch.gt(pg_losses2, pg_losses).float(), eos_mask)
    return pg_loss, pg_clipfrac, ppo_kl

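
# Usage sketch (illustrative only): a single token whose probability rose so much that
# the ratio (~1.65) exceeds 1 + cliprange; the clipped branch of the objective survives
# the max, capping the update for this positive advantage.
toy_old_log_prob = torch.tensor([[-1.0]])
toy_log_prob = torch.tensor([[-0.5]])
toy_advantages = torch.tensor([[2.0]])
toy_mask = torch.ones_like(toy_advantages)
toy_pg_loss, toy_clipfrac, toy_ppo_kl = compute_policy_loss(toy_old_log_prob, toy_log_prob,
                                                            toy_advantages, toy_mask, cliprange=0.2)
# toy_pg_loss ~= -2.4 (clipped at ratio 1.2 instead of -2 * 1.65 ~= -3.3),
# toy_clipfrac == 1.0 and toy_ppo_kl ~= -0.5.
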
def compute_entropy_loss(logits, eos_mask):
    """Compute Categorical entropy loss

    Args:
        logits: `(torch.Tensor)`
            shape: (bs, response_length, vocab_size)
        eos_mask: `(torch.Tensor)`
            shape: (bs, response_length)

    Returns:
        entropy: a scalar torch.Tensor

    """
    # compute entropy
    entropy = verl_F.entropy_from_logits(logits)  # (bs, response_len)
    entropy_loss = verl_F.masked_mean(entropy, mask=eos_mask)
    return entropy_loss

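
# Usage sketch (illustrative only): one response token over a 3-word vocabulary.
# A peaked logit vector gives entropy ~0.67 nats, while uniform logits would give the
# maximum log(3) ~= 1.10 nats; the entropy bonus rewards keeping some spread.
toy_logits = torch.tensor([[[2.0, 0.0, 0.0]]])   # (bs=1, response_length=1, vocab=3)
toy_entropy = compute_entropy_loss(toy_logits, torch.ones(1, 1))   # ~0.67
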
def compute_value_loss(vpreds, returns, values, eos_mask, cliprange_value):
    """Compute the value loss. Copied from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1151

    Args:
        vpreds (`torch.FloatTensor`):
            Predicted values of the value head, shape (`batch_size`, `response_length`)
        values (`torch.FloatTensor`):
            Old values of the value head, shape (`batch_size`, `response_length`)
        returns (`torch.FloatTensor`):
            Ground truth returns, shape (`batch_size`, `response_length`)

    Returns:
        vf_loss: a scalar (`torch.FloatTensor`):
            value function loss
        vf_clipfrac: a float
            The ratio of vf being clipped

    """
    vpredclipped = verl_F.clip_by_value(vpreds, values - cliprange_value, values + cliprange_value)
    vf_losses1 = (vpreds - returns)**2
    vf_losses2 = (vpredclipped - returns)**2
    vf_loss = 0.5 * verl_F.masked_mean(torch.max(vf_losses1, vf_losses2), eos_mask)
    vf_clipfrac = verl_F.masked_mean(torch.gt(vf_losses2, vf_losses1).float(), eos_mask)
    return vf_loss, vf_clipfrac

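
# Usage sketch (illustrative only): the new value prediction (1.0) moves far from the
# old one (0.0), so the clipped prediction is 0.2 and torch.max keeps the larger
# unclipped error, discouraging large value-function updates in a single step.
toy_vpreds = torch.tensor([[1.0]])
toy_values = torch.tensor([[0.0]])
toy_returns = torch.tensor([[0.1]])
toy_vf_loss, toy_vf_clipfrac = compute_value_loss(toy_vpreds, toy_returns, toy_values,
                                                  torch.ones_like(toy_vpreds), cliprange_value=0.2)
# toy_vf_loss == 0.5 * max((1.0 - 0.1)**2, (0.2 - 0.1)**2) == 0.405, toy_vf_clipfrac == 0.0
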
def kl_penalty(logprob: torch.FloatTensor, ref_logprob: torch.FloatTensor, kl_penalty) -> torch.FloatTensor:
    """Compute KL divergence given logprob and ref_logprob.
    Copied from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1104

    Args:
        logprob: per-token log-probabilities under the current policy
        ref_logprob: per-token log-probabilities under the reference policy

    Returns:
        a tensor of per-token KL penalties with the same shape as `logprob`

    """
    if kl_penalty == "kl":
        return logprob - ref_logprob

    if kl_penalty == "abs":
        return (logprob - ref_logprob).abs()

    if kl_penalty == "mse":
        return 0.5 * (logprob - ref_logprob).square()

    # J. Schulman. Approximating KL divergence, 2020.
    # URL http://joschu.net/blog/kl-approx.html
    if kl_penalty == 'low_var_kl':
        kl = ref_logprob - logprob
        ratio = torch.exp(kl)
        kld = (ratio - kl - 1).contiguous()
        return torch.clamp(kld, min=-10, max=10)

    if kl_penalty == "full":
        # here, logprob and ref_logprob should contain the logits for every token in the vocabulary
        raise NotImplementedError

    raise NotImplementedError

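
# Usage sketch (illustrative only): the same pair of per-token log-probabilities under
# each penalty variant. The 'low_var_kl' (k3) estimator is non-negative by construction,
# unlike the plain difference.
toy_logprob = torch.tensor([-1.0, -2.0])
toy_ref_logprob = torch.tensor([-1.5, -1.5])
# kl_penalty(..., 'kl')         -> [ 0.5000, -0.5000]
# kl_penalty(..., 'abs')        -> [ 0.5000,  0.5000]
# kl_penalty(..., 'mse')        -> [ 0.1250,  0.1250]
# kl_penalty(..., 'low_var_kl') -> [ 0.1065,  0.1487]
toy_kld = kl_penalty(toy_logprob, toy_ref_logprob, 'low_var_kl')
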
920
verl/trainer/ppo/ray_trainer.py
Normal file
920
verl/trainer/ppo/ray_trainer.py
Normal file
@@ -0,0 +1,920 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
FSDP PPO Trainer with Ray-based single controller.
This trainer supports model-agnostic model initialization with huggingface
"""

import os
import uuid
from contextlib import contextmanager
from dataclasses import dataclass, field
from enum import Enum
from pprint import pprint
from typing import Type, Dict

import re
import json
from collections import defaultdict

import numpy as np
from codetiming import Timer
from omegaconf import OmegaConf, open_dict
from verl import DataProto
from verl.protocol import pad_dataproto_to_divisor, unpad_dataproto
from verl.single_controller.base import Worker
from verl.single_controller.ray import RayResourcePool, RayWorkerGroup, RayClassWithInitArgs
from verl.single_controller.ray.base import create_colocated_worker_cls
from verl.trainer.ppo import core_algos
from verl.utils.seqlen_balancing import get_seqlen_balanced_partitions, log_seqlen_unbalance

from search_r1.llm_agent.generation import LLMGenerationManager, GenerationConfig

WorkerType = Type[Worker]

class Role(Enum):
    """
    To create more roles dynamically, you can subclass Role and add new members
    """
    Actor = 0
    Rollout = 1
    ActorRollout = 2
    Critic = 3
    RefPolicy = 4
    RewardModel = 5
    ActorRolloutRef = 6


@dataclass
class ResourcePoolManager:
    """
    Define a resource pool specification. Resource pools will be initialized first.
    The mapping assigns each role to the name of the pool it runs on.
    """
    resource_pool_spec: dict[str, list[int]]
    mapping: dict[Role, str]
    resource_pool_dict: dict[str, RayResourcePool] = field(default_factory=dict)

    def create_resource_pool(self):
        for resource_pool_name, process_on_nodes in self.resource_pool_spec.items():
            # max_colocate_count means the number of WorkerGroups (i.e. processes) in each RayResourcePool
            # For the FSDP backend, we recommend using max_colocate_count=1, which merges all WorkerGroups into one.
            # For the Megatron backend, we recommend using max_colocate_count>1, which can utilize a different WorkerGroup for each model.
            resource_pool = RayResourcePool(process_on_nodes=process_on_nodes,
                                            use_gpu=True,
                                            max_colocate_count=1,
                                            name_prefix=resource_pool_name)
            self.resource_pool_dict[resource_pool_name] = resource_pool

    def get_resource_pool(self, role: Role) -> RayResourcePool:
        """Get the resource pool of the worker_cls"""
        return self.resource_pool_dict[self.mapping[role]]

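
# Hypothetical wiring sketch (illustrative only; the pool name and node shape are made
# up): one pool spanning a single 8-GPU node, with every role mapped onto it.
# resource_pool_spec maps a pool name to the number of processes per node, and mapping
# assigns each Role to the pool it should be scheduled on.
example_spec = {'global_pool': [8]}
example_mapping = {
    Role.ActorRollout: 'global_pool',
    Role.Critic: 'global_pool',
    Role.RefPolicy: 'global_pool',
    Role.RewardModel: 'global_pool',
}
example_manager = ResourcePoolManager(resource_pool_spec=example_spec, mapping=example_mapping)
# example_manager.create_resource_pool() would then build one RayResourcePool named
# 'global_pool', and get_resource_pool(Role.Critic) would return it.
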
import torch
from verl.utils.torch_functional import masked_mean


def apply_kl_penalty(data: DataProto, kl_ctrl: core_algos.AdaptiveKLController, kl_penalty='kl'):
    responses = data.batch['responses']
    response_length = responses.size(1)
    token_level_scores = data.batch['token_level_scores']
    batch_size = data.batch.batch_size[0]
    attention_mask = data.batch['attention_mask']
    response_mask = attention_mask[:, -response_length:]

    # compute kl between ref_policy and current policy
    if 'ref_log_prob' in data.batch.keys():
        kld = core_algos.kl_penalty(data.batch['old_log_probs'], data.batch['ref_log_prob'],
                                    kl_penalty=kl_penalty)  # (batch_size, response_length)
        kld = kld * response_mask
        beta = kl_ctrl.value
    else:
        beta = 0
        kld = torch.zeros_like(response_mask, dtype=torch.float32)

    token_level_rewards = token_level_scores - beta * kld

    current_kl = masked_mean(kld, mask=response_mask, axis=-1)  # average over sequence
    current_kl = torch.mean(current_kl, dim=0).item()

    # according to https://github.com/huggingface/trl/blob/951ca1841f29114b969b57b26c7d3e80a39f75a0/trl/trainer/ppo_trainer.py#L837
    kl_ctrl.update(current_kl=current_kl, n_steps=batch_size)
    data.batch['token_level_rewards'] = token_level_rewards

    metrics = {'critic/kl': current_kl, 'critic/kl_coeff': beta}

    return data, metrics

def compute_advantage(data: DataProto, adv_estimator, gamma=1.0, lam=1.0, num_repeat=1):
    # prepare response group
    # TODO: add other ways to estimate advantages
    if adv_estimator == 'gae':
        values = data.batch['values']
        responses = data.batch['responses']
        response_length = responses.size(-1)
        attention_mask = data.batch['attention_mask']
        response_mask = attention_mask[:, -response_length:]
        token_level_rewards = data.batch['token_level_rewards']
        advantages, returns = core_algos.compute_gae_advantage_return(token_level_rewards=token_level_rewards,
                                                                      values=values,
                                                                      eos_mask=response_mask,
                                                                      gamma=gamma,
                                                                      lam=lam)
        data.batch['advantages'] = advantages
        data.batch['returns'] = returns
    elif adv_estimator == 'grpo':
        token_level_rewards = data.batch['token_level_rewards']
        index = data.non_tensor_batch['uid']
        responses = data.batch['responses']
        response_length = responses.size(-1)
        attention_mask = data.batch['attention_mask']
        response_mask = attention_mask[:, -response_length:]
        advantages, returns = core_algos.compute_grpo_outcome_advantage(token_level_rewards=token_level_rewards,
                                                                        eos_mask=response_mask,
                                                                        index=index)
        data.batch['advantages'] = advantages
        data.batch['returns'] = returns
    else:
        raise NotImplementedError
    return data

def reduce_metrics(metrics: dict):
    for key, val in metrics.items():
        metrics[key] = np.mean(val)
    return metrics


def _compute_response_info(batch):
    response_length = batch.batch['responses'].shape[-1]

    prompt_mask = batch.batch['attention_mask'][:, :-response_length]
    response_mask = batch.batch['attention_mask'][:, -response_length:]

    prompt_length = prompt_mask.sum(-1).float()
    response_length = response_mask.sum(-1).float()  # (batch_size,)

    return dict(
        response_mask=response_mask,
        prompt_length=prompt_length,
        response_length=response_length,
    )

def compute_data_metrics(batch, use_critic=True):
    # TODO: add response length
    sequence_score = batch.batch['token_level_scores'].sum(-1)
    sequence_reward = batch.batch['token_level_rewards'].sum(-1)

    advantages = batch.batch['advantages']
    returns = batch.batch['returns']

    max_response_length = batch.batch['responses'].shape[-1]

    prompt_mask = batch.batch['attention_mask'][:, :-max_response_length].bool()
    response_mask = batch.batch['attention_mask'][:, -max_response_length:].bool()

    max_prompt_length = prompt_mask.size(-1)

    response_info = _compute_response_info(batch)
    prompt_length = response_info['prompt_length']
    response_length = response_info['response_length']

    valid_adv = torch.masked_select(advantages, response_mask)
    valid_returns = torch.masked_select(returns, response_mask)

    if use_critic:
        values = batch.batch['values']
        valid_values = torch.masked_select(values, response_mask)
        return_diff_var = torch.var(valid_returns - valid_values)
        return_var = torch.var(valid_returns)

    metrics = {
        # score
        'critic/score/mean': torch.mean(sequence_score).detach().item(),
        'critic/score/max': torch.max(sequence_score).detach().item(),
        'critic/score/min': torch.min(sequence_score).detach().item(),
        # reward
        'critic/rewards/mean': torch.mean(sequence_reward).detach().item(),
        'critic/rewards/max': torch.max(sequence_reward).detach().item(),
        'critic/rewards/min': torch.min(sequence_reward).detach().item(),
        # adv
        'critic/advantages/mean': torch.mean(valid_adv).detach().item(),
        'critic/advantages/max': torch.max(valid_adv).detach().item(),
        'critic/advantages/min': torch.min(valid_adv).detach().item(),
        # returns
        'critic/returns/mean': torch.mean(valid_returns).detach().item(),
        'critic/returns/max': torch.max(valid_returns).detach().item(),
        'critic/returns/min': torch.min(valid_returns).detach().item(),
        **({
            # values
            'critic/values/mean': torch.mean(valid_values).detach().item(),
            'critic/values/max': torch.max(valid_values).detach().item(),
            'critic/values/min': torch.min(valid_values).detach().item(),
            # vf explained var
            'critic/vf_explained_var': (1.0 - return_diff_var / (return_var + 1e-5)).detach().item(),
        } if use_critic else {}),

        # response length
        'response_length/mean': torch.mean(response_length).detach().item(),
        'response_length/max': torch.max(response_length).detach().item(),
        'response_length/min': torch.min(response_length).detach().item(),
        'response_length/clip_ratio': torch.mean(torch.eq(response_length, max_response_length).float()).detach().item(),
        # prompt length
        'prompt_length/mean': torch.mean(prompt_length).detach().item(),
        'prompt_length/max': torch.max(prompt_length).detach().item(),
        'prompt_length/min': torch.min(prompt_length).detach().item(),
        'prompt_length/clip_ratio': torch.mean(torch.eq(prompt_length, max_prompt_length).float()).detach().item(),

        # metrics for actions
        # 'metric/total_env':
        #     int(np.array(batch.non_tensor_batch['total_env'], dtype=np.int16).sum()),
        # 'metric/finished_env':
        #     int(np.array(batch.non_tensor_batch['finished_env'], dtype=np.int16).sum()),
        # 'metric/traj_length':
        #     float(np.array(batch.non_tensor_batch['traj_length'], dtype=np.int16).mean()),
        # 'metric/valid_action':
        #     float(np.array(batch.non_tensor_batch['valid_action'], dtype=np.int16).mean()),
        # 'metric/effective_action':
        #     float(np.array(batch.non_tensor_batch['effective_action'], dtype=np.int16).mean()),
        # 'metric/effective_action_ratio':
        #     float(np.array(batch.non_tensor_batch['effective_action_ratio'], dtype=np.float32).mean()),
    }

    return metrics

def compute_timing_metrics(batch, timing_raw):
    response_info = _compute_response_info(batch)
    num_prompt_tokens = torch.sum(response_info['prompt_length']).item()
    num_response_tokens = torch.sum(response_info['response_length']).item()
    num_overall_tokens = num_prompt_tokens + num_response_tokens

    num_tokens_of_section = {
        'gen': num_response_tokens,
        **{
            name: num_overall_tokens for name in ['ref', 'values', 'adv', 'update_critic', 'update_actor', 'rollout']
        },
    }

    return {
        **{
            f'timing_s/{name}': value for name, value in timing_raw.items()
        },
        **{
            f'timing_per_token_ms/{name}': timing_raw[name] * 1000 / num_tokens_of_section[name]
            for name in set(num_tokens_of_section.keys()) & set(timing_raw.keys())
        },
    }

@contextmanager
def _timer(name: str, timing_raw: Dict[str, float]):
    with Timer(name=name, logger=None) as timer:
        yield
    timing_raw[name] = timer.last

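
# Usage sketch (illustrative only): each labelled block records its wall-clock duration
# (in seconds) into the shared timing_raw dict; nesting works because the inner timer
# simply writes its own key.
example_timing_raw = {}
with _timer('step', example_timing_raw):
    with _timer('gen', example_timing_raw):
        pass  # rollout generation would run here
# example_timing_raw now holds {'gen': ..., 'step': ...}, which compute_timing_metrics
# converts into timing_s/* and timing_per_token_ms/* metrics.
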
class RayPPOTrainer(object):
    """
    Note that this trainer runs on the driver process on a single CPU/GPU node.
    """

    # TODO: support each role having an individual ray_worker_group_cls,
    # i.e., support a different backend for each role
    def __init__(self,
                 config,
                 tokenizer,
                 role_worker_mapping: dict[Role, WorkerType],
                 resource_pool_manager: ResourcePoolManager,
                 ray_worker_group_cls: RayWorkerGroup = RayWorkerGroup,
                 reward_fn=None,
                 val_reward_fn=None):

        # assert torch.cuda.is_available(), 'cuda must be available on driver'

        self.tokenizer = tokenizer
        self.config = config
        self.reward_fn = reward_fn
        self.val_reward_fn = val_reward_fn

        self.hybrid_engine = config.actor_rollout_ref.hybrid_engine
        assert self.hybrid_engine, 'Currently, only the hybrid engine is supported'

        if self.hybrid_engine:
            assert Role.ActorRollout in role_worker_mapping, f'{role_worker_mapping.keys()=}'

        self.role_worker_mapping = role_worker_mapping
        self.resource_pool_manager = resource_pool_manager
        self.use_reference_policy = Role.RefPolicy in role_worker_mapping
        self.use_rm = Role.RewardModel in role_worker_mapping
        self.ray_worker_group_cls = ray_worker_group_cls

        # define KL control
        if self.use_reference_policy:
            if config.algorithm.kl_ctrl.type == 'fixed':
                self.kl_ctrl = core_algos.FixedKLController(kl_coef=config.algorithm.kl_ctrl.kl_coef)
            elif config.algorithm.kl_ctrl.type == 'adaptive':
                assert config.algorithm.kl_ctrl.horizon > 0, f'horizon must be larger than 0. Got {config.algorithm.kl_ctrl.horizon}'
                self.kl_ctrl = core_algos.AdaptiveKLController(init_kl_coef=config.algorithm.kl_ctrl.kl_coef,
                                                               target_kl=config.algorithm.kl_ctrl.target_kl,
                                                               horizon=config.algorithm.kl_ctrl.horizon)
            else:
                raise NotImplementedError
        else:
            self.kl_ctrl = core_algos.FixedKLController(kl_coef=0.)

        self._create_dataloader()
        self._init_logger()

    def _init_logger(self):
        from verl.utils.tracking import Tracking
        self.logger = Tracking(project_name=self.config.trainer.project_name,
                               experiment_name=self.config.trainer.experiment_name,
                               default_backend=self.config.trainer.logger,
                               config=OmegaConf.to_container(self.config, resolve=True))

    def _create_dataloader(self):
        from torch.utils.data import DataLoader

        # TODO: we have to make sure the batch size is divisible by the dp size
        from verl.utils.dataset.rl_dataset import RLHFDataset, collate_fn
        self.train_dataset = RLHFDataset(parquet_files=self.config.data.train_files,
                                         tokenizer=self.tokenizer,
                                         prompt_key=self.config.data.prompt_key,
                                         max_prompt_length=self.config.data.max_prompt_length,
                                         filter_prompts=True,
                                         return_raw_chat=self.config.data.get('return_raw_chat', False),
                                         truncation='error')
        if self.config.data.train_data_num is not None:
            if self.config.data.train_data_num > len(self.train_dataset.dataframe):
                print(f"[WARNING] training dataset size is smaller than desired size. Using the dataset as the original size {len(self.train_dataset.dataframe)}")
            else:
                self.train_dataset.dataframe = self.train_dataset.dataframe.sample(self.config.data.train_data_num, random_state=42)
        print(f"filtered training dataset size: {len(self.train_dataset.dataframe)}")

        self.train_dataloader = DataLoader(dataset=self.train_dataset,
                                           batch_size=self.config.data.train_batch_size,
                                           shuffle=self.config.data.shuffle_train_dataloader,
                                           drop_last=True,
                                           collate_fn=collate_fn)

        self.val_dataset = RLHFDataset(parquet_files=self.config.data.val_files,
                                       tokenizer=self.tokenizer,
                                       prompt_key=self.config.data.prompt_key,
                                       max_prompt_length=self.config.data.max_prompt_length,
                                       filter_prompts=True,
                                       return_raw_chat=self.config.data.get('return_raw_chat', False),
                                       truncation='error')
        if self.config.data.val_data_num is not None:
            if self.config.data.val_data_num > len(self.val_dataset.dataframe):
                print(f"[WARNING] validation dataset size is smaller than desired size. Using the dataset as the original size {len(self.val_dataset.dataframe)}")
            else:
                self.val_dataset.dataframe = self.val_dataset.dataframe.sample(self.config.data.val_data_num, random_state=42)
        print(f"filtered validation dataset size: {len(self.val_dataset.dataframe)}")

        self.val_dataloader = DataLoader(dataset=self.val_dataset,
                                         batch_size=self.config.data.val_batch_size,
                                         shuffle=True,
                                         drop_last=True,
                                         collate_fn=collate_fn)

        print(f'Size of train dataloader: {len(self.train_dataloader)}')
        print(f'Size of val dataloader: {len(self.val_dataloader)}')

        assert len(self.train_dataloader) >= 1
        assert len(self.val_dataloader) >= 1

        # inject total_training_steps to actor/critic optim_config. This is hacky.
        total_training_steps = len(self.train_dataloader) * self.config.trainer.total_epochs

        if self.config.trainer.total_training_steps is not None:
            total_training_steps = self.config.trainer.total_training_steps

        self.total_training_steps = total_training_steps
        print(f'Total training steps: {self.total_training_steps}')

        OmegaConf.set_struct(self.config, True)
        with open_dict(self.config):
            self.config.actor_rollout_ref.actor.optim.total_training_steps = total_training_steps
            self.config.critic.optim.total_training_steps = total_training_steps

    def _validate(self):
        """
        Validation loop with global metric computation.
        Accumulates reward metrics across all validation batches before computing the final statistics.
        """
        import torch
        reward_tensor_lst = []
        data_source_lst = []

        gen_config = GenerationConfig(
            max_turns=self.config.max_turns,
            max_start_length=self.config.data.max_start_length,
            max_prompt_length=self.config.data.max_prompt_length,
            max_response_length=self.config.data.max_response_length,
            max_obs_length=self.config.data.max_obs_length,
            num_gpus=self.config.trainer.n_gpus_per_node,
            no_think_rl=self.config.algorithm.no_think_rl,
            search_url=self.config.retriever.url,
            topk=self.config.retriever.topk,
        )

        # Agent config preparation
        generation_manager = LLMGenerationManager(
            tokenizer=self.tokenizer,
            actor_rollout_wg=self.actor_rollout_wg,
            config=gen_config,
            is_validation=True,
        )

        if not self.config.do_search:
            for test_data in self.val_dataloader:
                test_batch = DataProto.from_single_dict(test_data)

                # we only do validation on rule-based rm
                if self.config.reward_model.enable and test_batch[0].non_tensor_batch['reward_model']['style'] == 'model':
                    return {}

                test_gen_batch = test_batch.pop(['input_ids', 'attention_mask', 'position_ids'])
                test_gen_batch.meta_info = {
                    'eos_token_id': self.tokenizer.eos_token_id,
                    'pad_token_id': self.tokenizer.pad_token_id,
                    'recompute_log_prob': False,
                    'do_sample': False,
                    'validate': True,
                }

                # pad to be divisible by dp_size
                test_gen_batch_padded, pad_size = pad_dataproto_to_divisor(test_gen_batch, self.actor_rollout_wg.world_size)
                test_output_gen_batch_padded = self.actor_rollout_wg.generate_sequences(test_gen_batch_padded)
                # unpad
                test_output_gen_batch = unpad_dataproto(test_output_gen_batch_padded, pad_size=pad_size)
                print('validation generation end')

                test_batch = test_batch.union(test_output_gen_batch)

                # evaluate using reward_function
                # for certain reward functions (e.g. sandbox), the generation can overlap with reward computation
                reward_tensor = self.val_reward_fn(test_batch)

                reward_tensor_lst.append(reward_tensor)
                data_source_lst.append(test_batch.non_tensor_batch.get('data_source', ['unknown'] * reward_tensor.shape[0]))
        else:
            for batch_dict in self.val_dataloader:
                timing_raw = {}
                test_batch: DataProto = DataProto.from_single_dict(batch_dict)
                # test_batch = test_batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n_agent, interleave=True)

                test_gen_batch = test_batch.pop(batch_keys=['input_ids', 'attention_mask', 'position_ids'])
                test_gen_batch.meta_info = {
                    'eos_token_id': self.tokenizer.eos_token_id,
                    'pad_token_id': self.tokenizer.pad_token_id,
                    'recompute_log_prob': False,
                    'do_sample': False,
                    'validate': True,
                }
                with _timer('step', timing_raw):
                    first_input_ids = test_gen_batch.batch['input_ids'][:, -gen_config.max_start_length:].clone()
                    with _timer('gen', timing_raw):
                        generation_manager.timing_raw = timing_raw
                        final_gen_batch_output = generation_manager.run_llm_loop(
                            gen_batch=test_gen_batch,
                            initial_input_ids=first_input_ids,
                        )

                    test_batch = test_batch.union(final_gen_batch_output)

                    for key in test_batch.batch.keys():
                        test_batch.batch[key] = test_batch.batch[key].long()

                    # evaluate using reward_function
                    # for certain reward functions (e.g. sandbox), the generation can overlap with reward computation
                    try:
                        reward_tensor = self.val_reward_fn(test_batch)
                    except Exception as e:
                        print(f'validation reward computation failed: {e}')
                        print(test_batch)
                        exit()

                    reward_tensor_lst.append(reward_tensor)
                    data_source_lst.append(test_batch.non_tensor_batch.get('data_source', ['unknown'] * reward_tensor.shape[0]))

        reward_tensor = torch.cat([rw.sum(-1) for rw in reward_tensor_lst], dim=0).cpu()  # (batch_size,)
        # reward_tensor = torch.cat(reward_tensor_lst, dim=0).sum(-1).cpu()  # (batch_size,)
        data_sources = np.concatenate(data_source_lst, axis=0)
        # evaluate test_score based on data source
        data_source_reward = {}
        for i in range(reward_tensor.shape[0]):
            data_source = data_sources[i]
            if data_source not in data_source_reward:
                data_source_reward[data_source] = []
            data_source_reward[data_source].append(reward_tensor[i].item())

        metric_dict = {}
        for data_source, rewards in data_source_reward.items():
            metric_dict[f'val/test_score/{data_source}'] = np.mean(rewards)

        return metric_dict

    def init_workers(self):
        """Init resource pool and worker group"""
        self.resource_pool_manager.create_resource_pool()

        self.resource_pool_to_cls = {pool: {} for pool in self.resource_pool_manager.resource_pool_dict.values()}

        # create actor and rollout
        if self.hybrid_engine:
            resource_pool = self.resource_pool_manager.get_resource_pool(Role.ActorRollout)
            actor_rollout_cls = RayClassWithInitArgs(cls=self.role_worker_mapping[Role.ActorRollout],
                                                     config=self.config.actor_rollout_ref,
                                                     role='actor_rollout')
            self.resource_pool_to_cls[resource_pool]['actor_rollout'] = actor_rollout_cls
        else:
            raise NotImplementedError

        # create critic
        if self.config.algorithm.adv_estimator == 'gae':
            resource_pool = self.resource_pool_manager.get_resource_pool(Role.Critic)
            critic_cls = RayClassWithInitArgs(cls=self.role_worker_mapping[Role.Critic], config=self.config.critic)
            self.resource_pool_to_cls[resource_pool]['critic'] = critic_cls
            self.use_critic = True
        elif self.config.algorithm.adv_estimator == 'grpo':
            self.use_critic = False
        else:
            raise NotImplementedError

        # create reference policy if needed
        if self.use_reference_policy:
            resource_pool = self.resource_pool_manager.get_resource_pool(Role.RefPolicy)
            ref_policy_cls = RayClassWithInitArgs(self.role_worker_mapping[Role.RefPolicy],
                                                  config=self.config.actor_rollout_ref,
                                                  role='ref')
            self.resource_pool_to_cls[resource_pool]['ref'] = ref_policy_cls

        # create a reward model if reward_fn is None
        if self.use_rm:
            # we create a RM here
            resource_pool = self.resource_pool_manager.get_resource_pool(Role.RewardModel)
            rm_cls = RayClassWithInitArgs(self.role_worker_mapping[Role.RewardModel], config=self.config.reward_model)
            self.resource_pool_to_cls[resource_pool]['rm'] = rm_cls

        # initialize WorkerGroup
        # NOTE: if you want to use a different resource pool for each role, which can support different parallel sizes,
        # you should not use `create_colocated_worker_cls`. Instead, directly pass different resource pools to different worker groups.
        # See https://github.com/volcengine/verl/blob/master/examples/ray/tutorial.ipynb for more information.
        all_wg = {}
        self.wg_dicts = []
        for resource_pool, class_dict in self.resource_pool_to_cls.items():
            worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict)
            wg_dict = self.ray_worker_group_cls(resource_pool=resource_pool, ray_cls_with_init=worker_dict_cls)
            spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys())
            all_wg.update(spawn_wg)
            # keep the reference of WorkerDict to support ray >= 2.31. Ref: https://github.com/ray-project/ray/pull/45699
            self.wg_dicts.append(wg_dict)

        if self.use_critic:
            self.critic_wg = all_wg['critic']
            self.critic_wg.init_model()

        if self.use_reference_policy:
            self.ref_policy_wg = all_wg['ref']
            self.ref_policy_wg.init_model()

        if self.use_rm:
            self.rm_wg = all_wg['rm']
            self.rm_wg.init_model()

        # we should create the rollout at the end so that vllm can have a better estimation of the kv cache memory
        self.actor_rollout_wg = all_wg['actor_rollout']
        self.actor_rollout_wg.init_model()

    def _save_checkpoint(self):
        actor_local_path = os.path.join(self.config.trainer.default_local_dir, 'actor',
                                        f'global_step_{self.global_steps}')
        actor_remote_path = None if self.config.trainer.default_hdfs_dir is None else os.path.join(
            self.config.trainer.default_hdfs_dir, 'actor')
        self.actor_rollout_wg.save_checkpoint(actor_local_path, actor_remote_path)

        if self.use_critic:
            critic_local_path = os.path.join(self.config.trainer.default_local_dir, 'critic',
                                             f'global_step_{self.global_steps}')
            critic_remote_path = None if self.config.trainer.default_hdfs_dir is None else os.path.join(
                self.config.trainer.default_hdfs_dir, 'critic')
            self.critic_wg.save_checkpoint(critic_local_path, critic_remote_path)

    def _balance_batch(self, batch: DataProto, metrics, logging_prefix='global_seqlen'):
        """Reorder the data on single controller such that each dp rank gets similar total tokens"""
        attention_mask = batch.batch['attention_mask']
        batch_size = attention_mask.shape[0]
        global_seqlen_lst = attention_mask.view(batch_size, -1).sum(-1).tolist()  # (train_batch_size,)
        world_size = self.actor_rollout_wg.world_size
        global_partition_lst = get_seqlen_balanced_partitions(global_seqlen_lst,
                                                              k_partitions=world_size,
                                                              equal_size=True)
        # reorder based on index. The data will be automatically equally partitioned by the dispatch function
        global_idx = torch.tensor([j for partition in global_partition_lst for j in partition])
        batch.reorder(global_idx)
        global_balance_stats = log_seqlen_unbalance(seqlen_list=global_seqlen_lst,
                                                    partitions=global_partition_lst,
                                                    prefix=logging_prefix)
        metrics.update(global_balance_stats)

    def fit(self):
        """
        The training loop of PPO.
        The driver process only needs to call the compute functions of the worker group through RPC to construct the PPO dataflow.
        The lightweight advantage computation is done on the driver process.
        """

        logger = self.logger
        self.global_steps = 0

        # perform validation before training
        # currently, we only support validation using the reward_function.
        if self.val_reward_fn is not None and self.config.trainer.get('val_before_train', True):
            val_metrics = self._validate()
            pprint(f'Initial validation metrics: {val_metrics}')
            logger.log(data=val_metrics, step=self.global_steps)
            if self.config.trainer.get('val_only', False):
                return

        # we start from step 1
        self.global_steps += 1

        # Agent config preparation
        gen_config = GenerationConfig(
            max_turns=self.config.max_turns,
            max_start_length=self.config.data.max_start_length,
            max_prompt_length=self.config.data.max_prompt_length,
            max_response_length=self.config.data.max_response_length,
            max_obs_length=self.config.data.max_obs_length,
            num_gpus=self.config.trainer.n_gpus_per_node,
            no_think_rl=self.config.algorithm.no_think_rl,
            search_url=self.config.retriever.url,
            topk=self.config.retriever.topk,
        )

        generation_manager = LLMGenerationManager(
            tokenizer=self.tokenizer,
            actor_rollout_wg=self.actor_rollout_wg,
            config=gen_config,
        )

        # start training loop
        for epoch in range(self.config.trainer.total_epochs):
            for batch_dict in self.train_dataloader:
                print(f'epoch {epoch}, step {self.global_steps}')
                metrics = {}
                timing_raw = {}

                batch: DataProto = DataProto.from_single_dict(batch_dict)
                batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n_agent, interleave=True)

                # pop those keys for generation
                gen_batch = batch.pop(batch_keys=['input_ids', 'attention_mask', 'position_ids'])

                ####################
                # original code here

                with _timer('step', timing_raw):
                    if not self.config.do_search:
                        gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)

                        batch.non_tensor_batch['uid'] = np.array([str(uuid.uuid4()) for _ in range(len(batch.batch))],
                                                                 dtype=object)
                        # repeat to align with repeated responses in rollout
                        batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)
                        batch = batch.union(gen_batch_output)

                    ####################
                    # Below is all about agents - the "LLM + for-loop"
                    ####################
                    # with _timer('step', timing_raw):
                    else:
                        first_input_ids = gen_batch.batch['input_ids'][:, -gen_config.max_start_length:].clone().long()

                        with _timer('gen', timing_raw):
                            generation_manager.timing_raw = timing_raw
                            final_gen_batch_output = generation_manager.run_llm_loop(
                                gen_batch=gen_batch,
                                initial_input_ids=first_input_ids,
                            )

                        # final_gen_batch_output.batch.apply(lambda x: x.long(), inplace=True)
                        for key in final_gen_batch_output.batch.keys():
                            final_gen_batch_output.batch[key] = final_gen_batch_output.batch[key].long()

                        with torch.no_grad():
                            try:
                                output = self.actor_rollout_wg.compute_log_prob(final_gen_batch_output)
                                final_gen_batch_output = final_gen_batch_output.union(output)
                            except Exception as e:
                                print(f'compute_log_prob failed: {e}')
                                print(final_gen_batch_output)

                        batch.non_tensor_batch['uid'] = np.array([str(uuid.uuid4()) for _ in range(len(batch.batch))],
                                                                 dtype=object)

                        # repeat to align with repeated responses in rollout
                        batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)
                        batch = batch.union(final_gen_batch_output)

                    ####################
                    ####################

                    # balance the number of valid tokens on each dp rank.
                    # Note that this breaks the order of data inside the batch.
                    # Please take care when you implement group-based adv computation such as GRPO and RLOO
                    self._balance_batch(batch, metrics=metrics)

                    # compute global_valid tokens
                    batch.meta_info['global_token_num'] = torch.sum(batch.batch['attention_mask'], dim=-1).tolist()

                    # batch.batch.apply(lambda x, key: x.long() if key != "old_log_probs" else x, inplace=True, key=True)
                    for key in batch.batch.keys():
                        if key != 'old_log_probs':
                            batch.batch[key] = batch.batch[key].long()

                    if self.use_reference_policy:
                        # compute reference log_prob
                        with _timer('ref', timing_raw):
                            try:
                                ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)
                                batch = batch.union(ref_log_prob)
                            except Exception as e:
                                print(f'compute_ref_log_prob failed: {e}')
                                print(batch)

                    # compute values
                    if self.use_critic:
                        with _timer('values', timing_raw):
                            values = self.critic_wg.compute_values(batch)
                            batch = batch.union(values)

                    with _timer('adv', timing_raw):
                        # compute scores. Support both model- and function-based rewards.
                        # We first compute the scores using the reward model. Then, we call reward_fn to combine
                        # the results from the reward model and the rule-based results.
                        if self.use_rm:
                            # we first compute reward model score
                            reward_tensor = self.rm_wg.compute_rm_score(batch)
                            batch = batch.union(reward_tensor)

                        # we combine with rule-based rm
                        reward_tensor = self.reward_fn(batch)
                        batch.batch['token_level_scores'] = reward_tensor

                        # compute rewards. apply_kl_penalty if available
                        if not self.config.actor_rollout_ref.actor.use_kl_loss:
                            batch, kl_metrics = apply_kl_penalty(batch,
                                                                 kl_ctrl=self.kl_ctrl,
                                                                 kl_penalty=self.config.algorithm.kl_penalty)
                            metrics.update(kl_metrics)
                        else:
                            batch.batch['token_level_rewards'] = batch.batch['token_level_scores']

                        # compute advantages, executed on the driver process
                        batch = compute_advantage(batch,
                                                  adv_estimator=self.config.algorithm.adv_estimator,
                                                  gamma=self.config.algorithm.gamma,
                                                  lam=self.config.algorithm.lam,
                                                  num_repeat=self.config.actor_rollout_ref.rollout.n)

                    # update critic
                    if self.use_critic:
                        with _timer('update_critic', timing_raw):
                            critic_output = self.critic_wg.update_critic(batch)
                        critic_output_metrics = reduce_metrics(critic_output.meta_info['metrics'])
                        metrics.update(critic_output_metrics)

                    # implement critic warmup
                    if self.config.trainer.critic_warmup <= self.global_steps:
                        # update actor
                        with _timer('update_actor', timing_raw):
                            if self.config.do_search and self.config.actor_rollout_ref.actor.state_masking:
                                batch, metrics = self._create_loss_mask(batch, metrics)
                            actor_output = self.actor_rollout_wg.update_actor(batch)
                        actor_output_metrics = reduce_metrics(actor_output.meta_info['metrics'])
                        metrics.update(actor_output_metrics)

                    # validate
                    if self.val_reward_fn is not None and self.config.trainer.test_freq > 0 and \
                            self.global_steps % self.config.trainer.test_freq == 0:
                        with _timer('testing', timing_raw):
                            val_metrics: dict = self._validate()
                        metrics.update(val_metrics)

                    if self.config.trainer.save_freq > 0 and \
                            self.global_steps % self.config.trainer.save_freq == 0:
                        with _timer('save_checkpoint', timing_raw):
                            self._save_checkpoint()

                # collect metrics
                metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic))
                metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw))

                # TODO: make a canonical logger that supports various backends
                logger.log(data=metrics, step=self.global_steps)

                self.global_steps += 1

                if self.global_steps >= self.total_training_steps:

                    # perform validation after training
                    if self.val_reward_fn is not None:
                        val_metrics = self._validate()
                        pprint(f'Final validation metrics: {val_metrics}')
                        logger.log(data=val_metrics, step=self.global_steps)
                    return

    def _create_loss_mask(self, batch, metrics):
        """Create loss mask for state tokens."""
        response_length = batch.batch['responses'].shape[-1]
        response_mask = batch.batch['attention_mask'][:, -response_length:]

        # Initialize state mask
        state_mask = torch.ones_like(response_mask)

        responses = [self.tokenizer.decode(resp, skip_special_tokens=False) for resp in batch.batch['responses']]

        for i, response in enumerate(responses):
            # Find all pairs of start and end marker positions
            start_marker = self.config.algorithm.state_masking.start_state_marker
            end_marker = self.config.algorithm.state_masking.end_state_marker

            # Get all start and end positions
            start_positions = [m.start() for m in re.finditer(re.escape(start_marker), response)]
            end_positions = [m.start() + len(end_marker) for m in re.finditer(re.escape(end_marker), response)]

            # Convert character positions to token positions
            for start, end in zip(start_positions, end_positions):
                prefix_to_start = response[:start]
                state_section = response[start:end]

                start_tokens = self.tokenizer.encode(prefix_to_start, add_special_tokens=False)
                state_tokens = self.tokenizer.encode(state_section, add_special_tokens=False)

                start_token_pos = len(start_tokens)
                end_token_pos = start_token_pos + len(state_tokens)

                state_mask[i, start_token_pos:end_token_pos] = 0

        loss_mask = state_mask * response_mask
        batch.batch['loss_mask'] = loss_mask

        # # Debug print
        # print("\nRaw batch[0] (before masking):\n", self.tokenizer.decode(batch.batch['responses'][0]))
        # response_ids = batch.batch['responses'][0]
        # unmasked_ids = response_ids[loss_mask[0] == 0]
        # print("\nMasked batch[0] (after masking):\n", self.tokenizer.decode(unmasked_ids))

        # masked_ids = response_ids[loss_mask[0] == 1]
        # print("\nUnmasked batch[0] (masked parts):\n", self.tokenizer.decode(masked_ids))

        # masked_ids = response_ids[response_mask[0] == 1]
        # print("\nresponse_mask[0] == 1:\n", self.tokenizer.decode(masked_ids))

        # masked_ids = response_ids[response_mask[0] == 0]
        # print("\nresponse_mask[0] == 0:\n", self.tokenizer.decode(masked_ids))

        metrics.update({
            'state_tokens/total': loss_mask.sum().item(),
            'state_tokens/coverage': (loss_mask.sum() / response_mask.sum()).item(),
        })

        return batch, metrics

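
# Toy sketch (illustrative only) of the marker-based masking idea behind
# _create_loss_mask: the markers below are hypothetical stand-ins for the configured
# start_state_marker / end_state_marker, and the mask is shown at the character level,
# whereas the method maps marker spans to token positions.
example_response = "<think>look it up</think><information>retrieved passage</information><answer>42</answer>"
example_start, example_end = "<information>", "</information>"
example_keep = [True] * len(example_response)
for s, e in zip(re.finditer(re.escape(example_start), example_response),
                re.finditer(re.escape(example_end), example_response)):
    for pos in range(s.start(), e.end()):
        example_keep[pos] = False
example_visible = ''.join(c for c, k in zip(example_response, example_keep) if k)
# example_visible == "<think>look it up</think><answer>42</answer>": the retrieved
# passage contributes no policy-gradient loss, only the model's own tokens do.
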
5
verl/trainer/runtime_env.yaml
Normal file
5
verl/trainer/runtime_env.yaml
Normal file
@@ -0,0 +1,5 @@
working_dir: ./
excludes: ["/.git/"]
env_vars:
  TORCH_NCCL_AVOID_RECORD_STREAMS: "1"
  VLLM_ATTENTION_BACKEND: "XFORMERS"