Initial commit

PeterGriffinJin
2025-02-28 15:16:19 +00:00
commit 068516be64
207 changed files with 33063 additions and 0 deletions

13
verl/workers/__init__.py Normal file
View File

@@ -0,0 +1,13 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

18
verl/workers/actor/__init__.py Normal file
View File

@@ -0,0 +1,18 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .base import BasePPOActor
from .dp_actor import DataParallelPPOActor
__all__ = ["BasePPOActor", "DataParallelPPOActor"]

66
verl/workers/actor/base.py Normal file
View File

@@ -0,0 +1,66 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
The base class for Actor
"""
from abc import ABC, abstractmethod
from typing import Iterable, Dict
from verl import DataProto
import torch
__all__ = ['BasePPOActor']
class BasePPOActor(ABC):
def __init__(self, config):
"""The base class for PPO actor
Args:
config (DictConfig): a config passed to the PPOActor. We expect the type to be
DictConfig (https://omegaconf.readthedocs.io/), but it can be any namedtuple in general.
"""
super().__init__()
self.config = config
@abstractmethod
def compute_log_prob(self, data: DataProto) -> torch.Tensor:
"""Compute logits given a batch of data.
Args:
data (DataProto): a batch of data represented by DataProto. It must contain key ```input_ids```,
```attention_mask``` and ```position_ids```.
Returns:
torch.Tensor: the log probabilities of the responses, of shape (batch_size, response_length)
"""
pass
@abstractmethod
def update_policy(self, data: DataProto) -> Dict:
"""Update the policy with an iterator of DataProto
Args:
data (DataProto): an iterator over the DataProto returned by
```make_minibatch_iterator```
Returns:
Dict: a dictionary that may contain anything. Typically, it contains the statistics collected while updating the model,
such as ```loss```, ```grad_norm```, etc.
"""
pass
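A minimal sketch of a concrete subclass (editor's illustration only, not part of this commit; the toy model wrapper and key names below are assumptions based on the docstrings above):
# Hypothetical example: a toy actor wrapping a HuggingFace-style causal LM.
class ToyPPOActor(BasePPOActor):

    def __init__(self, config, model, optimizer):
        super().__init__(config)
        self.model = model  # any module whose forward returns .logits of shape (bs, seq, vocab)
        self.optimizer = optimizer

    def compute_log_prob(self, data: DataProto) -> torch.Tensor:
        batch = data.batch
        logits = self.model(input_ids=batch['input_ids'],
                            attention_mask=batch['attention_mask'],
                            position_ids=batch['position_ids']).logits
        response_length = batch['responses'].size(-1)
        logits = logits[:, -response_length - 1:-1]  # logits that predict each response token
        log_probs = torch.log_softmax(logits, dim=-1)
        return torch.gather(log_probs, dim=-1, index=batch['responses'].unsqueeze(-1)).squeeze(-1)

    def update_policy(self, data: DataProto) -> Dict:
        # a real implementation would compute the clipped PPO loss and step the optimizer here
        return {'actor/pg_loss': 0.0}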

290
verl/workers/actor/dp_actor.py Normal file
View File

@@ -0,0 +1,290 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Single Process Actor
"""
import itertools
from typing import Iterable, Tuple
import torch
from torch import nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from verl import DataProto
from verl.trainer.ppo import core_algos
from verl.workers.actor import BasePPOActor
from verl.utils.py_functional import append_to_dict
from verl.utils.torch_functional import logprobs_from_logits, masked_mean
from verl.utils.ulysses import ulysses_pad_and_slice_inputs, gather_outpus_and_unpad
from verl.utils.seqlen_balancing import rearrange_micro_batches, get_reverse_idx
import verl.utils.torch_functional as verl_F
from flash_attn.bert_padding import pad_input, unpad_input, rearrange, index_first_axis
__all__ = ['DataParallelPPOActor']
class DataParallelPPOActor(BasePPOActor):
def __init__(
self,
config,
actor_module: nn.Module,
actor_optimizer: torch.optim.Optimizer = None,
):
"""When optimizer is None, it is Reference Policy"""
super().__init__(config)
self.actor_module = actor_module
self.actor_optimizer = actor_optimizer
self.use_remove_padding = self.config.get('use_remove_padding', False)
print(f'Actor use_remove_padding={self.use_remove_padding}')
self.ulysses_sequence_parallel_size = self.config.ulysses_sequence_parallel_size
self.use_ulysses_sp = self.ulysses_sequence_parallel_size > 1
self.compute_entropy_from_logits = torch.compile(verl_F.entropy_from_logits, dynamic=True)
def _forward_micro_batch(self, micro_batch, temperature) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Returns:
entropy: # (bs, response_len)
log_probs: # (bs, response_len)
"""
response_length = micro_batch['responses'].size(-1)
with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
input_ids = micro_batch['input_ids']
batch_size, seqlen = input_ids.shape
attention_mask = micro_batch['attention_mask']
position_ids = micro_batch['position_ids']
if self.use_remove_padding:
input_ids_rmpad, indices, *_ = unpad_input(input_ids.unsqueeze(-1),
attention_mask) # input_ids_rmpad (total_nnz, ...)
input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz)
# unpad the position_ids to align the rotary
position_ids_rmpad = index_first_axis(rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."),
indices).transpose(0, 1)
# for computing the log_prob
input_ids_rmpad_rolled = torch.roll(input_ids_rmpad, shifts=-1, dims=1) # (1, total_nnz)
# pad and slice the inputs if sp > 1
if self.use_ulysses_sp:
input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad_and_slice_inputs(input_ids_rmpad, \
position_ids_rmpad, \
sp_size=self.ulysses_sequence_parallel_size)
input_ids_rmpad_rolled, _, _ = ulysses_pad_and_slice_inputs(input_ids_rmpad_rolled, None,
self.ulysses_sequence_parallel_size)
input_ids_rmpad_rolled = input_ids_rmpad_rolled.squeeze(0) # ((total_nnz / sp) + pad)
# only pass input_ids and position_ids to enable flash_attn_varlen
output = self.actor_module(input_ids=input_ids_rmpad,
attention_mask=None,
position_ids=position_ids_rmpad,
use_cache=False)  # prevent the model from thinking we are generating
logits_rmpad = output.logits.squeeze(0) # (total_nnz, vocab_size)
logits_rmpad.div_(temperature)
# compute entropy
entropy_rmpad = self.compute_entropy_from_logits(logits_rmpad) # ((total_nnz / sp) + pad)
# if use_sp: ((total_nnz / sp) + pad) ; if not use_sp: (batch, seqlen)
log_probs = logprobs_from_logits(logits=logits_rmpad, labels=input_ids_rmpad_rolled)
# gather log_prob if sp > 1
if self.use_ulysses_sp:
# gather and unpad for the ulysses sp
log_probs = gather_outpus_and_unpad(log_probs, gather_dim=0, unpad_dim=0, padding_size=pad_size)
entropy_rmpad = gather_outpus_and_unpad(entropy_rmpad,
gather_dim=0,
unpad_dim=0,
padding_size=pad_size)
# pad back to (bsz, seqlen)
full_entropy = pad_input(hidden_states=entropy_rmpad.unsqueeze(-1),
indices=indices,
batch=batch_size,
seqlen=seqlen)
full_log_probs = pad_input(hidden_states=log_probs.unsqueeze(-1),
indices=indices,
batch=batch_size,
seqlen=seqlen)
# only return response part:
entropy = full_entropy.squeeze(-1)[:, -response_length - 1:-1] # (bsz, response_length)
log_probs = full_log_probs.squeeze(-1)[:, -response_length - 1:-1] # (bsz, response_length)
else: # not using rmpad and no ulysses sp
output = self.actor_module(input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
use_cache=False)  # prevent the model from thinking we are generating
logits = output.logits
logits.div_(temperature)
logits = logits[:, -response_length - 1:-1] # (bsz, response_length)
log_probs = logprobs_from_logits(logits, micro_batch['responses'])
entropy = verl_F.entropy_from_logits(logits) # (bsz, response_length)
return entropy, log_probs
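# --- Editor's note (illustration, not part of the original file) ---------------
# The torch.roll above shifts the flattened input_ids left by one, so the logits at
# flattened position t are scored against the token at position t + 1 (standard
# next-token alignment). For example:
#     ids = torch.tensor([[5, 7, 9, 2]])
#     torch.roll(ids, shifts=-1, dims=1)  # -> tensor([[7, 9, 2, 5]]); last slot wraps around
# The wrapped-around final position is never used, because the response slice
# [:, -response_length - 1:-1] above keeps only the positions that predict response tokens.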
def _optimizer_step(self):
assert self.config.grad_clip is not None
if isinstance(self.actor_module, FSDP):
grad_norm = self.actor_module.clip_grad_norm_(max_norm=self.config.grad_clip)
else:
grad_norm = torch.nn.utils.clip_grad_norm_(self.actor_module.parameters(), max_norm=self.config.grad_clip)
self.actor_optimizer.step()
return grad_norm
def compute_log_prob(self, data: DataProto) -> torch.Tensor:
"""Compute the log probability of the responses given input_ids, attention_mask and position_ids
Args:
data (DataProto): a DataProto containing keys
``input_ids``: tensor of shape [batch_size, sequence_length]. torch.int64. Note that input_ids is the
concatenation of prompt and response. Note that ``sequence_length = prompt_length + response_length``.
``attention_mask``: tensor of shape [batch_size, sequence_length]. torch.int64.
``position_ids``: tensor of shape [batch_size, sequence_length]. torch.int64.
``responses``: tensor of shape [batch_size, response_length]. torch.int64.
Returns:
torch.Tensor: the log_prob tensor
"""
# set to eval
self.actor_module.eval()
micro_batch_size = data.meta_info['micro_batch_size']
temperature = data.meta_info['temperature']  # temperature must be in data.meta_info to avoid silent errors
use_dynamic_bsz = data.meta_info['use_dynamic_bsz']
select_keys = ['responses', 'input_ids', 'attention_mask', 'position_ids']
batch = data.select(batch_keys=select_keys).batch
if use_dynamic_bsz:
# split using dynamic bsz
max_token_len = data.meta_info['max_token_len'] * self.ulysses_sequence_parallel_size
micro_batches, indices = rearrange_micro_batches(batch=batch, max_token_len=max_token_len)
else:
micro_batches = batch.split(micro_batch_size)
log_probs_lst = []
for micro_batch in micro_batches:
with torch.no_grad():
_, log_probs = self._forward_micro_batch(micro_batch, temperature=temperature)
log_probs_lst.append(log_probs)
log_probs = torch.concat(log_probs_lst, dim=0)
if use_dynamic_bsz:
indices = list(itertools.chain.from_iterable(indices))
assert len(indices) == log_probs.size(0), f"{len(indices)} vs. {log_probs.size()}"
revert_indices = torch.tensor(get_reverse_idx(indices), dtype=torch.long)
log_probs = log_probs[revert_indices]
return log_probs
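# --- Editor's note (illustration, not part of the original file) ---------------
# With use_dynamic_bsz, rearrange_micro_batches balances token counts by permuting the
# samples and returns that permutation; get_reverse_idx above builds its inverse so the
# concatenated micro-batch outputs can be restored to the original order. For example,
# if indices = [2, 0, 1] (output row 0 holds original sample 2, etc.), then
# get_reverse_idx(indices) = [1, 2, 0] and log_probs[[1, 2, 0]] restores sample order.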
def update_policy(self, data: DataProto):
# make sure we are in training mode
self.actor_module.train()
assert self.config.ppo_mini_batch_size % self.config.ppo_micro_batch_size == 0
self.gradient_accumulation = self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size
temperature = data.meta_info['temperature']  # temperature must be in data.meta_info to avoid silent errors
select_keys = ['responses', 'input_ids', 'attention_mask', 'position_ids', 'old_log_probs', 'advantages']
if self.config.state_masking:
select_keys.append('loss_mask')
if self.config.use_kl_loss:
select_keys.append('ref_log_prob')
batch = data.select(batch_keys=select_keys).batch
# Split to make minibatch iterator for updating the actor
# See PPO paper for details. https://arxiv.org/abs/1707.06347
dataloader = batch.split(self.config.ppo_mini_batch_size)
metrics = {}
for batch_idx, data in enumerate(dataloader):
# split batch into micro_batches
mini_batch = data
if self.config.use_dynamic_bsz:
max_token_len = self.config.ppo_max_token_len_per_gpu * self.ulysses_sequence_parallel_size
micro_batches, _ = rearrange_micro_batches(batch=mini_batch, max_token_len=max_token_len)
else:
# split batch into micro_batches
micro_batches = mini_batch.split(self.config.ppo_micro_batch_size)
self.actor_optimizer.zero_grad()
for data in micro_batches:
data = data.cuda() # actor device is cpu when using offload
responses = data['responses']
response_length = responses.size(1)
attention_mask = data['attention_mask']
response_mask = attention_mask[:, -response_length:]
if self.config.state_masking:
response_mask = data['loss_mask']
old_log_prob = data['old_log_probs']
advantages = data['advantages']
clip_ratio = self.config.clip_ratio
entropy_coeff = self.config.entropy_coeff
# all return: (bsz, response_length)
entropy, log_prob = self._forward_micro_batch(micro_batch=data, temperature=temperature)
pg_loss, pg_clipfrac, ppo_kl = core_algos.compute_policy_loss(old_log_prob=old_log_prob,
log_prob=log_prob,
advantages=advantages,
eos_mask=response_mask,
cliprange=clip_ratio)
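# --- Editor's note (illustration): compute_policy_loss implements the standard
# clipped PPO surrogate from https://arxiv.org/abs/1707.06347,
#     ratio_t = exp(log_prob_t - old_log_prob_t)
#     pg_loss = -mean_t[min(ratio_t * A_t, clip(ratio_t, 1 - eps, 1 + eps) * A_t)]
# with eps = clip_ratio, A_t = advantages, and the mean taken over tokens where
# response_mask is 1; ppo_kl is the masked mean of old_log_prob - log_prob and
# pg_clipfrac is the fraction of tokens where the clipped term is active.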
# compute entropy loss from entropy
entropy_loss = verl_F.masked_mean(entropy, response_mask)
# compute policy loss
policy_loss = pg_loss - entropy_loss * entropy_coeff
if self.config.use_kl_loss:
ref_log_prob = data['ref_log_prob']
# compute kl loss
kld = core_algos.kl_penalty(logprob=log_prob,
ref_logprob=ref_log_prob,
kl_penalty=self.config.kl_loss_type)
kl_loss = masked_mean(kld, response_mask)
policy_loss = policy_loss - kl_loss * self.config.kl_loss_coef
metrics['actor/kl_loss'] = kl_loss.detach().item()
metrics['actor/kl_coef'] = self.config.kl_loss_coef
loss = policy_loss / self.gradient_accumulation
loss.backward()
data = {
'actor/entropy_loss': entropy_loss.detach().item(),
'actor/pg_loss': pg_loss.detach().item(),
'actor/pg_clipfrac': pg_clipfrac.detach().item(),
'actor/ppo_kl': ppo_kl.detach().item(),
}
append_to_dict(metrics, data)
grad_norm = self._optimizer_step()
data = {'actor/grad_norm': grad_norm.detach().item()}
append_to_dict(metrics, data)
self.actor_optimizer.zero_grad()
return metrics
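A hedged usage sketch (editor's illustration only; in this commit the class is driven by the FSDP workers, and names such as actor_config, fsdp_module, optimizer and data below are assumptions):
# Editor's illustration only; error handling and distributed setup are omitted.
actor = DataParallelPPOActor(config=actor_config, actor_module=fsdp_module, actor_optimizer=optimizer)
data.meta_info.update({'micro_batch_size': 8, 'temperature': 1.0, 'use_dynamic_bsz': False})
old_log_probs = actor.compute_log_prob(data)   # (batch_size, response_length)
# ... a trainer would then add 'old_log_probs' and 'advantages' to data.batch ...
metrics = actor.update_policy(data)            # dict of 'actor/*' statistics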

368
verl/workers/actor/megatron_actor.py Normal file
View File

@@ -0,0 +1,368 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Megatron Actor.
In the megatron actor, the differences are:
1. We only build the minibatch iterator here; micro-batch splitting is handled inside ``forward_backward_batch``.
Note that our model doesn't have to be `MegatronModule` because we don't share the embedding in the last layer.
"""
from functools import partial
from typing import Iterable, Dict
import torch
from torch import nn
import torch.distributed
# from megatron import get_args
from megatron.optimizer import DistributedOptimizer
from verl.utils.megatron.optimizer_config import OptimizerConfig
from megatron.core import parallel_state as mpu
from megatron.core import ModelParallelConfig
from megatron.core.pipeline_parallel import get_forward_backward_func
# from megatron.core.optimizer import DistributedOptimizer
from omegaconf import OmegaConf
from verl.utils.megatron.tensor_parallel import vocab_parallel_compute_entropy_loss, vocab_parallel_log_probs_from_logits
from verl.utils.megatron.pipeline_parallel import (compute_transformers_input_shapes, make_batch_generator)
from verl import DataProto
from verl.trainer.ppo import core_algos
from verl.workers.actor import BasePPOActor
from verl.utils.py_functional import append_to_dict
from verl.utils.torch_functional import logprobs_from_logits, broadcast_dict_tensor, split_dict_tensor_into_batches
__all__ = ['MegatronPPOActor']
class MegatronPPOActor(BasePPOActor):
def __init__(self, config, model_config, megatron_config: ModelParallelConfig, actor_module: nn.ModuleList,
actor_optimizer: DistributedOptimizer, actor_optimizer_config: OptimizerConfig):
"""MeagtronPPOActor class. This class implements the simple PPO logics when the model is built with Megatron.
Args:
config (OmegaConf): the basic config that contains the hyper-parameters of PPO Actor. It must contain
``ppo_micro_batch_size``: micro batch size for a single forward/backward pass when updating ppo.
``ppo_mini_batch_size``: minibatch size when updating ppo using the batch data.
``ppo_epochs``: number of epochs to update the actor using the batch data.
``shuffle``: whether to shuffle the data after each ppo epoch.
``clip_ratio``: clip ratio of the ppo algorithm. See https://arxiv.org/abs/1707.06347.
``entropy_coeff``: entropy coefficient of the PPO loss. See https://arxiv.org/abs/1707.06347.
model_config (OmegaConf): model configuration. It must contain ``model_config.vocab_size`` and
``model_config.hidden_size``
megatron_config (OmegaConf): megatron configuration. It must contain
``sequence_parallel_enabled``: whether sequence parallelism is enabled.
``param_dtype``: the dtype of the parameters.
``virtual_pipeline_model_parallel_size``: virtual pipeline model parallel size, a.k.a. the number of chunks in each pp stage.
actor_module (nn.ModuleList): actor module is a ModuleList that contains a list of nn.Module in this pp stage.
each nn.Module in this rank holds a vpp module chunk. See https://arxiv.org/pdf/2104.04473.pdf for more details.
The actor module has some constraints to follow in order to use the update logic implemented here
1. It must implement unpad_input before any computation and pad_input after all the computation. Removing padding is an
optimization that drops the padding tokens. See the unpad_input and pad_input functions in flash-attn
(https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/bert_padding.py).
2. Each pp stage must return the hidden state with the same shape [total_nnz, 1, hidden_size],
where total_nnz is the number of valid tokens in this batch. If sequence parallel is enabled, the size
of the hidden state is [total_nnz // tp, 1, hidden_size].
actor_optimizer (DistributedOptimizer): currently, we only support DistributedOptimizer in Megatron. It implements
a ZeRO-1 optimizer that shards the optimizer states across dp ranks.
>>> def megatron_actor_model_provider(pre_process, post_process):
>>> vpp_rank = mpu.get_virtual_pipeline_model_parallel_rank()
>>> parallel_model = ParallelMistralForCausalLMRmPadPP(config=actor_model_config,
>>> megatron_config=megatron_config,
>>> pre_process=pre_process,
>>> post_process=post_process).cuda()
>>> return parallel_model
>>> from megatron.training import get_model
>>> from megatron.optimizer import get_megatron_optimizer
>>> actor_module = get_model(megatron_actor_model_provider, wrap_with_ddp=True)
>>> actor_module = nn.ModuleList(actor_module)
>>> actor_optimizer = get_megatron_optimizer(actor_module)
>>> actor = MegatronPPOActor(config=config,
>>> model_config=actor_model_config,
>>> megatron_config=megatron_config,
>>> actor_module=actor_module,
>>> actor_optimizer=actor_optimizer)
"""
super().__init__(config)
self.model_config = model_config
self.megatron_config = megatron_config
# self.megatron_args = get_args()
self.actor_module = actor_module
self.actor_optimizer: DistributedOptimizer = actor_optimizer
self.actor_optimizer_config = actor_optimizer_config
self.optimizer_step_args = OmegaConf.create({
'skip_grad': None,
'overlap_dp_param_comm': False,
'overlap_dp_grad_comm': False,
'gradient_accumulation_steps': 1,
'sequence_parallel': self.megatron_config.sequence_parallel,
'DDP_impl': 'local',
'layernorm_allreduce_bucket_threshold': 0,
'pipeline_model_parallel_split_rank': None,
'reduce_grads_use_alltoall': False
})
def compute_log_prob(self, data: DataProto) -> torch.Tensor:
"""Compute the log probability of the responses given input_ids, attention_mask and position_ids
Args:
data (DataProto): a DataProto containing keys
``input_ids``: tensor of shape [batch_size, sequence_length]. torch.int64. Note that input_ids is the
concatenation of prompt and response. Note that ``sequence_length = prompt_length + response_length``.
``attention_mask``: tensor of shape [batch_size, sequence_length]. torch.int64.
``position_ids``: tensor of shape [batch_size, sequence_length]. torch.int64.
``responses``: tensor of shape [batch_size, response_length]. torch.int64.
Returns:
torch.Tensor: the log_prob tensor
"""
data.batch = data.batch.contiguous()
def compute_logprobs_fn(output, data):
response = data['responses']
response_length = response.size(1)
logits = output['logits']
logits = logits[:, -response_length - 1:-1]
log_probs = vocab_parallel_log_probs_from_logits(logits, response)
return {'log_probs': log_probs}
# We recompute old_log_prob by default here.
# TODO (zhangchi.usc1992): actually, this function should only return log_prob and this logic should be handled by user outside
recompute_old_log_prob = self.config.get('recompute_old_log_prob', True)
if recompute_old_log_prob or 'old_log_probs' not in data.batch.keys():
select_keys = ['responses', 'input_ids', 'attention_mask', 'position_ids']
batch = data.select(batch_keys=select_keys).batch
input_ids = batch['input_ids']
batch_size = input_ids.size(0)
response = batch['responses']
response_length = response.size(1)
with torch.no_grad():
output = self.forward_backward_batch(data, forward_only=True, post_process_fn=compute_logprobs_fn)
if mpu.is_pipeline_last_stage(ignore_virtual=True):
# only available on the last pp rank; the result is identical on every tp rank
log_probs = torch.cat([o['log_probs'] for o in output], dim=0) # (bs, seq_size)
log_probs = log_probs.to(torch.float32)
else:
log_probs = torch.empty(size=(batch_size, response_length),
dtype=torch.float32,
device=input_ids.device)
# broadcast across pp ranks
torch.distributed.broadcast(tensor=log_probs,
src=mpu.get_pipeline_model_parallel_last_rank(),
group=mpu.get_pipeline_model_parallel_group(),
async_op=False)
# add empty cache after each compute
torch.cuda.empty_cache()
return log_probs
def make_minibatch_iterator(self, data: DataProto) -> Iterable[DataProto]:
"""Make minibatch iterator for updating the actor
Args:
data (DataProto): a DataProto containing keys
``input_ids``: tensor of shape [batch_size, sequence_length]. torch.int64, where ``sequence_length = prompt_length + response_length``
``attention_mask``: tensor of shape [batch_size, sequence_length]. torch.int64
``position_ids``: tensor of shape [batch_size, sequence_length]. torch.int64
``responses``: tensor of shape [batch_size, response_length]. torch.int64. Note that responses = input_ids[:, -response_length:]
``old_log_probs``: tensor of shape [batch_size, response_length]. torch.float32. The log probability of responses.
``advantages``: tensor of shape [batch_size, response_length]. torch.float32. The advantages of responses.
See PPO paper for details. https://arxiv.org/abs/1707.06347
Returns:
Iterable[DataProto]: an iterator that yields mini-batches of DataProto for ``ppo_epochs`` epochs.
"""
select_keys = ['responses', 'input_ids', 'attention_mask', 'position_ids', 'old_log_probs', 'advantages']
data = data.select(batch_keys=select_keys)
return data.make_iterator(mini_batch_size=self.config.ppo_mini_batch_size,
epochs=self.config.ppo_epochs,
dataloader_kwargs={'shuffle': self.config.shuffle})
def forward_backward_batch(self, data: DataProto, forward_only=False, post_process_fn=None):
"""
We assume:
- The model takes input: (input_ids, attention_mask, position_ids). No rmpad for the input
- The communication shape is (total_nnz_pad_to_sp // tp_size, 1, hidden_size) if sequence parallel is enabled
"""
# broadcast from last pp rank to all other pp ranks
# TODO: actually, we just need to control the sampling order.
broadcast_dict_tensor(data.batch,
src=mpu.get_pipeline_model_parallel_last_rank(),
group=mpu.get_pipeline_model_parallel_group())
# split into micro-batches
data.batch['attention_mask'] = data.batch['attention_mask'].to(bool)
if data.meta_info.get('micro_batch_size', None) is not None:
batch_size = data.meta_info['micro_batch_size']
else:
batch_size = self.config.ppo_micro_batch_size
batches = split_dict_tensor_into_batches(data.batch, batch_size=batch_size)
# compute input shapes for pp stages
input_shapes = compute_transformers_input_shapes(
batches,
meta_info={
'sequence_parallel': self.megatron_config.sequence_parallel,
'hidden_size': self.model_config.hidden_size
})
n_micro_batch = len(batches)
seq_len = batches[0]['input_ids'].shape[1]
forward_backward_func = get_forward_backward_func()
def loss_func(output, data, meta_info):
if forward_only:
if post_process_fn is None:
return 1.0, {'logits': output.logits}
else:
return 1.0, post_process_fn(output, data)
responses = data['responses']
response_length = responses.size(1)
attention_mask = data['attention_mask']
response_mask = attention_mask[:, -response_length:]
old_log_prob = data['old_log_probs']
advantages = data['advantages']
clip_ratio = meta_info['clip_ratio']
entropy_coeff = meta_info['entropy_coeff']
# compute policy loss
logits = output.logits
logits = logits[:, -response_length - 1:-1]
log_prob = vocab_parallel_log_probs_from_logits(logits, responses)
pg_loss, pg_clipfrac, ppo_kl = core_algos.compute_policy_loss(old_log_prob=old_log_prob,
log_prob=log_prob,
advantages=advantages,
eos_mask=response_mask,
cliprange=clip_ratio)
entropy_loss = vocab_parallel_compute_entropy_loss(logits, eos_mask=response_mask)
policy_loss = pg_loss - entropy_loss * entropy_coeff
# return loss and stats
stats = {
'actor/entropy_loss': entropy_loss.detach().item(),
'actor/pg_loss': pg_loss.detach().item(),
'actor/pg_clipfrac': pg_clipfrac.detach().item(),
'actor/ppo_kl': ppo_kl.detach().item()
}
return policy_loss, stats
def forward_step(batch_iter, model):
batch = next(batch_iter)
input_ids = batch['input_ids']
attention_mask = batch['attention_mask']
position_ids = batch['position_ids']
output = model(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids)
if forward_only:
meta_info = None
else:
meta_info = {'clip_ratio': self.config.clip_ratio, 'entropy_coeff': self.config.entropy_coeff}
return output, partial(loss_func, data=batch, meta_info=meta_info)
# batch should be a list of batches inside micro-batches
batch_generator = make_batch_generator(batches, vpp_size=len(self.actor_module))
# TODO: we may use the new schedule instead
# for flash-attn: (seq_len, batch_size, hidden_size) = (mbs*seq_len, 1, hidden_size)
if mpu.get_pipeline_model_parallel_world_size() > 1:
losses_reduced = forward_backward_func(
forward_step_func=forward_step,
data_iterator=batch_generator,
model=self.actor_module,
num_microbatches=n_micro_batch,
input_shapes=input_shapes,  # must be set for flash-attn sequence packing
seq_length=batch_size * seq_len,  # unused when input_shapes is set
hidden_size=self.model_config.hidden_size,  # unused when input_shapes is set
micro_batch_size=1,  # unused when input_shapes is set
forward_only=forward_only,
)
else:
losses_reduced = forward_backward_func(
forward_step_func=forward_step,
data_iterator=batch_generator,
model=self.actor_module,
num_microbatches=n_micro_batch,
seq_length=batch_size * seq_len, # in use for pp = 1
hidden_size=self.model_config.hidden_size, # in use for pp = 1
micro_batch_size=1, # in use for pp = 1
forward_only=forward_only,
)
# losses_reduced contains the stats returned from loss_func
return losses_reduced
def update_policy(self, dataloader: Iterable[DataProto]) -> Dict:
"""Update the policy with an iterator of DataProto
Args:
dataloader (Iterable[DataProto]): an iterator over DataProto returned by ``make_minibatch_iterator``.
The keys of each data batch are described in ``make_minibatch_iterator``.
Returns:
Dict: a dictionary containing the statistics. Note that the statistics are only valid in the last pp stage
and users have to combine the output in each dp rank manually.
"""
metrics = {}
for data in dataloader:
# data = data.batch.to(self.actor_module.device)
self.actor_optimizer.zero_grad()
# use use_contiguous_buffers_in_local_ddp and no overlap_dp_param_comm
for chunk in self.actor_module:
# if use distributed optimizer, zero grad buffer will be handled by optimizer
chunk.zero_grad_buffer(zero_buffer=(not self.actor_optimizer_config.use_distributed_optimizer))
metric_micro_batch = self.forward_backward_batch(data)
update_successful, grad_norm, num_zeros_in_grad = self.actor_optimizer.step(
self.megatron_config, self.megatron_config.timers)
if update_successful:
# allgather already execute in optimizer.step in new megatron
pass
else:
raise NotImplementedError
for metric in metric_micro_batch:
append_to_dict(metrics, metric) # append the metric from this micro-batch to global metrics.
# add empty cache after each compute
torch.cuda.empty_cache()
return metrics
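A hedged sketch of the calling order this class assumes (editor's illustration only; `actor` is the object built in the docstring example above and `data` is a rollout DataProto):
# Editor's illustration only.
log_probs = actor.compute_log_prob(data)          # computed on the last pp stage, broadcast to all pp ranks
# ... a trainer would then add 'old_log_probs' and 'advantages' to data.batch ...
dataloader = actor.make_minibatch_iterator(data)  # yields mini-batches for config.ppo_epochs epochs
metrics = actor.update_policy(dataloader)         # statistics are only valid on the last pp stage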

18
verl/workers/critic/__init__.py Normal file
View File

@@ -0,0 +1,18 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .base import BasePPOCritic
from .dp_critic import DataParallelPPOCritic
__all__ = ["BasePPOCritic", "DataParallelPPOCritic"]

40
verl/workers/critic/base.py Normal file
View File

@@ -0,0 +1,40 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Base class for a critic
"""
from abc import ABC, abstractmethod
import torch
from verl import DataProto
__all__ = ['BasePPOCritic']
class BasePPOCritic(ABC):
def __init__(self, config):
super().__init__()
self.config = config
@abstractmethod
def compute_values(self, data: DataProto) -> torch.Tensor:
"""Compute values"""
pass
@abstractmethod
def update_critic(self, data: DataProto):
"""Update the critic"""
pass
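As with the actor base class, a concrete critic only needs these two methods; a minimal hypothetical sketch (editor's illustration, not part of this commit):
# Hypothetical example: a placeholder critic that returns zero values.
class ToyPPOCritic(BasePPOCritic):

    def compute_values(self, data: DataProto) -> torch.Tensor:
        # a real critic would run a value head over the sequence here
        return torch.zeros(data.batch['responses'].shape, dtype=torch.float32)

    def update_critic(self, data: DataProto):
        # a real critic would regress the value head towards 'returns' and report statistics
        return {'critic/vf_loss': 0.0}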

204
verl/workers/critic/dp_critic.py Normal file
View File

@@ -0,0 +1,204 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Implement a multiprocess PPOCritic
"""
import itertools
from typing import Iterable
import torch
import torch.distributed
from torch import nn, optim
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from verl import DataProto
from verl.trainer.ppo import core_algos
from verl.workers.critic import BasePPOCritic
from verl.utils.py_functional import append_to_dict
from verl.utils.torch_functional import masked_mean
from verl.utils.ulysses import ulysses_pad_and_slice_inputs, gather_outpus_and_unpad
from verl.utils.seqlen_balancing import rearrange_micro_batches, get_reverse_idx
from flash_attn.bert_padding import pad_input, unpad_input, rearrange, index_first_axis
__all__ = ['DataParallelPPOCritic']
class DataParallelPPOCritic(BasePPOCritic):
def __init__(self, config, critic_module: nn.Module, critic_optimizer: optim.Optimizer):
super().__init__(config=config)
self.critic_module = critic_module
self.critic_optimizer = critic_optimizer
self.use_remove_padding = self.config.model.get('use_remove_padding', False)
print(f'Critic use_remove_padding={self.use_remove_padding}')
assert self.config.ppo_mini_batch_size % self.config.ppo_micro_batch_size == 0
self.gradient_accumulation = self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size
self.ulysses_sequence_parallel_size = self.config.get('ulysses_sequence_parallel_size', 1)
def _forward_micro_batch(self, micro_batch):
response_length = micro_batch['responses'].size(-1)
with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
input_ids = micro_batch['input_ids']
batch, seqlen = input_ids.shape
attention_mask = micro_batch['attention_mask']
position_ids = micro_batch['position_ids']
if self.use_remove_padding:
input_ids_rmpad, indices, *_ = unpad_input(input_ids.unsqueeze(-1),
attention_mask) # input_ids_rmpad (total_nnz, ...)
input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz)
# unpad the position_ids to align the rotary
position_ids_rmpad = index_first_axis(rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."),
indices).transpose(0, 1)
# pad and slice the inputs if sp > 1
if self.ulysses_sequence_parallel_size > 1:
input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad_and_slice_inputs(input_ids_rmpad, \
position_ids_rmpad, \
sp_size=self.ulysses_sequence_parallel_size)
# only pass input_ids and position_ids to enable flash_attn_varlen
output = self.critic_module(input_ids=input_ids_rmpad,
attention_mask=None,
position_ids=position_ids_rmpad,
use_cache=False)  # prevent the model from thinking we are generating
values_rmpad = output.logits
values_rmpad = values_rmpad.squeeze(0) # (total_nnz)
# gather output if sp > 1
if self.ulysses_sequence_parallel_size > 1:
values_rmpad = gather_outpus_and_unpad(values_rmpad,
gather_dim=0,
unpad_dim=0,
padding_size=pad_size)
# pad it back
values = pad_input(values_rmpad, indices=indices, batch=batch, seqlen=seqlen).squeeze(-1)
values = values[:, -response_length - 1:-1]
else:
output = self.critic_module(input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
use_cache=False)  # prevent the model from thinking we are generating
values = output.logits
values = values[:, -response_length - 1:-1].squeeze(-1)
return values
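# --- Editor's note (illustration, not part of the original file) ---------------
# The slice [:, -response_length - 1:-1] keeps the value predicted at the position just
# before each response token, i.e. the value of the state from which that token is
# sampled; the prediction at the final response token is dropped because nothing is
# generated after it.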
def _optimizer_step(self):
assert self.config.grad_clip is not None
if isinstance(self.critic_module, FSDP):
grad_norm = self.critic_module.clip_grad_norm_(self.config.grad_clip)
else:
grad_norm = torch.nn.utils.clip_grad_norm_(self.critic_module.parameters(), max_norm=self.config.grad_clip)
self.critic_optimizer.step()
return grad_norm
def compute_values(self, data: DataProto) -> torch.Tensor:
self.critic_module.eval()
micro_batch_size = data.meta_info['micro_batch_size']
select_keys = ['responses', 'input_ids', 'attention_mask', 'position_ids']
batch = data.select(batch_keys=select_keys).batch
use_dynamic_bsz = data.meta_info['use_dynamic_bsz']
if use_dynamic_bsz:
# split using dynamic bsz
max_token_len = data.meta_info['max_token_len'] * self.ulysses_sequence_parallel_size
micro_batches, indices = rearrange_micro_batches(batch=batch, max_token_len=max_token_len)
else:
micro_batches = batch.split(micro_batch_size)
values_lst = []
for micro_batch in micro_batches:
with torch.no_grad():
values = self._forward_micro_batch(micro_batch)
values_lst.append(values)
values = torch.concat(values_lst, dim=0)
responses = data.batch['responses']
attention_mask = data.batch['attention_mask']
response_length = responses.size(1)
values = values * attention_mask[:, -response_length - 1:-1]
if use_dynamic_bsz:
indices = list(itertools.chain.from_iterable(indices))
assert len(indices) == values.size(0), f"{len(indices)} vs. {values.size()}"
revert_indices = torch.tensor(get_reverse_idx(indices), dtype=torch.long)
values = values[revert_indices]
return values
def update_critic(self, data: DataProto):
# make sure we are in training mode
self.critic_module.train()
metrics = {}
select_keys = ['input_ids', 'responses', 'attention_mask', 'position_ids', 'values', 'returns']
batch = data.select(batch_keys=select_keys).batch
# Split to make minibatch iterator for updating the critic
# See PPO paper for details. https://arxiv.org/abs/1707.06347
dataloader = batch.split(self.config.ppo_mini_batch_size)
for batch_idx, data in enumerate(dataloader):
# split batch into micro_batches
mini_batch = data
if self.config.use_dynamic_bsz:
max_token_len = self.config.ppo_max_token_len_per_gpu * self.ulysses_sequence_parallel_size
micro_batches, _ = rearrange_micro_batches(batch=mini_batch, max_token_len=max_token_len)
else:
micro_batches = mini_batch.split(self.config.ppo_micro_batch_size)
self.critic_optimizer.zero_grad()
for data in micro_batches:
data = data.cuda() # critic device is cpu when using offload
input_ids = data['input_ids']
responses = data['responses']
attention_mask = data['attention_mask']
position_ids = data['position_ids']
values = data['values']
returns = data['returns']
response_length = responses.size(1)
eos_mask = attention_mask[:, -response_length - 1:-1]
vpreds = self._forward_micro_batch(data)
# assert not torch.any(torch.isnan(vpreds)).item()
vf_loss, vf_clipfrac = core_algos.compute_value_loss(vpreds=vpreds,
values=values,
returns=returns,
eos_mask=eos_mask,
cliprange_value=self.config.cliprange_value)
loss = vf_loss / self.gradient_accumulation
loss.backward()
data = {
'critic/vf_loss': vf_loss.detach().item(),
'critic/vf_clipfrac': vf_clipfrac.detach().item(),
'critic/vpred_mean': masked_mean(vpreds, eos_mask).detach().item(),
}
append_to_dict(metrics, data)
grad_norm = self._optimizer_step()
data = {'critic/grad_norm': grad_norm.detach().item()}
append_to_dict(metrics, data)
self.critic_optimizer.zero_grad()
return metrics
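A hedged usage sketch mirroring the actor one above (editor's illustration only; critic_config, fsdp_value_model, optimizer and data are assumptions):
# Editor's illustration only.
critic = DataParallelPPOCritic(config=critic_config, critic_module=fsdp_value_model, critic_optimizer=optimizer)
data.meta_info.update({'micro_batch_size': 8, 'use_dynamic_bsz': False})
values = critic.compute_values(data)   # (batch_size, response_length), masked by the shifted response attention mask
# ... a trainer would then add 'values' and 'returns' (e.g. from GAE) to data.batch ...
metrics = critic.update_critic(data)   # dict of 'critic/*' statistics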

229
verl/workers/critic/megatron_critic.py Normal file
View File

@@ -0,0 +1,229 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Implement a multiprocess PPOCritic
"""
from functools import partial
from typing import Iterable
import torch
import torch.distributed
from omegaconf import OmegaConf
from torch import nn
from verl import DataProto
from verl.trainer.ppo import core_algos
from verl.workers.critic import BasePPOCritic
from verl.utils.megatron.pipeline_parallel import (compute_transformers_input_shapes, make_batch_generator)
from verl.utils.py_functional import append_to_dict
from verl.utils.torch_dtypes import PrecisionType
from verl.utils.torch_functional import masked_mean, broadcast_dict_tensor, split_dict_tensor_into_batches
from verl.utils.megatron import sequence_parallel as sp_utils
from verl.utils.megatron.optimizer_config import OptimizerConfig
from megatron.optimizer import DistributedOptimizer
from megatron.core import parallel_state as mpu
from megatron.core.pipeline_parallel import get_forward_backward_func
class MegatronPPOCritic(BasePPOCritic):
def __init__(self, config, model_config, megatron_config, critic_module: nn.ModuleList,
critic_optimizer: DistributedOptimizer, critic_optimizer_config: OptimizerConfig):
super().__init__(config=config)
self.model_config = model_config
self.megatron_config = megatron_config
self.critic_module = critic_module
self.critic_optimizer = critic_optimizer
self.critic_optimizer_config = critic_optimizer_config
# we create a separate namedtuple for the optimizer step so that global args won't affect it.
self.optimizer_step_args = OmegaConf.create({
'skip_grad': None,
'overlap_dp_param_comm': False,
'overlap_dp_grad_comm': False,
'gradient_accumulation_steps': 1,
'sequence_parallel': self.megatron_config.sequence_parallel,
'DDP_impl': 'local',
'layernorm_allreduce_bucket_threshold': 0,
'pipeline_model_parallel_split_rank': None,
'reduce_grads_use_alltoall': False
})
if self.config.kl_ctrl.type == 'fixed':
self.kl_ctrl = core_algos.FixedKLController(kl_coef=self.config.kl_ctrl.kl_coef)
elif self.config.kl_ctrl.type == 'adaptive':
assert self.config.kl_ctrl.horizon > 0, f'horizon must be larger than 0. Got {self.config.kl_ctrl.horizon}'
self.kl_ctrl = core_algos.AdaptiveKLController(init_kl_coef=self.config.kl_ctrl.kl_coef,
target_kl=self.config.kl_ctrl.target_kl,
horizon=self.config.kl_ctrl.horizon)
else:
raise NotImplementedError
def compute_values(self, data: DataProto) -> torch.Tensor:
# data.batch = data.batch.to(self.critic_module.module.device)
responses = data.batch['responses']
attention_mask = data.batch['attention_mask']
response_length = responses.size(1)
with torch.no_grad():
output = self.forward_backward_batch(data=data, forward_only=True)
if mpu.is_pipeline_last_stage(ignore_virtual=True):
# only available on the last pp rank; the result is identical on every tp rank
values = torch.cat([o['vpreds'] for o in output], dim=0)  # (bs, seq_len)
values = values.to(torch.float32)
else:
values = torch.empty_like(attention_mask, dtype=torch.float32)
# each tp rank should contain the same values
values = values * attention_mask
values = values[:, -response_length - 1:-1]
values = values.contiguous()
# sync among pp ranks
torch.distributed.broadcast(tensor=values,
src=mpu.get_pipeline_model_parallel_last_rank(),
group=mpu.get_pipeline_model_parallel_group())
# add empty cache after each compute
torch.cuda.empty_cache()
return values
def make_minibatch_iterator(self, data: DataProto) -> Iterable[DataProto]:
select_keys = ['input_ids', 'responses', 'attention_mask', 'position_ids', 'values', 'returns']
data = data.select(batch_keys=select_keys)
return data.make_iterator(mini_batch_size=self.config.ppo_mini_batch_size,
epochs=self.config.ppo_epochs,
dataloader_kwargs={'shuffle': self.config.shuffle})
def forward_backward_batch(self, data: DataProto, forward_only=False):
# broadcast from last pp rank to all other pp ranks
data.batch = data.batch.contiguous()
broadcast_dict_tensor(data.batch,
src=mpu.get_pipeline_model_parallel_last_rank(),
group=mpu.get_pipeline_model_parallel_group())
# split into micro-batches
data.batch['attention_mask'] = data.batch['attention_mask'].to(bool)
batches = split_dict_tensor_into_batches(data.batch, batch_size=self.config.ppo_micro_batch_size)
n_micro_batch = len(batches)
seq_len = batches[0]['input_ids'].shape[1]
# compute input shapes for pp stages
input_shapes = compute_transformers_input_shapes(
batches,
meta_info={
'sequence_parallel': self.megatron_config.sequence_parallel,
'hidden_size': self.model_config.hidden_size
})
forward_backward_func = get_forward_backward_func()
def loss_func(output, data, meta_info):
if forward_only:
return 1.0, {'vpreds': output.logits}
responses = data['responses']
attention_mask = data['attention_mask']
values = data['values']
returns = data['returns']
response_length = responses.size(1)
eos_mask = attention_mask[:, -response_length:]
cliprange_value = self.config.cliprange_value
vpreds = output.logits # (bs, sequence_length)
vpreds = vpreds[:, -response_length - 1:-1]
vf_loss, vf_clipfrac = core_algos.compute_value_loss(vpreds=vpreds,
values=values,
returns=returns,
eos_mask=eos_mask,
cliprange_value=cliprange_value)
stats = {
'critic/vf_loss': vf_loss.detach().item(),
'critic/vf_clipfrac': vf_clipfrac.detach().item(),
'critic/vpred_mean': masked_mean(vpreds, eos_mask).detach().item(),
}
return vf_loss, stats
def forward_step(batch_iter, model):
batch = next(batch_iter)
input_ids = batch['input_ids']
attention_mask = batch['attention_mask']
position_ids = batch['position_ids']
output = model(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids)
return output, partial(loss_func, data=batch, meta_info={})
# batch should be a list of batches inside micro-batches
batch_generator = make_batch_generator(batches, vpp_size=len(self.critic_module))
# TODO: we may use the new schedule instead
# for flash-attn: (seq_len, batch_size, hidden_size) = (mbs*seq_len, 1, hidden_size)
if mpu.get_pipeline_model_parallel_world_size() > 1:
losses_reduced = forward_backward_func(
forward_step_func=forward_step,
data_iterator=batch_generator,
model=self.critic_module,
num_microbatches=n_micro_batch,
input_shapes=input_shapes,  # must be set for flash-attn sequence packing
seq_length=self.config.ppo_micro_batch_size * seq_len,  # unused when input_shapes is set
hidden_size=self.model_config.hidden_size,  # unused when input_shapes is set
micro_batch_size=1,  # unused when input_shapes is set
forward_only=forward_only,
)
else:
losses_reduced = forward_backward_func(
forward_step_func=forward_step,
data_iterator=batch_generator,
model=self.critic_module,
num_microbatches=n_micro_batch,
seq_length=self.config.ppo_micro_batch_size * seq_len, # in use for pp = 1
hidden_size=self.model_config.hidden_size, # in use for pp = 1
micro_batch_size=1, # in use for pp = 1
forward_only=forward_only,
)
# losses_reduced contains the stats returned from loss_func
return losses_reduced
def update_critic(self, dataloader: Iterable[DataProto]):
metrics = {}
for data in dataloader:
# data = data.batch.to(self.critic_module.device)
self.critic_optimizer.zero_grad()
# use use_contiguous_buffers_in_local_ddp and no overlap_dp_param_comm
for chunk in self.critic_module:
chunk.zero_grad_buffer(zero_buffer=(not self.critic_optimizer_config.use_distributed_optimizer))
metric_micro_batch = self.forward_backward_batch(data)
update_successful, grad_norm, num_zeros_in_grad = self.critic_optimizer.step(
self.megatron_config, self.megatron_config.timers)
if update_successful:
# allgather already execute in optimizer.step in new megatron
pass
else:
raise NotImplementedError
for metric in metric_micro_batch:
append_to_dict(metrics, metric) # append the metric from this micro-batch to global metrics.
# add empty cache after each compute
torch.cuda.empty_cache()
return metrics

1054
verl/workers/fsdp_workers.py Normal file

File diff suppressed because it is too large

735
verl/workers/megatron_workers.py Normal file
View File

@@ -0,0 +1,735 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
The main entry point to run the PPO algorithm
"""
import os
import logging
import ray
import torch
import torch.distributed
import torch.nn as nn
from omegaconf import DictConfig
from verl.single_controller.base.megatron.worker import MegatronWorker
from verl.workers.actor.megatron_actor import MegatronPPOActor
from verl.workers.critic.megatron_critic import MegatronPPOCritic
from verl.workers.sharding_manager import AllGatherPPModel
from verl.workers.reward_model.megatron.reward_model import MegatronRewardModel
from verl.single_controller.base.decorator import register, Dispatch
from verl import DataProto
from verl.utils.fs import copy_local_path_from_hdfs
from verl.utils.debug import log_gpu_memory_usage
from verl.utils.model import load_megatron_model_weights
from verl.utils.megatron_utils import init_model_parallel_config
from verl.utils.megatron_utils import offload_megatron_param_and_grad, load_megatron_param_and_grad
from verl.utils import hf_tokenizer
from megatron.core import parallel_state as mpu
from megatron.core import ModelParallelConfig
logger = logging.getLogger(__file__)
logger.setLevel(os.getenv('VERL_PPO_LOGGING_LEVEL', 'WARN'))
def set_random_seed(seed):
import torch
import numpy as np
import random
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
if torch.cuda.device_count() > 0:
from megatron.core import tensor_parallel
tensor_parallel.model_parallel_cuda_manual_seed(seed)
# FIXME: torch cumsum not support deterministic (used in vllm sampler),
# https://github.com/pytorch/pytorch/issues/89492
# torch.use_deterministic_algorithms(True, warn_only=True)
# os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
class ActorRolloutRefWorker(MegatronWorker):
"""
This worker can be instantiated as a standalone actor or a standalone rollout or a standalone reference policy
or a hybrid engine based on the config.rollout
"""
def __init__(self, config: DictConfig, role: str):
super().__init__()
self.config = config
# NOTE(sgm): We utilize a colocated WorkerGroup by default.
# As a result, Workers for different models share the same process.
# Therefore, we only require one distributed initialization.
# To utilize different parallel strategies for different models:
# 1. users should disable WorkerDict; 2. assign a different ResourcePool to each model;
# 3. and apply the following patch in ray==2.10: https://github.com/ray-project/ray/pull/44385
if not torch.distributed.is_initialized():
rank = int(os.environ['LOCAL_RANK'])
torch.distributed.init_process_group(backend="nccl")
torch.cuda.set_device(rank)
if self.config.actor.megatron.sequence_parallel:
os.environ['CUDA_DEVICE_MAX_CONNECTIONS'] = '1'
mpu.initialize_model_parallel(
tensor_model_parallel_size=self.config.actor.megatron.tensor_model_parallel_size,
pipeline_model_parallel_size=self.config.actor.megatron.pipeline_model_parallel_size,
virtual_pipeline_model_parallel_size=None,
pipeline_model_parallel_split_rank=None,
use_sharp=False,
context_parallel_size=1,
expert_model_parallel_size=1,
nccl_communicator_config_path=None,
)
set_random_seed(seed=self.config.actor.megatron.seed)
self.role = role
assert self.role in ['actor', 'rollout', 'ref', 'actor_rollout', 'actor_rollout_ref']
self._is_actor = self.role in ['actor', 'actor_rollout', 'actor_rollout_ref']
self._is_rollout = self.role in ['rollout', 'actor_rollout', 'actor_rollout_ref']
self._is_ref = self.role in ['ref', 'actor_rollout_ref']
# TODO(sgm): Currently, we only support reference model param offload
# will support other offload later
self._is_offload_param = False
self._is_offload_grad = False
self._is_offload_optimizer = False
# normalize config
if self._is_actor and self._is_rollout:
self.config.actor.ppo_mini_batch_size //= mpu.get_data_parallel_world_size()
self.config.actor.ppo_micro_batch_size //= mpu.get_data_parallel_world_size()
self.config.rollout.log_prob_micro_batch_size //= mpu.get_data_parallel_world_size()
self._is_offload_param = self.config.actor.get('param_offload', False)
self._is_offload_grad = self.config.actor.get('grad_offload', False)
self._is_offload_optimizer = self.config.actor.get('optimizer_offload', False)
elif self._is_ref:
self.config.ref.log_prob_micro_batch_size //= mpu.get_data_parallel_world_size()
self._is_offload_param = self.config.ref.get('param_offload', False)
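# --- Editor's note (illustration): the divisions above convert global batch sizes into
# per-data-parallel-rank sizes. For example, with a global ppo_mini_batch_size of 256, a
# global ppo_micro_batch_size of 32 and a data-parallel world size of 8, each rank updates
# with mini-batches of 32 samples split into micro-batches of 4.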
def _build_model_optimizer(self,
model_path,
megatron_config: ModelParallelConfig,
optim_config,
override_model_config,
enable_gradient_checkpointing=False):
from verl.utils.megatron.optimizer import get_megatron_optimizer
from megatron.core.models.gpt.gpt_model import ModelType
from verl.utils.model import print_model_size, update_model_config
from verl.utils.megatron_utils import get_model, init_megatron_optim_config
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
# Step 1: initialize the tokenizer
local_path = copy_local_path_from_hdfs(model_path)
self.tokenizer = hf_tokenizer(local_path)
# Step 2: get the actor_model_config
actor_model_config = AutoConfig.from_pretrained(local_path)
override_config_kwargs = {
'bos_token_id': self.tokenizer.bos_token_id,
'eos_token_id': self.tokenizer.eos_token_id,
'pad_token_id': self.tokenizer.pad_token_id,
}
override_config_kwargs.update(override_model_config)
update_model_config(actor_model_config, override_config_kwargs=override_config_kwargs)
if self.rank == 0:
print(f'Model config after override: {actor_model_config}')
def megatron_actor_model_provider(pre_process, post_process):
from verl.utils.model import get_parallel_model_from_config
# vpp is not supported yet because it will hang for some reason. Need debugging
vpp_rank = mpu.get_virtual_pipeline_model_parallel_rank() # this will be set inside get_model
# this_megatron_config = copy.deepcopy(megatron_config)
# this_megatron_config.virtual_pipeline_model_parallel_rank = vpp_rank
parallel_model = get_parallel_model_from_config(config=actor_model_config,
megatron_config=megatron_config,
pre_process=pre_process,
post_process=post_process,
value=False)
parallel_model.cuda()
return parallel_model
# Step 3: initialize the megatron model
if self._is_actor and self._is_rollout:
# Initialize the 3D HybridEngine
hybrid_engine = AllGatherPPModel(model_provider=megatron_actor_model_provider)
# Fetch the model at current rank
actor_module = hybrid_engine.this_rank_models
if isinstance(actor_module, nn.ModuleList):
actor_module = [actor_module[0]]
if self.config.actor.load_weight:
load_megatron_model_weights(self.config,
actor_model_config,
actor_module,
params_dtype=megatron_config.params_dtype,
is_value_model=False)
if self.rank == 0:
print_model_size(actor_module[0])
log_gpu_memory_usage('After AllGatherPPModel init', logger=logger)
elif self._is_ref:
print(f'self.config.ref.load_weight: {self.config.ref.load_weight}')
ref_module = get_model(model_provider_func=megatron_actor_model_provider,
model_type=ModelType.encoder_or_decoder,
wrap_with_ddp=False)
# ref_module = nn.ModuleList(ref_module)
if self.config.ref.load_weight: # should align with the actor:
assert self.config.actor.load_weight == self.config.ref.load_weight
print(f'load ref weight start')
load_megatron_model_weights(self.config,
actor_model_config,
ref_module,
params_dtype=megatron_config.params_dtype,
is_value_model=False)
log_gpu_memory_usage('After ref module init', logger=logger)
return ref_module, actor_model_config
# TODO: add more optimizer args into config
if self._is_actor:
optim_config = init_megatron_optim_config(optim_config)
actor_optimizer = get_megatron_optimizer(model=actor_module, config=optim_config)
else:
optim_config = None
actor_optimizer = None
log_gpu_memory_usage('After actor optimizer init', logger=logger)
return actor_module, hybrid_engine, actor_optimizer, actor_model_config, optim_config
def _build_rollout(self):
if self.config.rollout.name == 'vllm':
from verl.workers.rollout.vllm_rollout import vLLMRollout
from verl.workers.sharding_manager import MegatronVLLMShardingManager
from verl.utils.model import normalize_pp_vpp_params
# NOTE(sgm): If the QKV and gate_up projection layers are concatenated together in the actor,
# we will reorganize their weight format when resharding from actor to rollout.
layer_name_mapping = {
"qkv_layer_name":
self.config.rollout.layer_name_map.get("qkv_layer_name", "qkv"),
"gate_proj_layer_name":
self.config.rollout.layer_name_map.get("gate_proj_layer_name", "linear_fc1.weight"),
}
# reshard the weight partition from actor to rollout to initialize the rollout class
# create a new cuda space for parameters not in this pp rank
self.hybrid_engine.load_params_to_cuda()
# broadcast the parameters from pp rank to other ranks
self.hybrid_engine.allgather_params()
# obtain name to parameters in pp/vpp
params = self.hybrid_engine.get_all_params()
# update the param name for the
params = normalize_pp_vpp_params(params=params,
num_hidden_layers=self.actor_model_config.num_hidden_layers,
layer_name='layers')
rollout = vLLMRollout(actor_module=params,
config=self.config.rollout,
tokenizer=self.tokenizer,
model_hf_config=self.actor_model_config,
train_tp=mpu.get_tensor_model_parallel_world_size())
log_gpu_memory_usage('After building vllm rollout', logger=logger)
# perform weight resharding between actor and rollout
sharding_manager = MegatronVLLMShardingManager(module=self.hybrid_engine,
inference_engine=rollout.inference_engine,
model_config=self.actor_model_config,
layer_name_mapping=layer_name_mapping)
log_gpu_memory_usage('After building sharding manager', logger=logger)
else:
raise NotImplementedError('Only vLLMRollout is supported with Megatron now')
return rollout, sharding_manager
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def init_model(self):
if self.config.model.get('external_lib', None) is not None:
# This is used to import external_lib into the huggingface systems
import importlib
importlib.import_module(self.config.model.external_lib)
from omegaconf import OmegaConf
from verl.utils.torch_dtypes import PrecisionType
override_model_config = OmegaConf.to_container(self.config.model.get('override_config', OmegaConf.create()))
torch_dtype = torch.bfloat16
megatron_config = OmegaConf.create({
'sequence_parallel': self.config.actor.megatron.get('sequence_parallel', True),
'param_dtype': PrecisionType.to_str(torch_dtype),
'tensor_model_parallel_size': mpu.get_tensor_model_parallel_world_size(),
'pipeline_model_parallel_rank': mpu.get_pipeline_model_parallel_rank(),
'pipeline_model_parallel_size': mpu.get_pipeline_model_parallel_world_size(),
'virtual_pipeline_model_parallel_rank': mpu.get_virtual_pipeline_model_parallel_rank(),
'virtual_pipeline_model_parallel_size': mpu.get_virtual_pipeline_model_parallel_world_size()
})
megatron_config = init_model_parallel_config(megatron_config)
if self._is_actor or self._is_rollout:
# we need the model for actor and rollout
if self._is_actor:
optim_config = self.config.actor.optim
else:
optim_config = None
self.actor_module, self.hybrid_engine, self.actor_optimizer, \
self.actor_model_config, self.actor_optim_config = self._build_model_optimizer(
model_path=self.config.model.path,
megatron_config=megatron_config,
optim_config=optim_config,
override_model_config=override_model_config,
)
if self._is_actor:
self.actor = MegatronPPOActor(config=self.config.actor,
model_config=self.actor_model_config,
megatron_config=megatron_config,
actor_module=self.actor_module,
actor_optimizer=self.actor_optimizer,
actor_optimizer_config=self.actor_optim_config)
if self._is_rollout:
self.rollout, self.sharding_manager = self._build_rollout()
if self._is_ref:
self.ref_module, self.ref_model_config = self._build_model_optimizer(
model_path=self.config.model.path,
megatron_config=megatron_config,
optim_config=None,
override_model_config=override_model_config,
)
self.ref_policy = MegatronPPOActor(config=self.config.ref,
model_config=self.ref_model_config,
megatron_config=megatron_config,
actor_module=self.ref_module,
actor_optimizer=None,
actor_optimizer_config=None)
torch.cuda.empty_cache()
@register(dispatch_mode=Dispatch.MEGATRON_COMPUTE_PROTO)
def update_actor(self, data: DataProto):
assert self._is_actor
data.batch = data.batch.cuda()
log_gpu_memory_usage('Before update policy', logger=logger)
dataloader = self.actor.make_minibatch_iterator(data=data)
metrics = self.actor.update_policy(dataloader=dataloader)
log_gpu_memory_usage('After update policy', logger=logger)
# TODO: here, we should return all metrics
output = DataProto(meta_info={'metrics': metrics})
output = output.to('cpu')
torch.cuda.empty_cache()
return output
# @register(dispatch_mode=Dispatch.MEGATRON_PP_AS_DP_PROTO)
# def compute_log_prob(self, data: DataProto) -> DataProto:
# assert self._is_rollout
# output = self.actor.compute_log_prob(data=data)
# output = DataProto.from_dict(tensors={'old_log_probs': output})
# torch.cuda.empty_cache()
# return output
@register(dispatch_mode=Dispatch.MEGATRON_PP_AS_DP_PROTO)
def generate_sequences(self, prompts: DataProto):
assert self._is_rollout
prompts.batch = prompts.batch.cuda()
meta_info = {'eos_token_id': self.tokenizer.eos_token_id, 'pad_token_id': self.tokenizer.pad_token_id}
prompts.meta_info.update(meta_info)
with self.sharding_manager:
log_gpu_memory_usage('After entering sharding manager', logger=logger)
prompts = self.sharding_manager.preprocess_data(prompts)
output = self.rollout.generate_sequences(prompts=prompts)
log_gpu_memory_usage('After rollout generation', logger=logger)
output = self.sharding_manager.postprocess_data(output)
validate = prompts.meta_info.get('validate', False)
if self._is_actor and not validate:
                # we should always recompute old_log_probs when using the HybridEngine
output.meta_info['micro_batch_size'] = self.config.rollout.log_prob_micro_batch_size
output.meta_info['temperature'] = self.config.rollout.temperature
old_log_probs = self.actor.compute_log_prob(data=output)
output.batch['old_log_probs'] = old_log_probs
output = output.to('cpu')
# clear kv cache
torch.cuda.empty_cache()
log_gpu_memory_usage('After recompute log prob', logger=logger)
return output
@register(dispatch_mode=Dispatch.MEGATRON_COMPUTE_PROTO)
def compute_ref_log_prob(self, data: DataProto):
data = data.to('cuda')
assert self._is_ref
if self._is_offload_param:
load_megatron_param_and_grad(self.ref_module, torch.cuda.current_device(), self._is_offload_grad)
micro_batch_size = self.config.rollout.log_prob_micro_batch_size
data.meta_info['micro_batch_size'] = micro_batch_size
data.meta_info['temperature'] = self.config.rollout.temperature
output = self.ref_policy.compute_log_prob(data=data)
output = DataProto.from_dict(tensors={'ref_log_prob': output})
output = output.to('cpu')
if self._is_offload_param:
offload_megatron_param_and_grad(self.ref_module, self._is_offload_grad)
torch.cuda.empty_cache()
return output
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def load_checkpoint(self, checkpoint_path):
pass
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def load_pretrained_model(self, checkpoint_path):
pass
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def save_checkpoint(self, checkpoint_path):
assert self._is_actor
pass
class CriticWorker(MegatronWorker):
def __init__(self, config):
super().__init__()
self.config = config
        # NOTE(sgm): We use a colocated WorkerGroup by default.
        # As a result, Workers for different models share the same process.
        # Therefore, we only require one distributed initialization.
        # To use a different parallel strategy for different models:
        # 1. disable WorkerDict; 2. assign a different ResourcePool to each model;
        # 3. apply the following patch to ray==2.10: https://github.com/ray-project/ray/pull/44385
if not torch.distributed.is_initialized():
rank = int(os.environ['LOCAL_RANK'])
torch.distributed.init_process_group(backend="nccl")
torch.cuda.set_device(rank)
if self.config.megatron.sequence_parallel:
os.environ['CUDA_DEVICE_MAX_CONNECTIONS'] = '1'
mpu.initialize_model_parallel(
tensor_model_parallel_size=self.config.megatron.tensor_model_parallel_size,
pipeline_model_parallel_size=self.config.megatron.pipeline_model_parallel_size,
virtual_pipeline_model_parallel_size=None,
pipeline_model_parallel_split_rank=None,
use_sharp=False,
context_parallel_size=1,
expert_model_parallel_size=1,
nccl_communicator_config_path=None,
)
set_random_seed(seed=self.config.megatron.seed)
# normalize config
self.config.ppo_mini_batch_size //= mpu.get_data_parallel_world_size()
self.config.ppo_micro_batch_size //= mpu.get_data_parallel_world_size()
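        # after this division, ppo_mini_batch_size and ppo_micro_batch_size are interpreted per data-parallel rank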
# TODO(sgm): support critic model offload
def _build_critic_model_optimizer(self,
model_path,
megatron_config: ModelParallelConfig,
optim_config,
override_model_config,
enable_gradient_checkpointing=False):
from megatron.core.models.gpt.gpt_model import ModelType
from verl.utils.model import print_model_size, update_model_config
from verl.utils.megatron.optimizer import get_megatron_optimizer
from verl.utils.megatron_utils import get_model, init_megatron_optim_config, init_model_parallel_config
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
# Step 1: initialize the tokenizer
local_path = copy_local_path_from_hdfs(model_path)
self.tokenizer = hf_tokenizer(local_path)
# Step 2: get the actor_model_config
critic_model_config = AutoConfig.from_pretrained(local_path)
override_config_kwargs = {
'bos_token_id': self.tokenizer.bos_token_id,
'eos_token_id': self.tokenizer.eos_token_id,
'pad_token_id': self.tokenizer.pad_token_id,
}
override_config_kwargs.update(override_model_config)
update_model_config(critic_model_config, override_config_kwargs=override_config_kwargs)
if self.rank == 0:
print(f'Model config after override: {critic_model_config}')
def megatron_critic_model_provider(pre_process, post_process):
from verl.utils.model import get_parallel_model_from_config
# TODO: support vpp here
# vpp_rank = mpu.get_virtual_pipeline_model_parallel_rank() # this will be set inside get_model
# this_megatron_config = copy.deepcopy(megatron_config)
# this_megatron_config.virtual_pipeline_model_parallel_rank = vpp_rank
parallel_model = get_parallel_model_from_config(config=critic_model_config,
megatron_config=megatron_config,
pre_process=pre_process,
post_process=post_process,
value=True)
parallel_model.cuda()
return parallel_model
# Step 3: initialize the megatron model
critic_module = get_model(model_provider_func=megatron_critic_model_provider,
model_type=ModelType.encoder_or_decoder,
wrap_with_ddp=True)
# note that here critic_module will be a list to be compatible with the construction of interleaved pp (vpp).
# but here, we do not use pp (vpp) yet. For simplicity, we remove the list
# critic_module = nn.ModuleList(critic_module)
if self.config.load_weight:
load_megatron_model_weights(self.config,
critic_model_config,
critic_module,
params_dtype=megatron_config.params_dtype,
is_value_model=True)
if self.rank == 0:
print_model_size(critic_module[0])
# TODO: add more optimizer args into config
optim_config = init_megatron_optim_config(optim_config)
critic_optimizer = get_megatron_optimizer(model=critic_module, config=optim_config)
torch.cuda.empty_cache()
return critic_module, critic_optimizer, critic_model_config, optim_config
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def init_model(self):
# create critic
from omegaconf import OmegaConf
from verl.utils.torch_dtypes import PrecisionType
if self.config.model.get('external_lib', None) is not None:
# This is used to import external_lib into the huggingface systems
import importlib
importlib.import_module(self.config.model.external_lib)
override_model_config = OmegaConf.to_container(self.config.model.get('override_config', OmegaConf.create()))
torch_dtype = torch.bfloat16
megatron_config = OmegaConf.create({
'sequence_parallel': self.config.megatron.get('sequence_parallel', True),
'param_dtype': PrecisionType.to_str(torch_dtype),
'tensor_model_parallel_size': mpu.get_tensor_model_parallel_world_size(),
'pipeline_model_parallel_rank': mpu.get_pipeline_model_parallel_rank(),
'pipeline_model_parallel_size': mpu.get_pipeline_model_parallel_world_size(),
'virtual_pipeline_model_parallel_rank': mpu.get_virtual_pipeline_model_parallel_rank(),
'virtual_pipeline_model_parallel_size': mpu.get_virtual_pipeline_model_parallel_world_size()
})
megatron_config = init_model_parallel_config(megatron_config)
critic_module, critic_optimizer, critic_model_config, critic_optimizer_config = self._build_critic_model_optimizer(
model_path=self.config.model.path,
megatron_config=megatron_config,
optim_config=self.config.optim,
override_model_config=override_model_config)
self.critic = MegatronPPOCritic(config=self.config,
model_config=critic_model_config,
megatron_config=megatron_config,
critic_module=critic_module,
critic_optimizer=critic_optimizer,
critic_optimizer_config=critic_optimizer_config)
@register(dispatch_mode=Dispatch.MEGATRON_COMPUTE_PROTO)
def compute_values(self, data: DataProto):
data = data.to('cuda')
values = self.critic.compute_values(data=data)
output = DataProto.from_dict(tensors={'values': values})
output = output.to('cpu')
return output
@register(dispatch_mode=Dispatch.MEGATRON_COMPUTE_PROTO)
def update_critic(self, data: DataProto):
data = data.to('cuda')
dataloader = self.critic.make_minibatch_iterator(data)
metrics = self.critic.update_critic(dataloader=dataloader)
output = DataProto(batch=None, meta_info={'metrics': metrics})
output = output.to('cpu')
return output
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def load_checkpoint(self, checkpoint_path):
pass
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def save_checkpoint(self, checkpoint_path):
pass
class RewardModelWorker(MegatronWorker):
"""
Note that we only implement the reward model that is subclass of AutoModelForSequenceClassification.
"""
def __init__(self, config):
super().__init__()
self.config = config
        # NOTE(sgm): We use a colocated WorkerGroup by default.
        # As a result, Workers for different models share the same process.
        # Therefore, we only require one distributed initialization.
        # To use a different parallel strategy for different models:
        # 1. disable WorkerDict; 2. assign a different ResourcePool to each model;
        # 3. apply the following patch to ray==2.10: https://github.com/ray-project/ray/pull/44385
if not torch.distributed.is_initialized():
rank = int(os.environ['LOCAL_RANK'])
torch.distributed.init_process_group(backend="nccl")
torch.cuda.set_device(rank)
if self.config.megatron.sequence_parallel:
os.environ['CUDA_DEVICE_MAX_CONNECTIONS'] = '1'
mpu.initialize_model_parallel(
tensor_model_parallel_size=self.config.megatron.tensor_model_parallel_size,
pipeline_model_parallel_size=self.config.megatron.pipeline_model_parallel_size,
virtual_pipeline_model_parallel_size=None,
pipeline_model_parallel_split_rank=None,
use_sharp=False,
context_parallel_size=1,
expert_model_parallel_size=1,
nccl_communicator_config_path=None,
)
set_random_seed(seed=self.config.megatron.seed)
# normalize config
self.config.micro_batch_size //= mpu.get_data_parallel_world_size()
def _build_rm_model(self, model_path, megatron_config: ModelParallelConfig, override_model_config):
from megatron.core.models.gpt.gpt_model import ModelType
from verl.utils.model import print_model_size, update_model_config
from verl.utils.megatron_utils import get_model
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
# Step 1: initialize the tokenizer
local_path = copy_local_path_from_hdfs(model_path)
self.tokenizer = hf_tokenizer(local_path)
# Step 2: get the actor_model_config
rm_model_config = AutoConfig.from_pretrained(local_path)
override_config_kwargs = {
'bos_token_id': self.tokenizer.bos_token_id,
'eos_token_id': self.tokenizer.eos_token_id,
'pad_token_id': self.tokenizer.pad_token_id,
}
override_config_kwargs.update(override_model_config)
update_model_config(rm_model_config, override_config_kwargs=override_config_kwargs)
if self.rank == 0:
print(f'Model config after override: {rm_model_config}')
def megatron_rm_model_provider(pre_process, post_process):
from verl.utils.model import get_parallel_model_from_config
# vpp is not supported yet because it will hang for some reason. Need debugging
vpp_rank = mpu.get_virtual_pipeline_model_parallel_rank() # this will be set inside get_model
# this_megatron_config = copy.deepcopy(megatron_config)
# this_megatron_config.virtual_pipeline_model_parallel_rank = vpp_rank
parallel_model = get_parallel_model_from_config(config=rm_model_config,
megatron_config=megatron_config,
pre_process=pre_process,
post_process=post_process,
value=True)
parallel_model.cuda()
return parallel_model
# Step 3: initialize the megatron model
reward_model = get_model(model_provider_func=megatron_rm_model_provider,
model_type=ModelType.encoder_or_decoder,
wrap_with_ddp=False)
        # note that here reward_model will be a list to be compatible with the construction of interleaved pp (vpp).
        # but here, we do not use pp (vpp) yet. For simplicity, we remove the list
# reward_model = nn.ModuleList(reward_model)
if self.config.load_weight:
load_megatron_model_weights(self.config,
rm_model_config,
reward_model,
params_dtype=megatron_config.params_dtype,
is_value_model=True)
# TODO: add more optimizer args into config
torch.cuda.empty_cache()
return reward_model, rm_model_config
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def init_model(self):
# create critic
from omegaconf import OmegaConf
from verl.utils.torch_dtypes import PrecisionType
from transformers import AutoTokenizer
if self.config.model.get('external_lib', None) is not None:
# This is used to import external_lib into the huggingface systems
import importlib
importlib.import_module(self.config.model.external_lib)
override_model_config = OmegaConf.to_container(self.config.model.get('override_config', OmegaConf.create()))
sft_tokenizer_local_path = copy_local_path_from_hdfs(self.config.model.input_tokenizer)
sft_tokenizer = hf_tokenizer(sft_tokenizer_local_path)
rm_tokenizer_path = self.config.model.get('rm_tokenizer', None)
rm_tokenizer = None
if rm_tokenizer_path is not None:
rm_tokenizer_local_path = copy_local_path_from_hdfs(rm_tokenizer_path)
rm_tokenizer = hf_tokenizer(rm_tokenizer_local_path)
torch_dtype = torch.bfloat16
megatron_config = OmegaConf.create({
'sequence_parallel': self.config.megatron.get('sequence_parallel', True),
'param_dtype': PrecisionType.to_str(torch_dtype),
'tensor_model_parallel_size': mpu.get_tensor_model_parallel_world_size(),
'pipeline_model_parallel_rank': mpu.get_pipeline_model_parallel_rank(),
'pipeline_model_parallel_size': mpu.get_pipeline_model_parallel_world_size(),
'virtual_pipeline_model_parallel_rank': mpu.get_virtual_pipeline_model_parallel_rank(),
'virtual_pipeline_model_parallel_size': mpu.get_virtual_pipeline_model_parallel_world_size()
})
megatron_config = init_model_parallel_config(megatron_config)
reward_model_module, reward_model_config = self._build_rm_model(
model_path=self.config.model.path,
megatron_config=megatron_config,
override_model_config=override_model_config,
)
# FIXME(sgm): reward model param offload is implemented in MegatronRewardModel
# should be implemented in workers
self.rm = MegatronRewardModel(config=self.config,
reward_model_module=reward_model_module,
model_config=reward_model_config,
megatron_config=megatron_config,
sft_tokenizer=sft_tokenizer,
rm_tokenizer=rm_tokenizer)
        # TODO: the reward model should use its own tokenizer instead of the sft tokenizer;
        # the input_ids, responses, attention_mask and position_ids may differ!
@register(dispatch_mode=Dispatch.MEGATRON_COMPUTE_PROTO)
def compute_rm_score(self, data: DataProto):
data.batch = data.batch.cuda()
output = self.rm.compute_reward(data)
output = output.to('cpu')
return output

View File

@@ -0,0 +1,383 @@
import os
import json
import warnings
import functools
import argparse
from typing import List, Dict
from multiprocessing import Pool

import faiss
import torch
import numpy as np
import datasets
from tqdm import tqdm
from transformers import AutoConfig, AutoTokenizer, AutoModel

from verl.single_controller.base import Worker
from verl.single_controller.base.decorator import register, Dispatch
# from search_r1.env.search.retrieval import get_retriever
def load_corpus(corpus_path: str):
corpus = datasets.load_dataset(
'json',
data_files=corpus_path,
split="train",
num_proc=4)
return corpus
def read_jsonl(file_path):
data = []
with open(file_path, "r") as f:
readin = f.readlines()
for line in readin:
data.append(json.loads(line))
return data
def load_docs(corpus, doc_idxs):
results = [corpus[int(idx)] for idx in doc_idxs]
return results
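# Assumed corpus format (inferred from the retrievers below): a JSONL file where each line is a
# JSON object with at least a 'contents' field, e.g.
#   {"id": "0", "contents": "Article Title\nArticle body text ..."}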
def load_model(
model_path: str,
use_fp16: bool = False
):
model_config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
model = AutoModel.from_pretrained(model_path, trust_remote_code=True)
model.eval()
model.cuda()
if use_fp16:
model = model.half()
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True, trust_remote_code=True)
return model, tokenizer
def pooling(
pooler_output,
last_hidden_state,
attention_mask = None,
pooling_method = "mean"
):
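    """Pool token-level hidden states into a single embedding.

    "mean": masked mean over valid tokens; "cls": first-token hidden state; "pooler": the model's pooler_output.
    """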
if pooling_method == "mean":
last_hidden = last_hidden_state.masked_fill(~attention_mask[..., None].bool(), 0.0)
return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
elif pooling_method == "cls":
return last_hidden_state[:, 0]
elif pooling_method == "pooler":
return pooler_output
else:
raise NotImplementedError("Pooling method not implemented!")
class Encoder:
def __init__(self, model_name, model_path, pooling_method, max_length, use_fp16):
self.model_name = model_name
self.model_path = model_path
self.pooling_method = pooling_method
self.max_length = max_length
self.use_fp16 = use_fp16
self.model, self.tokenizer = load_model(model_path=model_path,
use_fp16=use_fp16)
@torch.no_grad()
def encode(self, query_list: List[str], is_query=True) -> np.ndarray:
# processing query for different encoders
if isinstance(query_list, str):
query_list = [query_list]
if "e5" in self.model_name.lower():
if is_query:
query_list = [f"query: {query}" for query in query_list]
else:
query_list = [f"passage: {query}" for query in query_list]
if "bge" in self.model_name.lower():
if is_query:
query_list = [f"Represent this sentence for searching relevant passages: {query}" for query in query_list]
inputs = self.tokenizer(query_list,
max_length=self.max_length,
padding=True,
truncation=True,
return_tensors="pt"
)
inputs = {k: v.cuda() for k, v in inputs.items()}
if "T5" in type(self.model).__name__:
# T5-based retrieval model
decoder_input_ids = torch.zeros(
(inputs['input_ids'].shape[0], 1), dtype=torch.long
).to(inputs['input_ids'].device)
output = self.model(
**inputs, decoder_input_ids=decoder_input_ids, return_dict=True
)
query_emb = output.last_hidden_state[:, 0, :]
else:
output = self.model(**inputs, return_dict=True)
query_emb = pooling(output.pooler_output,
output.last_hidden_state,
inputs['attention_mask'],
self.pooling_method)
if "dpr" not in self.model_name.lower():
query_emb = torch.nn.functional.normalize(query_emb, dim=-1)
query_emb = query_emb.detach().cpu().numpy()
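        # faiss expects float32, C-contiguous arrays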
query_emb = query_emb.astype(np.float32, order="C")
return query_emb
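# Minimal usage sketch (the model path below is an assumed example checkpoint, not part of this repo):
#   encoder = Encoder(model_name='e5', model_path='intfloat/e5-base-v2',
#                     pooling_method='mean', max_length=256, use_fp16=True)
#   emb = encoder.encode(["who wrote hamlet?"])  # np.ndarray of shape (1, hidden_size)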
class BaseRetriever:
"""Base object for all retrievers."""
def __init__(self, config):
self.config = config
self.retrieval_method = config.retrieval_method
self.topk = config.retrieval_topk
self.index_path = config.index_path
self.corpus_path = config.corpus_path
# self.cache_save_path = os.path.join(config.save_dir, 'retrieval_cache.json')
def _search(self, query: str, num: int, return_score:bool) -> List[Dict[str, str]]:
r"""Retrieve topk relevant documents in corpus.
Return:
list: contains information related to the document, including:
contents: used for building index
title: (if provided)
text: (if provided)
"""
pass
def _batch_search(self, query_list, num, return_score):
pass
def search(self, *args, **kwargs):
return self._search(*args, **kwargs)
def batch_search(self, *args, **kwargs):
return self._batch_search(*args, **kwargs)
class BM25Retriever(BaseRetriever):
r"""BM25 retriever based on pre-built pyserini index."""
def __init__(self, config):
super().__init__(config)
raise NotImplementedError
from pyserini.search.lucene import LuceneSearcher
self.searcher = LuceneSearcher(self.index_path)
self.contain_doc = self._check_contain_doc()
if not self.contain_doc:
self.corpus = load_corpus(self.corpus_path)
self.max_process_num = 8
def _check_contain_doc(self):
r"""Check if the index contains document content
"""
return self.searcher.doc(0).raw() is not None
def _search(self, query: str, num: int = None, return_score = False) -> List[Dict[str, str]]:
if num is None:
num = self.topk
hits = self.searcher.search(query, num)
if len(hits) < 1:
if return_score:
return [],[]
else:
return []
scores = [hit.score for hit in hits]
if len(hits) < num:
warnings.warn('Not enough documents retrieved!')
else:
hits = hits[:num]
if self.contain_doc:
all_contents = [json.loads(self.searcher.doc(hit.docid).raw())['contents'] for hit in hits]
results = [{'title': content.split("\n")[0].strip("\""),
'text': "\n".join(content.split("\n")[1:]),
'contents': content} for content in all_contents]
else:
results = load_docs(self.corpus, [hit.docid for hit in hits])
if return_score:
return results, scores
else:
return results
def _batch_search(self, query_list, num: int = None, return_score = False):
# TODO: modify batch method
results = []
scores = []
for query in query_list:
            item_result, item_score = self._search(query, num, True)
results.append(item_result)
scores.append(item_score)
if return_score:
return results, scores
else:
return results
class DenseRetriever(BaseRetriever):
r"""Dense retriever based on pre-built faiss index."""
def __init__(self, config: dict, index):
super().__init__(config)
self.index = index
# self.index = faiss.read_index(self.index_path)
# if config.faiss_gpu:
# co = faiss.GpuMultipleClonerOptions()
# co.useFloat16 = True
# co.shard = True
# self.index = faiss.index_cpu_to_all_gpus(self.index, co=co)
# # self.index = faiss.index_cpu_to_all_gpus(self.index)
self.corpus = load_corpus(self.corpus_path)
self.encoder = Encoder(
model_name = self.retrieval_method,
model_path = config.retrieval_model_path,
pooling_method = config.retrieval_pooling_method,
max_length = config.retrieval_query_max_length,
use_fp16 = config.retrieval_use_fp16
)
self.topk = config.retrieval_topk
self.batch_size = self.config.retrieval_batch_size
def _search(self, query: str, num: int = None, return_score = False):
raise NotImplementedError
if num is None:
num = self.topk
query_emb = self.encoder.encode(query)
scores, idxs = self.index.search(query_emb, k=num)
idxs = idxs[0]
scores = scores[0]
results = load_docs(self.corpus, idxs)
if return_score:
return results, scores
else:
return results
def _batch_search(self, query_list: List[str], num: int = None, return_score = False):
if isinstance(query_list, str):
query_list = [query_list]
if num is None:
num = self.topk
batch_size = self.batch_size
results = []
scores = []
for start_idx in tqdm(range(0, len(query_list), batch_size), desc='Retrieval process: '):
query_batch = query_list[start_idx:start_idx + batch_size]
# from time import time
# a = time()
batch_emb = self.encoder.encode(query_batch)
# b = time()
# print(f'################### encode time {b-a} #####################')
batch_scores, batch_idxs = ray.get(self.index.batch_search.remote(batch_emb, k=num))
batch_scores = batch_scores.tolist()
batch_idxs = batch_idxs.tolist()
# print(f'################### search time {time()-b} #####################')
# exit()
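            # flatten the per-query doc indices, fetch the docs, then regroup into (num_queries, num_docs_per_query)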
flat_idxs = sum(batch_idxs, [])
batch_results = load_docs(self.corpus, flat_idxs)
batch_results = [batch_results[i*num : (i+1)*num] for i in range(len(batch_idxs))]
scores.extend(batch_scores)
results.extend(batch_results)
if return_score:
return results, scores
else:
return results
def get_retriever(config, index):
r"""Automatically select retriever class based on config's retrieval method
Args:
config (dict): configuration with 'retrieval_method' key
Returns:
Retriever: retriever instance
"""
if config.retrieval_method == "bm25":
raise NotImplementedError
return BM25Retriever(config)
else:
return DenseRetriever(config, index)
class RetrieveWorker(Worker):
"""Environment worker that handles GPU-based environment operations."""
def __init__(self, config, faiss_server):
super().__init__()
config.index_path = os.path.join(config.index_path, f'{config.retrieval_method}_Flat.index') if config.retrieval_method != 'bm25' else os.path.join(config.index_path, 'bm25')
self.config = config # Initialize environment later
self.faiss_server = faiss_server
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def init_model(self):
self.retriever = get_retriever(self.config, self.faiss_server)
torch.cuda.empty_cache()
@register(dispatch_mode=Dispatch.ALL_TO_ALL)
def batch_search(self, queries):
return self.retriever.batch_search(queries)
import ray
@ray.remote(num_gpus=8) # Allocate all GPUs
class FAISSIndexServer:
"""Ray Actor that loads and serves a shared FAISS index with FAISS GPU optimization."""
def __init__(self, config):
"""Initialize the FAISS index only once."""
print("[FAISSIndexServer] Loading FAISS index...")
self.config = config
self.index = self.load_index(config)
def load_index(self, config):
"""Loads the FAISS index into GPU memory with sharding."""
index_path = os.path.join(config.index_path, f'{config.retrieval_method}_Flat.index')
index = faiss.read_index(index_path)
if self.config.faiss_gpu:
# Apply FAISS GPU settings
co = faiss.GpuMultipleClonerOptions()
co.useFloat16 = True # Reduce memory footprint
co.shard = True # Distribute index across all GPUs
print("[FAISSIndexServer] Moving FAISS index to all GPUs with sharding enabled...")
index = faiss.index_cpu_to_all_gpus(index, co=co)
print("[FAISSIndexServer] FAISS index successfully moved to GPUs.")
return index
def batch_search(self, batch_emb, k):
"""Perform batch search on the FAISS index."""
print(f"[FAISSIndexServer] Received {len(batch_emb)} queries.")
return self.index.search(batch_emb, k) # Adjust 'k' as needed
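# Minimal usage sketch (hypothetical wiring that bypasses the WorkerGroup machinery; names follow this file):
#   server = FAISSIndexServer.remote(config)              # one shared index actor
#   worker = RetrieveWorker(config, faiss_server=server)
#   worker.init_model()
#   docs = worker.batch_search(["who wrote hamlet?"])     # list of top-k documents per query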

View File

@@ -0,0 +1,15 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .base import BasePPORewardModel

View File

@@ -0,0 +1,45 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
The base class for reward model
"""
from abc import ABC, abstractmethod
from verl import DataProto
class BasePPORewardModel(ABC):
def __init__(self, config):
self.config = config
@abstractmethod
def compute_reward(self, data: DataProto) -> DataProto:
"""Computing reward given input_ids. The transformers should output a tensor with shape
[batch_size, sequence_length], and the value at [EOS] mask should be gathered.
Args:
data: must contain keys "input_ids", "attention_mask" and "position_ids".
- input_ids: [batch_size, sequence_length]
- attention_mask: [batch_size, sequence_length]
- position_ids: [batch_size, sequence_length]
Returns: a data pass protocol containing "reward". Only the [EOS] position contains the reward.
Other position should have zero reward. Note that this may change in the future if we use
dense reward. So, we leave the interface for general case.
- reward: [batch_size, sequence_length].
"""
pass

View File

@@ -0,0 +1,15 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .reward_model import MegatronRewardModel

View File

@@ -0,0 +1,278 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Megatron Reward Model.
"""
from functools import partial

import torch
import torch.distributed
from tensordict import TensorDict

from verl import DataProto
from verl.utils.torch_functional import (logprobs_from_logits, get_eos_mask, pad_sequence_to_length,
                                         broadcast_dict_tensor, split_dict_tensor_into_batches)
from verl.utils.torch_dtypes import PrecisionType
from verl.utils.megatron.pipeline_parallel import (compute_transformers_input_shapes, make_batch_generator)
from verl.utils.megatron import sequence_parallel as sp_utils
from verl.workers.reward_model.base import BasePPORewardModel

from megatron.core import parallel_state as mpu
from megatron.core.pipeline_parallel import get_forward_backward_func
class MegatronRewardModel(BasePPORewardModel):
def __init__(self,
config,
model_config,
reward_model_module: torch.nn.ModuleList,
megatron_config,
sft_tokenizer=None,
rm_tokenizer=None):
self.config = config
self.reward_model_module = reward_model_module
self.megatron_config = megatron_config
self.model_config = model_config
self.device = 'cuda'
self.sft_tokenizer = sft_tokenizer
self.rm_tokenizer = rm_tokenizer
self.use_different_tokenizer = rm_tokenizer is not None
if self.config.param_offload:
self.offload_params_to_cpu()
def re_encode_by_rm_tokenizer(self, data: DataProto) -> DataProto:
        assert self.use_different_tokenizer, 're-encode requires rm_tokenizer to not be None!'
# need to use rm tokenizer to re-generate input_ids, attention_mask and position_ids
# 1. remove pad for each sequence
# 2. decode by sft_tokenizer, remove sft system prompts
# 3. encode by rm_tokenizer with rm system prompts, get rm_input_ids
# 4. generate attention_mask and position_ids
input_ids = data.batch['input_ids'] # (bs, seq_len)
attention_mask = data.batch['attention_mask']
position_ids = data.batch['position_ids']
ori_values = {'input_ids': input_ids, 'attention_mask': attention_mask, 'position_ids': position_ids}
ori_bs, ori_seqlen = input_ids.size(0), input_ids.size(1)
input_ids_for_rm = []
attention_mask_for_rm = []
position_ids_for_rm = []
print_decode = True
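        # NOTE: pad the target length with extra headroom, since re-encoding with the rm chat template may lengthen sequences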
ori_seqlen = ori_seqlen + 128
for id, mask in zip(input_ids, attention_mask):
# 1. remove pad for each sequence
non_zero_indices = torch.nonzero(mask).view(-1)
begin_pos, end_pos = non_zero_indices[0].item(), non_zero_indices[-1].item()
valid_id = id[begin_pos:end_pos + 1]
# 2. decode by sft_tokenizer, remove sft system prompts
decode_result = self.sft_tokenizer.decode(valid_id)
# workaround
decode_with_rm_chat = decode_result.replace("<|user|>\n", "[INST] ").replace(
"</s>\n<|assistant|>\n", " [/INST]").replace("</s> \n<|assistant|>\n", " [/INST]") + "</s>"
print(f"decode_with_rm_chat: {decode_with_rm_chat}")
if print_decode and torch.distributed.get_rank() == 0:
# only print first decode result
print(f'device {torch.cuda.current_device()}: sft decode result:\n{decode_result}\n \
\ndevice {torch.cuda.current_device()}: sft decode result with rm chat template:\n{decode_with_rm_chat}\n\n'
)
print_decode = False
# 3. encode by rm_tokenizer
rm_input_ids = self.rm_tokenizer(decode_with_rm_chat,
return_tensors='pt')['input_ids'][0].to(input_ids.device)
# 4. generate attention_mask and position_ids
rm_attention_mask = torch.ones_like(rm_input_ids, device=input_ids.device)
cur_seqlen = rm_input_ids.shape[-1]
# NOTE(gh): the later reward compute will process the shape (bs, seqlen_pad_128)
if cur_seqlen > ori_seqlen:
                print(f'warning: rm encode seqlen {cur_seqlen} > sft encode seqlen {ori_seqlen}')
rm_input_ids = rm_input_ids[:ori_seqlen]
rm_attention_mask = rm_attention_mask[:ori_seqlen]
else:
# right padding
rm_input_ids = pad_sequence_to_length(rm_input_ids, ori_seqlen, self.rm_tokenizer.pad_token_id)
rm_attention_mask = pad_sequence_to_length(rm_attention_mask, ori_seqlen, 0)
rm_position_ids = torch.arange(0, ori_seqlen, device=input_ids.device)
input_ids_for_rm.append(torch.unsqueeze(rm_input_ids, dim=0))
attention_mask_for_rm.append(torch.unsqueeze(rm_attention_mask, dim=0))
position_ids_for_rm.append(torch.unsqueeze(rm_position_ids, dim=0))
input_ids_for_rm = torch.cat(input_ids_for_rm, dim=0)
attention_mask_for_rm = torch.cat(attention_mask_for_rm, dim=0)
position_ids_for_rm = torch.cat(position_ids_for_rm, dim=0)
# (bs, seqlen) will not change, but input_ids, attention_mask and position_ids will change
# NOTE(gh): need to replace into origin values after compute reward!
data.batch['input_ids'] = input_ids_for_rm
data.batch['attention_mask'] = attention_mask_for_rm
data.batch['position_ids'] = position_ids_for_rm
return data, ori_values
@torch.no_grad()
def compute_reward(self, data: DataProto) -> DataProto:
if self.config.param_offload:
self.load_params_to_cuda()
if self.use_different_tokenizer:
data, ori_values = self.re_encode_by_rm_tokenizer(data)
input_ids = data.batch['input_ids'] # (bs, seq_len')
attention_mask = data.batch['attention_mask']
position_ids = data.batch['position_ids']
responses = data.batch['responses']
batch_size = responses.size(0)
response_length = responses.size(1)
with torch.no_grad():
output = self.forward_batch(data)
if mpu.is_pipeline_last_stage(ignore_virtual=True):
logits = torch.cat([o['logits'] for o in output], dim=0)
else:
logits = torch.empty(
(input_ids.shape[0], input_ids.shape[1]),
dtype=torch.bfloat16, # TODO(sgm): check why is bfloat16
device=input_ids.device)
# broadcast across pp ranks
torch.distributed.broadcast(tensor=logits,
src=mpu.get_pipeline_model_parallel_last_rank(),
group=mpu.get_pipeline_model_parallel_group(),
async_op=False)
# (bs, seqlen', hidden_size) -> (bs, seqlen', 1) -> (bs, seqlen')
token_level_rewards = logits
# find the last token reward
ends = attention_mask.cumsum(dim=-1).argmax(dim=-1).view(-1, 1) # (bs, 1)
rewards = torch.gather(token_level_rewards, dim=1, index=ends) # (bs, 1)
if self.use_different_tokenizer:
data.batch.update(ori_values)
input_ids = ori_values['input_ids']
attention_mask = ori_values['attention_mask']
position_ids = ori_values['position_ids']
token_level_rewards = rewards.expand(attention_mask.shape[0], attention_mask.shape[1]) # (bs, ori_seqlen)
# assign last valid token reward to ori position
eos_mask_idx = torch.argmax(position_ids * attention_mask, dim=-1) # (bs,)
eos_mask = torch.zeros_like(attention_mask)
eos_mask[torch.arange(batch_size), eos_mask_idx] = 1.
token_level_rewards = token_level_rewards * eos_mask
token_level_rewards = token_level_rewards[:, -response_length:]
if self.config.param_offload:
self.offload_params_to_cpu()
else:
# add empty cache after each compute
torch.cuda.empty_cache()
batch = TensorDict({'rm_scores': token_level_rewards}, batch_size=input_ids.shape[0])
return DataProto(batch=batch)
def forward_batch(self, data: DataProto):
"""
We assume:
- The model takes input: (input_ids, attention_mask, position_ids). No rmpad for the input
- The communication shape is (total_nnz_pad_to_sp // tp_size, 1, hidden_size) if sequence parallel is enabled
"""
# broadcast from last pp rank to all other pp ranks
# TODO: actually, we just need to control the sampling order.
data.batch = data.batch.contiguous()
broadcast_dict_tensor(data.batch,
src=mpu.get_pipeline_model_parallel_last_rank(),
group=mpu.get_pipeline_model_parallel_group())
# split into micro-batches
if self.config is not None and 'ppo_micro_batch_size' in self.config:
infer_batch_size = self.config.ppo_micro_batch_size
else:
infer_batch_size = data.batch.batch_size[0]
data.batch['attention_mask'] = data.batch['attention_mask'].to(bool)
batches = split_dict_tensor_into_batches(data.batch, batch_size=infer_batch_size)
n_micro_batch = len(batches)
seq_len = batches[0]['input_ids'].shape[1]
# compute input shapes for pp stages
input_shapes = compute_transformers_input_shapes(
batches,
meta_info={
'sequence_parallel': self.megatron_config.sequence_parallel,
'hidden_size': self.model_config.hidden_size
})
        # get the forward-backward schedule function for the current pipeline-parallel configuration
forward_backward_func = get_forward_backward_func()
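        # dummy loss function: with forward_only=True the loss value itself is ignored;
        # we only collect the logits through the stats dict returned here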
def loss_func(output):
return 1., {'logits': output.logits}
def forward_step(batch_iter, model):
batch = next(batch_iter)
input_ids = batch['input_ids']
attention_mask = batch['attention_mask']
position_ids = batch['position_ids']
output = model(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids)
return output, loss_func
# batch should be a list of batches inside micro-batches
batch_generator = make_batch_generator(batches, vpp_size=len(self.reward_model_module))
# TODO: we may use the new schedule instead
# for flash-attn: (seq_len, batch_size, hidden_size) = (mbs*seq_len, 1, hidden_size)
if mpu.get_pipeline_model_parallel_world_size() > 1:
losses_reduced = forward_backward_func(
forward_step_func=forward_step,
data_iterator=batch_generator,
model=self.reward_model_module,
num_microbatches=n_micro_batch,
input_shapes=input_shapes, # must set for flash-attn sequence packing
seq_length=infer_batch_size * seq_len, # no use when input_shapes was set
hidden_size=self.model_config.hidden_size, # no use when input_shapes was set
micro_batch_size=1, # no use when input_shapes was set
forward_only=True,
)
else:
losses_reduced = forward_backward_func(
forward_step_func=forward_step,
data_iterator=batch_generator,
model=self.reward_model_module,
num_microbatches=n_micro_batch,
seq_length=infer_batch_size * seq_len, # in use for pp = 1
hidden_size=self.model_config.hidden_size, # in use for pp = 1
micro_batch_size=1, # in use for pp = 1
forward_only=True,
)
# loss_reduces contains the stats returned from loss_func
return losses_reduced
def offload_params_to_cpu(self):
if self.device == 'cuda':
for reward_model_module in self.reward_model_module:
for name, param in reward_model_module.named_parameters():
param.data = param.data.to('cpu', non_blocking=True)
self.device = 'cpu'
torch.cuda.empty_cache()
def load_params_to_cuda(self):
if self.device == 'cpu':
for reward_model_module in self.reward_model_module:
for name, param in reward_model_module.named_parameters():
param.data = param.data.to(torch.cuda.current_device(), non_blocking=True)
self.device = 'cuda'

View File

@@ -0,0 +1,19 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .base import BaseRollout
from .naive import NaiveRollout
from .hf_rollout import HFRollout
__all__ = ["BaseRollout", "NaiveRollout", "HFRollout"]

View File

@@ -0,0 +1,37 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, abstractmethod
from typing import Iterable, Union
from verl import DataProto
__all__ = ['BaseRollout']
class BaseRollout(ABC):
def __init__(self):
"""
Args:
dataloader: an Iterable of TensorDict that consistently generates prompts. Note that the dataloader
should handle when the training stops.
"""
super().__init__()
@abstractmethod
def generate_sequences(self, prompts: DataProto) -> DataProto:
"""Generate sequences"""
pass

View File

@@ -0,0 +1,140 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Rollout with huggingface models.
TODO: refactor this class. Currently, it will hang when using FSDP HybridShard. We should actually create a single GPU model.
Then, get full state_dict and bind the state_dict to the single GPU model. Then, use the single GPU model to perform generation.
"""
import contextlib
import torch
import torch.distributed
from tensordict import TensorDict
from torch import nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from verl import DataProto
from verl.utils.torch_functional import get_eos_mask
from .base import BaseRollout
from transformers import GenerationConfig
__all__ = ['HFRollout']
class HFRollout(BaseRollout):
def __init__(self, module: nn.Module, config):
super().__init__()
self.config = config
self.module = module
def generate_sequences(self, prompts: DataProto) -> DataProto:
batch_size = prompts.batch.batch_size[0]
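        # split the prompt batch into chunks no larger than micro_batch_size (the whole batch if unset)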
num_chunks = max(batch_size // self.config.get('micro_batch_size', batch_size), 1)
batch_prompts = prompts.chunk(chunks=num_chunks)
output = [self._generate_minibatch(p) for p in batch_prompts]
output = DataProto.concat(output)
return output
@torch.no_grad()
def _generate_minibatch(self, prompts: DataProto) -> DataProto:
idx = prompts.batch['input_ids'] # (bs, prompt_length)
attention_mask = prompts.batch['attention_mask'] # left-padded attention_mask
position_ids = prompts.batch['position_ids']
# used to construct attention_mask
eos_token_id = prompts.meta_info['eos_token_id']
pad_token_id = prompts.meta_info['pad_token_id']
batch_size = idx.size(0)
prompt_length = idx.size(1)
self.module.eval()
param_ctx = contextlib.nullcontext()
        # allow sampling args to be overridden by the inputs
do_sample = prompts.meta_info.get('do_sample', self.config.do_sample)
response_length = prompts.meta_info.get('response_length', self.config.response_length)
top_p = prompts.meta_info.get('top_p', self.config.get('top_p', 1.0))
top_k = prompts.meta_info.get('top_k', self.config.get('top_k', 0))
if top_k is None:
top_k = 0
top_k = max(0, top_k) # to be compatible with vllm
temperature = prompts.meta_info.get('temperature', self.config.temperature)
generation_config = GenerationConfig(temperature=temperature, top_p=top_p, top_k=top_k)
if isinstance(self.module, FSDP):
# recurse need to set to False according to https://github.com/pytorch/pytorch/issues/100069
param_ctx = FSDP.summon_full_params(self.module, writeback=False, recurse=False)
with param_ctx:
with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
output = self.module.generate(
input_ids=idx,
attention_mask=attention_mask,
do_sample=do_sample,
max_new_tokens=response_length,
# max_length=max_length,
eos_token_id=eos_token_id,
pad_token_id=pad_token_id,
generation_config=generation_config,
# renormalize_logits=True,
output_scores=False, # this is potentially very large
return_dict_in_generate=True,
use_cache=True)
# TODO: filter out the seq with no answers like ds-chat
seq = output.sequences
        # huggingface generate stops when every sequence in the batch reaches [EOS],
        # so we have to pad the output back to response_length
sequence_length = prompt_length + self.config.response_length
delta_length = sequence_length - seq.shape[1]
if delta_length > 0:
delta_tokens = torch.ones(size=(batch_size, delta_length), device=seq.device, dtype=seq.dtype)
delta_tokens = pad_token_id * delta_tokens
seq = torch.cat((seq, delta_tokens), dim=1)
assert seq.shape[1] == sequence_length
prompt = seq[:, :prompt_length] # (bs, prompt_length)
response = seq[:, prompt_length:] # (bs, response_length)
response_length = response.size(1)
delta_position_id = torch.arange(1, response_length + 1, device=position_ids.device)
delta_position_id = delta_position_id.unsqueeze(0).repeat(batch_size, 1)
response_position_ids = position_ids[:, -1:] + delta_position_id
position_ids = torch.cat([position_ids, response_position_ids], dim=-1)
response_attention_mask = get_eos_mask(response_id=response, eos_token=eos_token_id, dtype=attention_mask.dtype)
attention_mask = torch.cat((attention_mask, response_attention_mask), dim=-1)
batch = TensorDict(
{
'prompts': prompt,
'responses': response,
'input_ids': seq,
'attention_mask': attention_mask,
'position_ids': position_ids
},
batch_size=batch_size)
# empty cache before compute old_log_prob
torch.cuda.empty_cache()
self.module.train()
return DataProto(batch=batch)

View File

@@ -0,0 +1,15 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .naive_rollout import NaiveRollout

View File

@@ -0,0 +1,119 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
In single GPU rollout, the sequences are generated directly by sampling from the model.
The output will contain
1. output_ids
2. attention_masks (left padding)
3. eos_masks
4. log_probs
"""
from typing import Iterable, Union
import torch
import torch.nn.functional as F
from tensordict import TensorDict
from torch import nn
from verl import DataProto
from verl.utils.torch_functional import logprobs_from_logits
from ..base import BaseRollout
__all__ = ['NaiveRollout']
class NaiveRollout(BaseRollout):
def __init__(self, module: nn.Module, config):
"""A naive rollout. It requires the module to be compatible with huggingface APIs. That is:
The module should define __call__ to receive input_ids, attention_mask and position_ids.
It outputs a structure that contains logits field.
Args:
module: module here follows huggingface APIs
config: DictConfig
"""
super().__init__()
self.config = config
self.module = module
@torch.no_grad()
def generate_sequences(self, prompts: DataProto) -> DataProto:
"""Generate sequences"""
idx = prompts.batch['input_ids'] # (bs, prompt_length)
attention_mask = prompts.batch['attention_mask'] # left-padded attention_mask
position_ids = prompts.batch['position_ids']
# used to construct attention_mask
eos_token_id = prompts.meta_info['eos_token_id']
batch_size = idx.size(0)
prompt_length = idx.size(1)
self.module.eval()
prev_attention_mask = torch.ones(size=(batch_size, 1), dtype=attention_mask.dtype, device=attention_mask.device)
logits_lst = []
for _ in range(self.config.response_length):
# if the sequence context is growing too long we must crop it at block_size
# idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]
idx_cond = idx
# forward the model to get the logits for the index in the sequence
# we use huggingface APIs here
output = self.module(input_ids=idx_cond, attention_mask=attention_mask, position_ids=position_ids)
logits = output.logits
# pluck the logits at the final step and scale by desired temperature
logits = logits[:, -1, :] / self.config.temperature # (bs, vocab_size)
# optionally crop the logits to only the top k options
if self.config.top_k is not None:
v, _ = torch.topk(logits, min(self.config.top_k, logits.size(-1)))
logits[logits < v[:, [-1]]] = -float('Inf')
# apply softmax to convert logits to (normalized) probabilities
probs = F.softmax(logits, dim=-1)
# sample from the distribution
if self.config.do_sample:
idx_next = torch.multinomial(probs, num_samples=1)
else:
idx_next = torch.argmax(probs, dim=-1, keepdim=True)
attention_mask = torch.cat((attention_mask, prev_attention_mask), dim=-1)
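            # once a sequence has emitted eos, prev_attention_mask stays 0, so all later tokens are masked out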
prev_attention_mask = torch.logical_and(idx_next != eos_token_id, prev_attention_mask.bool())
            prev_attention_mask = prev_attention_mask.to(attention_mask.dtype)
position_ids = torch.cat((position_ids, position_ids[:, -1:] + 1), dim=-1)
# append sampled index to the running sequence and continue
idx = torch.cat((idx, idx_next), dim=1)
logits_lst.append(logits)
logits = torch.stack(logits_lst, dim=1) # (bs, response_length, vocab_size)
prompts = idx[:, :prompt_length] # (bs, prompt_length)
response = idx[:, prompt_length:] # (bs, response_length)
log_probs = logprobs_from_logits(logits=logits, labels=response)
batch = TensorDict(
{
'input_ids': prompts,
'responses': response,
'sequences': idx,
'old_log_probs': log_probs,
'attention_mask': attention_mask,
'position_ids': position_ids,
},
batch_size=batch_size)
self.module.train()
return DataProto(batch=batch)

View File

@@ -0,0 +1,162 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
The base tokenizer class, required for any hybrid engine based rollout or inference with vLLM.
"""
from abc import ABC, abstractmethod
from typing import Dict, List, Union
__all__ = ['HybridEngineBaseTokenizer']
class HybridEngineBaseTokenizer(ABC):
"""the tokenizer property and function name should align with HF's to meet vllm requirement"""
@property
@abstractmethod
def vocab_size(self):
"""
`int`: Size of the base vocabulary (without the added tokens).
"""
pass
@property
@abstractmethod
def pad_token_id(self):
"""
`Optional[int]`: Id of the padding token in the vocabulary. Returns `None` if the token has not been set.
"""
pass
@property
@abstractmethod
def eos_token_id(self):
"""
`Optional[int]`: Id of the end of sentence token in the vocabulary. Returns `None` if the token has not been
set.
"""
pass
@property
@abstractmethod
def all_special_ids(self) -> List[int]:
"""
`List[int]`: List the ids of the special tokens(`'<unk>'`, `'<cls>'`, etc.) mapped to class attributes.
"""
pass
@property
@abstractmethod
def all_special_tokens(self) -> List[str]:
"""
`List[str]`: A list of the unique special tokens (`'<unk>'`, `'<cls>'`, ..., etc.).
Convert tokens of `tokenizers.AddedToken` type to string.
"""
pass
@abstractmethod
def encode(self, text):
"""
Converts a string to a sequence of ids (integer), using the tokenizer and vocabulary.
Args:
text (`str`, `List[str]` or `List[int]`):
The first sequence to be encoded. This can be a string, a list of strings (tokenized string using the
`tokenize` method) or a list of integers.
text_pair (`str`, `List[str]` or `List[int]`, *optional*):
Optional second sequence to be encoded. This can be a string, a list of strings (tokenized string using
the `tokenize` method) or a list of integers.
"""
pass
@abstractmethod
def decode(
self,
token_ids: Union[int, List[int], "np.ndarray", "torch.Tensor", "tf.Tensor"],
skip_special_tokens: bool = False,
clean_up_tokenization_spaces: bool = None,
**kwargs,
) -> str:
"""
Converts a sequence of ids in a string, using the tokenizer and vocabulary with options to remove special
tokens and clean up tokenization spaces.
Similar to doing `self.convert_tokens_to_string(self.convert_ids_to_tokens(token_ids))`.
Args:
token_ids (`Union[int, List[int], np.ndarray, torch.Tensor, tf.Tensor]`):
List of tokenized input ids. Can be obtained using the `__call__` method.
skip_special_tokens (`bool`, *optional*, defaults to `False`):
Whether or not to remove special tokens in the decoding.
clean_up_tokenization_spaces (`bool`, *optional*):
Whether or not to clean up the tokenization spaces. If `None`, will default to
`self.clean_up_tokenization_spaces`.
kwargs (additional keyword arguments, *optional*):
Will be passed to the underlying model specific decode method.
Returns:
`str`: The decoded sentence.
"""
pass
@abstractmethod
def convert_ids_to_tokens(self,
ids: Union[int, List[int]],
skip_special_tokens: bool = False) -> Union[str, List[str]]:
"""
Converts a single index or a sequence of indices into a token or a sequence of tokens, using the vocabulary and
added tokens.
Args:
ids (`int` or `List[int]`):
The token id (or token ids) to convert to tokens.
skip_special_tokens (`bool`, *optional*, defaults to `False`):
Whether or not to remove special tokens in the decoding.
Returns:
`str` or `List[str]`: The decoded token(s).
"""
pass
@abstractmethod
def get_added_vocab(self) -> Dict[str, int]:
"""
Returns the added tokens in the vocabulary as a dictionary of token to index. Results might be different from
the fast call because for now we always add the tokens even if they are already in the vocabulary. This is
something we should change.
Returns:
`Dict[str, int]`: The added tokens.
"""
pass
@abstractmethod
def convert_tokens_to_string(self, tokens: List[str]) -> str:
"""
Converts a sequence of tokens into a single string. The simplest way to do it is `" ".join(tokens)` but we
often want to remove sub-word tokenization artifacts at the same time.
Args:
tokens (`List[str]`): The tokens to join into a string.
Returns:
`str`: The joined tokens.
"""
pass
@property
def is_fast(self):
return False
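A minimal sketch (not part of this commit) of how the interface above could be satisfied by delegating to a Hugging Face slow tokenizer; the adapter name and the import path of HybridEngineBaseTokenizer are assumptions for illustration.

from typing import Dict, List

# assumed import path for the abstract class defined above
from verl.workers.rollout.tokenizer import HybridEngineBaseTokenizer


class HFTokenizerAdapter(HybridEngineBaseTokenizer):
    """Hypothetical adapter that forwards every required member to a HF tokenizer."""

    def __init__(self, hf_tokenizer):
        self._tok = hf_tokenizer

    @property
    def vocab_size(self):
        return self._tok.vocab_size

    @property
    def pad_token_id(self):
        return self._tok.pad_token_id

    @property
    def eos_token_id(self):
        return self._tok.eos_token_id

    @property
    def all_special_ids(self) -> List[int]:
        return self._tok.all_special_ids

    @property
    def all_special_tokens(self) -> List[str]:
        return self._tok.all_special_tokens

    def encode(self, text):
        return self._tok.encode(text)

    def decode(self, token_ids, skip_special_tokens=False, clean_up_tokenization_spaces=None, **kwargs):
        return self._tok.decode(token_ids, skip_special_tokens=skip_special_tokens,
                                clean_up_tokenization_spaces=clean_up_tokenization_spaces, **kwargs)

    def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
        return self._tok.convert_ids_to_tokens(ids, skip_special_tokens=skip_special_tokens)

    def get_added_vocab(self) -> Dict[str, int]:
        return self._tok.get_added_vocab()

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        return self._tok.convert_tokens_to_string(tokens)


# usage: tok = HFTokenizerAdapter(AutoTokenizer.from_pretrained("gpt2", use_fast=False))
# (AutoTokenizer comes from `transformers`; the model name is only an example)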

View File

@@ -0,0 +1,15 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .vllm_rollout import vLLMRollout

View File

@@ -0,0 +1,226 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
The vllm_rollout that can be applied in different backend
When working with FSDP:
- Use DTensor weight loader (recommended) or HF weight loader
- Utilize state_dict from the FSDP to synchronize the weights among tp ranks in vLLM
When working with Megatron:
- Use Megatron weight loader
- During training, only the current pp stage holds the parameters
- Before inference, broadcast the parameters of the current pp rank to all other pp ranks (so that all pp ranks hold all the parameters)
- Bind the parameters to the inference engine
- Do inference in tp. pp is treated as additional dp
- After inference, all the parameters that don't belong to this pp rank are freed.
"""
from typing import List
from contextlib import contextmanager
from omegaconf import DictConfig
import torch
import torch.distributed
from tensordict import TensorDict
from torch import nn
from verl import DataProto
from verl.utils.torch_functional import get_eos_mask, pad_sequence_to_length
from verl.workers.rollout.base import BaseRollout
from verl.third_party.vllm import LLM, vllm_version
from verl.third_party.vllm import parallel_state as vllm_ps
from vllm import SamplingParams
# TODO
# 1. support pp in vllm
# 2. passing tokenizer is not necessary? no encoding/decoding is happening here
# 3. simplify init logics
# NOTE(sgm): add for verl. We can optimize it by making the dataloader yield List[int] without padding.
def _pre_process_inputs(pad_token_id, prompt_token_ids: torch.Tensor) -> List[int]:
# remove the left padding in the prompt token_id
# pad_token_id = self.llm_engine.tokenizer.pad_token_id if self.llm_engine.tokenizer.pad_token_id is not None else self.llm_engine.tokenizer.eos_token_id
non_pad_index = torch.nonzero(prompt_token_ids != pad_token_id, as_tuple=False)[0][0]
token_ids = prompt_token_ids[non_pad_index:].tolist()
return token_ids
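# Illustrative example (not in the original file): with pad_token_id == 0,
#   _pre_process_inputs(0, torch.tensor([0, 0, 5, 6, 7])) -> [5, 6, 7]
# i.e. only the leading (left) padding is stripped and the token order is preserved.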
class vLLMRollout(BaseRollout):
def __init__(self, actor_module: nn.Module, config: DictConfig, tokenizer, model_hf_config, **kwargs):
"""A vLLM rollout. It requires the module is supported by the vllm.
Args:
module: module here follows huggingface APIs
config: DictConfig
tokenizer: the task/model tokenizer
model_hf_config: the huggingface config used to initialize the generating model in vLLM
**kwargs: train_tp, for Megatron Backend to initialize hybrid engine (zero redundancy) process group
"""
super().__init__()
self.config = config
assert not (not config.enforce_eager and config.free_cache_engine), \
"disable CUDA graph (enforce_eager = False) if free cache engine"
tensor_parallel_size = self.config.get('tensor_model_parallel_size', 1)
assert tensor_parallel_size <= torch.distributed.get_world_size(), \
"tensor parallel size should be less than or equal to the world size"
if kwargs.get('train_tp', None) is not None:
# deployed with megatron
import os
os.environ['CUDA_TIMER_STREAM_KAFKA_ENABLE'] = '0'
os.environ['MEGATRON_IMPORT_TIMERS'] = '0'
train_tp = kwargs.get('train_tp', None)
num_tp_per_train_tp = train_tp // tensor_parallel_size
if vllm_version in ('0.4.2', '0.5.4', '0.6.3'):
vllm_ps.initialize_parallel_state(tensor_model_parallel_size=tensor_parallel_size,
num_tp_per_train_tp=num_tp_per_train_tp)
assert model_hf_config.max_position_embeddings >= config.prompt_length + config.response_length, \
"model context length should be greater than total sequence length"
self.inference_engine = LLM(actor_module,
tokenizer=tokenizer,
model_hf_config=model_hf_config,
tensor_parallel_size=tensor_parallel_size,
dtype=config.dtype,
enforce_eager=config.enforce_eager,
gpu_memory_utilization=config.gpu_memory_utilization,
skip_tokenizer_init=False,
max_model_len=config.prompt_length + config.response_length,
load_format=config.load_format)
# Offload vllm model to reduce peak memory usage
self.inference_engine.offload_model_weights()
kwargs = dict(
n=1,
logprobs=1,  # can be set to 0 and let the actor recompute
max_tokens=config.response_length,
)
# we may detokenize the result all together later
if vllm_version in ('0.4.2', '0.5.4', '0.6.3'):
kwargs['detokenize'] = False
# support adding any sampling params from the config file
for k in config.keys():
if hasattr(SamplingParams(), str(k)):
kwargs[k] = config.get(k)
print(f"kwargs: {kwargs}")
self.sampling_params = SamplingParams(**kwargs)
self.pad_token_id = tokenizer.pad_token_id
@contextmanager
def update_sampling_params(self, **kwargs):
# update sampling params
old_sampling_params_args = {}
if kwargs:
for key, value in kwargs.items():
if hasattr(self.sampling_params, key):
old_value = getattr(self.sampling_params, key)
old_sampling_params_args[key] = old_value
setattr(self.sampling_params, key, value)
yield
# roll back to previous sampling params
# if len(old_sampling_params_args):
for key, value in old_sampling_params_args.items():
setattr(self.sampling_params, key, value)
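# Illustrative usage (not in the original file): temporarily force greedy decoding for one
# generate call; the previous sampling params are restored automatically on exit.
#   with self.update_sampling_params(temperature=0, n=1):
#       output = self.inference_engine.generate(..., sampling_params=self.sampling_params, ...)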
@torch.no_grad()
def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto:
# rebuild vllm cache engine
if self.config.free_cache_engine:
self.inference_engine.init_cache_engine()
idx = prompts.batch['input_ids'] # (bs, prompt_length)
# left-padded attention_mask
attention_mask = prompts.batch['attention_mask']
position_ids = prompts.batch['position_ids']
# used to construct attention_mask
eos_token_id = prompts.meta_info['eos_token_id']
batch_size = idx.size(0)
idx_list = []
# parse idx from torch.Tensor to List[List[int]]
for i in range(batch_size):
idx_list.append(_pre_process_inputs(self.pad_token_id, idx[i]))
do_sample = prompts.meta_info.get('do_sample', True)
if not do_sample:
kwargs = {
'best_of': 1,
'top_p': 1.0,
'top_k': -1,
'min_p': 0.0,
'temperature': 0,
'n': 1 # if greedy, only 1 response
}
# users can customize different sampling_params for different runs
with self.update_sampling_params(**kwargs):
output = self.inference_engine.generate(
prompts=None,  # because we have already converted the prompts to token ids
sampling_params=self.sampling_params,
prompt_token_ids=idx_list,
use_tqdm=False)
# TODO(sgm): disable logprob when recompute_log_prob is enabled
# if n = 1: (bs, response_length) ; if n > 1: (bs * n, response_length)
response = output[0].to(idx.device)
log_probs = output[1].to(idx.device)
if response.shape[1] < self.config.response_length:
response = pad_sequence_to_length(response, self.config.response_length, self.pad_token_id)
log_probs = pad_sequence_to_length(log_probs, self.config.response_length, self.pad_token_id)
if self.config.n > 1 and do_sample:
idx = idx.repeat_interleave(self.config.n, dim=0)
attention_mask = attention_mask.repeat_interleave(self.config.n, dim=0)
position_ids = position_ids.repeat_interleave(self.config.n, dim=0)
batch_size = batch_size * self.config.n
seq = torch.cat([idx, response], dim=-1)
response_length = response.size(1)
delta_position_id = torch.arange(1, response_length + 1, device=position_ids.device)
delta_position_id = delta_position_id.unsqueeze(0).repeat(batch_size, 1)
# TODO(sgm): fix position_ids on right_pad
# prompt: left pad + response: right pad
# attention_mask: [0,0,0,0,1,1,1,1, | 1,1,1,0,0,0,0,0]
# position_ids: [0,0,0,0,0,1,2,3, | 4,5,6,7,8,9,10,11]
response_position_ids = position_ids[:, -1:] + delta_position_id
position_ids = torch.cat([position_ids, response_position_ids], dim=-1)
response_attention_mask = get_eos_mask(response_id=response, eos_token=eos_token_id, dtype=attention_mask.dtype)
attention_mask = torch.cat((attention_mask, response_attention_mask), dim=-1)
# all the tp ranks should contain the same data here. data in all ranks are valid
batch = TensorDict(
{
'prompts': idx,
'responses': response,
'input_ids': seq, # here input_ids become the whole sentences
# 'old_log_probs': log_probs, # we will recompute old log prob with actor
'attention_mask': attention_mask,
'position_ids': position_ids
},
batch_size=batch_size)
# free vllm cache engine
if self.config.free_cache_engine:
self.inference_engine.free_cache_engine()
return DataProto(batch=batch)
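A self-contained sketch (not part of this commit) of the mask and position-id stitching performed above for a left-padded prompt and a right-padded response. The inline eos-mask computation stands in for get_eos_mask and assumes the eos token occurs in the response; the token ids are made up for illustration.

import torch

# left-padded prompt: [pad, pad, pad, pad, t, t, t, t]
prompt_attention_mask = torch.tensor([[0, 0, 0, 0, 1, 1, 1, 1]])
# position ids stay 0 on pads, then count 0, 1, 2, 3 over the real tokens
prompt_position_ids = torch.clamp(torch.cumsum(prompt_attention_mask, dim=-1) - 1, min=0)

# right-padded response: 3 real tokens, an eos (id 2), then padding
response = torch.tensor([[11, 12, 13, 2, 0, 0, 0, 0]])
eos_token_id = 2
eos_pos = (response == eos_token_id).int().argmax(dim=-1, keepdim=True)
col = torch.arange(response.size(1)).unsqueeze(0)
response_attention_mask = (col <= eos_pos).to(prompt_attention_mask.dtype)

# response position ids continue counting from the last prompt position
delta_position_id = torch.arange(1, response.size(1) + 1).unsqueeze(0)
response_position_ids = prompt_position_ids[:, -1:] + delta_position_id

attention_mask = torch.cat([prompt_attention_mask, response_attention_mask], dim=-1)
position_ids = torch.cat([prompt_position_ids, response_position_ids], dim=-1)
# attention_mask: [0,0,0,0,1,1,1,1, 1,1,1,1,0,0,0,0]
# position_ids:   [0,0,0,0,0,1,2,3, 4,5,6,7,8,9,10,11]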

View File

@@ -0,0 +1,33 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from verl.utils.import_utils import is_vllm_available, is_megatron_core_available
from .base import BaseShardingManager
from .fsdp_ulysses import FSDPUlyssesShardingManager
AllGatherPPModel = None
if is_megatron_core_available() and is_vllm_available():
from .megatron_vllm import AllGatherPPModel, MegatronVLLMShardingManager
else:
AllGatherPPModel = None
MegatronVLLMShardingManager = None
if is_vllm_available():
from .fsdp_vllm import FSDPVLLMShardingManager
else:
FSDPVLLMShardingManager = None

View File

@@ -0,0 +1,33 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Sharding manager to implement HybridEngine
"""
from verl import DataProto
class BaseShardingManager:
def __enter__(self):
pass
def __exit__(self, exc_type, exc_value, traceback):
pass
def preprocess_data(self, data: DataProto) -> DataProto:
return data
def postprocess_data(self, data: DataProto) -> DataProto:
return data
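A runnable toy (not from the repo) showing the intended call pattern around a rollout: weights would be synced in __enter__, data resharded around generation, and weights freed in __exit__. The import path and the subclass name are assumptions for illustration.

# assumed import path for the class defined above
from verl.workers.sharding_manager.base import BaseShardingManager


class LoggingShardingManager(BaseShardingManager):
    """Hypothetical subclass that only logs where weight sync/offload would happen."""

    def __enter__(self):
        print("sync training weights into the inference engine")

    def __exit__(self, exc_type, exc_value, traceback):
        print("offload inference weights and return to training mode")


manager = LoggingShardingManager()
with manager:
    # data = manager.preprocess_data(data)       # reshard data for inference
    # output = rollout.generate_sequences(data)  # run the rollout here
    # output = manager.postprocess_data(output)  # reshard outputs back
    pass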

View File

@@ -0,0 +1,88 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Contains a sharding manager that reshards data when FSDP is combined with Ulysses sequence parallelism
"""
from typing import Optional
from .base import BaseShardingManager
import random
from torch.distributed.device_mesh import DeviceMesh
from verl.utils.torch_functional import allgather_dict_tensors
from verl.utils.ulysses import set_ulysses_sequence_parallel_group, get_ulysses_sequence_parallel_group
import numpy as np
import torch
import torch.distributed
from verl import DataProto
class FSDPUlyssesShardingManager(BaseShardingManager):
"""
Sharding manager to support data resharding when using FSDP + Ulysses
"""
def __init__(self, device_mesh: DeviceMesh):
super().__init__()
self.device_mesh = device_mesh
self.seed_offset = 12345
def __enter__(self):
if self.device_mesh is not None:
# We have a global SP group
# so we have to change to use model-specific sp group
self.prev_sp_group = get_ulysses_sequence_parallel_group()
set_ulysses_sequence_parallel_group(self.device_mesh['sp'].get_group())
# TODO: check how to set seed for each model
def __exit__(self, exc_type, exc_value, traceback):
# restore random states
if self.device_mesh is not None:
# revert to previous sp group
set_ulysses_sequence_parallel_group(self.prev_sp_group)
# TODO: check how to set seed for each model
def preprocess_data(self, data: DataProto) -> DataProto:
"""
AllGather data from the SP region.
The data is first sharded along the FSDP dimension (DP_COMPUTE), but with Ulysses the same
data must be used across every rank within an SP group.
"""
if self.device_mesh is not None:
sp_size = self.device_mesh['sp'].size()
group = self.device_mesh['sp'].get_group()
prev_device = data.batch.device
data.batch = data.batch.cuda(device=torch.cuda.current_device())
data.batch = allgather_dict_tensors(data.batch.contiguous(), size=sp_size, group=group, dim=0)
data.batch = data.batch.to(prev_device)
# all gather non_tensor_batch
all_non_tensor_batch = [None for _ in range(sp_size)]
torch.distributed.all_gather_object(all_non_tensor_batch, data.non_tensor_batch, group=group)
data.non_tensor_batch = {
k: np.concatenate([d[k] for d in all_non_tensor_batch]) for k in data.non_tensor_batch
}
return data
def postprocess_data(self, data: DataProto) -> DataProto:
"""
Split the data to follow FSDP partition
"""
if self.device_mesh is not None:
sp_size = self.device_mesh['sp'].size()
sp_rank = self.device_mesh['sp'].get_local_rank()
data = data.chunk(chunks=sp_size)[sp_rank]
return data
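# Illustrative round trip (not in the original file): with sp_size == 2 and a
# per-rank batch of 4, preprocess_data all-gathers along dim 0 so both sp ranks
# see the same batch of 8; postprocess_data then chunks by sp_size and keeps
# this rank's slice, restoring the per-rank batch of 4.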

View File

@@ -0,0 +1,133 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import logging
import torch
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.api import ShardingStrategy, ShardedStateDictConfig, StateDictType, FullStateDictConfig
from torch.distributed.device_mesh import DeviceMesh
from verl.third_party.vllm import LLM
from verl.third_party.vllm import parallel_state as vllm_ps
from verl import DataProto
from verl.utils.torch_functional import (broadcast_dict_tensor, allgather_dict_tensors)
from verl.utils.debug import log_gpu_memory_usage
from .base import BaseShardingManager
logger = logging.getLogger(__file__)
logger.setLevel(os.getenv('VERL_PPO_LOGGING_LEVEL', 'WARN'))
class FSDPVLLMShardingManager(BaseShardingManager):
def __init__(self,
module: FSDP,
inference_engine: LLM,
model_config,
full_params: bool = False,
device_mesh: DeviceMesh = None):
self.module = module
self.inference_engine = inference_engine
self.model_config = model_config
self.device_mesh = device_mesh
# Full params
self.full_params = full_params
if full_params:
FSDP.set_state_dict_type(self.module,
state_dict_type=StateDictType.FULL_STATE_DICT,
state_dict_config=FullStateDictConfig())
else:
FSDP.set_state_dict_type(self.module,
state_dict_type=StateDictType.SHARDED_STATE_DICT,
state_dict_config=ShardedStateDictConfig())
# Note that torch_random_states may be different on each dp rank
self.torch_random_states = torch.cuda.get_rng_state()
# get a random rng states
if self.device_mesh is not None:
gen_dp_rank = self.device_mesh['dp'].get_local_rank()
torch.cuda.manual_seed(gen_dp_rank + 1000) # make sure all tp ranks have the same random states
self.gen_random_states = torch.cuda.get_rng_state()
torch.cuda.set_rng_state(self.torch_random_states)
else:
self.gen_random_states = None
def __enter__(self):
log_gpu_memory_usage('Before state_dict() in sharding manager memory', logger=logger)
params = self.module.state_dict()
log_gpu_memory_usage('After state_dict() in sharding manager memory', logger=logger)
# Copy, not share memory
load_format = 'hf' if self.full_params else 'dtensor'
self.inference_engine.sync_model_weights(params, load_format=load_format)
log_gpu_memory_usage('After sync model weights in sharding manager', logger=logger)
del params
torch.cuda.empty_cache()
log_gpu_memory_usage('After del state_dict and empty_cache in sharding manager', logger=logger)
# TODO: offload FSDP model weights
# self.module.cpu()
# torch.cuda.empty_cache()
# if torch.distributed.get_rank() == 0:
# print(f'after model to cpu in sharding manager memory allocated: {torch.cuda.memory_allocated() / 1e9}GB, reserved: {torch.cuda.memory_reserved() / 1e9}GB')
# important: need to manually set the random states of each tp to be identical.
if self.device_mesh is not None:
self.torch_random_states = torch.cuda.get_rng_state()
torch.cuda.set_rng_state(self.gen_random_states)
def __exit__(self, exc_type, exc_value, traceback):
log_gpu_memory_usage('Before vllm offload in sharding manager', logger=logger)
self.inference_engine.offload_model_weights()
log_gpu_memory_usage('After vllm offload in sharding manager', logger=logger)
# self.module.to('cuda')
# if torch.distributed.get_rank() == 0:
# print(f'after actor module to cuda in sharding manager memory allocated: {torch.cuda.memory_allocated() / 1e9}GB, reserved: {torch.cuda.memory_reserved() / 1e9}GB')
self.module.train()
# add empty cache after each compute
torch.cuda.empty_cache()
# restore random states
if self.device_mesh is not None:
self.gen_random_states = torch.cuda.get_rng_state()
torch.cuda.set_rng_state(self.torch_random_states)
def preprocess_data(self, data: DataProto) -> DataProto:
# TODO: Current impl doesn't consider FSDP with torch micro-dp
data.batch = allgather_dict_tensors(data.batch.contiguous(),
size=vllm_ps.get_tensor_model_parallel_world_size(),
group=vllm_ps.get_tensor_model_parallel_group(),
dim=0)
return data
def postprocess_data(self, data: DataProto) -> DataProto:
# TODO: Current impl doesn't consider FSDP with torch micro-dp
broadcast_dict_tensor(data.batch,
src=vllm_ps.get_tensor_model_parallel_src_rank(),
group=vllm_ps.get_tensor_model_parallel_group())
dp_rank = torch.distributed.get_rank()
dp_size = torch.distributed.get_world_size()  # does not consider torch micro-dp
tp_size = vllm_ps.get_tensor_model_parallel_world_size()
if tp_size > 1:
# TODO: shall we build a micro_dp group for vllm when integrating with vLLM?
local_prompts = data.chunk(chunks=tp_size)
data = local_prompts[dp_rank % tp_size]
return data
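A sketch (not part of this commit) of how this manager is typically driven around generation; fsdp_actor, llm, hf_config, device_mesh, rollout, and batch are placeholder names.

sharding_manager = FSDPVLLMShardingManager(module=fsdp_actor,
                                           inference_engine=llm,
                                           model_config=hf_config,
                                           full_params=False,
                                           device_mesh=device_mesh)

with sharding_manager:                               # FSDP state_dict -> sync_model_weights
    batch = sharding_manager.preprocess_data(batch)  # all-gather across the vLLM tp group
    out = rollout.generate_sequences(batch)          # every tp rank sees the full batch
    out = sharding_manager.postprocess_data(out)     # keep only this rank's shard
# on exit: vLLM weights offloaded, actor set back to train(), CUDA cache emptied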

View File

@@ -0,0 +1,428 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file contains a Megatron style Hybrid Engine that shares the weights of the actor with the inference engine.
"""
import torch
import torch.distributed as dist
from torch import nn
from megatron.core import parallel_state as mpu
from megatron.core import DistributedDataParallel as LocalDDP
from megatron.core.transformer.module import Float16Module
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from verl.utils.megatron_utils import get_model, unwrap_model
from verl.utils.memory_buffer import (
build_memory_buffer,
build_memory_reference_from_module,
get_weight_buffer_meta_from_module,
)
class AllGatherPPModel:
def __init__(self, model_provider) -> None:
self._pp_group = mpu.get_pipeline_model_parallel_group()
self._pp_rank = mpu.get_pipeline_model_parallel_rank()
self._pp_size = mpu.get_pipeline_model_parallel_world_size()
self._vpp_size = mpu.get_virtual_pipeline_model_parallel_world_size()
self._model_chunk_size = self._vpp_size or 1
# each one holds a list of model_chunks in this pp stage
self._pp_models = [None] * self.pp_size
rank_list = list(range(self.pp_size))
# make current rank the last one to initialize
rank_list[self.pp_rank], rank_list[-1] = rank_list[-1], rank_list[self.pp_rank]
self._this_rank_models = None
# store the parameter of each pp stage
self.memory_buffers = [None] * self.pp_size
for cur_pp_rank in rank_list:
print(
f'create pp model', f'torch allocated {torch.cuda.memory_allocated() / 1e9:.4f} GB, '
f'reserved {torch.cuda.memory_reserved() / 1e9:.4f} GB')
# since the last initialized rank is the current pp rank, after init, the pp rank is still correct
mpu.set_pipeline_model_parallel_rank(cur_pp_rank)
if cur_pp_rank != self.pp_rank:
models = get_model(model_provider, wrap_with_ddp=False)
models = nn.ModuleList(models)
assert len(models) == self._model_chunk_size, f"{len(models)} != {self._model_chunk_size}"
self.pp_models[cur_pp_rank] = models
else:
# for regular model, we wrapped it with DDP
models = get_model(model_provider)
assert len(models) == self._model_chunk_size, f"{len(models)} != {self._model_chunk_size}"
self._this_rank_models = nn.ModuleList(models)
self.pp_models[cur_pp_rank] = nn.ModuleList(unwrap_model(models, (torchDDP, LocalDDP)))
self._build_param_buffer(cur_pp_rank)
self._build_param_references(cur_pp_rank, maintain_weight=cur_pp_rank == self.pp_rank)
# TODO: after binding to the memory buffer, we can load the checkpoint here
if cur_pp_rank != self.pp_rank:
for model in self.pp_models[cur_pp_rank]:
model.eval()
self._offload_params_to_cpu(cur_pp_rank)
def _build_param_buffer(self, pp_rank):
"""Build the parameter buffer in each pp rank"""
model = self.pp_models[pp_rank]
weight_buffer_meta = get_weight_buffer_meta_from_module(model)
self.memory_buffers[pp_rank] = build_memory_buffer(weight_buffer_meta)
def _build_param_references(self, pp_rank, maintain_weight=False):
model = self.pp_models[pp_rank]
build_memory_reference_from_module(model, self.memory_buffers[pp_rank], maintain_weight=maintain_weight)
def _load_params_to_cuda(self, pp_rank, to_empty=False):
assert pp_rank != self.pp_rank, f"unexpected to load current pp rank [{pp_rank}] back to cuda"
for buffer in self.memory_buffers[pp_rank].values():
if not to_empty:
buffer.data = buffer.data.to(torch.cuda.current_device(), non_blocking=True)
else:
buffer.data = torch.empty_like(buffer.data, device='cuda')
# rebuild reference after loading to CUDA
self._build_param_references(pp_rank)
def _offload_params_to_cpu(self, pp_rank, to_empty=False):
assert pp_rank != self.pp_rank, f"unexpected to offload current pp rank [{pp_rank}] to cpu"
for buffer in self.memory_buffers[pp_rank].values():
if not to_empty:
# offload the whole memory buffer to CPU
buffer.data = buffer.data.to('cpu', non_blocking=True)
else:
buffer.data = torch.empty_like(buffer.data, device='cpu')
self._build_param_references(pp_rank)
def load_params_to_cuda(self, to_empty=False):
"""load all model params to cuda"""
for cur_pp_rank in range(self.pp_size):
if cur_pp_rank != self.pp_rank:
self._load_params_to_cuda(cur_pp_rank, to_empty=to_empty)
def allgather_params(self):
"""allgather params of all pp ranks. Return a list of handles"""
for cur_pp_rank in range(self.pp_size):
global_src = dist.get_global_rank(group=self.pp_group, group_rank=cur_pp_rank)
# NOTE(sgm): the async op may cause memory leakage of the memory_buffer/pp_models
for memory_buffer in self.memory_buffers[cur_pp_rank].values():
dist.broadcast(tensor=memory_buffer.data, src=global_src, group=self.pp_group, async_op=False)
def forward(self, *inputs, **kwargs):
try:
prev_output = None
for cur_chunk_rank in range(self._model_chunk_size):
if self._vpp_size:
mpu.set_virtual_pipeline_model_parallel_rank(cur_chunk_rank)
for cur_pp_rank in range(self.pp_size):
mpu.set_pipeline_model_parallel_rank(cur_pp_rank)
self.pp_models[cur_pp_rank][cur_chunk_rank].set_input_tensor(prev_output)
ret = self.pp_models[cur_pp_rank][cur_chunk_rank](*inputs, **kwargs)
self.pp_models[cur_pp_rank][cur_chunk_rank].set_input_tensor(None)
prev_output = ret
finally:
if self._vpp_size:
mpu.set_virtual_pipeline_model_parallel_rank(0)
mpu.set_pipeline_model_parallel_rank(self.pp_rank)
return ret
def __call__(self, *inputs, **kwargs):
return self.forward(*inputs, **kwargs)
def eval(self):
for model in self.pp_models[self.pp_rank]:
model.eval()
def train(self):
for model in self.pp_models[self.pp_rank]:
model.train()
def offload_params_to_cpu(self, to_empty=False):
"""offload params of models that are not of current pp rank to cpu"""
for cur_pp_rank in range(self.pp_size):
if cur_pp_rank != self.pp_rank:
self._offload_params_to_cpu(cur_pp_rank, to_empty=to_empty)
def get_all_params(self):
"""Get all the parameters of the models in all pp ranks
Returns:
params: List[List[Dict[str, Tensor]]]: a list of parameters in all pp, where each is a list of dict
tensors of each model chunk
"""
params = []
for pp_rank in range(self.pp_size):
params.append([])
for model_chunk_idx in range(len(self.pp_models[pp_rank])):
params[pp_rank].append({})
pp_model = self.pp_models[pp_rank][model_chunk_idx]
pp_model = unwrap_model(pp_model, ((torchDDP, LocalDDP, Float16Module))) # not use Float16Module
for name, param in pp_model.named_parameters():
# NOTE(gh) workaround: should not get lora params for inference
if 'lora' in name:
continue
params[pp_rank][model_chunk_idx][name] = param
return params
def update_this_rank_models(self, new_models):
self._this_rank_models = new_models
self._pp_models[self.pp_rank] = unwrap_model(new_models, (torchDDP, LocalDDP))
@property
def this_rank_models(self):
return self._this_rank_models
@property
def pp_size(self):
return self._pp_size
@property
def pp_rank(self):
return self._pp_rank
@property
def pp_group(self):
return self._pp_group
@property
def pp_models(self):
return self._pp_models
"""
Megatron Hybrid Engine:
- During training, only the current pp stage holds the parameters
- Before inference, broadcast the parameters of the current pp rank to all other pp ranks (so that all pp ranks hold all the parameters)
- Bind the parameters to the inference engine
- Do inference in tp. pp is treated as additional dp
- After inference, all the parameters that don't belong to this pp rank are freed.
"""
from .base import BaseShardingManager
import torch
from torch import nn
import torch.distributed
from torch.distributed import new_group
from verl import DataProto
from verl.utils.torch_functional import (broadcast_dict_tensor, allgather_dict_tensors)
import verl.utils.megatron.tensor_parallel as tp_utils
from verl.third_party.vllm import parallel_state as vllm_ps
from verl.third_party.vllm import LLM
from verl.utils.model import normalize_pp_vpp_params
# Micro data-parallel group: an additional dp group that originates from splitting the training tp
# into infer_tp and micro_dp. By default, we use the order micro_dp - tp
_MICRO_DATA_PARALLEL_GROUP = None
class MegatronVLLMShardingManager(BaseShardingManager):
def __init__(self, module: AllGatherPPModel, inference_engine: LLM, model_config, layer_name_mapping):
self.module = module
self.inference_engine = inference_engine
self.model_config = model_config
self.layer_name_mapping = layer_name_mapping
# initialize micro_dp group for vllm inference
global _MICRO_DATA_PARALLEL_GROUP
world_size = torch.distributed.get_world_size()
rank = torch.distributed.get_rank()
train_tensor_parallel_size = mpu.get_tensor_model_parallel_world_size()
infer_tensor_parallel_size = vllm_ps.get_tensor_model_parallel_world_size()
# TODO(sgm): this may not be true for FSDP -> vLLM
assert infer_tensor_parallel_size <= train_tensor_parallel_size, \
'Not implemented for infer_tp > train_tp'
assert train_tensor_parallel_size % infer_tensor_parallel_size == 0
micro_dp_size = train_tensor_parallel_size // infer_tensor_parallel_size
num_micro_dp_groups = world_size // micro_dp_size
assert _MICRO_DATA_PARALLEL_GROUP is None, ("micro data parallel group is already initialized")
for i in range(num_micro_dp_groups):
ranks = range(i * micro_dp_size, (i + 1) * micro_dp_size)
group = new_group(ranks=ranks)
if rank in ranks:
_MICRO_DATA_PARALLEL_GROUP = group
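# Illustrative layout (not in the original file): with world_size == 8,
# train_tp == 4 and infer_tp == 2, micro_dp_size == 2 and the ranks are grouped
# as {0,1}, {2,3}, {4,5}, {6,7}; each group all-gathers tp-sharded weights below.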
def default_tp_concat_fn(self, name, param, infer_params, model_config):
"""
name: name of the parameter
param: training parameters
infer_params (List[torch.Tensor]): a list of parameters all-gathered from micro_dp_group
model_config: huggingface model_config
TODO(zhangchi.usc1992): currently, the implementation is ad hoc. We can move this function into the model
definition so that it is model-agnostic. If the model doesn't implement this function,
we can throw an error to force the user to disable the TP HybridEngine.
"""
if self.layer_name_mapping.get("qkv_layer_name") in name:
# if the tensor is qkv, for each param on tp, split into q, k, v
# concat q, k, v separately.
q_lst = []
k_lst = []
v_lst = []
assert model_config.num_attention_heads % model_config.num_key_value_heads == 0
num_q_per_kv = model_config.num_attention_heads // model_config.num_key_value_heads
assert infer_params[0].shape[0] % (num_q_per_kv + 2) == 0
kv_size_per_tp = infer_params[0].shape[0] // (num_q_per_kv + 2)
split_size = [kv_size_per_tp * num_q_per_kv, kv_size_per_tp, kv_size_per_tp]
for infer_param in infer_params:
q, k, v = infer_param.split(split_size)
q_lst.append(q)
k_lst.append(k)
v_lst.append(v)
q = torch.cat(q_lst, dim=0)
k = torch.cat(k_lst, dim=0)
v = torch.cat(v_lst, dim=0)
infer_params = torch.cat((q, k, v), dim=0)
elif self.layer_name_mapping.get("gate_proj_layer_name") in name:
# if the tensor is the fused gate and up projection
gate_lst = []
up_lst = []
for infer_param in infer_params:
gate, up = infer_param.chunk(2)
gate_lst.append(gate)
up_lst.append(up)
gate = torch.cat(gate_lst, dim=0)
up = torch.cat(up_lst, dim=0)
infer_params = torch.cat((gate, up), dim=0)
else:
# concat tensor
infer_params = torch.cat(infer_params, dim=tp_utils.get_tensor_parallel_partition_dim(param))
return infer_params
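# Illustrative shapes (not in the original file): with num_attention_heads == 8 and
# num_key_value_heads == 2, num_q_per_kv == 4, so each gathered shard of the fused qkv
# weight is split row-wise as [4 * kv_size_per_tp, kv_size_per_tp, kv_size_per_tp];
# the q/k/v pieces from all micro-dp shards are concatenated separately and then
# stacked back into a single qkv tensor.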
def _post_process_params(self, params):
"""
For each param, if it is a tp-split param, we all-gather it from the micro_dp group.
"""
# here the params are in train tp format. we iterate params and all-gather
# TODO(zhangchi.usc1992) We can consider copying the non-tp weights to another infer buffer.
# In this way, all the params in the original memory_buffers can be offloaded.
micro_dp_size = get_micro_data_parallel_world_size()
micro_dp_group = get_micro_data_parallel_group()
if micro_dp_size <= 1:
return
origin_params = {}
for name in params.keys():
param = params[name]
if tp_utils.is_tensor_parallel_param(param):
# allocate a new tensor with proper size
infer_params = [torch.empty_like(param) for _ in range(micro_dp_size)]
torch.distributed.all_gather(infer_params, param, group=micro_dp_group)
infer_params = self.default_tp_concat_fn(name, param, infer_params, self.model_config)
# replace with original param
params[name] = infer_params
origin_params[name] = param
return origin_params
def __enter__(self):
# create a new cuda space for parameters not in this pp rank
self.module.load_params_to_cuda()
# broadcast the parameters from pp rank to other ranks
self.module.allgather_params()
# obtain name to parameters in pp/vpp
params = self.module.get_all_params()
# bind the params to inference engine
self.params = normalize_pp_vpp_params(params=params,
num_hidden_layers=self.model_config.num_hidden_layers,
layer_name='layers')
self.origin_params = self._post_process_params(self.params)
self.inference_engine.sync_model_weights(self.params, load_format='megatron')
def __exit__(self, exc_type, exc_value, traceback):
# offload parameters that don't belong to this pp rank
self.module.offload_params_to_cpu()
# FIXME(sgm): the best practice is to delete the cuda tensor
# rebind the model weights, can be any cpu tensor
if get_micro_data_parallel_world_size() > 1:
for name in self.params.keys():
self.params[name] = self.origin_params[name]
# self.inference_engine.sync_model_weights(params)
self.inference_engine.offload_model_weights()
self.module.train()
# add empty cache after each compute
torch.cuda.empty_cache()
def preprocess_data(self, data: DataProto) -> DataProto:
# prompts are identical for each training tp. We select for each inference tp
micro_dp_size = get_micro_data_parallel_world_size()
micro_dp_rank = get_micro_data_parallel_rank()
# broadcast from tp=0 to other tp ranks
broadcast_dict_tensor(data.batch,
src=mpu.get_tensor_model_parallel_src_rank(),
group=mpu.get_tensor_model_parallel_group())
if micro_dp_size > 1:
local_prompts = data.chunk(chunks=micro_dp_size)
data = local_prompts[micro_dp_rank]
return data
def postprocess_data(self, data: DataProto) -> DataProto:
meta_info = data.meta_info
# all gather batch among micro-dp groups
micro_dp_size = get_micro_data_parallel_world_size()
if micro_dp_size > 1:
data.batch = allgather_dict_tensors(data.batch.contiguous(),
size=get_micro_data_parallel_world_size(),
group=get_micro_data_parallel_group(),
dim=0)
# all gather batch among pp group
if meta_info.get('allgather_pp_output', True):
data.batch = allgather_dict_tensors(data.batch.contiguous(),
size=mpu.get_pipeline_model_parallel_world_size(),
group=mpu.get_pipeline_model_parallel_group(),
dim=0)
return data
"""
Micro Data parallel group
"""
def get_micro_data_parallel_group():
assert _MICRO_DATA_PARALLEL_GROUP is not None
return _MICRO_DATA_PARALLEL_GROUP
def get_micro_data_parallel_world_size():
return torch.distributed.get_world_size(group=get_micro_data_parallel_group())
def get_micro_data_parallel_rank():
return torch.distributed.get_rank(group=get_micro_data_parallel_group())