Initial commit
This commit is contained in:
18
verl/workers/critic/__init__.py
Normal file
18
verl/workers/critic/__init__.py
Normal file
@@ -0,0 +1,18 @@
|
||||
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .base import BasePPOCritic
|
||||
from .dp_critic import DataParallelPPOCritic
|
||||
|
||||
__all__ = ["BasePPOCritic", "DataParallelPPOCritic"]
|
||||
40
verl/workers/critic/base.py
Normal file
40
verl/workers/critic/base.py
Normal file
@@ -0,0 +1,40 @@
|
||||
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Base class for a critic
|
||||
"""
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
import torch
|
||||
|
||||
from verl import DataProto
|
||||
|
||||
__all__ = ['BasePPOCritic']
|
||||
|
||||
|
||||
class BasePPOCritic(ABC):
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
@abstractmethod
|
||||
def compute_values(self, data: DataProto) -> torch.Tensor:
|
||||
"""Compute values"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def update_critic(self, data: DataProto):
|
||||
"""Update the critic"""
|
||||
pass
|
||||
204
verl/workers/critic/dp_critic.py
Normal file
204
verl/workers/critic/dp_critic.py
Normal file
@@ -0,0 +1,204 @@
|
||||
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Implement a multiprocess PPOCritic
|
||||
"""
|
||||
import itertools
|
||||
from typing import Iterable
|
||||
|
||||
import torch
|
||||
import torch.distributed
|
||||
from torch import nn, optim
|
||||
|
||||
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
|
||||
|
||||
from verl import DataProto
|
||||
from verl.trainer.ppo import core_algos
|
||||
from verl.workers.critic import BasePPOCritic
|
||||
from verl.utils.py_functional import append_to_dict
|
||||
from verl.utils.torch_functional import masked_mean
|
||||
from verl.utils.ulysses import ulysses_pad_and_slice_inputs, gather_outpus_and_unpad
|
||||
from verl.utils.seqlen_balancing import rearrange_micro_batches, get_reverse_idx
|
||||
|
||||
from flash_attn.bert_padding import pad_input, unpad_input, rearrange, index_first_axis
|
||||
|
||||
__all__ = ['DataParallelPPOCritic']
|
||||
|
||||
|
||||
class DataParallelPPOCritic(BasePPOCritic):
|
||||
|
||||
def __init__(self, config, critic_module: nn.Module, critic_optimizer: optim.Optimizer):
|
||||
super().__init__(config=config)
|
||||
self.critic_module = critic_module
|
||||
self.critic_optimizer = critic_optimizer
|
||||
self.use_remove_padding = self.config.model.get('use_remove_padding', False)
|
||||
print(f'Critic use_remove_padding={self.use_remove_padding}')
|
||||
|
||||
assert self.config.ppo_mini_batch_size % self.config.ppo_micro_batch_size == 0
|
||||
self.gradient_accumulation = self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size
|
||||
|
||||
self.ulysses_sequence_parallel_size = self.config.get('ulysses_sequence_parallel_size', 1)
|
||||
|
||||
def _forward_micro_batch(self, micro_batch):
|
||||
response_length = micro_batch['responses'].size(-1)
|
||||
with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
|
||||
input_ids = micro_batch['input_ids']
|
||||
batch, seqlen = input_ids.shape
|
||||
attention_mask = micro_batch['attention_mask']
|
||||
position_ids = micro_batch['position_ids']
|
||||
|
||||
if self.use_remove_padding:
|
||||
input_ids_rmpad, indices, *_ = unpad_input(input_ids.unsqueeze(-1),
|
||||
attention_mask) # input_ids_rmpad (total_nnz, ...)
|
||||
input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz)
|
||||
|
||||
# unpad the position_ids to align the rotary
|
||||
position_ids_rmpad = index_first_axis(rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."),
|
||||
indices).transpose(0, 1)
|
||||
|
||||
# pad and slice the inputs if sp > 1
|
||||
if self.ulysses_sequence_parallel_size > 1:
|
||||
input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad_and_slice_inputs(input_ids_rmpad, \
|
||||
position_ids_rmpad, \
|
||||
sp_size=self.ulysses_sequence_parallel_size)
|
||||
|
||||
# only pass input_ids and position_ids to enable flash_attn_varlen
|
||||
output = self.critic_module(input_ids=input_ids_rmpad,
|
||||
attention_mask=None,
|
||||
position_ids=position_ids_rmpad,
|
||||
use_cache=False) # prevent model thinks we are generating
|
||||
values_rmpad = output.logits
|
||||
values_rmpad = values_rmpad.squeeze(0) # (total_nnz)
|
||||
|
||||
# gather output if sp > 1
|
||||
if self.ulysses_sequence_parallel_size > 1:
|
||||
values_rmpad = gather_outpus_and_unpad(values_rmpad,
|
||||
gather_dim=0,
|
||||
unpad_dim=0,
|
||||
padding_size=pad_size)
|
||||
|
||||
# pad it back
|
||||
values = pad_input(values_rmpad, indices=indices, batch=batch, seqlen=seqlen).squeeze(-1)
|
||||
values = values[:, -response_length - 1:-1]
|
||||
else:
|
||||
output = self.critic_module(input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
use_cache=False) # prevent model thinks we are generating
|
||||
values = output.logits
|
||||
values = values[:, -response_length - 1:-1].squeeze(-1)
|
||||
return values
|
||||
|
||||
def _optimizer_step(self):
|
||||
assert self.config.grad_clip is not None
|
||||
|
||||
if isinstance(self.critic_module, FSDP):
|
||||
grad_norm = self.critic_module.clip_grad_norm_(self.config.grad_clip)
|
||||
else:
|
||||
grad_norm = torch.nn.utils.clip_grad_norm_(self.critic_module.parameters(), max_norm=self.config.grad_clip)
|
||||
self.critic_optimizer.step()
|
||||
return grad_norm
|
||||
|
||||
def compute_values(self, data: DataProto) -> torch.Tensor:
|
||||
self.critic_module.eval()
|
||||
micro_batch_size = data.meta_info['micro_batch_size']
|
||||
select_keys = ['responses', 'input_ids', 'attention_mask', 'position_ids']
|
||||
batch = data.select(batch_keys=select_keys).batch
|
||||
use_dynamic_bsz = data.meta_info['use_dynamic_bsz']
|
||||
|
||||
if use_dynamic_bsz:
|
||||
# split using dynamic bsz
|
||||
max_token_len = data.meta_info['max_token_len'] * self.ulysses_sequence_parallel_size
|
||||
micro_batches, indices = rearrange_micro_batches(batch=batch, max_token_len=max_token_len)
|
||||
else:
|
||||
micro_batches = batch.split(micro_batch_size)
|
||||
|
||||
values_lst = []
|
||||
for micro_batch in micro_batches:
|
||||
with torch.no_grad():
|
||||
values = self._forward_micro_batch(micro_batch)
|
||||
values_lst.append(values)
|
||||
values = torch.concat(values_lst, dim=0)
|
||||
responses = data.batch['responses']
|
||||
attention_mask = data.batch['attention_mask']
|
||||
response_length = responses.size(1)
|
||||
values = values * attention_mask[:, -response_length - 1:-1]
|
||||
|
||||
if use_dynamic_bsz:
|
||||
indices = list(itertools.chain.from_iterable(indices))
|
||||
assert len(indices) == values.size(0), f"{len(indices)} vs. {values.size()}"
|
||||
revert_indices = torch.tensor(get_reverse_idx(indices), dtype=torch.long)
|
||||
values = values[revert_indices]
|
||||
|
||||
return values
|
||||
|
||||
def update_critic(self, data: DataProto):
|
||||
# make sure we are in training mode
|
||||
self.critic_module.train()
|
||||
metrics = {}
|
||||
|
||||
select_keys = ['input_ids', 'responses', 'attention_mask', 'position_ids', 'values', 'returns']
|
||||
batch = data.select(batch_keys=select_keys).batch
|
||||
# Split to make minibatch iterator for updating the actor
|
||||
# See PPO paper for details. https://arxiv.org/abs/1707.06347
|
||||
dataloader = batch.split(self.config.ppo_mini_batch_size)
|
||||
|
||||
for batch_idx, data in enumerate(dataloader):
|
||||
# split batch into micro_batches
|
||||
mini_batch = data
|
||||
if self.config.use_dynamic_bsz:
|
||||
max_token_len = self.config.ppo_max_token_len_per_gpu * self.ulysses_sequence_parallel_size
|
||||
micro_batches, _ = rearrange_micro_batches(batch=mini_batch, max_token_len=max_token_len)
|
||||
else:
|
||||
micro_batches = mini_batch.split(self.config.ppo_micro_batch_size)
|
||||
|
||||
self.critic_optimizer.zero_grad()
|
||||
|
||||
for data in micro_batches:
|
||||
data = data.cuda() # critic device is cpu when using offload
|
||||
input_ids = data['input_ids']
|
||||
responses = data['responses']
|
||||
attention_mask = data['attention_mask']
|
||||
position_ids = data['position_ids']
|
||||
values = data['values']
|
||||
returns = data['returns']
|
||||
response_length = responses.size(1)
|
||||
|
||||
eos_mask = attention_mask[:, -response_length - 1:-1]
|
||||
|
||||
vpreds = self._forward_micro_batch(data)
|
||||
|
||||
# assert not torch.any(torch.isnan(vpreds)).item()
|
||||
|
||||
vf_loss, vf_clipfrac = core_algos.compute_value_loss(vpreds=vpreds,
|
||||
values=values,
|
||||
returns=returns,
|
||||
eos_mask=eos_mask,
|
||||
cliprange_value=self.config.cliprange_value)
|
||||
loss = vf_loss / self.gradient_accumulation
|
||||
loss.backward()
|
||||
|
||||
data = {
|
||||
'critic/vf_loss': vf_loss.detach().item(),
|
||||
'critic/vf_clipfrac': vf_clipfrac.detach().item(),
|
||||
'critic/vpred_mean': masked_mean(vpreds, eos_mask).detach().item(),
|
||||
}
|
||||
|
||||
append_to_dict(metrics, data)
|
||||
|
||||
grad_norm = self._optimizer_step()
|
||||
data = {'critic/grad_norm': grad_norm.detach().item()}
|
||||
append_to_dict(metrics, data)
|
||||
self.critic_optimizer.zero_grad()
|
||||
return metrics
|
||||
229
verl/workers/critic/megatron_critic.py
Normal file
229
verl/workers/critic/megatron_critic.py
Normal file
@@ -0,0 +1,229 @@
|
||||
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Implement a multiprocess PPOCritic
|
||||
"""
|
||||
|
||||
from functools import partial
|
||||
from typing import Iterable
|
||||
|
||||
import torch
|
||||
import torch.distributed
|
||||
from omegaconf import OmegaConf
|
||||
from torch import nn
|
||||
|
||||
from verl import DataProto
|
||||
from verl.trainer.ppo import core_algos
|
||||
from verl.workers.critic import BasePPOCritic
|
||||
from verl.utils.megatron.pipeline_parallel import (compute_transformers_input_shapes, make_batch_generator)
|
||||
from verl.utils.py_functional import append_to_dict
|
||||
from verl.utils.torch_dtypes import PrecisionType
|
||||
from verl.utils.torch_functional import masked_mean, broadcast_dict_tensor, split_dict_tensor_into_batches
|
||||
from verl.utils.megatron import sequence_parallel as sp_utils
|
||||
from verl.utils.megatron.optimizer_config import OptimizerConfig
|
||||
|
||||
from megatron.optimizer import DistributedOptimizer
|
||||
from megatron.core import parallel_state as mpu
|
||||
from megatron.core.pipeline_parallel import get_forward_backward_func
|
||||
|
||||
|
||||
class MegatronPPOCritic(BasePPOCritic):
|
||||
|
||||
def __init__(self, config, model_config, megatron_config, critic_module: nn.ModuleList,
|
||||
critic_optimizer: DistributedOptimizer, critic_optimizer_config: OptimizerConfig):
|
||||
super().__init__(config=config)
|
||||
|
||||
self.model_config = model_config
|
||||
self.megatron_config = megatron_config
|
||||
|
||||
self.critic_module = critic_module
|
||||
self.critic_optimizer = critic_optimizer
|
||||
self.critic_optimizer_config = critic_optimizer_config
|
||||
|
||||
# we create a separate nametuple for optimizer step so that global args won't affect it.
|
||||
self.optimizer_step_args = OmegaConf.create({
|
||||
'skip_grad': None,
|
||||
'overlap_dp_param_comm': False,
|
||||
'overlap_dp_grad_comm': False,
|
||||
'gradient_accumulation_steps': 1,
|
||||
'sequence_parallel': self.megatron_config.sequence_parallel,
|
||||
'DDP_impl': 'local',
|
||||
'layernorm_allreduce_bucket_threshold': 0,
|
||||
'pipeline_model_parallel_split_rank': None,
|
||||
'reduce_grads_use_alltoall': False
|
||||
})
|
||||
|
||||
if self.config.kl_ctrl.type == 'fixed':
|
||||
self.kl_ctrl = core_algos.FixedKLController(kl_coef=self.config.kl_ctrl.kl_coef)
|
||||
elif self.config.kl_ctrl.type == 'adaptive':
|
||||
assert self.config.kl_ctrl.horizon > 0, f'horizon must be larger than 0. Got {self.config.kl_ctrl.horizon}'
|
||||
self.kl_ctrl = core_algos.AdaptiveKLController(init_kl_coef=self.config.kl_ctrl.kl_coef,
|
||||
target_kl=self.config.kl_ctrl.target_kl,
|
||||
horizon=self.config.kl_ctrl.horizon)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
def compute_values(self, data: DataProto) -> DataProto:
|
||||
# data.batch = data.batch.to(self.critic_module.module.device)
|
||||
responses = data.batch['responses']
|
||||
attention_mask = data.batch['attention_mask']
|
||||
response_length = responses.size(1)
|
||||
with torch.no_grad():
|
||||
output = self.forward_backward_batch(data=data, forward_only=True)
|
||||
if mpu.is_pipeline_last_stage(ignore_virtual=True):
|
||||
# only on last rank. It should be on every tp rank
|
||||
values = torch.cat([o['vpreds'] for o in output], dim=0) # (bs, seq_size, vocal_size)
|
||||
values = values.to(torch.float32)
|
||||
else:
|
||||
values = torch.empty_like(attention_mask, dtype=torch.float32)
|
||||
|
||||
# each tp ranks should contain the same value
|
||||
values = values * attention_mask
|
||||
values = values[:, -response_length - 1:-1]
|
||||
values = values.contiguous()
|
||||
|
||||
# sync among pp ranks
|
||||
torch.distributed.broadcast(tensor=values,
|
||||
src=mpu.get_pipeline_model_parallel_last_rank(),
|
||||
group=mpu.get_pipeline_model_parallel_group())
|
||||
|
||||
# add empty cache after each compute
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
return values
|
||||
|
||||
def make_minibatch_iterator(self, data: DataProto) -> Iterable[DataProto]:
|
||||
select_keys = ['input_ids', 'responses', 'attention_mask', 'position_ids', 'values', 'returns']
|
||||
data = data.select(batch_keys=select_keys)
|
||||
return data.make_iterator(mini_batch_size=self.config.ppo_mini_batch_size,
|
||||
epochs=self.config.ppo_epochs,
|
||||
dataloader_kwargs={'shuffle': self.config.shuffle})
|
||||
|
||||
def forward_backward_batch(self, data: DataProto, forward_only=False):
|
||||
# broadcast from last pp rank to all other pp ranks
|
||||
data.batch = data.batch.contiguous()
|
||||
broadcast_dict_tensor(data.batch,
|
||||
src=mpu.get_pipeline_model_parallel_last_rank(),
|
||||
group=mpu.get_pipeline_model_parallel_group())
|
||||
# split into micro-batches
|
||||
data.batch['attention_mask'] = data.batch['attention_mask'].to(bool)
|
||||
batches = split_dict_tensor_into_batches(data.batch, batch_size=self.config.ppo_micro_batch_size)
|
||||
n_micro_batch = len(batches)
|
||||
seq_len = batches[0]['input_ids'].shape[1]
|
||||
|
||||
# compute input shapes for pp stages
|
||||
input_shapes = compute_transformers_input_shapes(
|
||||
batches,
|
||||
meta_info={
|
||||
'sequence_parallel': self.megatron_config.sequence_parallel,
|
||||
'hidden_size': self.model_config.hidden_size
|
||||
})
|
||||
|
||||
forward_backward_func = get_forward_backward_func()
|
||||
|
||||
def loss_func(output, data, meta_info):
|
||||
if forward_only:
|
||||
return 1.0, {'vpreds': output.logits}
|
||||
|
||||
responses = data['responses']
|
||||
attention_mask = data['attention_mask']
|
||||
values = data['values']
|
||||
returns = data['returns']
|
||||
response_length = responses.size(1)
|
||||
|
||||
eos_mask = attention_mask[:, -response_length:]
|
||||
|
||||
cliprange_value = self.config.cliprange_value
|
||||
|
||||
vpreds = output.logits # (bs, sequence_length)
|
||||
vpreds = vpreds[:, -response_length - 1:-1]
|
||||
|
||||
vf_loss, vf_clipfrac = core_algos.compute_value_loss(vpreds=vpreds,
|
||||
values=values,
|
||||
returns=returns,
|
||||
eos_mask=eos_mask,
|
||||
cliprange_value=cliprange_value)
|
||||
stats = {
|
||||
'critic/vf_loss': vf_loss.detach().item(),
|
||||
'critic/vf_clipfrac': vf_clipfrac.detach().item(),
|
||||
'critic/vpred_mean': masked_mean(vpreds, eos_mask).detach().item(),
|
||||
}
|
||||
|
||||
return vf_loss, stats
|
||||
|
||||
def forward_step(batch_iter, model):
|
||||
batch = next(batch_iter)
|
||||
input_ids = batch['input_ids']
|
||||
attention_mask = batch['attention_mask']
|
||||
position_ids = batch['position_ids']
|
||||
output = model(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids)
|
||||
return output, partial(loss_func, data=batch, meta_info={})
|
||||
|
||||
# batch should be a list of batches inside micro-batches
|
||||
batch_generator = make_batch_generator(batches, vpp_size=len(self.critic_module))
|
||||
|
||||
# TODO: we may use the new schedule instead
|
||||
# for flash-attn: (seq_len, batch_size, hidden_size) = (mbs*seq_len, 1, hidden_size)
|
||||
if mpu.get_pipeline_model_parallel_world_size() > 1:
|
||||
losses_reduced = forward_backward_func(
|
||||
forward_step_func=forward_step,
|
||||
data_iterator=batch_generator,
|
||||
model=self.critic_module,
|
||||
num_microbatches=n_micro_batch,
|
||||
input_shapes=input_shapes, # must set for flash-attn sequence packing
|
||||
seq_length=self.config.ppo_micro_batch_size * seq_len, # no use when input_shapes was set
|
||||
hidden_size=self.model_config.hidden_size, # no use when input_shapes was set
|
||||
micro_batch_size=1, # no use when input_shapes was set
|
||||
forward_only=forward_only,
|
||||
)
|
||||
else:
|
||||
losses_reduced = forward_backward_func(
|
||||
forward_step_func=forward_step,
|
||||
data_iterator=batch_generator,
|
||||
model=self.critic_module,
|
||||
num_microbatches=n_micro_batch,
|
||||
seq_length=self.config.ppo_micro_batch_size * seq_len, # in use for pp = 1
|
||||
hidden_size=self.model_config.hidden_size, # in use for pp = 1
|
||||
micro_batch_size=1, # in use for pp = 1
|
||||
forward_only=forward_only,
|
||||
)
|
||||
# loss_reduces contains the stats returned from loss_func
|
||||
return losses_reduced
|
||||
|
||||
def update_critic(self, dataloader: Iterable[DataProto]):
|
||||
metrics = {}
|
||||
|
||||
for data in dataloader:
|
||||
# data = data.batch.to(self.critic_module.device)
|
||||
self.critic_optimizer.zero_grad()
|
||||
# use use_contiguous_buffers_in_local_ddp and no overlap_dp_param_comm
|
||||
for chunk in self.critic_module:
|
||||
chunk.zero_grad_buffer(zero_buffer=(not self.critic_optimizer_config.use_distributed_optimizer))
|
||||
|
||||
metric_micro_batch = self.forward_backward_batch(data)
|
||||
|
||||
update_successful, grad_norm, num_zeros_in_grad = self.critic_optimizer.step(
|
||||
self.megatron_config, self.megatron_config.timers)
|
||||
if update_successful:
|
||||
# allgather already execute in optimizer.step in new megatron
|
||||
pass
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
for metric in metric_micro_batch:
|
||||
append_to_dict(metrics, metric) # append the metric from this micro-batch to global metrics.
|
||||
|
||||
# add empty cache after each compute
|
||||
torch.cuda.empty_cache()
|
||||
return metrics
|
||||
Reference in New Issue
Block a user