Initial commit
13
verl/workers/__init__.py
Normal file
@@ -0,0 +1,13 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
18
verl/workers/actor/__init__.py
Normal file
@@ -0,0 +1,18 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .base import BasePPOActor
from .dp_actor import DataParallelPPOActor

__all__ = ["BasePPOActor", "DataParallelPPOActor"]
66
verl/workers/actor/base.py
Normal file
@@ -0,0 +1,66 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
The base class for Actor
"""
from abc import ABC, abstractmethod
from typing import Iterable, Dict

from verl import DataProto
import torch

__all__ = ['BasePPOActor']


class BasePPOActor(ABC):

    def __init__(self, config):
        """The base class for PPO actor

        Args:
            config (DictConfig): a config passed to the PPOActor. We expect the type to be
                DictConfig (https://omegaconf.readthedocs.io/), but it can be any namedtuple in general.
        """
        super().__init__()
        self.config = config

    @abstractmethod
    def compute_log_prob(self, data: DataProto) -> torch.Tensor:
        """Compute the log probability of the responses given a batch of data.

        Args:
            data (DataProto): a batch of data represented by DataProto. It must contain the keys ``input_ids``,
                ``attention_mask`` and ``position_ids``.

        Returns:
            torch.Tensor: the log_prob tensor

        """
        pass

    @abstractmethod
    def update_policy(self, data: DataProto) -> Dict:
        """Update the policy with an iterator of DataProto

        Args:
            data (DataProto): an iterator over the DataProto returned by
                ``make_minibatch_iterator``

        Returns:
            Dict: a dictionary containing the statistics recorded while updating the model,
            such as ``loss``, ``grad_norm``, etc.

        """
        pass
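The abstract interface above fixes only the constructor and the two abstract methods. As a rough orientation, here is a minimal, hypothetical sketch of a subclass (not part of this commit; ToyPPOActor and policy_module are assumed names, where policy_module is any callable returning per-token log-probs):

# Hypothetical illustration only; not part of this commit.
import torch
from verl import DataProto
from verl.workers.actor import BasePPOActor


class ToyPPOActor(BasePPOActor):

    def __init__(self, config, policy_module):
        super().__init__(config)
        # assumed callable: (input_ids, attention_mask, position_ids) -> per-token log-probs
        self.policy_module = policy_module

    def compute_log_prob(self, data: DataProto) -> torch.Tensor:
        batch = data.batch
        with torch.no_grad():
            return self.policy_module(batch['input_ids'], batch['attention_mask'], batch['position_ids'])

    def update_policy(self, data: DataProto) -> dict:
        # a real implementation would compute the PPO loss over minibatches and step an optimizer here
        return {'actor/pg_loss': 0.0}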
290
verl/workers/actor/dp_actor.py
Normal file
@@ -0,0 +1,290 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Single Process Actor
"""

import itertools
from typing import Iterable, Tuple

import torch
from torch import nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

from verl import DataProto
from verl.trainer.ppo import core_algos
from verl.workers.actor import BasePPOActor
from verl.utils.py_functional import append_to_dict
from verl.utils.torch_functional import logprobs_from_logits, masked_mean
from verl.utils.ulysses import ulysses_pad_and_slice_inputs, gather_outpus_and_unpad
from verl.utils.seqlen_balancing import rearrange_micro_batches, get_reverse_idx
import verl.utils.torch_functional as verl_F

from flash_attn.bert_padding import pad_input, unpad_input, rearrange, index_first_axis

__all__ = ['DataParallelPPOActor']


class DataParallelPPOActor(BasePPOActor):

    def __init__(
        self,
        config,
        actor_module: nn.Module,
        actor_optimizer: torch.optim.Optimizer = None,
    ):
        """When optimizer is None, it is a Reference Policy"""
        super().__init__(config)
        self.actor_module = actor_module
        self.actor_optimizer = actor_optimizer
        self.use_remove_padding = self.config.get('use_remove_padding', False)
        print(f'Actor use_remove_padding={self.use_remove_padding}')
        self.ulysses_sequence_parallel_size = self.config.ulysses_sequence_parallel_size
        self.use_ulysses_sp = self.ulysses_sequence_parallel_size > 1

        self.compute_entropy_from_logits = torch.compile(verl_F.entropy_from_logits, dynamic=True)

    def _forward_micro_batch(self, micro_batch, temperature) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Returns:
            entropy: # (bs, response_len)
            log_probs: # (bs, response_len)
        """
        response_length = micro_batch['responses'].size(-1)
        with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
            input_ids = micro_batch['input_ids']
            batch_size, seqlen = input_ids.shape
            attention_mask = micro_batch['attention_mask']
            position_ids = micro_batch['position_ids']

            if self.use_remove_padding:
                input_ids_rmpad, indices, *_ = unpad_input(input_ids.unsqueeze(-1),
                                                           attention_mask)  # input_ids_rmpad (total_nnz, ...)
                input_ids_rmpad = input_ids_rmpad.transpose(0, 1)  # (1, total_nnz)

                # unpad the position_ids to align with the rotary embedding
                position_ids_rmpad = index_first_axis(rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."),
                                                      indices).transpose(0, 1)

                # roll the input_ids for computing the log_prob of the next token
                input_ids_rmpad_rolled = torch.roll(input_ids_rmpad, shifts=-1, dims=1)  # (1, total_nnz)

                # pad and slice the inputs if sp > 1
                if self.use_ulysses_sp:
                    input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad_and_slice_inputs(
                        input_ids_rmpad, position_ids_rmpad, sp_size=self.ulysses_sequence_parallel_size)
                    input_ids_rmpad_rolled, _, _ = ulysses_pad_and_slice_inputs(input_ids_rmpad_rolled, None,
                                                                                self.ulysses_sequence_parallel_size)

                input_ids_rmpad_rolled = input_ids_rmpad_rolled.squeeze(0)  # ((total_nnz / sp) + pad)

                # only pass input_ids and position_ids to enable flash_attn_varlen
                output = self.actor_module(input_ids=input_ids_rmpad,
                                           attention_mask=None,
                                           position_ids=position_ids_rmpad,
                                           use_cache=False)  # prevent the model from thinking we are generating
                logits_rmpad = output.logits.squeeze(0)  # (total_nnz, vocab_size)

                logits_rmpad.div_(temperature)

                # compute entropy
                entropy_rmpad = self.compute_entropy_from_logits(logits_rmpad)  # ((total_nnz / sp) + pad)

                # if use_sp: ((total_nnz / sp) + pad) ; if not use_sp: (batch, seqlen)
                log_probs = logprobs_from_logits(logits=logits_rmpad, labels=input_ids_rmpad_rolled)

                # gather log_prob if sp > 1
                if self.use_ulysses_sp:
                    # gather and unpad for the ulysses sp
                    log_probs = gather_outpus_and_unpad(log_probs, gather_dim=0, unpad_dim=0, padding_size=pad_size)
                    entropy_rmpad = gather_outpus_and_unpad(entropy_rmpad,
                                                            gather_dim=0,
                                                            unpad_dim=0,
                                                            padding_size=pad_size)
                # pad back to (bsz, seqlen)
                full_entropy = pad_input(hidden_states=entropy_rmpad.unsqueeze(-1),
                                         indices=indices,
                                         batch=batch_size,
                                         seqlen=seqlen)
                full_log_probs = pad_input(hidden_states=log_probs.unsqueeze(-1),
                                           indices=indices,
                                           batch=batch_size,
                                           seqlen=seqlen)

                # only return the response part:
                entropy = full_entropy.squeeze(-1)[:, -response_length - 1:-1]  # (bsz, response_length)
                log_probs = full_log_probs.squeeze(-1)[:, -response_length - 1:-1]  # (bsz, response_length)

            else:  # not using rmpad and no ulysses sp
                output = self.actor_module(input_ids=input_ids,
                                           attention_mask=attention_mask,
                                           position_ids=position_ids,
                                           use_cache=False)  # prevent the model from thinking we are generating
                logits = output.logits
                logits.div_(temperature)
                logits = logits[:, -response_length - 1:-1]  # (bsz, response_length)
                log_probs = logprobs_from_logits(logits, micro_batch['responses'])
                entropy = verl_F.entropy_from_logits(logits)  # (bsz, response_length)

            return entropy, log_probs

    def _optimizer_step(self):
        assert self.config.grad_clip is not None

        if isinstance(self.actor_module, FSDP):
            grad_norm = self.actor_module.clip_grad_norm_(max_norm=self.config.grad_clip)
        else:
            grad_norm = torch.nn.utils.clip_grad_norm_(self.actor_module.parameters(), max_norm=self.config.grad_clip)
        self.actor_optimizer.step()
        return grad_norm

    def compute_log_prob(self, data: DataProto) -> torch.Tensor:
        """Compute the log probability of the responses given input_ids, attention_mask and position_ids

        Args:
            data (DataProto): a DataProto containing keys

                ``input_ids``: tensor of shape [batch_size, sequence_length]. torch.int64. Note that input_ids is the
                concatenation of prompt and response. Note that ``sequence_length = prompt_length + response_length``.

                ``attention_mask``: tensor of shape [batch_size, sequence_length]. torch.int64.

                ``position_ids``: tensor of shape [batch_size, sequence_length]. torch.int64.

                ``responses``: tensor of shape [batch_size, response_length]. torch.int64.

        Returns:
            torch.Tensor: the log_prob tensor
        """
        # set to eval
        self.actor_module.eval()

        micro_batch_size = data.meta_info['micro_batch_size']
        temperature = data.meta_info['temperature']  # temperature must be in data.meta_info to avoid a silent error
        use_dynamic_bsz = data.meta_info['use_dynamic_bsz']

        select_keys = ['responses', 'input_ids', 'attention_mask', 'position_ids']
        batch = data.select(batch_keys=select_keys).batch

        if use_dynamic_bsz:
            # split using dynamic bsz
            max_token_len = data.meta_info['max_token_len'] * self.ulysses_sequence_parallel_size
            micro_batches, indices = rearrange_micro_batches(batch=batch, max_token_len=max_token_len)
        else:
            micro_batches = batch.split(micro_batch_size)

        log_probs_lst = []
        for micro_batch in micro_batches:
            with torch.no_grad():
                _, log_probs = self._forward_micro_batch(micro_batch, temperature=temperature)
            log_probs_lst.append(log_probs)
        log_probs = torch.concat(log_probs_lst, dim=0)

        if use_dynamic_bsz:
            indices = list(itertools.chain.from_iterable(indices))
            assert len(indices) == log_probs.size(0), f"{len(indices)} vs. {log_probs.size()}"
            revert_indices = torch.tensor(get_reverse_idx(indices), dtype=torch.long)
            log_probs = log_probs[revert_indices]

        return log_probs

    def update_policy(self, data: DataProto):
        # make sure we are in training mode
        self.actor_module.train()

        assert self.config.ppo_mini_batch_size % self.config.ppo_micro_batch_size == 0
        self.gradient_accumulation = self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size
        temperature = data.meta_info['temperature']  # temperature must be in data.meta_info to avoid a silent error

        select_keys = ['responses', 'input_ids', 'attention_mask', 'position_ids', 'old_log_probs', 'advantages']
        if self.config.state_masking:
            select_keys.append('loss_mask')
        if self.config.use_kl_loss:
            select_keys.append('ref_log_prob')
        batch = data.select(batch_keys=select_keys).batch

        # Split to make a minibatch iterator for updating the actor
        # See PPO paper for details. https://arxiv.org/abs/1707.06347
        dataloader = batch.split(self.config.ppo_mini_batch_size)

        metrics = {}
        for batch_idx, data in enumerate(dataloader):
            # split batch into micro_batches
            mini_batch = data
            if self.config.use_dynamic_bsz:
                max_token_len = self.config.ppo_max_token_len_per_gpu * self.ulysses_sequence_parallel_size
                micro_batches, _ = rearrange_micro_batches(batch=mini_batch, max_token_len=max_token_len)
            else:
                # split batch into micro_batches
                micro_batches = mini_batch.split(self.config.ppo_micro_batch_size)

            self.actor_optimizer.zero_grad()

            for data in micro_batches:
                data = data.cuda()  # actor device is cpu when using offload
                responses = data['responses']
                response_length = responses.size(1)
                attention_mask = data['attention_mask']
                response_mask = attention_mask[:, -response_length:]
                if self.config.state_masking:
                    response_mask = data['loss_mask']
                old_log_prob = data['old_log_probs']
                advantages = data['advantages']

                clip_ratio = self.config.clip_ratio
                entropy_coeff = self.config.entropy_coeff

                # all return: (bsz, response_length)
                entropy, log_prob = self._forward_micro_batch(micro_batch=data, temperature=temperature)

                pg_loss, pg_clipfrac, ppo_kl = core_algos.compute_policy_loss(old_log_prob=old_log_prob,
                                                                              log_prob=log_prob,
                                                                              advantages=advantages,
                                                                              eos_mask=response_mask,
                                                                              cliprange=clip_ratio)
                # compute entropy loss from entropy
                entropy_loss = verl_F.masked_mean(entropy, response_mask)

                # compute policy loss
                policy_loss = pg_loss - entropy_loss * entropy_coeff

                if self.config.use_kl_loss:
                    ref_log_prob = data['ref_log_prob']
                    # compute kl loss
                    kld = core_algos.kl_penalty(logprob=log_prob,
                                                ref_logprob=ref_log_prob,
                                                kl_penalty=self.config.kl_loss_type)
                    kl_loss = masked_mean(kld, response_mask)

                    policy_loss = policy_loss - kl_loss * self.config.kl_loss_coef
                    metrics['actor/kl_loss'] = kl_loss.detach().item()
                    metrics['actor/kl_coef'] = self.config.kl_loss_coef

                loss = policy_loss / self.gradient_accumulation
                loss.backward()

                data = {
                    'actor/entropy_loss': entropy_loss.detach().item(),
                    'actor/pg_loss': pg_loss.detach().item(),
                    'actor/pg_clipfrac': pg_clipfrac.detach().item(),
                    'actor/ppo_kl': ppo_kl.detach().item(),
                }
                append_to_dict(metrics, data)

            grad_norm = self._optimizer_step()
            data = {'actor/grad_norm': grad_norm.detach().item()}
            append_to_dict(metrics, data)
        self.actor_optimizer.zero_grad()
        return metrics
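For orientation, a minimal sketch of how this actor is driven during one PPO step (the FSDP workers added later in this commit do the real wiring; actor and data below are hypothetical placeholders, and the meta_info keys are the ones read by compute_log_prob above):

# Illustration only; `actor` is a constructed DataParallelPPOActor and `data` a DataProto
# holding 'responses', 'input_ids', 'attention_mask' and 'position_ids'.
data.meta_info.update({'micro_batch_size': 8, 'temperature': 1.0, 'use_dynamic_bsz': False})
old_log_probs = actor.compute_log_prob(data)      # (bsz, response_length)

data.batch['old_log_probs'] = old_log_probs       # 'advantages' are assumed to be filled in by the trainer
metrics = actor.update_policy(data)               # e.g. {'actor/pg_loss': [...], 'actor/grad_norm': [...]}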
368
verl/workers/actor/megatron_actor.py
Normal file
@@ -0,0 +1,368 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Megatron Actor.
In the Megatron actor, the difference is:
1. We only make the minibatch iterator here.

Note that our model doesn't have to be `MegatronModule` because we don't share the embedding in the last layer.
"""

from functools import partial
from typing import Iterable, Dict

import torch
from torch import nn
import torch.distributed
# from megatron import get_args
from megatron.optimizer import DistributedOptimizer
from verl.utils.megatron.optimizer_config import OptimizerConfig
from megatron.core import parallel_state as mpu
from megatron.core import ModelParallelConfig
from megatron.core.pipeline_parallel import get_forward_backward_func
# from megatron.core.optimizer import DistributedOptimizer

from omegaconf import OmegaConf
from verl.utils.megatron.tensor_parallel import vocab_parallel_compute_entropy_loss, vocab_parallel_log_probs_from_logits
from verl.utils.megatron.pipeline_parallel import (compute_transformers_input_shapes, make_batch_generator)
from verl import DataProto
from verl.trainer.ppo import core_algos
from verl.workers.actor import BasePPOActor
from verl.utils.py_functional import append_to_dict
from verl.utils.torch_functional import logprobs_from_logits, broadcast_dict_tensor, split_dict_tensor_into_batches

__all__ = ['MegatronPPOActor']


class MegatronPPOActor(BasePPOActor):

    def __init__(self, config, model_config, megatron_config: ModelParallelConfig, actor_module: nn.ModuleList,
                 actor_optimizer: DistributedOptimizer, actor_optimizer_config: OptimizerConfig):
        """MegatronPPOActor class. This class implements the simple PPO logic when the model is built with Megatron.

        Args:
            config (OmegaConf): the basic config that contains the hyper-parameters of the PPO Actor. It must contain

                ``ppo_micro_batch_size``: micro batch size when updating ppo.

                ``ppo_mini_batch_size``: minibatch size when updating ppo using the batch data.

                ``ppo_epochs``: number of epochs to update the actor using the batch data.

                ``shuffle``: whether to shuffle the data after each ppo epoch.

                ``clip_ratio``: clip ratio of the ppo algorithm. See https://arxiv.org/abs/1707.06347.

                ``entropy_coeff``: entropy coefficient of the PPO loss. See https://arxiv.org/abs/1707.06347.
            model_config (OmegaConf): model configuration. It must contain ``model_config.vocab_size`` and
                ``model_config.hidden_size``
            megatron_config (OmegaConf): megatron configuration. It must contain

                ``sequence_parallel_enabled``: whether sequence parallelism is enabled.

                ``param_dtype``: the dtype of the parameters.

                ``virtual_pipeline_model_parallel_size``: virtual pipeline model parallel size, a.k.a. the number of chunks in each pp stage.
            actor_module (nn.ModuleList): the actor module is a ModuleList that contains a list of nn.Module in this pp stage.
                Each nn.Module in this rank holds a vpp module chunk. See https://arxiv.org/pdf/2104.04473.pdf for more details.
                The actor module has some constraints to follow in order to use the update logic implemented here

                1. It must call unpad_input before any computation and pad_input after all the computation. Remove padding is an
                optimization that removes the padding tokens. See the unpad_input and pad_input functions in flash-attn
                (https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/bert_padding.py).

                2. Each pp stage must return the hidden state with the same shape [total_nnz, 1, hidden_size],
                where total_nnz is the number of valid tokens in this batch. If sequence parallel is enabled, the size
                of the hidden state is [total_nnz // tp, 1, hidden_size].
            actor_optimizer (DistributedOptimizer): currently, we only support DistributedOptimizer in Megatron. It implements
                a ZeRO-1 optimizer that shards the optimizer state across dp ranks.

        >>> def megatron_actor_model_provider(pre_process, post_process):
        >>>     vpp_rank = mpu.get_virtual_pipeline_model_parallel_rank()
        >>>     parallel_model = ParallelMistralForCausalLMRmPadPP(config=actor_model_config,
        >>>                                                        megatron_config=megatron_config,
        >>>                                                        pre_process=pre_process,
        >>>                                                        post_process=post_process).cuda()
        >>>     return parallel_model
        >>> from megatron.training import get_model
        >>> from megatron.optimizer import get_megatron_optimizer
        >>> actor_module = get_model(megatron_actor_model_provider, wrap_with_ddp=True)
        >>> actor_module = nn.ModuleList(actor_module)
        >>> actor_optimizer = get_megatron_optimizer(actor_module)
        >>> actor = MegatronPPOActor(config=config,
        >>>                          model_config=actor_model_config,
        >>>                          megatron_config=megatron_config,
        >>>                          actor_module=actor_module,
        >>>                          actor_optimizer=actor_optimizer)
        """
        super().__init__(config)
        self.model_config = model_config
        self.megatron_config = megatron_config
        # self.megatron_args = get_args()
        self.actor_module = actor_module
        self.actor_optimizer: DistributedOptimizer = actor_optimizer
        self.actor_optimizer_config = actor_optimizer_config

        self.optimizer_step_args = OmegaConf.create({
            'skip_grad': None,
            'overlap_dp_param_comm': False,
            'overlap_dp_grad_comm': False,
            'gradient_accumulation_steps': 1,
            'sequence_parallel': self.megatron_config.sequence_parallel,
            'DDP_impl': 'local',
            'layernorm_allreduce_bucket_threshold': 0,
            'pipeline_model_parallel_split_rank': None,
            'reduce_grads_use_alltoall': False
        })

    def compute_log_prob(self, data: DataProto) -> torch.Tensor:
        """Compute the log probability of the responses given input_ids, attention_mask and position_ids

        Args:
            data (DataProto): a DataProto containing keys

                ``input_ids``: tensor of shape [batch_size, sequence_length]. torch.int64. Note that input_ids is the
                concatenation of prompt and response. Note that ``sequence_length = prompt_length + response_length``.

                ``attention_mask``: tensor of shape [batch_size, sequence_length]. torch.int64.

                ``position_ids``: tensor of shape [batch_size, sequence_length]. torch.int64.

                ``responses``: tensor of shape [batch_size, response_length]. torch.int64.

        Returns:
            torch.Tensor: the log_prob tensor
        """
        data.batch = data.batch.contiguous()

        def compute_logprobs_fn(output, data):
            response = data['responses']
            response_length = response.size(1)
            logits = output['logits']
            logits = logits[:, -response_length - 1:-1]
            log_probs = vocab_parallel_log_probs_from_logits(logits, response)
            return {'log_probs': log_probs}

        # We recompute old_log_prob by default here.
        # TODO (zhangchi.usc1992): actually, this function should only return log_prob and this logic should be handled by the user outside
        recompute_old_log_prob = self.config.get('recompute_old_log_prob', True)

        if recompute_old_log_prob or 'old_log_probs' not in data.batch.keys():
            select_keys = ['responses', 'input_ids', 'attention_mask', 'position_ids']
            batch = data.select(batch_keys=select_keys).batch
            input_ids = batch['input_ids']
            batch_size = input_ids.size(0)
            response = batch['responses']
            response_length = response.size(1)
            with torch.no_grad():
                output = self.forward_backward_batch(data, forward_only=True, post_process_fn=compute_logprobs_fn)
                if mpu.is_pipeline_last_stage(ignore_virtual=True):
                    # only on the last pp rank. It should be on every tp rank
                    log_probs = torch.cat([o['log_probs'] for o in output], dim=0)  # (bs, seq_size)
                    log_probs = log_probs.to(torch.float32)
                else:
                    log_probs = torch.empty(size=(batch_size, response_length),
                                            dtype=torch.float32,
                                            device=input_ids.device)

                # broadcast across pp ranks
                torch.distributed.broadcast(tensor=log_probs,
                                            src=mpu.get_pipeline_model_parallel_last_rank(),
                                            group=mpu.get_pipeline_model_parallel_group(),
                                            async_op=False)

        # add empty cache after each compute
        torch.cuda.empty_cache()

        return log_probs

    def make_minibatch_iterator(self, data: DataProto) -> Iterable[DataProto]:
        """Make a minibatch iterator for updating the actor

        Args:
            data (DataProto): a DataProto containing keys

                ``input_ids``: tensor of shape [batch_size, sequence_length]. torch.int64, where ``sequence_length = prompt_length + response_length``

                ``attention_mask``: tensor of shape [batch_size, sequence_length]. torch.int64

                ``position_ids``: tensor of shape [batch_size, sequence_length]. torch.int64

                ``responses``: tensor of shape [batch_size, response_length]. torch.int64. Note that responses = input_ids[:, -response_length:]

                ``old_log_probs``: tensor of shape [batch_size, response_length]. torch.float32. The log probability of responses.

                ``advantages``: tensor of shape [batch_size, response_length]. torch.float32. The advantages of responses.
                See PPO paper for details. https://arxiv.org/abs/1707.06347

        Returns:
            Iterable[DataProto]: an iterator over minibatches of the selected keys
        """
        select_keys = ['responses', 'input_ids', 'attention_mask', 'position_ids', 'old_log_probs', 'advantages']
        data = data.select(batch_keys=select_keys)
        return data.make_iterator(mini_batch_size=self.config.ppo_mini_batch_size,
                                  epochs=self.config.ppo_epochs,
                                  dataloader_kwargs={'shuffle': self.config.shuffle})

    def forward_backward_batch(self, data: DataProto, forward_only=False, post_process_fn=None):
        """
        We assume:
        - The model takes input: (input_ids, attention_mask, position_ids). No rmpad for the input
        - The communication shape is (total_nnz_pad_to_sp // tp_size, 1, hidden_size) if sequence parallel is enabled
        """
        # broadcast from last pp rank to all other pp ranks
        # TODO: actually, we just need to control the sampling order.
        broadcast_dict_tensor(data.batch,
                              src=mpu.get_pipeline_model_parallel_last_rank(),
                              group=mpu.get_pipeline_model_parallel_group())
        # split into micro-batches
        data.batch['attention_mask'] = data.batch['attention_mask'].to(bool)

        if data.meta_info.get('micro_batch_size', None) is not None:
            batch_size = data.meta_info['micro_batch_size']
        else:
            batch_size = self.config.ppo_micro_batch_size
        batches = split_dict_tensor_into_batches(data.batch, batch_size=batch_size)
        # compute input shapes for pp stages
        input_shapes = compute_transformers_input_shapes(
            batches,
            meta_info={
                'sequence_parallel': self.megatron_config.sequence_parallel,
                'hidden_size': self.model_config.hidden_size
            })
        n_micro_batch = len(batches)
        seq_len = batches[0]['input_ids'].shape[1]

        forward_backward_func = get_forward_backward_func()

        def loss_func(output, data, meta_info):
            if forward_only:
                if post_process_fn is None:
                    return 1.0, {'logits': output.logits}
                else:
                    return 1.0, post_process_fn(output, data)

            responses = data['responses']
            response_length = responses.size(1)
            attention_mask = data['attention_mask']
            response_mask = attention_mask[:, -response_length:]
            old_log_prob = data['old_log_probs']
            advantages = data['advantages']

            clip_ratio = meta_info['clip_ratio']
            entropy_coeff = meta_info['entropy_coeff']

            # compute policy loss
            logits = output.logits
            logits = logits[:, -response_length - 1:-1]
            log_prob = vocab_parallel_log_probs_from_logits(logits, responses)
            pg_loss, pg_clipfrac, ppo_kl = core_algos.compute_policy_loss(old_log_prob=old_log_prob,
                                                                          log_prob=log_prob,
                                                                          advantages=advantages,
                                                                          eos_mask=response_mask,
                                                                          cliprange=clip_ratio)
            entropy_loss = vocab_parallel_compute_entropy_loss(logits, eos_mask=response_mask)
            policy_loss = pg_loss - entropy_loss * entropy_coeff
            # return loss and stats
            stats = {
                'actor/entropy_loss': entropy_loss.detach().item(),
                'actor/pg_loss': pg_loss.detach().item(),
                'actor/pg_clipfrac': pg_clipfrac.detach().item(),
                'actor/ppo_kl': ppo_kl.detach().item()
            }
            return policy_loss, stats

        def forward_step(batch_iter, model):
            batch = next(batch_iter)
            input_ids = batch['input_ids']
            attention_mask = batch['attention_mask']
            position_ids = batch['position_ids']
            output = model(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids)
            if forward_only:
                meta_info = None
            else:
                meta_info = {'clip_ratio': self.config.clip_ratio, 'entropy_coeff': self.config.entropy_coeff}
            return output, partial(loss_func, data=batch, meta_info=meta_info)

        # batch should be a list of batches inside micro-batches
        batch_generator = make_batch_generator(batches, vpp_size=len(self.actor_module))

        # TODO: we may use the new schedule instead
        # for flash-attn: (seq_len, batch_size, hidden_size) = (mbs*seq_len, 1, hidden_size)
        if mpu.get_pipeline_model_parallel_world_size() > 1:
            losses_reduced = forward_backward_func(
                forward_step_func=forward_step,
                data_iterator=batch_generator,
                model=self.actor_module,
                num_microbatches=n_micro_batch,
                input_shapes=input_shapes,  # must set for flash-attn sequence packing
                seq_length=batch_size * seq_len,  # no use when input_shapes was set
                hidden_size=self.model_config.hidden_size,  # no use when input_shapes was set
                micro_batch_size=1,  # no use when input_shapes was set
                forward_only=forward_only,
            )
        else:
            losses_reduced = forward_backward_func(
                forward_step_func=forward_step,
                data_iterator=batch_generator,
                model=self.actor_module,
                num_microbatches=n_micro_batch,
                seq_length=batch_size * seq_len,  # in use for pp = 1
                hidden_size=self.model_config.hidden_size,  # in use for pp = 1
                micro_batch_size=1,  # in use for pp = 1
                forward_only=forward_only,
            )
        # losses_reduced contains the stats returned from loss_func
        return losses_reduced

    def update_policy(self, dataloader: Iterable[DataProto]) -> Dict:
        """Update the policy with an iterator of DataProto

        Args:
            dataloader (Iterable[DataProto]): an iterator over the DataProto returned by ``make_minibatch_iterator``.
                The keys of each data batch are described in ``make_minibatch_iterator``.

        Returns:
            Dict: a dictionary containing the statistics. Note that the statistics are only valid on the last pp stage
            and users have to combine the output of each dp rank manually.

        """
        metrics = {}
        for data in dataloader:
            # data = data.batch.to(self.actor_module.device)
            self.actor_optimizer.zero_grad()
            # use use_contiguous_buffers_in_local_ddp and no overlap_dp_param_comm
            for chunk in self.actor_module:
                # if the distributed optimizer is used, the grad buffer will be zeroed by the optimizer
                chunk.zero_grad_buffer(zero_buffer=(not self.actor_optimizer_config.use_distributed_optimizer))

            metric_micro_batch = self.forward_backward_batch(data)
            for metric in metric_micro_batch:
                append_to_dict(metrics, metric)  # append the metric from this micro-batch to global metrics.

            update_successful, grad_norm, num_zeros_in_grad = self.actor_optimizer.step(
                self.megatron_config, self.megatron_config.timers)
            if update_successful:
                # allgather is already executed in optimizer.step in new megatron
                pass
            else:
                raise NotImplementedError

            for metric in metric_micro_batch:
                append_to_dict(metrics, metric)  # append the metric from this micro-batch to global metrics.

        # add empty cache after each compute
        torch.cuda.empty_cache()

        return metrics
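As a rough sketch of the intended call order (actor and batch below are hypothetical placeholders; the Megatron worker added later in this commit is the authoritative wiring):

# Illustration only; `actor` is a constructed MegatronPPOActor and `batch` a DataProto
# with the keys listed in the docstrings above.
batch.meta_info['micro_batch_size'] = 8           # optional override of config.ppo_micro_batch_size
old_log_probs = actor.compute_log_prob(batch)     # float32, broadcast to every pp rank

batch.batch['old_log_probs'] = old_log_probs      # 'advantages' are assumed to be computed by the trainer
dataloader = actor.make_minibatch_iterator(batch)
metrics = actor.update_policy(dataloader)         # stats are only valid on the last pp stage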
18
verl/workers/critic/__init__.py
Normal file
@@ -0,0 +1,18 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .base import BasePPOCritic
from .dp_critic import DataParallelPPOCritic

__all__ = ["BasePPOCritic", "DataParallelPPOCritic"]
40
verl/workers/critic/base.py
Normal file
@@ -0,0 +1,40 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Base class for a critic
"""
from abc import ABC, abstractmethod

import torch

from verl import DataProto

__all__ = ['BasePPOCritic']


class BasePPOCritic(ABC):

    def __init__(self, config):
        super().__init__()
        self.config = config

    @abstractmethod
    def compute_values(self, data: DataProto) -> torch.Tensor:
        """Compute values"""
        pass

    @abstractmethod
    def update_critic(self, data: DataProto):
        """Update the critic"""
        pass
204
verl/workers/critic/dp_critic.py
Normal file
@@ -0,0 +1,204 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Implement a multiprocess PPOCritic
"""
import itertools
from typing import Iterable

import torch
import torch.distributed
from torch import nn, optim

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

from verl import DataProto
from verl.trainer.ppo import core_algos
from verl.workers.critic import BasePPOCritic
from verl.utils.py_functional import append_to_dict
from verl.utils.torch_functional import masked_mean
from verl.utils.ulysses import ulysses_pad_and_slice_inputs, gather_outpus_and_unpad
from verl.utils.seqlen_balancing import rearrange_micro_batches, get_reverse_idx

from flash_attn.bert_padding import pad_input, unpad_input, rearrange, index_first_axis

__all__ = ['DataParallelPPOCritic']


class DataParallelPPOCritic(BasePPOCritic):

    def __init__(self, config, critic_module: nn.Module, critic_optimizer: optim.Optimizer):
        super().__init__(config=config)
        self.critic_module = critic_module
        self.critic_optimizer = critic_optimizer
        self.use_remove_padding = self.config.model.get('use_remove_padding', False)
        print(f'Critic use_remove_padding={self.use_remove_padding}')

        assert self.config.ppo_mini_batch_size % self.config.ppo_micro_batch_size == 0
        self.gradient_accumulation = self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size

        self.ulysses_sequence_parallel_size = self.config.get('ulysses_sequence_parallel_size', 1)

    def _forward_micro_batch(self, micro_batch):
        response_length = micro_batch['responses'].size(-1)
        with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
            input_ids = micro_batch['input_ids']
            batch, seqlen = input_ids.shape
            attention_mask = micro_batch['attention_mask']
            position_ids = micro_batch['position_ids']

            if self.use_remove_padding:
                input_ids_rmpad, indices, *_ = unpad_input(input_ids.unsqueeze(-1),
                                                           attention_mask)  # input_ids_rmpad (total_nnz, ...)
                input_ids_rmpad = input_ids_rmpad.transpose(0, 1)  # (1, total_nnz)

                # unpad the position_ids to align with the rotary embedding
                position_ids_rmpad = index_first_axis(rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."),
                                                      indices).transpose(0, 1)

                # pad and slice the inputs if sp > 1
                if self.ulysses_sequence_parallel_size > 1:
                    input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad_and_slice_inputs(
                        input_ids_rmpad, position_ids_rmpad, sp_size=self.ulysses_sequence_parallel_size)

                # only pass input_ids and position_ids to enable flash_attn_varlen
                output = self.critic_module(input_ids=input_ids_rmpad,
                                            attention_mask=None,
                                            position_ids=position_ids_rmpad,
                                            use_cache=False)  # prevent the model from thinking we are generating
                values_rmpad = output.logits
                values_rmpad = values_rmpad.squeeze(0)  # (total_nnz)

                # gather output if sp > 1
                if self.ulysses_sequence_parallel_size > 1:
                    values_rmpad = gather_outpus_and_unpad(values_rmpad,
                                                           gather_dim=0,
                                                           unpad_dim=0,
                                                           padding_size=pad_size)

                # pad it back
                values = pad_input(values_rmpad, indices=indices, batch=batch, seqlen=seqlen).squeeze(-1)
                values = values[:, -response_length - 1:-1]
            else:
                output = self.critic_module(input_ids=input_ids,
                                            attention_mask=attention_mask,
                                            position_ids=position_ids,
                                            use_cache=False)  # prevent the model from thinking we are generating
                values = output.logits
                values = values[:, -response_length - 1:-1].squeeze(-1)
            return values

    def _optimizer_step(self):
        assert self.config.grad_clip is not None

        if isinstance(self.critic_module, FSDP):
            grad_norm = self.critic_module.clip_grad_norm_(self.config.grad_clip)
        else:
            grad_norm = torch.nn.utils.clip_grad_norm_(self.critic_module.parameters(), max_norm=self.config.grad_clip)
        self.critic_optimizer.step()
        return grad_norm

    def compute_values(self, data: DataProto) -> torch.Tensor:
        self.critic_module.eval()
        micro_batch_size = data.meta_info['micro_batch_size']
        select_keys = ['responses', 'input_ids', 'attention_mask', 'position_ids']
        batch = data.select(batch_keys=select_keys).batch
        use_dynamic_bsz = data.meta_info['use_dynamic_bsz']

        if use_dynamic_bsz:
            # split using dynamic bsz
            max_token_len = data.meta_info['max_token_len'] * self.ulysses_sequence_parallel_size
            micro_batches, indices = rearrange_micro_batches(batch=batch, max_token_len=max_token_len)
        else:
            micro_batches = batch.split(micro_batch_size)

        values_lst = []
        for micro_batch in micro_batches:
            with torch.no_grad():
                values = self._forward_micro_batch(micro_batch)
            values_lst.append(values)
        values = torch.concat(values_lst, dim=0)
        responses = data.batch['responses']
        attention_mask = data.batch['attention_mask']
        response_length = responses.size(1)
        values = values * attention_mask[:, -response_length - 1:-1]

        if use_dynamic_bsz:
            indices = list(itertools.chain.from_iterable(indices))
            assert len(indices) == values.size(0), f"{len(indices)} vs. {values.size()}"
            revert_indices = torch.tensor(get_reverse_idx(indices), dtype=torch.long)
            values = values[revert_indices]

        return values

    def update_critic(self, data: DataProto):
        # make sure we are in training mode
        self.critic_module.train()
        metrics = {}

        select_keys = ['input_ids', 'responses', 'attention_mask', 'position_ids', 'values', 'returns']
        batch = data.select(batch_keys=select_keys).batch
        # Split to make a minibatch iterator for updating the critic
        # See PPO paper for details. https://arxiv.org/abs/1707.06347
        dataloader = batch.split(self.config.ppo_mini_batch_size)

        for batch_idx, data in enumerate(dataloader):
            # split batch into micro_batches
            mini_batch = data
            if self.config.use_dynamic_bsz:
                max_token_len = self.config.ppo_max_token_len_per_gpu * self.ulysses_sequence_parallel_size
                micro_batches, _ = rearrange_micro_batches(batch=mini_batch, max_token_len=max_token_len)
            else:
                micro_batches = mini_batch.split(self.config.ppo_micro_batch_size)

            self.critic_optimizer.zero_grad()

            for data in micro_batches:
                data = data.cuda()  # critic device is cpu when using offload
                input_ids = data['input_ids']
                responses = data['responses']
                attention_mask = data['attention_mask']
                position_ids = data['position_ids']
                values = data['values']
                returns = data['returns']
                response_length = responses.size(1)

                eos_mask = attention_mask[:, -response_length - 1:-1]

                vpreds = self._forward_micro_batch(data)

                # assert not torch.any(torch.isnan(vpreds)).item()

                vf_loss, vf_clipfrac = core_algos.compute_value_loss(vpreds=vpreds,
                                                                     values=values,
                                                                     returns=returns,
                                                                     eos_mask=eos_mask,
                                                                     cliprange_value=self.config.cliprange_value)
                loss = vf_loss / self.gradient_accumulation
                loss.backward()

                data = {
                    'critic/vf_loss': vf_loss.detach().item(),
                    'critic/vf_clipfrac': vf_clipfrac.detach().item(),
                    'critic/vpred_mean': masked_mean(vpreds, eos_mask).detach().item(),
                }

                append_to_dict(metrics, data)

            grad_norm = self._optimizer_step()
            data = {'critic/grad_norm': grad_norm.detach().item()}
            append_to_dict(metrics, data)
        self.critic_optimizer.zero_grad()
        return metrics
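A minimal sketch of the critic's round trip inside one PPO iteration (critic and data below are hypothetical placeholders; the meta_info keys are the ones read by compute_values above):

# Illustration only; `critic` is a constructed DataParallelPPOCritic and `data` a DataProto
# holding 'responses', 'input_ids', 'attention_mask' and 'position_ids'.
data.meta_info.update({'micro_batch_size': 8, 'use_dynamic_bsz': False})
values = critic.compute_values(data)              # (bsz, response_length), masked by the attention mask

data.batch['values'] = values                     # 'returns' are assumed to come from the advantage estimator
metrics = critic.update_critic(data)              # e.g. {'critic/vf_loss': [...], 'critic/grad_norm': [...]}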
229
verl/workers/critic/megatron_critic.py
Normal file
@@ -0,0 +1,229 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Implement a multiprocess PPOCritic
"""

from functools import partial
from typing import Iterable

import torch
import torch.distributed
from omegaconf import OmegaConf
from torch import nn

from verl import DataProto
from verl.trainer.ppo import core_algos
from verl.workers.critic import BasePPOCritic
from verl.utils.megatron.pipeline_parallel import (compute_transformers_input_shapes, make_batch_generator)
from verl.utils.py_functional import append_to_dict
from verl.utils.torch_dtypes import PrecisionType
from verl.utils.torch_functional import masked_mean, broadcast_dict_tensor, split_dict_tensor_into_batches
from verl.utils.megatron import sequence_parallel as sp_utils
from verl.utils.megatron.optimizer_config import OptimizerConfig

from megatron.optimizer import DistributedOptimizer
from megatron.core import parallel_state as mpu
from megatron.core.pipeline_parallel import get_forward_backward_func


class MegatronPPOCritic(BasePPOCritic):

    def __init__(self, config, model_config, megatron_config, critic_module: nn.ModuleList,
                 critic_optimizer: DistributedOptimizer, critic_optimizer_config: OptimizerConfig):
        super().__init__(config=config)

        self.model_config = model_config
        self.megatron_config = megatron_config

        self.critic_module = critic_module
        self.critic_optimizer = critic_optimizer
        self.critic_optimizer_config = critic_optimizer_config

        # we create a separate namedtuple for the optimizer step so that global args won't affect it.
        self.optimizer_step_args = OmegaConf.create({
            'skip_grad': None,
            'overlap_dp_param_comm': False,
            'overlap_dp_grad_comm': False,
            'gradient_accumulation_steps': 1,
            'sequence_parallel': self.megatron_config.sequence_parallel,
            'DDP_impl': 'local',
            'layernorm_allreduce_bucket_threshold': 0,
            'pipeline_model_parallel_split_rank': None,
            'reduce_grads_use_alltoall': False
        })

        if self.config.kl_ctrl.type == 'fixed':
            self.kl_ctrl = core_algos.FixedKLController(kl_coef=self.config.kl_ctrl.kl_coef)
        elif self.config.kl_ctrl.type == 'adaptive':
            assert self.config.kl_ctrl.horizon > 0, f'horizon must be larger than 0. Got {self.config.kl_ctrl.horizon}'
            self.kl_ctrl = core_algos.AdaptiveKLController(init_kl_coef=self.config.kl_ctrl.kl_coef,
                                                           target_kl=self.config.kl_ctrl.target_kl,
                                                           horizon=self.config.kl_ctrl.horizon)
        else:
            raise NotImplementedError

    def compute_values(self, data: DataProto) -> DataProto:
        # data.batch = data.batch.to(self.critic_module.module.device)
        responses = data.batch['responses']
        attention_mask = data.batch['attention_mask']
        response_length = responses.size(1)
        with torch.no_grad():
            output = self.forward_backward_batch(data=data, forward_only=True)
            if mpu.is_pipeline_last_stage(ignore_virtual=True):
                # only on the last pp rank. It should be on every tp rank
                values = torch.cat([o['vpreds'] for o in output], dim=0)  # (bs, seq_size)
                values = values.to(torch.float32)
            else:
                values = torch.empty_like(attention_mask, dtype=torch.float32)

            # each tp rank should contain the same values
            values = values * attention_mask
            values = values[:, -response_length - 1:-1]
            values = values.contiguous()

        # sync among pp ranks
        torch.distributed.broadcast(tensor=values,
                                    src=mpu.get_pipeline_model_parallel_last_rank(),
                                    group=mpu.get_pipeline_model_parallel_group())

        # add empty cache after each compute
        torch.cuda.empty_cache()

        return values

    def make_minibatch_iterator(self, data: DataProto) -> Iterable[DataProto]:
        select_keys = ['input_ids', 'responses', 'attention_mask', 'position_ids', 'values', 'returns']
        data = data.select(batch_keys=select_keys)
        return data.make_iterator(mini_batch_size=self.config.ppo_mini_batch_size,
                                  epochs=self.config.ppo_epochs,
                                  dataloader_kwargs={'shuffle': self.config.shuffle})

    def forward_backward_batch(self, data: DataProto, forward_only=False):
        # broadcast from last pp rank to all other pp ranks
        data.batch = data.batch.contiguous()
        broadcast_dict_tensor(data.batch,
                              src=mpu.get_pipeline_model_parallel_last_rank(),
                              group=mpu.get_pipeline_model_parallel_group())
        # split into micro-batches
        data.batch['attention_mask'] = data.batch['attention_mask'].to(bool)
        batches = split_dict_tensor_into_batches(data.batch, batch_size=self.config.ppo_micro_batch_size)
        n_micro_batch = len(batches)
        seq_len = batches[0]['input_ids'].shape[1]

        # compute input shapes for pp stages
        input_shapes = compute_transformers_input_shapes(
            batches,
            meta_info={
                'sequence_parallel': self.megatron_config.sequence_parallel,
                'hidden_size': self.model_config.hidden_size
            })

        forward_backward_func = get_forward_backward_func()

        def loss_func(output, data, meta_info):
            if forward_only:
                return 1.0, {'vpreds': output.logits}

            responses = data['responses']
            attention_mask = data['attention_mask']
            values = data['values']
            returns = data['returns']
            response_length = responses.size(1)

            eos_mask = attention_mask[:, -response_length:]

            cliprange_value = self.config.cliprange_value

            vpreds = output.logits  # (bs, sequence_length)
            vpreds = vpreds[:, -response_length - 1:-1]

            vf_loss, vf_clipfrac = core_algos.compute_value_loss(vpreds=vpreds,
                                                                 values=values,
                                                                 returns=returns,
                                                                 eos_mask=eos_mask,
                                                                 cliprange_value=cliprange_value)
            stats = {
                'critic/vf_loss': vf_loss.detach().item(),
                'critic/vf_clipfrac': vf_clipfrac.detach().item(),
                'critic/vpred_mean': masked_mean(vpreds, eos_mask).detach().item(),
            }

            return vf_loss, stats

        def forward_step(batch_iter, model):
            batch = next(batch_iter)
            input_ids = batch['input_ids']
            attention_mask = batch['attention_mask']
            position_ids = batch['position_ids']
            output = model(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids)
            return output, partial(loss_func, data=batch, meta_info={})

        # batch should be a list of batches inside micro-batches
        batch_generator = make_batch_generator(batches, vpp_size=len(self.critic_module))

        # TODO: we may use the new schedule instead
        # for flash-attn: (seq_len, batch_size, hidden_size) = (mbs*seq_len, 1, hidden_size)
        if mpu.get_pipeline_model_parallel_world_size() > 1:
            losses_reduced = forward_backward_func(
                forward_step_func=forward_step,
                data_iterator=batch_generator,
                model=self.critic_module,
                num_microbatches=n_micro_batch,
                input_shapes=input_shapes,  # must set for flash-attn sequence packing
                seq_length=self.config.ppo_micro_batch_size * seq_len,  # no use when input_shapes was set
                hidden_size=self.model_config.hidden_size,  # no use when input_shapes was set
                micro_batch_size=1,  # no use when input_shapes was set
                forward_only=forward_only,
            )
        else:
            losses_reduced = forward_backward_func(
                forward_step_func=forward_step,
                data_iterator=batch_generator,
                model=self.critic_module,
                num_microbatches=n_micro_batch,
                seq_length=self.config.ppo_micro_batch_size * seq_len,  # in use for pp = 1
                hidden_size=self.model_config.hidden_size,  # in use for pp = 1
                micro_batch_size=1,  # in use for pp = 1
                forward_only=forward_only,
            )
        # losses_reduced contains the stats returned from loss_func
        return losses_reduced

    def update_critic(self, dataloader: Iterable[DataProto]):
        metrics = {}

        for data in dataloader:
            # data = data.batch.to(self.critic_module.device)
            self.critic_optimizer.zero_grad()
            # use use_contiguous_buffers_in_local_ddp and no overlap_dp_param_comm
            for chunk in self.critic_module:
                chunk.zero_grad_buffer(zero_buffer=(not self.critic_optimizer_config.use_distributed_optimizer))

            metric_micro_batch = self.forward_backward_batch(data)

            update_successful, grad_norm, num_zeros_in_grad = self.critic_optimizer.step(
                self.megatron_config, self.megatron_config.timers)
            if update_successful:
                # allgather is already executed in optimizer.step in new megatron
                pass
            else:
                raise NotImplementedError

            for metric in metric_micro_batch:
                append_to_dict(metrics, metric)  # append the metric from this micro-batch to global metrics.

        # add empty cache after each compute
        torch.cuda.empty_cache()
        return metrics
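Mirroring the Megatron actor, a rough sketch of the expected call order (critic and batch below are hypothetical placeholders):

# Illustration only; `critic` is a constructed MegatronPPOCritic and `batch` a DataProto.
values = critic.compute_values(batch)             # synced across pp ranks via broadcast

batch.batch['values'] = values                    # 'returns' are assumed to be filled in by the trainer
dataloader = critic.make_minibatch_iterator(batch)
metrics = critic.update_critic(dataloader)        # per-micro-batch stats appended into lists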
1054
verl/workers/fsdp_workers.py
Normal file
File diff suppressed because it is too large
735
verl/workers/megatron_workers.py
Normal file
@@ -0,0 +1,735 @@
|
||||
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
The main entry point to run the PPO algorithm
|
||||
"""
|
||||
|
||||
import os
|
||||
import logging
|
||||
import ray
|
||||
import torch
|
||||
import torch.distributed
|
||||
import torch.nn as nn
|
||||
from omegaconf import DictConfig
|
||||
from verl.single_controller.base.megatron.worker import MegatronWorker
|
||||
from verl.workers.actor.megatron_actor import MegatronPPOActor
|
||||
from verl.workers.critic.megatron_critic import MegatronPPOCritic
|
||||
from verl.workers.sharding_manager import AllGatherPPModel
|
||||
from verl.workers.reward_model.megatron.reward_model import MegatronRewardModel
|
||||
|
||||
from verl.single_controller.base.decorator import register, Dispatch
|
||||
from verl import DataProto
|
||||
from verl.utils.fs import copy_local_path_from_hdfs
|
||||
from verl.utils.debug import log_gpu_memory_usage
|
||||
from verl.utils.model import load_megatron_model_weights
|
||||
from verl.utils.megatron_utils import init_model_parallel_config
|
||||
from verl.utils.megatron_utils import offload_megatron_param_and_grad, load_megatron_param_and_grad
|
||||
from verl.utils import hf_tokenizer
|
||||
|
||||
from megatron.core import parallel_state as mpu
|
||||
from megatron.core import ModelParallelConfig
|
||||
|
||||
logger = logging.getLogger(__file__)
|
||||
logger.setLevel(os.getenv('VERL_PPO_LOGGING_LEVEL', 'WARN'))
|
||||
|
||||
|
||||
def set_random_seed(seed):
|
||||
import torch
|
||||
import numpy as np
|
||||
import random
|
||||
torch.manual_seed(seed)
|
||||
np.random.seed(seed)
|
||||
random.seed(seed)
|
||||
if torch.cuda.device_count() > 0:
|
||||
from megatron.core import tensor_parallel
|
||||
tensor_parallel.model_parallel_cuda_manual_seed(seed)
|
||||
# FIXME: torch.cumsum does not support deterministic mode (used in the vLLM sampler),
|
||||
# https://github.com/pytorch/pytorch/issues/89492
|
||||
# torch.use_deterministic_algorithms(True, warn_only=True)
|
||||
# os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
|
||||
|
||||
|
||||
class ActorRolloutRefWorker(MegatronWorker):
|
||||
"""
|
||||
This worker can be instantiated as a standalone actor, a standalone rollout, a standalone
reference policy, or a hybrid engine, depending on config.rollout
|
||||
"""
|
||||
|
||||
def __init__(self, config: DictConfig, role: str):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
# NOTE(sgm): We use the colocated WorkerGroup by default, so workers for different models
# share the same process and only one distributed initialization is required.
# To use a different parallel strategy for each model:
# 1. disable WorkerDict; 2. assign a different ResourcePool to each model;
# 3. apply the following patch on ray==2.10: https://github.com/ray-project/ray/pull/44385
|
||||
if not torch.distributed.is_initialized():
|
||||
rank = int(os.environ['LOCAL_RANK'])
|
||||
torch.distributed.init_process_group(backend="nccl")
|
||||
torch.cuda.set_device(rank)
|
||||
|
||||
if self.config.actor.megatron.sequence_parallel:
|
||||
os.environ['CUDA_DEVICE_MAX_CONNECTIONS'] = '1'
|
||||
mpu.initialize_model_parallel(
|
||||
tensor_model_parallel_size=self.config.actor.megatron.tensor_model_parallel_size,
|
||||
pipeline_model_parallel_size=self.config.actor.megatron.pipeline_model_parallel_size,
|
||||
virtual_pipeline_model_parallel_size=None,
|
||||
pipeline_model_parallel_split_rank=None,
|
||||
use_sharp=False,
|
||||
context_parallel_size=1,
|
||||
expert_model_parallel_size=1,
|
||||
nccl_communicator_config_path=None,
|
||||
)
|
||||
|
||||
set_random_seed(seed=self.config.actor.megatron.seed)
|
||||
|
||||
self.role = role
|
||||
assert self.role in ['actor', 'rollout', 'ref', 'actor_rollout', 'actor_rollout_ref']
|
||||
|
||||
self._is_actor = self.role in ['actor', 'actor_rollout', 'actor_rollout_ref']
|
||||
self._is_rollout = self.role in ['rollout', 'actor_rollout', 'actor_rollout_ref']
|
||||
self._is_ref = self.role in ['ref', 'actor_rollout_ref']
|
||||
|
||||
# TODO(sgm): Currently, we only support reference model param offload
|
||||
# will support other offload later
|
||||
self._is_offload_param = False
|
||||
self._is_offload_grad = False
|
||||
self._is_offload_optimizer = False
|
||||
|
||||
# normalize config
|
||||
if self._is_actor and self._is_rollout:
|
||||
self.config.actor.ppo_mini_batch_size //= mpu.get_data_parallel_world_size()
|
||||
self.config.actor.ppo_micro_batch_size //= mpu.get_data_parallel_world_size()
|
||||
self.config.rollout.log_prob_micro_batch_size //= mpu.get_data_parallel_world_size()
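# e.g. (illustrative numbers) a global ppo_micro_batch_size of 32 with a data-parallel
# world size of 8 leaves a per-DP-rank micro batch size of 4 after this normalization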
|
||||
self._is_offload_param = self.config.actor.get('param_offload', False)
|
||||
self._is_offload_grad = self.config.actor.get('grad_offload', False)
|
||||
self._is_offload_optimizer = self.config.actor.get('optimizer_offload', False)
|
||||
elif self._is_ref:
|
||||
self.config.ref.log_prob_micro_batch_size //= mpu.get_data_parallel_world_size()
|
||||
self._is_offload_param = self.config.ref.get('param_offload', False)
|
||||
|
||||
def _build_model_optimizer(self,
|
||||
model_path,
|
||||
megatron_config: ModelParallelConfig,
|
||||
optim_config,
|
||||
override_model_config,
|
||||
enable_gradient_checkpointing=False):
|
||||
from verl.utils.megatron.optimizer import get_megatron_optimizer
|
||||
from megatron.core.models.gpt.gpt_model import ModelType
|
||||
from verl.utils.model import print_model_size, update_model_config
|
||||
from verl.utils.megatron_utils import get_model, init_megatron_optim_config
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
|
||||
|
||||
# Step 1: initialize the tokenizer
|
||||
local_path = copy_local_path_from_hdfs(model_path)
|
||||
self.tokenizer = hf_tokenizer(local_path)
|
||||
|
||||
# Step 2: get the actor_model_config
|
||||
actor_model_config = AutoConfig.from_pretrained(local_path)
|
||||
|
||||
override_config_kwargs = {
|
||||
'bos_token_id': self.tokenizer.bos_token_id,
|
||||
'eos_token_id': self.tokenizer.eos_token_id,
|
||||
'pad_token_id': self.tokenizer.pad_token_id,
|
||||
}
|
||||
override_config_kwargs.update(override_model_config)
|
||||
update_model_config(actor_model_config, override_config_kwargs=override_config_kwargs)
|
||||
|
||||
if self.rank == 0:
|
||||
print(f'Model config after override: {actor_model_config}')
|
||||
|
||||
def megatron_actor_model_provider(pre_process, post_process):
|
||||
from verl.utils.model import get_parallel_model_from_config
|
||||
# vpp is not supported yet because it will hang for some reason. Need debugging
|
||||
vpp_rank = mpu.get_virtual_pipeline_model_parallel_rank() # this will be set inside get_model
|
||||
# this_megatron_config = copy.deepcopy(megatron_config)
|
||||
# this_megatron_config.virtual_pipeline_model_parallel_rank = vpp_rank
|
||||
parallel_model = get_parallel_model_from_config(config=actor_model_config,
|
||||
megatron_config=megatron_config,
|
||||
pre_process=pre_process,
|
||||
post_process=post_process,
|
||||
value=False)
|
||||
parallel_model.cuda()
|
||||
return parallel_model
|
||||
|
||||
# Step 3: initialize the megatron model
|
||||
if self._is_actor and self._is_rollout:
|
||||
# Initialize the 3D HybridEngine
|
||||
hybrid_engine = AllGatherPPModel(model_provider=megatron_actor_model_provider)
|
||||
# Fetch the model at current rank
|
||||
actor_module = hybrid_engine.this_rank_models
|
||||
if isinstance(actor_module, nn.ModuleList):
|
||||
actor_module = [actor_module[0]]
|
||||
if self.config.actor.load_weight:
|
||||
load_megatron_model_weights(self.config,
|
||||
actor_model_config,
|
||||
actor_module,
|
||||
params_dtype=megatron_config.params_dtype,
|
||||
is_value_model=False)
|
||||
|
||||
if self.rank == 0:
|
||||
print_model_size(actor_module[0])
|
||||
log_gpu_memory_usage('After AllGatherPPModel init', logger=logger)
|
||||
elif self._is_ref:
|
||||
print(f'self.config.ref.load_weight: {self.config.ref.load_weight}')
|
||||
ref_module = get_model(model_provider_func=megatron_actor_model_provider,
|
||||
model_type=ModelType.encoder_or_decoder,
|
||||
wrap_with_ddp=False)
|
||||
# ref_module = nn.ModuleList(ref_module)
|
||||
|
||||
if self.config.ref.load_weight: # should align with the actor:
|
||||
assert self.config.actor.load_weight == self.config.ref.load_weight
|
||||
print(f'load ref weight start')
|
||||
load_megatron_model_weights(self.config,
|
||||
actor_model_config,
|
||||
ref_module,
|
||||
params_dtype=megatron_config.params_dtype,
|
||||
is_value_model=False)
|
||||
log_gpu_memory_usage('After ref module init', logger=logger)
|
||||
return ref_module, actor_model_config
|
||||
|
||||
# TODO: add more optimizer args into config
|
||||
if self._is_actor:
|
||||
optim_config = init_megatron_optim_config(optim_config)
|
||||
actor_optimizer = get_megatron_optimizer(model=actor_module, config=optim_config)
|
||||
else:
|
||||
optim_config = None
|
||||
actor_optimizer = None
|
||||
|
||||
log_gpu_memory_usage('After actor optimizer init', logger=logger)
|
||||
|
||||
return actor_module, hybrid_engine, actor_optimizer, actor_model_config, optim_config
|
||||
|
||||
def _build_rollout(self):
|
||||
if self.config.rollout.name == 'vllm':
|
||||
from verl.workers.rollout.vllm_rollout import vLLMRollout
|
||||
from verl.workers.sharding_manager import MegatronVLLMShardingManager
|
||||
from verl.utils.model import normalize_pp_vpp_params
|
||||
|
||||
# NOTE(sgm): If the QKV and gate_up projection layers are concatenated in the actor,
# we reorganize their weight format when resharding from actor to rollout.
|
||||
layer_name_mapping = {
|
||||
"qkv_layer_name":
|
||||
self.config.rollout.layer_name_map.get("qkv_layer_name", "qkv"),
|
||||
"gate_proj_layer_name":
|
||||
self.config.rollout.layer_name_map.get("gate_proj_layer_name", "linear_fc1.weight"),
|
||||
}
|
||||
|
||||
# reshard the weight partition from actor to rollout to initialize the rollout class
|
||||
# create a new cuda space for parameters not in this pp rank
|
||||
self.hybrid_engine.load_params_to_cuda()
|
||||
# broadcast the parameters from pp rank to other ranks
|
||||
self.hybrid_engine.allgather_params()
|
||||
# obtain name to parameters in pp/vpp
|
||||
params = self.hybrid_engine.get_all_params()
|
||||
# normalize the param names across pp/vpp partitions
|
||||
params = normalize_pp_vpp_params(params=params,
|
||||
num_hidden_layers=self.actor_model_config.num_hidden_layers,
|
||||
layer_name='layers')
|
||||
rollout = vLLMRollout(actor_module=params,
|
||||
config=self.config.rollout,
|
||||
tokenizer=self.tokenizer,
|
||||
model_hf_config=self.actor_model_config,
|
||||
train_tp=mpu.get_tensor_model_parallel_world_size())
|
||||
log_gpu_memory_usage('After building vllm rollout', logger=logger)
|
||||
|
||||
# perform weight resharding between actor and rollout
|
||||
sharding_manager = MegatronVLLMShardingManager(module=self.hybrid_engine,
|
||||
inference_engine=rollout.inference_engine,
|
||||
model_config=self.actor_model_config,
|
||||
layer_name_mapping=layer_name_mapping)
|
||||
log_gpu_memory_usage('After building sharding manager', logger=logger)
|
||||
else:
|
||||
raise NotImplementedError('Only vLLMRollout is supported with Megatron now')
|
||||
|
||||
return rollout, sharding_manager
|
||||
|
||||
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
|
||||
def init_model(self):
|
||||
if self.config.model.get('external_lib', None) is not None:
|
||||
# This is used to import external_lib into the huggingface systems
|
||||
import importlib
|
||||
importlib.import_module(self.config.model.external_lib)
|
||||
|
||||
from omegaconf import OmegaConf
|
||||
from verl.utils.torch_dtypes import PrecisionType
|
||||
override_model_config = OmegaConf.to_container(self.config.model.get('override_config', OmegaConf.create()))
|
||||
torch_dtype = torch.bfloat16
|
||||
|
||||
megatron_config = OmegaConf.create({
|
||||
'sequence_parallel': self.config.actor.megatron.get('sequence_parallel', True),
|
||||
'param_dtype': PrecisionType.to_str(torch_dtype),
|
||||
'tensor_model_parallel_size': mpu.get_tensor_model_parallel_world_size(),
|
||||
'pipeline_model_parallel_rank': mpu.get_pipeline_model_parallel_rank(),
|
||||
'pipeline_model_parallel_size': mpu.get_pipeline_model_parallel_world_size(),
|
||||
'virtual_pipeline_model_parallel_rank': mpu.get_virtual_pipeline_model_parallel_rank(),
|
||||
'virtual_pipeline_model_parallel_size': mpu.get_virtual_pipeline_model_parallel_world_size()
|
||||
})
|
||||
|
||||
megatron_config = init_model_parallel_config(megatron_config)
|
||||
|
||||
if self._is_actor or self._is_rollout:
|
||||
# we need the model for actor and rollout
|
||||
if self._is_actor:
|
||||
optim_config = self.config.actor.optim
|
||||
else:
|
||||
optim_config = None
|
||||
self.actor_module, self.hybrid_engine, self.actor_optimizer, \
|
||||
self.actor_model_config, self.actor_optim_config = self._build_model_optimizer(
|
||||
model_path=self.config.model.path,
|
||||
megatron_config=megatron_config,
|
||||
optim_config=optim_config,
|
||||
override_model_config=override_model_config,
|
||||
)
|
||||
|
||||
if self._is_actor:
|
||||
self.actor = MegatronPPOActor(config=self.config.actor,
|
||||
model_config=self.actor_model_config,
|
||||
megatron_config=megatron_config,
|
||||
actor_module=self.actor_module,
|
||||
actor_optimizer=self.actor_optimizer,
|
||||
actor_optimizer_config=self.actor_optim_config)
|
||||
|
||||
if self._is_rollout:
|
||||
self.rollout, self.sharding_manager = self._build_rollout()
|
||||
|
||||
if self._is_ref:
|
||||
self.ref_module, self.ref_model_config = self._build_model_optimizer(
|
||||
model_path=self.config.model.path,
|
||||
megatron_config=megatron_config,
|
||||
optim_config=None,
|
||||
override_model_config=override_model_config,
|
||||
)
|
||||
self.ref_policy = MegatronPPOActor(config=self.config.ref,
|
||||
model_config=self.ref_model_config,
|
||||
megatron_config=megatron_config,
|
||||
actor_module=self.ref_module,
|
||||
actor_optimizer=None,
|
||||
actor_optimizer_config=None)
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
@register(dispatch_mode=Dispatch.MEGATRON_COMPUTE_PROTO)
|
||||
def update_actor(self, data: DataProto):
|
||||
assert self._is_actor
|
||||
|
||||
data.batch = data.batch.cuda()
|
||||
|
||||
log_gpu_memory_usage('Before update policy', logger=logger)
|
||||
|
||||
dataloader = self.actor.make_minibatch_iterator(data=data)
|
||||
metrics = self.actor.update_policy(dataloader=dataloader)
|
||||
|
||||
log_gpu_memory_usage('After update policy', logger=logger)
|
||||
|
||||
# TODO: here, we should return all metrics
|
||||
output = DataProto(meta_info={'metrics': metrics})
|
||||
output = output.to('cpu')
|
||||
torch.cuda.empty_cache()
|
||||
return output
|
||||
|
||||
# @register(dispatch_mode=Dispatch.MEGATRON_PP_AS_DP_PROTO)
|
||||
# def compute_log_prob(self, data: DataProto) -> DataProto:
|
||||
# assert self._is_rollout
|
||||
# output = self.actor.compute_log_prob(data=data)
|
||||
# output = DataProto.from_dict(tensors={'old_log_probs': output})
|
||||
# torch.cuda.empty_cache()
|
||||
# return output
|
||||
|
||||
@register(dispatch_mode=Dispatch.MEGATRON_PP_AS_DP_PROTO)
|
||||
def generate_sequences(self, prompts: DataProto):
|
||||
assert self._is_rollout
|
||||
|
||||
prompts.batch = prompts.batch.cuda()
|
||||
meta_info = {'eos_token_id': self.tokenizer.eos_token_id, 'pad_token_id': self.tokenizer.pad_token_id}
|
||||
prompts.meta_info.update(meta_info)
|
||||
with self.sharding_manager:
|
||||
log_gpu_memory_usage('After entering sharding manager', logger=logger)
|
||||
|
||||
prompts = self.sharding_manager.preprocess_data(prompts)
|
||||
output = self.rollout.generate_sequences(prompts=prompts)
|
||||
|
||||
log_gpu_memory_usage('After rollout generation', logger=logger)
|
||||
|
||||
output = self.sharding_manager.postprocess_data(output)
|
||||
|
||||
validate = prompts.meta_info.get('validate', False)
|
||||
if self._is_actor and not validate:
|
||||
# we should always recompute old_log_probs when using the HybridEngine
|
||||
output.meta_info['micro_batch_size'] = self.config.rollout.log_prob_micro_batch_size
|
||||
output.meta_info['temperature'] = self.config.rollout.temperature
|
||||
old_log_probs = self.actor.compute_log_prob(data=output)
|
||||
output.batch['old_log_probs'] = old_log_probs
|
||||
|
||||
output = output.to('cpu')
|
||||
# clear kv cache
|
||||
torch.cuda.empty_cache()
|
||||
log_gpu_memory_usage('After recompute log prob', logger=logger)
|
||||
return output
|
||||
|
||||
@register(dispatch_mode=Dispatch.MEGATRON_COMPUTE_PROTO)
|
||||
def compute_ref_log_prob(self, data: DataProto):
|
||||
data = data.to('cuda')
|
||||
|
||||
assert self._is_ref
|
||||
if self._is_offload_param:
|
||||
load_megatron_param_and_grad(self.ref_module, torch.cuda.current_device(), self._is_offload_grad)
|
||||
|
||||
micro_batch_size = self.config.rollout.log_prob_micro_batch_size
|
||||
data.meta_info['micro_batch_size'] = micro_batch_size
|
||||
data.meta_info['temperature'] = self.config.rollout.temperature
|
||||
output = self.ref_policy.compute_log_prob(data=data)
|
||||
output = DataProto.from_dict(tensors={'ref_log_prob': output})
|
||||
output = output.to('cpu')
|
||||
if self._is_offload_param:
|
||||
offload_megatron_param_and_grad(self.ref_module, self._is_offload_grad)
|
||||
torch.cuda.empty_cache()
|
||||
return output
|
||||
|
||||
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
|
||||
def load_checkpoint(self, checkpoint_path):
|
||||
pass
|
||||
|
||||
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
|
||||
def load_pretrained_model(self, checkpoint_path):
|
||||
pass
|
||||
|
||||
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
|
||||
def save_checkpoint(self, checkpoint_path):
|
||||
assert self._is_actor
|
||||
pass
|
||||
|
||||
|
||||
class CriticWorker(MegatronWorker):
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
# NOTE(sgm): We use the colocated WorkerGroup by default, so workers for different models
# share the same process and only one distributed initialization is required.
# To use a different parallel strategy for each model:
# 1. disable WorkerDict; 2. assign a different ResourcePool to each model;
# 3. apply the following patch on ray==2.10: https://github.com/ray-project/ray/pull/44385
|
||||
if not torch.distributed.is_initialized():
|
||||
rank = int(os.environ['LOCAL_RANK'])
|
||||
torch.distributed.init_process_group(backend="nccl")
|
||||
torch.cuda.set_device(rank)
|
||||
|
||||
if self.config.megatron.sequence_parallel:
|
||||
os.environ['CUDA_DEVICE_MAX_CONNECTIONS'] = '1'
|
||||
mpu.initialize_model_parallel(
|
||||
tensor_model_parallel_size=self.config.megatron.tensor_model_parallel_size,
|
||||
pipeline_model_parallel_size=self.config.megatron.pipeline_model_parallel_size,
|
||||
virtual_pipeline_model_parallel_size=None,
|
||||
pipeline_model_parallel_split_rank=None,
|
||||
use_sharp=False,
|
||||
context_parallel_size=1,
|
||||
expert_model_parallel_size=1,
|
||||
nccl_communicator_config_path=None,
|
||||
)
|
||||
|
||||
set_random_seed(seed=self.config.megatron.seed)
|
||||
|
||||
# normalize config
|
||||
self.config.ppo_mini_batch_size //= mpu.get_data_parallel_world_size()
|
||||
self.config.ppo_micro_batch_size //= mpu.get_data_parallel_world_size()
|
||||
|
||||
# TODO(sgm): support critic model offload
|
||||
|
||||
def _build_critic_model_optimizer(self,
|
||||
model_path,
|
||||
megatron_config: ModelParallelConfig,
|
||||
optim_config,
|
||||
override_model_config,
|
||||
enable_gradient_checkpointing=False):
|
||||
from megatron.core.models.gpt.gpt_model import ModelType
|
||||
from verl.utils.model import print_model_size, update_model_config
|
||||
from verl.utils.megatron.optimizer import get_megatron_optimizer
|
||||
from verl.utils.megatron_utils import get_model, init_megatron_optim_config, init_model_parallel_config
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
|
||||
|
||||
# Step 1: initialize the tokenizer
|
||||
local_path = copy_local_path_from_hdfs(model_path)
|
||||
self.tokenizer = hf_tokenizer(local_path)
|
||||
|
||||
# Step 2: get the critic_model_config
|
||||
critic_model_config = AutoConfig.from_pretrained(local_path)
|
||||
|
||||
override_config_kwargs = {
|
||||
'bos_token_id': self.tokenizer.bos_token_id,
|
||||
'eos_token_id': self.tokenizer.eos_token_id,
|
||||
'pad_token_id': self.tokenizer.pad_token_id,
|
||||
}
|
||||
override_config_kwargs.update(override_model_config)
|
||||
update_model_config(critic_model_config, override_config_kwargs=override_config_kwargs)
|
||||
|
||||
if self.rank == 0:
|
||||
print(f'Model config after override: {critic_model_config}')
|
||||
|
||||
def megatron_critic_model_provider(pre_process, post_process):
|
||||
from verl.utils.model import get_parallel_model_from_config
|
||||
# TODO: support vpp here
|
||||
# vpp_rank = mpu.get_virtual_pipeline_model_parallel_rank() # this will be set inside get_model
|
||||
# this_megatron_config = copy.deepcopy(megatron_config)
|
||||
# this_megatron_config.virtual_pipeline_model_parallel_rank = vpp_rank
|
||||
parallel_model = get_parallel_model_from_config(config=critic_model_config,
|
||||
megatron_config=megatron_config,
|
||||
pre_process=pre_process,
|
||||
post_process=post_process,
|
||||
value=True)
|
||||
parallel_model.cuda()
|
||||
return parallel_model
|
||||
|
||||
# Step 3: initialize the megatron model
|
||||
critic_module = get_model(model_provider_func=megatron_critic_model_provider,
|
||||
model_type=ModelType.encoder_or_decoder,
|
||||
wrap_with_ddp=True)
|
||||
# note that critic_module is a list, to be compatible with the construction of interleaved pp (vpp).
# interleaved pp (vpp) is not used here yet, so the list could be unwrapped for simplicity
|
||||
# critic_module = nn.ModuleList(critic_module)
|
||||
|
||||
if self.config.load_weight:
|
||||
load_megatron_model_weights(self.config,
|
||||
critic_model_config,
|
||||
critic_module,
|
||||
params_dtype=megatron_config.params_dtype,
|
||||
is_value_model=True)
|
||||
if self.rank == 0:
|
||||
print_model_size(critic_module[0])
|
||||
|
||||
# TODO: add more optimizer args into config
|
||||
optim_config = init_megatron_optim_config(optim_config)
|
||||
critic_optimizer = get_megatron_optimizer(model=critic_module, config=optim_config)
|
||||
torch.cuda.empty_cache()
|
||||
return critic_module, critic_optimizer, critic_model_config, optim_config
|
||||
|
||||
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
|
||||
def init_model(self):
|
||||
# create critic
|
||||
from omegaconf import OmegaConf
|
||||
from verl.utils.torch_dtypes import PrecisionType
|
||||
|
||||
if self.config.model.get('external_lib', None) is not None:
|
||||
# This is used to import external_lib into the huggingface systems
|
||||
import importlib
|
||||
importlib.import_module(self.config.model.external_lib)
|
||||
override_model_config = OmegaConf.to_container(self.config.model.get('override_config', OmegaConf.create()))
|
||||
torch_dtype = torch.bfloat16
|
||||
|
||||
megatron_config = OmegaConf.create({
|
||||
'sequence_parallel': self.config.megatron.get('sequence_parallel', True),
|
||||
'param_dtype': PrecisionType.to_str(torch_dtype),
|
||||
'tensor_model_parallel_size': mpu.get_tensor_model_parallel_world_size(),
|
||||
'pipeline_model_parallel_rank': mpu.get_pipeline_model_parallel_rank(),
|
||||
'pipeline_model_parallel_size': mpu.get_pipeline_model_parallel_world_size(),
|
||||
'virtual_pipeline_model_parallel_rank': mpu.get_virtual_pipeline_model_parallel_rank(),
|
||||
'virtual_pipeline_model_parallel_size': mpu.get_virtual_pipeline_model_parallel_world_size()
|
||||
})
|
||||
|
||||
megatron_config = init_model_parallel_config(megatron_config)
|
||||
|
||||
critic_module, critic_optimizer, critic_model_config, critic_optimizer_config = self._build_critic_model_optimizer(
|
||||
model_path=self.config.model.path,
|
||||
megatron_config=megatron_config,
|
||||
optim_config=self.config.optim,
|
||||
override_model_config=override_model_config)
|
||||
self.critic = MegatronPPOCritic(config=self.config,
|
||||
model_config=critic_model_config,
|
||||
megatron_config=megatron_config,
|
||||
critic_module=critic_module,
|
||||
critic_optimizer=critic_optimizer,
|
||||
critic_optimizer_config=critic_optimizer_config)
|
||||
|
||||
@register(dispatch_mode=Dispatch.MEGATRON_COMPUTE_PROTO)
|
||||
def compute_values(self, data: DataProto):
|
||||
data = data.to('cuda')
|
||||
values = self.critic.compute_values(data=data)
|
||||
output = DataProto.from_dict(tensors={'values': values})
|
||||
output = output.to('cpu')
|
||||
return output
|
||||
|
||||
@register(dispatch_mode=Dispatch.MEGATRON_COMPUTE_PROTO)
|
||||
def update_critic(self, data: DataProto):
|
||||
data = data.to('cuda')
|
||||
dataloader = self.critic.make_minibatch_iterator(data)
|
||||
metrics = self.critic.update_critic(dataloader=dataloader)
|
||||
output = DataProto(batch=None, meta_info={'metrics': metrics})
|
||||
output = output.to('cpu')
|
||||
return output
|
||||
|
||||
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
|
||||
def load_checkpoint(self, checkpoint_path):
|
||||
pass
|
||||
|
||||
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
|
||||
def save_checkpoint(self, checkpoint_path):
|
||||
pass
|
||||
|
||||
|
||||
class RewardModelWorker(MegatronWorker):
|
||||
"""
|
||||
Note that we only implement reward models that are subclasses of AutoModelForSequenceClassification.
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
# NOTE(sgm): We use the colocated WorkerGroup by default, so workers for different models
# share the same process and only one distributed initialization is required.
# To use a different parallel strategy for each model:
# 1. disable WorkerDict; 2. assign a different ResourcePool to each model;
# 3. apply the following patch on ray==2.10: https://github.com/ray-project/ray/pull/44385
|
||||
if not torch.distributed.is_initialized():
|
||||
rank = int(os.environ['LOCAL_RANK'])
|
||||
torch.distributed.init_process_group(backend="nccl")
|
||||
torch.cuda.set_device(rank)
|
||||
|
||||
if self.config.megatron.sequence_parallel:
|
||||
os.environ['CUDA_DEVICE_MAX_CONNECTIONS'] = '1'
|
||||
mpu.initialize_model_parallel(
|
||||
tensor_model_parallel_size=self.config.megatron.tensor_model_parallel_size,
|
||||
pipeline_model_parallel_size=self.config.megatron.pipeline_model_parallel_size,
|
||||
virtual_pipeline_model_parallel_size=None,
|
||||
pipeline_model_parallel_split_rank=None,
|
||||
use_sharp=False,
|
||||
context_parallel_size=1,
|
||||
expert_model_parallel_size=1,
|
||||
nccl_communicator_config_path=None,
|
||||
)
|
||||
|
||||
set_random_seed(seed=self.config.megatron.seed)
|
||||
|
||||
# normalize config
|
||||
self.config.micro_batch_size //= mpu.get_data_parallel_world_size()
|
||||
|
||||
def _build_rm_model(self, model_path, megatron_config: ModelParallelConfig, override_model_config):
|
||||
from megatron.core.models.gpt.gpt_model import ModelType
|
||||
from verl.utils.model import print_model_size, update_model_config
|
||||
from verl.utils.megatron_utils import get_model
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
|
||||
|
||||
# Step 1: initialize the tokenizer
|
||||
local_path = copy_local_path_from_hdfs(model_path)
|
||||
self.tokenizer = hf_tokenizer(local_path)
|
||||
|
||||
# Step 2: get the rm_model_config
|
||||
rm_model_config = AutoConfig.from_pretrained(local_path)
|
||||
|
||||
override_config_kwargs = {
|
||||
'bos_token_id': self.tokenizer.bos_token_id,
|
||||
'eos_token_id': self.tokenizer.eos_token_id,
|
||||
'pad_token_id': self.tokenizer.pad_token_id,
|
||||
}
|
||||
override_config_kwargs.update(override_model_config)
|
||||
update_model_config(rm_model_config, override_config_kwargs=override_config_kwargs)
|
||||
|
||||
if self.rank == 0:
|
||||
print(f'Model config after override: {rm_model_config}')
|
||||
|
||||
def megatron_rm_model_provider(pre_process, post_process):
|
||||
from verl.utils.model import get_parallel_model_from_config
|
||||
# vpp is not supported yet because it will hang for some reason. Need debugging
|
||||
vpp_rank = mpu.get_virtual_pipeline_model_parallel_rank() # this will be set inside get_model
|
||||
# this_megatron_config = copy.deepcopy(megatron_config)
|
||||
# this_megatron_config.virtual_pipeline_model_parallel_rank = vpp_rank
|
||||
parallel_model = get_parallel_model_from_config(config=rm_model_config,
|
||||
megatron_config=megatron_config,
|
||||
pre_process=pre_process,
|
||||
post_process=post_process,
|
||||
value=True)
|
||||
parallel_model.cuda()
|
||||
return parallel_model
|
||||
|
||||
# Step 3: initialize the megatron model
|
||||
reward_model = get_model(model_provider_func=megatron_rm_model_provider,
|
||||
model_type=ModelType.encoder_or_decoder,
|
||||
wrap_with_ddp=False)
|
||||
# note that reward_model is a list, to be compatible with the construction of interleaved pp (vpp).
# interleaved pp (vpp) is not used here yet, so the list could be unwrapped for simplicity
|
||||
# reward_model = nn.ModuleList(reward_model)
|
||||
|
||||
if self.config.load_weight:
|
||||
load_megatron_model_weights(self.config,
|
||||
rm_model_config,
|
||||
reward_model,
|
||||
params_dtype=megatron_config.params_dtype,
|
||||
is_value_model=True)
|
||||
|
||||
# TODO: add more optimizer args into config
|
||||
torch.cuda.empty_cache()
|
||||
return reward_model, rm_model_config
|
||||
|
||||
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
|
||||
def init_model(self):
|
||||
# create the reward model
|
||||
from omegaconf import OmegaConf
|
||||
from verl.utils.torch_dtypes import PrecisionType
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
if self.config.model.get('external_lib', None) is not None:
|
||||
# This is used to import external_lib into the huggingface systems
|
||||
import importlib
|
||||
importlib.import_module(self.config.model.external_lib)
|
||||
override_model_config = OmegaConf.to_container(self.config.model.get('override_config', OmegaConf.create()))
|
||||
|
||||
sft_tokenizer_local_path = copy_local_path_from_hdfs(self.config.model.input_tokenizer)
|
||||
sft_tokenizer = hf_tokenizer(sft_tokenizer_local_path)
|
||||
rm_tokenizer_path = self.config.model.get('rm_tokenizer', None)
|
||||
rm_tokenizer = None
|
||||
if rm_tokenizer_path is not None:
|
||||
rm_tokenizer_local_path = copy_local_path_from_hdfs(rm_tokenizer_path)
|
||||
rm_tokenizer = hf_tokenizer(rm_tokenizer_local_path)
|
||||
|
||||
torch_dtype = torch.bfloat16
|
||||
|
||||
megatron_config = OmegaConf.create({
|
||||
'sequence_parallel': self.config.megatron.get('sequence_parallel', True),
|
||||
'param_dtype': PrecisionType.to_str(torch_dtype),
|
||||
'tensor_model_parallel_size': mpu.get_tensor_model_parallel_world_size(),
|
||||
'pipeline_model_parallel_rank': mpu.get_pipeline_model_parallel_rank(),
|
||||
'pipeline_model_parallel_size': mpu.get_pipeline_model_parallel_world_size(),
|
||||
'virtual_pipeline_model_parallel_rank': mpu.get_virtual_pipeline_model_parallel_rank(),
|
||||
'virtual_pipeline_model_parallel_size': mpu.get_virtual_pipeline_model_parallel_world_size()
|
||||
})
|
||||
|
||||
megatron_config = init_model_parallel_config(megatron_config)
|
||||
|
||||
reward_model_module, reward_model_config = self._build_rm_model(
|
||||
model_path=self.config.model.path,
|
||||
megatron_config=megatron_config,
|
||||
override_model_config=override_model_config,
|
||||
)
|
||||
# FIXME(sgm): reward model param offload is implemented in MegatronRewardModel
|
||||
# should be implemented in workers
|
||||
self.rm = MegatronRewardModel(config=self.config,
|
||||
reward_model_module=reward_model_module,
|
||||
model_config=reward_model_config,
|
||||
megatron_config=megatron_config,
|
||||
sft_tokenizer=sft_tokenizer,
|
||||
rm_tokenizer=rm_tokenizer)
|
||||
|
||||
# TODO: the reward model should use its own tokenizer instead of the sft tokenizer;
# the input_ids, responses, attention_mask and position_ids may differ between the two!
|
||||
@register(dispatch_mode=Dispatch.MEGATRON_COMPUTE_PROTO)
|
||||
def compute_rm_score(self, data: DataProto):
|
||||
data.batch = data.batch.cuda()
|
||||
output = self.rm.compute_reward(data)
|
||||
output = output.to('cpu')
|
||||
return output
|
||||
383
verl/workers/retriever_workers.py
Normal file
@@ -0,0 +1,383 @@
|
||||
import os
|
||||
import torch
|
||||
from verl.single_controller.base import Worker
|
||||
from verl.single_controller.base.decorator import register, Dispatch
|
||||
# from search_r1.env.search.retrieval import get_retriever
|
||||
|
||||
import json
|
||||
import os
|
||||
import warnings
|
||||
from typing import List, Dict
|
||||
import functools
|
||||
from tqdm import tqdm
|
||||
from multiprocessing import Pool
|
||||
import faiss
|
||||
import torch
|
||||
import numpy as np
|
||||
from transformers import AutoConfig, AutoTokenizer, AutoModel
|
||||
import argparse
|
||||
import datasets
|
||||
|
||||
|
||||
def load_corpus(corpus_path: str):
|
||||
corpus = datasets.load_dataset(
|
||||
'json',
|
||||
data_files=corpus_path,
|
||||
split="train",
|
||||
num_proc=4)
|
||||
return corpus
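# Each corpus line is expected to be one JSON object; a hypothetical example (the exact schema
# depends on how the corpus and index were built) is:
#   {"id": "doc_0", "contents": "Article Title\nBody text of the passage ..."}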
|
||||
|
||||
|
||||
def read_jsonl(file_path):
|
||||
data = []
|
||||
|
||||
with open(file_path, "r") as f:
|
||||
readin = f.readlines()
|
||||
for line in readin:
|
||||
data.append(json.loads(line))
|
||||
return data
|
||||
|
||||
|
||||
def load_docs(corpus, doc_idxs):
|
||||
results = [corpus[int(idx)] for idx in doc_idxs]
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def load_model(
|
||||
model_path: str,
|
||||
use_fp16: bool = False
|
||||
):
|
||||
model_config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
|
||||
model = AutoModel.from_pretrained(model_path, trust_remote_code=True)
|
||||
model.eval()
|
||||
model.cuda()
|
||||
if use_fp16:
|
||||
model = model.half()
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True, trust_remote_code=True)
|
||||
|
||||
return model, tokenizer
|
||||
|
||||
|
||||
def pooling(
|
||||
pooler_output,
|
||||
last_hidden_state,
|
||||
attention_mask = None,
|
||||
pooling_method = "mean"
|
||||
):
|
||||
if pooling_method == "mean":
|
||||
last_hidden = last_hidden_state.masked_fill(~attention_mask[..., None].bool(), 0.0)
|
||||
return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
|
||||
elif pooling_method == "cls":
|
||||
return last_hidden_state[:, 0]
|
||||
elif pooling_method == "pooler":
|
||||
return pooler_output
|
||||
else:
|
||||
raise NotImplementedError("Pooling method not implemented!")
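# A minimal sanity-check sketch for `pooling` (illustrative only; not called anywhere in this file).
def _example_mean_pooling():
    hidden = torch.randn(2, 4, 8)                      # (bs, seq_len, hidden_size)
    mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])  # (bs, seq_len); 0 marks padding
    emb = pooling(pooler_output=None,
                  last_hidden_state=hidden,
                  attention_mask=mask,
                  pooling_method="mean")
    assert emb.shape == (2, 8)                         # one embedding per sequence
    return emb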
|
||||
|
||||
|
||||
class Encoder:
|
||||
def __init__(self, model_name, model_path, pooling_method, max_length, use_fp16):
|
||||
self.model_name = model_name
|
||||
self.model_path = model_path
|
||||
self.pooling_method = pooling_method
|
||||
self.max_length = max_length
|
||||
self.use_fp16 = use_fp16
|
||||
|
||||
self.model, self.tokenizer = load_model(model_path=model_path,
|
||||
use_fp16=use_fp16)
|
||||
|
||||
@torch.no_grad()
|
||||
def encode(self, query_list: List[str], is_query=True) -> np.ndarray:
|
||||
# processing query for different encoders
|
||||
if isinstance(query_list, str):
|
||||
query_list = [query_list]
|
||||
|
||||
if "e5" in self.model_name.lower():
|
||||
if is_query:
|
||||
query_list = [f"query: {query}" for query in query_list]
|
||||
else:
|
||||
query_list = [f"passage: {query}" for query in query_list]
|
||||
|
||||
if "bge" in self.model_name.lower():
|
||||
if is_query:
|
||||
query_list = [f"Represent this sentence for searching relevant passages: {query}" for query in query_list]
|
||||
|
||||
inputs = self.tokenizer(query_list,
|
||||
max_length=self.max_length,
|
||||
padding=True,
|
||||
truncation=True,
|
||||
return_tensors="pt"
|
||||
)
|
||||
inputs = {k: v.cuda() for k, v in inputs.items()}
|
||||
|
||||
if "T5" in type(self.model).__name__:
|
||||
# T5-based retrieval model
|
||||
decoder_input_ids = torch.zeros(
|
||||
(inputs['input_ids'].shape[0], 1), dtype=torch.long
|
||||
).to(inputs['input_ids'].device)
|
||||
output = self.model(
|
||||
**inputs, decoder_input_ids=decoder_input_ids, return_dict=True
|
||||
)
|
||||
query_emb = output.last_hidden_state[:, 0, :]
|
||||
|
||||
else:
|
||||
output = self.model(**inputs, return_dict=True)
|
||||
query_emb = pooling(output.pooler_output,
|
||||
output.last_hidden_state,
|
||||
inputs['attention_mask'],
|
||||
self.pooling_method)
|
||||
if "dpr" not in self.model_name.lower():
|
||||
query_emb = torch.nn.functional.normalize(query_emb, dim=-1)
|
||||
|
||||
query_emb = query_emb.detach().cpu().numpy()
|
||||
query_emb = query_emb.astype(np.float32, order="C")
|
||||
return query_emb
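# Hypothetical usage sketch for `Encoder` (the model name/path are placeholders; a CUDA device is assumed).
def _example_encode_queries(model_path='/path/to/e5-base-v2'):
    encoder = Encoder(model_name='e5-base-v2',
                      model_path=model_path,
                      pooling_method='mean',
                      max_length=256,
                      use_fp16=True)
    # returns a float32 numpy array of shape (num_queries, hidden_dim)
    return encoder.encode(["what is reinforcement learning?"])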
|
||||
|
||||
|
||||
class BaseRetriever:
|
||||
"""Base object for all retrievers."""
|
||||
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
self.retrieval_method = config.retrieval_method
|
||||
self.topk = config.retrieval_topk
|
||||
|
||||
self.index_path = config.index_path
|
||||
self.corpus_path = config.corpus_path
|
||||
|
||||
# self.cache_save_path = os.path.join(config.save_dir, 'retrieval_cache.json')
|
||||
|
||||
def _search(self, query: str, num: int, return_score:bool) -> List[Dict[str, str]]:
|
||||
r"""Retrieve topk relevant documents in corpus.
|
||||
Return:
|
||||
list: contains information related to the document, including:
|
||||
contents: used for building index
|
||||
title: (if provided)
|
||||
text: (if provided)
|
||||
"""
|
||||
pass
|
||||
|
||||
def _batch_search(self, query_list, num, return_score):
|
||||
pass
|
||||
|
||||
def search(self, *args, **kwargs):
|
||||
return self._search(*args, **kwargs)
|
||||
|
||||
def batch_search(self, *args, **kwargs):
|
||||
return self._batch_search(*args, **kwargs)
|
||||
|
||||
|
||||
class BM25Retriever(BaseRetriever):
|
||||
r"""BM25 retriever based on pre-built pyserini index."""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
raise NotImplementedError
|
||||
from pyserini.search.lucene import LuceneSearcher
|
||||
self.searcher = LuceneSearcher(self.index_path)
|
||||
self.contain_doc = self._check_contain_doc()
|
||||
if not self.contain_doc:
|
||||
self.corpus = load_corpus(self.corpus_path)
|
||||
self.max_process_num = 8
|
||||
|
||||
def _check_contain_doc(self):
|
||||
r"""Check if the index contains document content
|
||||
"""
|
||||
return self.searcher.doc(0).raw() is not None
|
||||
|
||||
def _search(self, query: str, num: int = None, return_score = False) -> List[Dict[str, str]]:
|
||||
if num is None:
|
||||
num = self.topk
|
||||
|
||||
hits = self.searcher.search(query, num)
|
||||
if len(hits) < 1:
|
||||
if return_score:
|
||||
return [],[]
|
||||
else:
|
||||
return []
|
||||
|
||||
scores = [hit.score for hit in hits]
|
||||
if len(hits) < num:
|
||||
warnings.warn('Not enough documents retrieved!')
|
||||
else:
|
||||
hits = hits[:num]
|
||||
|
||||
if self.contain_doc:
|
||||
all_contents = [json.loads(self.searcher.doc(hit.docid).raw())['contents'] for hit in hits]
|
||||
results = [{'title': content.split("\n")[0].strip("\""),
|
||||
'text': "\n".join(content.split("\n")[1:]),
|
||||
'contents': content} for content in all_contents]
|
||||
else:
|
||||
results = load_docs(self.corpus, [hit.docid for hit in hits])
|
||||
|
||||
if return_score:
|
||||
return results, scores
|
||||
else:
|
||||
return results
|
||||
|
||||
def _batch_search(self, query_list, num: int = None, return_score = False):
|
||||
# TODO: modify batch method
|
||||
results = []
|
||||
scores = []
|
||||
for query in query_list:
|
||||
item_result, item_score = self._search(query, num,True)
|
||||
results.append(item_result)
|
||||
scores.append(item_score)
|
||||
|
||||
if return_score:
|
||||
return results, scores
|
||||
else:
|
||||
return results
|
||||
|
||||
|
||||
class DenseRetriever(BaseRetriever):
|
||||
r"""Dense retriever based on pre-built faiss index."""
|
||||
|
||||
def __init__(self, config: dict, index):
|
||||
super().__init__(config)
|
||||
self.index = index
|
||||
# self.index = faiss.read_index(self.index_path)
|
||||
# if config.faiss_gpu:
|
||||
# co = faiss.GpuMultipleClonerOptions()
|
||||
# co.useFloat16 = True
|
||||
# co.shard = True
|
||||
# self.index = faiss.index_cpu_to_all_gpus(self.index, co=co)
|
||||
# # self.index = faiss.index_cpu_to_all_gpus(self.index)
|
||||
|
||||
self.corpus = load_corpus(self.corpus_path)
|
||||
self.encoder = Encoder(
|
||||
model_name = self.retrieval_method,
|
||||
model_path = config.retrieval_model_path,
|
||||
pooling_method = config.retrieval_pooling_method,
|
||||
max_length = config.retrieval_query_max_length,
|
||||
use_fp16 = config.retrieval_use_fp16
|
||||
)
|
||||
self.topk = config.retrieval_topk
|
||||
self.batch_size = self.config.retrieval_batch_size
|
||||
|
||||
def _search(self, query: str, num: int = None, return_score = False):
|
||||
raise NotImplementedError
|
||||
if num is None:
|
||||
num = self.topk
|
||||
query_emb = self.encoder.encode(query)
|
||||
scores, idxs = self.index.search(query_emb, k=num)
|
||||
idxs = idxs[0]
|
||||
scores = scores[0]
|
||||
|
||||
results = load_docs(self.corpus, idxs)
|
||||
if return_score:
|
||||
return results, scores
|
||||
else:
|
||||
return results
|
||||
|
||||
def _batch_search(self, query_list: List[str], num: int = None, return_score = False):
|
||||
if isinstance(query_list, str):
|
||||
query_list = [query_list]
|
||||
if num is None:
|
||||
num = self.topk
|
||||
|
||||
batch_size = self.batch_size
|
||||
|
||||
results = []
|
||||
scores = []
|
||||
|
||||
for start_idx in tqdm(range(0, len(query_list), batch_size), desc='Retrieval process: '):
|
||||
query_batch = query_list[start_idx:start_idx + batch_size]
|
||||
|
||||
# from time import time
|
||||
# a = time()
|
||||
batch_emb = self.encoder.encode(query_batch)
|
||||
# b = time()
|
||||
# print(f'################### encode time {b-a} #####################')
|
||||
batch_scores, batch_idxs = ray.get(self.index.batch_search.remote(batch_emb, k=num))
|
||||
batch_scores = batch_scores.tolist()
|
||||
batch_idxs = batch_idxs.tolist()
|
||||
# print(f'################### search time {time()-b} #####################')
|
||||
# exit()
|
||||
|
||||
flat_idxs = sum(batch_idxs, [])
|
||||
batch_results = load_docs(self.corpus, flat_idxs)
|
||||
batch_results = [batch_results[i*num : (i+1)*num] for i in range(len(batch_idxs))]
|
||||
|
||||
scores.extend(batch_scores)
|
||||
results.extend(batch_results)
|
||||
|
||||
if return_score:
|
||||
return results, scores
|
||||
else:
|
||||
return results
|
||||
|
||||
def get_retriever(config, index):
|
||||
r"""Automatically select retriever class based on config's retrieval method
|
||||
|
||||
Args:
|
||||
config (dict): configuration with 'retrieval_method' key
|
||||
|
||||
Returns:
|
||||
Retriever: retriever instance
|
||||
"""
|
||||
if config.retrieval_method == "bm25":
|
||||
raise NotImplementedError
|
||||
return BM25Retriever(config)
|
||||
else:
|
||||
return DenseRetriever(config, index)
|
||||
|
||||
|
||||
|
||||
class RetrieveWorker(Worker):
|
||||
"""Environment worker that handles GPU-based environment operations."""
|
||||
|
||||
def __init__(self, config, faiss_server):
|
||||
super().__init__()
|
||||
config.index_path = os.path.join(config.index_path, f'{config.retrieval_method}_Flat.index') if config.retrieval_method != 'bm25' else os.path.join(config.index_path, 'bm25')
|
||||
|
||||
self.config = config # the retriever itself is initialized later in init_model
|
||||
self.faiss_server = faiss_server
|
||||
|
||||
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
|
||||
def init_model(self):
|
||||
self.retriever = get_retriever(self.config, self.faiss_server)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
@register(dispatch_mode=Dispatch.ALL_TO_ALL)
|
||||
def batch_search(self, queries):
|
||||
return self.retriever.batch_search(queries)
|
||||
|
||||
|
||||
import ray
|
||||
import faiss
|
||||
import torch
|
||||
|
||||
@ray.remote(num_gpus=8) # Allocate all GPUs
|
||||
class FAISSIndexServer:
|
||||
"""Ray Actor that loads and serves a shared FAISS index with FAISS GPU optimization."""
|
||||
|
||||
def __init__(self, config):
|
||||
"""Initialize the FAISS index only once."""
|
||||
print("[FAISSIndexServer] Loading FAISS index...")
|
||||
self.config = config
|
||||
self.index = self.load_index(config)
|
||||
|
||||
def load_index(self, config):
|
||||
"""Loads the FAISS index into GPU memory with sharding."""
|
||||
index_path = os.path.join(config.index_path, f'{config.retrieval_method}_Flat.index')
|
||||
index = faiss.read_index(index_path)
|
||||
|
||||
if self.config.faiss_gpu:
|
||||
|
||||
# Apply FAISS GPU settings
|
||||
co = faiss.GpuMultipleClonerOptions()
|
||||
co.useFloat16 = True # Reduce memory footprint
|
||||
co.shard = True # Distribute index across all GPUs
|
||||
|
||||
print("[FAISSIndexServer] Moving FAISS index to all GPUs with sharding enabled...")
|
||||
index = faiss.index_cpu_to_all_gpus(index, co=co)
|
||||
|
||||
print("[FAISSIndexServer] FAISS index successfully moved to GPUs.")
|
||||
return index
|
||||
|
||||
def batch_search(self, batch_emb, k):
|
||||
"""Perform batch search on the FAISS index."""
|
||||
print(f"[FAISSIndexServer] Received {len(batch_emb)} queries.")
|
||||
return self.index.search(batch_emb, k) # Adjust 'k' as needed
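# Hypothetical end-to-end sketch (illustrative only; the config fields mirror those accessed above
# and any paths inside `config` are placeholders).
def _example_dense_retrieval(config):
    ray.init(ignore_reinit_error=True)
    index_server = FAISSIndexServer.remote(config)          # load (and optionally shard) the index once
    retriever = DenseRetriever(config, index=index_server)  # workers can share the same actor handle
    docs, scores = retriever.batch_search(["who wrote hamlet"], num=3, return_score=True)
    return docs, scores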
|
||||
15
verl/workers/reward_model/__init__.py
Normal file
@@ -0,0 +1,15 @@
|
||||
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .base import BasePPORewardModel
|
||||
45
verl/workers/reward_model/base.py
Normal file
@@ -0,0 +1,45 @@
|
||||
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
The base class for reward model
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from verl import DataProto
|
||||
|
||||
|
||||
class BasePPORewardModel(ABC):
|
||||
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
|
||||
@abstractmethod
|
||||
def compute_reward(self, data: DataProto) -> DataProto:
|
||||
"""Computing reward given input_ids. The transformers should output a tensor with shape
|
||||
[batch_size, sequence_length], and the value at [EOS] mask should be gathered.
|
||||
|
||||
Args:
|
||||
data: must contain keys "input_ids", "attention_mask" and "position_ids".
|
||||
- input_ids: [batch_size, sequence_length]
|
||||
- attention_mask: [batch_size, sequence_length]
|
||||
- position_ids: [batch_size, sequence_length]
|
||||
|
||||
Returns: a DataProto containing "reward". Only the [EOS] position carries the reward;
all other positions should have zero reward. Note that this may change in the future if we
use dense rewards, so the interface is kept general.
- reward: [batch_size, sequence_length].
|
||||
|
||||
"""
|
||||
pass
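# A minimal, self-contained sketch of the "reward only at the [EOS] position" convention described
# above (illustrative only; not used by subclasses): a scalar score per sequence is scattered onto
# the last non-padded token while every other position stays zero.
def _example_eos_reward_placement():
    import torch
    scores = torch.tensor([0.7, -0.2])                             # (bs,) one scalar reward per sequence
    attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])    # (bs, seq_len); 0 marks padding
    eos_idx = attention_mask.cumsum(dim=-1).argmax(dim=-1)         # index of the last valid token per row
    reward = torch.zeros_like(attention_mask, dtype=scores.dtype)  # (bs, seq_len)
    reward[torch.arange(scores.shape[0]), eos_idx] = scores
    return reward                                                  # [[0, 0, 0.7, 0], [0, -0.2, 0, 0]]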
|
||||
15
verl/workers/reward_model/megatron/__init__.py
Normal file
@@ -0,0 +1,15 @@
|
||||
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .reward_model import MegatronRewardModel
|
||||
278
verl/workers/reward_model/megatron/reward_model.py
Normal file
@@ -0,0 +1,278 @@
|
||||
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Megatron Reward Model.
|
||||
"""
|
||||
|
||||
import torch
import torch.distributed
from functools import partial
from tensordict import TensorDict

from megatron.core import parallel_state as mpu
from megatron.core.pipeline_parallel import get_forward_backward_func

from verl import DataProto
from verl.utils.megatron import sequence_parallel as sp_utils
from verl.utils.megatron.pipeline_parallel import compute_transformers_input_shapes, make_batch_generator
from verl.utils.torch_dtypes import PrecisionType
from verl.utils.torch_functional import (broadcast_dict_tensor, get_eos_mask, logprobs_from_logits,
                                         pad_sequence_to_length, split_dict_tensor_into_batches)
from verl.workers.reward_model.base import BasePPORewardModel
|
||||
|
||||
|
||||
class MegatronRewardModel(BasePPORewardModel):
|
||||
|
||||
def __init__(self,
|
||||
config,
|
||||
model_config,
|
||||
reward_model_module: torch.nn.ModuleList,
|
||||
megatron_config,
|
||||
sft_tokenizer=None,
|
||||
rm_tokenizer=None):
|
||||
self.config = config
|
||||
self.reward_model_module = reward_model_module
|
||||
self.megatron_config = megatron_config
|
||||
self.model_config = model_config
|
||||
self.device = 'cuda'
|
||||
self.sft_tokenizer = sft_tokenizer
|
||||
self.rm_tokenizer = rm_tokenizer
|
||||
self.use_different_tokenizer = rm_tokenizer is not None
|
||||
|
||||
if self.config.param_offload:
|
||||
self.offload_params_to_cpu()
|
||||
|
||||
def re_encode_by_rm_tokenizer(self, data: DataProto) -> DataProto:
|
||||
assert self.use_different_tokenizer, 're-encode need rm tokenizer not be None!'
|
||||
# need to use rm tokenizer to re-generate input_ids, attention_mask and position_ids
|
||||
# 1. remove pad for each sequence
|
||||
# 2. decode by sft_tokenizer, remove sft system prompts
|
||||
# 3. encode by rm_tokenizer with rm system prompts, get rm_input_ids
|
||||
# 4. generate attention_mask and position_ids
|
||||
input_ids = data.batch['input_ids'] # (bs, seq_len)
|
||||
attention_mask = data.batch['attention_mask']
|
||||
position_ids = data.batch['position_ids']
|
||||
ori_values = {'input_ids': input_ids, 'attention_mask': attention_mask, 'position_ids': position_ids}
|
||||
ori_bs, ori_seqlen = input_ids.size(0), input_ids.size(1)
|
||||
input_ids_for_rm = []
|
||||
attention_mask_for_rm = []
|
||||
position_ids_for_rm = []
|
||||
print_decode = True
|
||||
ori_seqlen = ori_seqlen + 128
|
||||
for id, mask in zip(input_ids, attention_mask):
|
||||
# 1. remove pad for each sequence
|
||||
non_zero_indices = torch.nonzero(mask).view(-1)
|
||||
begin_pos, end_pos = non_zero_indices[0].item(), non_zero_indices[-1].item()
|
||||
valid_id = id[begin_pos:end_pos + 1]
|
||||
# 2. decode by sft_tokenizer, remove sft system prompts
|
||||
decode_result = self.sft_tokenizer.decode(valid_id)
|
||||
# workaround
|
||||
decode_with_rm_chat = decode_result.replace("<|user|>\n", "[INST] ").replace(
|
||||
"</s>\n<|assistant|>\n", " [/INST]").replace("</s> \n<|assistant|>\n", " [/INST]") + "</s>"
|
||||
|
||||
print(f"decode_with_rm_chat: {decode_with_rm_chat}")
|
||||
|
||||
if print_decode and torch.distributed.get_rank() == 0:
|
||||
# only print first decode result
|
||||
print(f'device {torch.cuda.current_device()}: sft decode result:\n{decode_result}\n \
|
||||
\ndevice {torch.cuda.current_device()}: sft decode result with rm chat template:\n{decode_with_rm_chat}\n\n'
|
||||
)
|
||||
print_decode = False
|
||||
# 3. encode by rm_tokenizer
|
||||
rm_input_ids = self.rm_tokenizer(decode_with_rm_chat,
|
||||
return_tensors='pt')['input_ids'][0].to(input_ids.device)
|
||||
# 4. generate attention_mask and position_ids
|
||||
rm_attention_mask = torch.ones_like(rm_input_ids, device=input_ids.device)
|
||||
cur_seqlen = rm_input_ids.shape[-1]
|
||||
# NOTE(gh): the later reward compute will process the shape (bs, seqlen_pad_128)
|
||||
if cur_seqlen > ori_seqlen:
|
||||
print(f'warning: rm encode seqlen {cur_seqlen} > sft encode seqlen {ori_seqlen}')
|
||||
rm_input_ids = rm_input_ids[:ori_seqlen]
|
||||
rm_attention_mask = rm_attention_mask[:ori_seqlen]
|
||||
else:
|
||||
# right padding
|
||||
rm_input_ids = pad_sequence_to_length(rm_input_ids, ori_seqlen, self.rm_tokenizer.pad_token_id)
|
||||
rm_attention_mask = pad_sequence_to_length(rm_attention_mask, ori_seqlen, 0)
|
||||
rm_position_ids = torch.arange(0, ori_seqlen, device=input_ids.device)
|
||||
input_ids_for_rm.append(torch.unsqueeze(rm_input_ids, dim=0))
|
||||
attention_mask_for_rm.append(torch.unsqueeze(rm_attention_mask, dim=0))
|
||||
position_ids_for_rm.append(torch.unsqueeze(rm_position_ids, dim=0))
|
||||
input_ids_for_rm = torch.cat(input_ids_for_rm, dim=0)
|
||||
attention_mask_for_rm = torch.cat(attention_mask_for_rm, dim=0)
|
||||
position_ids_for_rm = torch.cat(position_ids_for_rm, dim=0)
|
||||
|
||||
# (bs, seqlen) will not change, but input_ids, attention_mask and position_ids will change
|
||||
# NOTE(gh): need to replace into origin values after compute reward!
|
||||
data.batch['input_ids'] = input_ids_for_rm
|
||||
data.batch['attention_mask'] = attention_mask_for_rm
|
||||
data.batch['position_ids'] = position_ids_for_rm
|
||||
|
||||
return data, ori_values
|
||||
|
||||
@torch.no_grad()
|
||||
def compute_reward(self, data: DataProto) -> DataProto:
|
||||
if self.config.param_offload:
|
||||
self.load_params_to_cuda()
|
||||
|
||||
if self.use_different_tokenizer:
|
||||
data, ori_values = self.re_encode_by_rm_tokenizer(data)
|
||||
|
||||
input_ids = data.batch['input_ids'] # (bs, seq_len')
|
||||
attention_mask = data.batch['attention_mask']
|
||||
position_ids = data.batch['position_ids']
|
||||
|
||||
responses = data.batch['responses']
|
||||
batch_size = responses.size(0)
|
||||
response_length = responses.size(1)
|
||||
|
||||
with torch.no_grad():
|
||||
output = self.forward_batch(data)
|
||||
if mpu.is_pipeline_last_stage(ignore_virtual=True):
|
||||
logits = torch.cat([o['logits'] for o in output], dim=0)
|
||||
else:
|
||||
logits = torch.empty(
|
||||
(input_ids.shape[0], input_ids.shape[1]),
|
||||
dtype=torch.bfloat16, # TODO(sgm): check why bfloat16 is used here
|
||||
device=input_ids.device)
|
||||
# broadcast across pp ranks
|
||||
torch.distributed.broadcast(tensor=logits,
|
||||
src=mpu.get_pipeline_model_parallel_last_rank(),
|
||||
group=mpu.get_pipeline_model_parallel_group(),
|
||||
async_op=False)
|
||||
|
||||
# (bs, seqlen', hidden_size) -> (bs, seqlen', 1) -> (bs, seqlen')
|
||||
token_level_rewards = logits
|
||||
# find the last token reward
|
||||
ends = attention_mask.cumsum(dim=-1).argmax(dim=-1).view(-1, 1) # (bs, 1)
|
||||
rewards = torch.gather(token_level_rewards, dim=1, index=ends) # (bs, 1)
|
||||
|
||||
if self.use_different_tokenizer:
|
||||
data.batch.update(ori_values)
|
||||
input_ids = ori_values['input_ids']
|
||||
attention_mask = ori_values['attention_mask']
|
||||
position_ids = ori_values['position_ids']
|
||||
|
||||
token_level_rewards = rewards.expand(attention_mask.shape[0], attention_mask.shape[1]) # (bs, ori_seqlen)
|
||||
|
||||
# assign last valid token reward to ori position
|
||||
eos_mask_idx = torch.argmax(position_ids * attention_mask, dim=-1) # (bs,)
|
||||
eos_mask = torch.zeros_like(attention_mask)
|
||||
eos_mask[torch.arange(batch_size), eos_mask_idx] = 1.
|
||||
|
||||
token_level_rewards = token_level_rewards * eos_mask
|
||||
token_level_rewards = token_level_rewards[:, -response_length:]
|
||||
|
||||
if self.config.param_offload:
|
||||
self.offload_params_to_cpu()
|
||||
else:
|
||||
# add empty cache after each compute
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
batch = TensorDict({'rm_scores': token_level_rewards}, batch_size=input_ids.shape[0])
|
||||
|
||||
return DataProto(batch=batch)
|
||||
|
||||
def forward_batch(self, data: DataProto):
|
||||
"""
|
||||
We assume:
|
||||
- The model takes input: (input_ids, attention_mask, position_ids). No rmpad for the input
|
||||
- The communication shape is (total_nnz_pad_to_sp // tp_size, 1, hidden_size) if sequence parallel is enabled
|
||||
"""
|
||||
# broadcast from last pp rank to all other pp ranks
|
||||
# TODO: actually, we just need to control the sampling order.
|
||||
data.batch = data.batch.contiguous()
|
||||
broadcast_dict_tensor(data.batch,
|
||||
src=mpu.get_pipeline_model_parallel_last_rank(),
|
||||
group=mpu.get_pipeline_model_parallel_group())
|
||||
|
||||
# split into micro-batches
|
||||
if self.config is not None and 'ppo_micro_batch_size' in self.config:
|
||||
infer_batch_size = self.config.ppo_micro_batch_size
|
||||
else:
|
||||
infer_batch_size = data.batch.batch_size[0]
|
||||
|
||||
data.batch['attention_mask'] = data.batch['attention_mask'].to(bool)
|
||||
batches = split_dict_tensor_into_batches(data.batch, batch_size=infer_batch_size)
|
||||
n_micro_batch = len(batches)
|
||||
seq_len = batches[0]['input_ids'].shape[1]
|
||||
|
||||
# compute input shapes for pp stages
|
||||
input_shapes = compute_transformers_input_shapes(
|
||||
batches,
|
||||
meta_info={
|
||||
'sequence_parallel': self.megatron_config.sequence_parallel,
|
||||
'hidden_size': self.model_config.hidden_size
|
||||
})
|
||||
# get the forward-backward function for the pipeline schedule
|
||||
forward_backward_func = get_forward_backward_func()
|
||||
|
||||
def loss_func(output):
|
||||
return 1., {'logits': output.logits}
|
||||
|
||||
def forward_step(batch_iter, model):
|
||||
batch = next(batch_iter)
|
||||
input_ids = batch['input_ids']
|
||||
attention_mask = batch['attention_mask']
|
||||
position_ids = batch['position_ids']
|
||||
output = model(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids)
|
||||
return output, loss_func
|
||||
|
||||
# batch should be a list of batches inside micro-batches
|
||||
batch_generator = make_batch_generator(batches, vpp_size=len(self.reward_model_module))
|
||||
|
||||
# TODO: we may use the new schedule instead
|
||||
# for flash-attn: (seq_len, batch_size, hidden_size) = (mbs*seq_len, 1, hidden_size)
|
||||
if mpu.get_pipeline_model_parallel_world_size() > 1:
|
||||
losses_reduced = forward_backward_func(
|
||||
forward_step_func=forward_step,
|
||||
data_iterator=batch_generator,
|
||||
model=self.reward_model_module,
|
||||
num_microbatches=n_micro_batch,
|
||||
input_shapes=input_shapes, # must set for flash-attn sequence packing
|
||||
seq_length=infer_batch_size * seq_len,  # not used when input_shapes is set
|
||||
hidden_size=self.model_config.hidden_size,  # not used when input_shapes is set
|
||||
micro_batch_size=1,  # not used when input_shapes is set
|
||||
forward_only=True,
|
||||
)
|
||||
else:
|
||||
losses_reduced = forward_backward_func(
|
||||
forward_step_func=forward_step,
|
||||
data_iterator=batch_generator,
|
||||
model=self.reward_model_module,
|
||||
num_microbatches=n_micro_batch,
|
||||
seq_length=infer_batch_size * seq_len, # in use for pp = 1
|
||||
hidden_size=self.model_config.hidden_size, # in use for pp = 1
|
||||
micro_batch_size=1, # in use for pp = 1
|
||||
forward_only=True,
|
||||
)
|
||||
# loss_reduces contains the stats returned from loss_func
|
||||
|
||||
return losses_reduced
|
||||
|
||||
def offload_params_to_cpu(self):
|
||||
if self.device == 'cuda':
|
||||
for reward_model_module in self.reward_model_module:
|
||||
for name, param in reward_model_module.named_parameters():
|
||||
param.data = param.data.to('cpu', non_blocking=True)
|
||||
self.device = 'cpu'
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def load_params_to_cuda(self):
|
||||
if self.device == 'cpu':
|
||||
for reward_model_module in self.reward_model_module:
|
||||
for name, param in reward_model_module.named_parameters():
|
||||
param.data = param.data.to(torch.cuda.current_device(), non_blocking=True)
|
||||
self.device = 'cuda'
|
||||
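Aside: the last-token reward extraction in compute_reward above boils down to two indexing tricks, locating the last valid (non-padded) token from the attention mask and then scattering the scalar reward back onto that single position. The following is a minimal standalone sketch of the same idea on toy tensors; shapes and values are illustrative only and are not part of this commit.

import torch

# toy batch: 2 sequences of length 6; 0 marks padding in the attention mask
attention_mask = torch.tensor([[1, 1, 1, 1, 0, 0],
                               [1, 1, 1, 1, 1, 0]])
token_level_rewards = torch.arange(12, dtype=torch.float32).view(2, 6)

# index of the last valid token per sequence: the cumsum peaks at the last 1,
# and argmax returns the first position where that peak is reached
ends = attention_mask.cumsum(dim=-1).argmax(dim=-1).view(-1, 1)      # (bs, 1)
rewards = torch.gather(token_level_rewards, dim=1, index=ends)       # (bs, 1)

# scatter the scalar reward back so only the last valid position is non-zero
eos_mask = torch.zeros_like(attention_mask, dtype=torch.float32)
eos_mask[torch.arange(2), ends.squeeze(-1)] = 1.0
sparse_rewards = rewards.expand_as(eos_mask) * eos_mask
print(sparse_rewards)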
19
verl/workers/rollout/__init__.py
Normal file
19
verl/workers/rollout/__init__.py
Normal file
@@ -0,0 +1,19 @@
|
||||
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .base import BaseRollout
|
||||
from .naive import NaiveRollout
|
||||
from .hf_rollout import HFRollout
|
||||
|
||||
__all__ = ["BaseRollout", "NaiveRollout", "HFRollout"]
|
||||
37
verl/workers/rollout/base.py
Normal file
37
verl/workers/rollout/base.py
Normal file
@@ -0,0 +1,37 @@
|
||||
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Iterable, Union
|
||||
|
||||
from verl import DataProto
|
||||
|
||||
__all__ = ['BaseRollout']
|
||||
|
||||
|
||||
class BaseRollout(ABC):
|
||||
|
||||
def __init__(self):
|
||||
"""
|
||||
|
||||
Args:
|
||||
dataloader: an Iterable of TensorDict that consistently generates prompts. Note that the dataloader
|
||||
should handle when the training stops.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
@abstractmethod
|
||||
def generate_sequences(self, prompts: DataProto) -> DataProto:
|
||||
"""Generate sequences"""
|
||||
pass
|
||||
140
verl/workers/rollout/hf_rollout.py
Normal file
140
verl/workers/rollout/hf_rollout.py
Normal file
@@ -0,0 +1,140 @@
|
||||
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Rollout with huggingface models.
|
||||
TODO: refactor this class. Currently, it will hang when using FSDP HybridShard. We should actually create a single GPU model.
|
||||
Then, get full state_dict and bind the state_dict to the single GPU model. Then, use the single GPU model to perform generation.
|
||||
"""
|
||||
import contextlib
|
||||
import torch
|
||||
import torch.distributed
|
||||
from tensordict import TensorDict
|
||||
from torch import nn
|
||||
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
|
||||
|
||||
from verl import DataProto
|
||||
from verl.utils.torch_functional import get_eos_mask
|
||||
from .base import BaseRollout
|
||||
|
||||
from transformers import GenerationConfig
|
||||
|
||||
__all__ = ['HFRollout']
|
||||
|
||||
|
||||
class HFRollout(BaseRollout):
|
||||
|
||||
def __init__(self, module: nn.Module, config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.module = module
|
||||
|
||||
def generate_sequences(self, prompts: DataProto) -> DataProto:
|
||||
batch_size = prompts.batch.batch_size[0]
|
||||
num_chunks = max(batch_size // self.config.get('micro_batch_size', batch_size), 1)
|
||||
batch_prompts = prompts.chunk(chunks=num_chunks)
|
||||
output = [self._generate_minibatch(p) for p in batch_prompts]
|
||||
output = DataProto.concat(output)
|
||||
return output
|
||||
|
||||
@torch.no_grad()
|
||||
def _generate_minibatch(self, prompts: DataProto) -> DataProto:
|
||||
idx = prompts.batch['input_ids'] # (bs, prompt_length)
|
||||
attention_mask = prompts.batch['attention_mask'] # left-padded attention_mask
|
||||
position_ids = prompts.batch['position_ids']
|
||||
|
||||
# used to construct attention_mask
|
||||
eos_token_id = prompts.meta_info['eos_token_id']
|
||||
pad_token_id = prompts.meta_info['pad_token_id']
|
||||
|
||||
batch_size = idx.size(0)
|
||||
prompt_length = idx.size(1)
|
||||
|
||||
self.module.eval()
|
||||
param_ctx = contextlib.nullcontext()
|
||||
|
||||
# allow sampling args to be overridden by inputs
|
||||
do_sample = prompts.meta_info.get('do_sample', self.config.do_sample)
|
||||
response_length = prompts.meta_info.get('response_length', self.config.response_length)
|
||||
top_p = prompts.meta_info.get('top_p', self.config.get('top_p', 1.0))
|
||||
top_k = prompts.meta_info.get('top_k', self.config.get('top_k', 0))
|
||||
|
||||
if top_k is None:
|
||||
top_k = 0
|
||||
top_k = max(0, top_k) # to be compatible with vllm
|
||||
|
||||
temperature = prompts.meta_info.get('temperature', self.config.temperature)
|
||||
|
||||
generation_config = GenerationConfig(temperature=temperature, top_p=top_p, top_k=top_k)
|
||||
|
||||
if isinstance(self.module, FSDP):
|
||||
# recurse need to set to False according to https://github.com/pytorch/pytorch/issues/100069
|
||||
param_ctx = FSDP.summon_full_params(self.module, writeback=False, recurse=False)
|
||||
with param_ctx:
|
||||
with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
|
||||
output = self.module.generate(
|
||||
input_ids=idx,
|
||||
attention_mask=attention_mask,
|
||||
do_sample=do_sample,
|
||||
max_new_tokens=response_length,
|
||||
# max_length=max_length,
|
||||
eos_token_id=eos_token_id,
|
||||
pad_token_id=pad_token_id,
|
||||
generation_config=generation_config,
|
||||
# renormalize_logits=True,
|
||||
output_scores=False, # this is potentially very large
|
||||
return_dict_in_generate=True,
|
||||
use_cache=True)
|
||||
# TODO: filter out the seq with no answers like ds-chat
|
||||
seq = output.sequences
|
||||
|
||||
# huggingface generate will stop generating when all the batch reaches [EOS].
|
||||
# We have to pad to response_length
|
||||
sequence_length = prompt_length + self.config.response_length
|
||||
delta_length = sequence_length - seq.shape[1]
|
||||
|
||||
if delta_length > 0:
|
||||
delta_tokens = torch.ones(size=(batch_size, delta_length), device=seq.device, dtype=seq.dtype)
|
||||
delta_tokens = pad_token_id * delta_tokens
|
||||
seq = torch.cat((seq, delta_tokens), dim=1)
|
||||
|
||||
assert seq.shape[1] == sequence_length
|
||||
|
||||
prompt = seq[:, :prompt_length] # (bs, prompt_length)
|
||||
response = seq[:, prompt_length:] # (bs, response_length)
|
||||
|
||||
response_length = response.size(1)
|
||||
delta_position_id = torch.arange(1, response_length + 1, device=position_ids.device)
|
||||
delta_position_id = delta_position_id.unsqueeze(0).repeat(batch_size, 1)
|
||||
|
||||
response_position_ids = position_ids[:, -1:] + delta_position_id
|
||||
position_ids = torch.cat([position_ids, response_position_ids], dim=-1)
|
||||
|
||||
response_attention_mask = get_eos_mask(response_id=response, eos_token=eos_token_id, dtype=attention_mask.dtype)
|
||||
attention_mask = torch.cat((attention_mask, response_attention_mask), dim=-1)
|
||||
|
||||
batch = TensorDict(
|
||||
{
|
||||
'prompts': prompt,
|
||||
'responses': response,
|
||||
'input_ids': seq,
|
||||
'attention_mask': attention_mask,
|
||||
'position_ids': position_ids
|
||||
},
|
||||
batch_size=batch_size)
|
||||
|
||||
# empty cache before compute old_log_prob
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
self.module.train()
|
||||
return DataProto(batch=batch)
|
||||
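Aside: the padding logic in HFRollout above handles the case where generate stops early once every sequence has hit EOS. Below is a small standalone sketch of that pad-to-fixed-length step together with an EOS-based response mask; the masking is re-implemented in plain torch for illustration, and the actual get_eos_mask helper may differ in edge cases.

import torch

pad_token_id, eos_token_id, response_length = 0, 2, 6
# generated responses that stopped early (length 4 < response_length)
response = torch.tensor([[5, 7, 2, 0],
                         [9, 9, 9, 2]])

# right-pad up to the fixed response_length expected downstream
delta = response_length - response.shape[1]
if delta > 0:
    pad = torch.full((response.shape[0], delta), pad_token_id, dtype=response.dtype)
    response = torch.cat((response, pad), dim=1)

# keep tokens up to and including the first EOS; everything after is masked out
is_eos = response == eos_token_id
after_eos = torch.cumsum(is_eos.int(), dim=1) - is_eos.int()   # > 0 strictly after the first EOS
response_attention_mask = (after_eos == 0).long()
print(response, response_attention_mask, sep="\n")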
15
verl/workers/rollout/naive/__init__.py
Normal file
15
verl/workers/rollout/naive/__init__.py
Normal file
@@ -0,0 +1,15 @@
|
||||
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .naive_rollout import NaiveRollout
|
||||
119
verl/workers/rollout/naive/naive_rollout.py
Normal file
119
verl/workers/rollout/naive/naive_rollout.py
Normal file
@@ -0,0 +1,119 @@
|
||||
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
In single GPU rollout, the sequences are generated directly by sampling from the model.
|
||||
The output will contain
|
||||
1. output_ids
|
||||
2. attention_masks (left padding)
|
||||
3. eos_masks
|
||||
4. log_probs
|
||||
"""
|
||||
from typing import Iterable, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from tensordict import TensorDict
|
||||
from torch import nn
|
||||
|
||||
from verl import DataProto
|
||||
from verl.utils.torch_functional import logprobs_from_logits
|
||||
from ..base import BaseRollout
|
||||
|
||||
__all__ = ['NaiveRollout']
|
||||
|
||||
|
||||
class NaiveRollout(BaseRollout):
|
||||
|
||||
def __init__(self, module: nn.Module, config):
|
||||
"""A naive rollout. It requires the module to be compatible with huggingface APIs. That is:
|
||||
The module should define __call__ to receive input_ids, attention_mask and position_ids.
|
||||
It outputs a structure that contains logits field.
|
||||
|
||||
Args:
|
||||
module: module here follows huggingface APIs
|
||||
config: DictConfig
|
||||
"""
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.module = module
|
||||
|
||||
@torch.no_grad()
|
||||
def generate_sequences(self, prompts: DataProto) -> DataProto:
|
||||
"""Generate sequences"""
|
||||
idx = prompts.batch['input_ids'] # (bs, prompt_length)
|
||||
attention_mask = prompts.batch['attention_mask'] # left-padded attention_mask
|
||||
position_ids = prompts.batch['position_ids']
|
||||
|
||||
# used to construct attention_mask
|
||||
eos_token_id = prompts.meta_info['eos_token_id']
|
||||
|
||||
batch_size = idx.size(0)
|
||||
prompt_length = idx.size(1)
|
||||
|
||||
self.module.eval()
|
||||
|
||||
prev_attention_mask = torch.ones(size=(batch_size, 1), dtype=attention_mask.dtype, device=attention_mask.device)
|
||||
|
||||
logits_lst = []
|
||||
for _ in range(self.config.response_length):
|
||||
# if the sequence context is growing too long we must crop it at block_size
|
||||
# idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]
|
||||
idx_cond = idx
|
||||
# forward the model to get the logits for the index in the sequence
|
||||
# we use huggingface APIs here
|
||||
output = self.module(input_ids=idx_cond, attention_mask=attention_mask, position_ids=position_ids)
|
||||
logits = output.logits
|
||||
# pluck the logits at the final step and scale by desired temperature
|
||||
logits = logits[:, -1, :] / self.config.temperature # (bs, vocab_size)
|
||||
# optionally crop the logits to only the top k options
|
||||
if self.config.top_k is not None:
|
||||
v, _ = torch.topk(logits, min(self.config.top_k, logits.size(-1)))
|
||||
logits[logits < v[:, [-1]]] = -float('Inf')
|
||||
# apply softmax to convert logits to (normalized) probabilities
|
||||
probs = F.softmax(logits, dim=-1)
|
||||
# sample from the distribution
|
||||
if self.config.do_sample:
|
||||
idx_next = torch.multinomial(probs, num_samples=1)
|
||||
else:
|
||||
idx_next = torch.argmax(probs, dim=-1, keepdim=True)
|
||||
|
||||
attention_mask = torch.cat((attention_mask, prev_attention_mask), dim=-1)
|
||||
|
||||
prev_attention_mask = torch.logical_and(idx_next != eos_token_id, prev_attention_mask.bool())
|
||||
prev_attention_mask = prev_attention_mask.to(attention_mask.dtype)
|
||||
|
||||
position_ids = torch.cat((position_ids, position_ids[:, -1:] + 1), dim=-1)
|
||||
|
||||
# append sampled index to the running sequence and continue
|
||||
idx = torch.cat((idx, idx_next), dim=1)
|
||||
logits_lst.append(logits)
|
||||
|
||||
logits = torch.stack(logits_lst, dim=1) # (bs, response_length, vocab_size)
|
||||
prompts = idx[:, :prompt_length] # (bs, prompt_length)
|
||||
response = idx[:, prompt_length:] # (bs, response_length)
|
||||
log_probs = logprobs_from_logits(logits=logits, labels=response)
|
||||
batch = TensorDict(
|
||||
{
|
||||
'input_ids': prompts,
|
||||
'responses': response,
|
||||
'sequences': idx,
|
||||
'old_log_probs': log_probs,
|
||||
'attention_mask': attention_mask,
|
||||
'position_ids': position_ids,
|
||||
},
|
||||
batch_size=batch_size)
|
||||
|
||||
self.module.train()
|
||||
|
||||
return DataProto(batch=batch)
|
||||
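Aside: the sampling step inside the loop above (temperature scaling, optional top-k truncation, then a multinomial draw) can be exercised on its own. The sketch below uses dummy logits; the temperature and top_k values are arbitrary and only illustrate the technique.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(2, 10)          # (bs, vocab_size)
temperature, top_k = 0.8, 3

logits = logits / temperature
# keep only the top_k logits per row, set the rest to -inf
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
logits[logits < v[:, [-1]]] = -float('Inf')

probs = F.softmax(logits, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)   # (bs, 1)
print(next_token)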
162
verl/workers/rollout/tokenizer.py
Normal file
162
verl/workers/rollout/tokenizer.py
Normal file
@@ -0,0 +1,162 @@
|
||||
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
The base tokenizer class, required for any hybrid engine based rollout or inference with vLLM.
|
||||
"""
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, List, Union
|
||||
|
||||
__all__ = ['HybridEngineBaseTokenizer']
|
||||
|
||||
|
||||
class HybridEngineBaseTokenizer(ABC):
|
||||
"""the tokenizer property and function name should align with HF's to meet vllm requirement"""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def vocab_size(self):
|
||||
"""
|
||||
`int`: Size of the base vocabulary (without the added tokens).
|
||||
"""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def pad_token_id(self):
|
||||
"""
|
||||
`Optional[int]`: Id of the padding token in the vocabulary. Returns `None` if the token has not been set.
|
||||
"""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def eos_token_id(self):
|
||||
"""
|
||||
`Optional[int]`: Id of the end of sentence token in the vocabulary. Returns `None` if the token has not been
|
||||
set.
|
||||
"""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def all_special_ids(self) -> List[int]:
|
||||
"""
|
||||
`List[int]`: List the ids of the special tokens(`'<unk>'`, `'<cls>'`, etc.) mapped to class attributes.
|
||||
"""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def all_special_tokens(self) -> List[str]:
|
||||
"""
|
||||
`List[str]`: A list of the unique special tokens (`'<unk>'`, `'<cls>'`, ..., etc.).
|
||||
|
||||
Convert tokens of `tokenizers.AddedToken` type to string.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def encode(self, text):
|
||||
"""
|
||||
Converts a string to a sequence of ids (integer), using the tokenizer and vocabulary.
|
||||
|
||||
Args:
|
||||
text (`str`, `List[str]` or `List[int]`):
|
||||
The first sequence to be encoded. This can be a string, a list of strings (tokenized string using the
|
||||
`tokenize` method) or a list of integers.
|
||||
|
||||
text_pair (`str`, `List[str]` or `List[int]`, *optional*):
|
||||
Optional second sequence to be encoded. This can be a string, a list of strings (tokenized string using
|
||||
the `tokenize` method) or a list of integers.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def decode(
|
||||
self,
|
||||
token_ids: Union[int, List[int], "np.ndarray", "torch.Tensor", "tf.Tensor"],
|
||||
skip_special_tokens: bool = False,
|
||||
clean_up_tokenization_spaces: bool = None,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
"""
|
||||
Converts a sequence of ids in a string, using the tokenizer and vocabulary with options to remove special
|
||||
tokens and clean up tokenization spaces.
|
||||
|
||||
Similar to doing `self.convert_tokens_to_string(self.convert_ids_to_tokens(token_ids))`.
|
||||
|
||||
Args:
|
||||
token_ids (`Union[int, List[int], np.ndarray, torch.Tensor, tf.Tensor]`):
|
||||
List of tokenized input ids. Can be obtained using the `__call__` method.
|
||||
skip_special_tokens (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to remove special tokens in the decoding.
|
||||
clean_up_tokenization_spaces (`bool`, *optional*):
|
||||
Whether or not to clean up the tokenization spaces. If `None`, will default to
|
||||
`self.clean_up_tokenization_spaces`.
|
||||
kwargs (additional keyword arguments, *optional*):
|
||||
Will be passed to the underlying model specific decode method.
|
||||
|
||||
Returns:
|
||||
`str`: The decoded sentence.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def convert_ids_to_tokens(self,
|
||||
ids: Union[int, List[int]],
|
||||
skip_special_tokens: bool = False) -> Union[str, List[str]]:
|
||||
"""
|
||||
Converts a single index or a sequence of indices in a token or a sequence of tokens, using the vocabulary and
|
||||
added tokens.
|
||||
|
||||
Args:
|
||||
ids (`int` or `List[int]`):
|
||||
The token id (or token ids) to convert to tokens.
|
||||
skip_special_tokens (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to remove special tokens in the decoding.
|
||||
|
||||
Returns:
|
||||
`str` or `List[str]`: The decoded token(s).
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_added_vocab(self) -> Dict[str, int]:
|
||||
"""
|
||||
Returns the added tokens in the vocabulary as a dictionary of token to index. Results might be different from
|
||||
the fast call because for now we always add the tokens even if they are already in the vocabulary. This is
|
||||
something we should change.
|
||||
|
||||
Returns:
|
||||
`Dict[str, int]`: The added tokens.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def convert_tokens_to_string(self, tokens: List[str]) -> str:
|
||||
"""
|
||||
Converts a sequence of tokens in a single string. The most simple way to do it is `" ".join(tokens)` but we
|
||||
often want to remove sub-word tokenization artifacts at the same time.
|
||||
|
||||
Args:
|
||||
tokens (`List[str]`): The token to join in a string.
|
||||
|
||||
Returns:
|
||||
`str`: The joined tokens.
|
||||
"""
|
||||
pass
|
||||
|
||||
@property
|
||||
def is_fast(self):
|
||||
return False
|
||||
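For reference, a concrete implementation of this interface would simply delegate to a HuggingFace tokenizer. The class below is a hypothetical sketch (not part of this commit) and assumes the standard transformers.PreTrainedTokenizer attributes and methods.

# hypothetical example, not part of this commit
from typing import Dict, List


class HFWrappedTokenizer(HybridEngineBaseTokenizer):
    """Delegates every abstract member to an underlying HuggingFace tokenizer."""

    def __init__(self, hf_tokenizer):
        self._tok = hf_tokenizer

    @property
    def vocab_size(self):
        return self._tok.vocab_size

    @property
    def pad_token_id(self):
        return self._tok.pad_token_id

    @property
    def eos_token_id(self):
        return self._tok.eos_token_id

    @property
    def all_special_ids(self) -> List[int]:
        return self._tok.all_special_ids

    @property
    def all_special_tokens(self) -> List[str]:
        return self._tok.all_special_tokens

    def encode(self, text):
        return self._tok.encode(text)

    def decode(self, token_ids, skip_special_tokens=False, clean_up_tokenization_spaces=None, **kwargs):
        return self._tok.decode(token_ids, skip_special_tokens=skip_special_tokens,
                                clean_up_tokenization_spaces=clean_up_tokenization_spaces, **kwargs)

    def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
        return self._tok.convert_ids_to_tokens(ids, skip_special_tokens=skip_special_tokens)

    def get_added_vocab(self) -> Dict[str, int]:
        return self._tok.get_added_vocab()

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        return self._tok.convert_tokens_to_string(tokens)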
15
verl/workers/rollout/vllm_rollout/__init__.py
Normal file
15
verl/workers/rollout/vllm_rollout/__init__.py
Normal file
@@ -0,0 +1,15 @@
|
||||
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .vllm_rollout import vLLMRollout
|
||||
226
verl/workers/rollout/vllm_rollout/vllm_rollout.py
Normal file
226
verl/workers/rollout/vllm_rollout/vllm_rollout.py
Normal file
@@ -0,0 +1,226 @@
|
||||
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
The vllm_rollout that can be applied in different backend
|
||||
When working with FSDP:
|
||||
- Use DTensor weight loader (recommended) or HF weight loader
|
||||
- Utilize state_dict from the FSDP to synchronize the weights among tp ranks in vLLM
|
||||
When working with Megatron:
|
||||
- Use Megatron weight loader
|
||||
- During training, only the current pp stage holds the parameters
|
||||
- Before inference, broadcast the parameters of the current pp rank to all other pp ranks (all pp ranks hold all the parameters)
|
||||
- Bind the parameters to the inference engine
|
||||
- Do inference in tp. pp is treated as additional dp
|
||||
- After inference, all the parameters that don't belong to this pp rank are freed.
|
||||
"""
|
||||
from typing import List
|
||||
from contextlib import contextmanager
|
||||
from omegaconf import DictConfig
|
||||
import torch
|
||||
import torch.distributed
|
||||
from tensordict import TensorDict
|
||||
from torch import nn
|
||||
|
||||
from verl import DataProto
|
||||
from verl.utils.torch_functional import get_eos_mask, pad_sequence_to_length
|
||||
from verl.workers.rollout.base import BaseRollout
|
||||
from verl.third_party.vllm import LLM, vllm_version
|
||||
from verl.third_party.vllm import parallel_state as vllm_ps
|
||||
from vllm import SamplingParams
|
||||
|
||||
# TODO
|
||||
# 1. support pp in vllm
|
||||
# 2. passing tokenizer is not necessary? no encoding/decoding is happening here
|
||||
# 3. simplify init logics
|
||||
|
||||
|
||||
# NOTE(sgm): add for verl. We can optimize it by making the dataloader yield List[int] without padding.
|
||||
def _pre_process_inputs(pad_token_id, prompt_token_ids: torch.Tensor) -> List[int]:
|
||||
# remove the left padding in the prompt token_id
|
||||
# pad_token_id = self.llm_engine.tokenizer.pad_token_id if self.llm_engine.tokenizer.pad_token_id is not None else self.llm_engine.tokenizer.eos_token_id
|
||||
non_pad_index = torch.nonzero(prompt_token_ids != pad_token_id, as_tuple=False)[0][0]
|
||||
token_ids = prompt_token_ids[non_pad_index:].tolist()
|
||||
return token_ids
|
||||
|
||||
|
||||
class vLLMRollout(BaseRollout):
|
||||
|
||||
def __init__(self, actor_module: nn.Module, config: DictConfig, tokenizer, model_hf_config, **kwargs):
|
||||
"""A vLLM rollout. It requires the module is supported by the vllm.
|
||||
|
||||
Args:
|
||||
module: module here follows huggingface APIs
|
||||
config: DictConfig
|
||||
tokenizer: the task/model tokenizer
|
||||
model_hf_config: the huggingface config to initialize the generating model in vllm
|
||||
**kwargs: train_tp, for Megatron Backend to initialize hybrid engine (zero redundancy) process group
|
||||
"""
|
||||
super().__init__()
|
||||
self.config = config
|
||||
assert not (not config.enforce_eager and config.free_cache_engine), \
|
||||
"disable CUDA graph (enforce_eager = False) if free cache engine"
|
||||
|
||||
tensor_parallel_size = self.config.get('tensor_model_parallel_size', 1)
|
||||
assert tensor_parallel_size <= torch.distributed.get_world_size(), \
|
||||
"tensor parallel size should be less than or equal to the world size"
|
||||
|
||||
if kwargs.get('train_tp', None) is not None:
|
||||
# deployed with megatron
|
||||
import os
|
||||
os.environ['CUDA_TIMER_STREAM_KAFKA_ENABLE'] = '0'
|
||||
os.environ['MEGATRON_IMPORT_TIMERS'] = '0'
|
||||
train_tp = kwargs.get('train_tp', None)
|
||||
num_tp_per_train_tp = train_tp // tensor_parallel_size
|
||||
if vllm_version in ('0.4.2', '0.5.4', '0.6.3'):
|
||||
vllm_ps.initialize_parallel_state(tensor_model_parallel_size=tensor_parallel_size,
|
||||
num_tp_per_train_tp=num_tp_per_train_tp)
|
||||
|
||||
assert model_hf_config.max_position_embeddings >= config.prompt_length + config.response_length, \
|
||||
"model context length should be greater than total sequence length"
|
||||
self.inference_engine = LLM(actor_module,
|
||||
tokenizer=tokenizer,
|
||||
model_hf_config=model_hf_config,
|
||||
tensor_parallel_size=tensor_parallel_size,
|
||||
dtype=config.dtype,
|
||||
enforce_eager=config.enforce_eager,
|
||||
gpu_memory_utilization=config.gpu_memory_utilization,
|
||||
skip_tokenizer_init=False,
|
||||
max_model_len=config.prompt_length + config.response_length,
|
||||
load_format=config.load_format)
|
||||
|
||||
# Offload vllm model to reduce peak memory usage
|
||||
self.inference_engine.offload_model_weights()
|
||||
|
||||
kwargs = dict(
|
||||
n=1,
|
||||
logprobs=1, # can be set to 0 and let actor to recompute
|
||||
max_tokens=config.response_length,
|
||||
)
|
||||
|
||||
# we may detokenize the result all together later
|
||||
if vllm_version in ('0.4.2', '0.5.4', '0.6.3'):
|
||||
kwargs['detokenize'] = False
|
||||
|
||||
# supporting adding any sampling params from the config file
|
||||
for k in config.keys():
|
||||
if hasattr(SamplingParams(), str(k)):
|
||||
kwargs[k] = config.get(k)
|
||||
|
||||
print(f"kwargs: {kwargs}")
|
||||
self.sampling_params = SamplingParams(**kwargs)
|
||||
|
||||
self.pad_token_id = tokenizer.pad_token_id
|
||||
|
||||
@contextmanager
|
||||
def update_sampling_params(self, **kwargs):
|
||||
# update sampling params
|
||||
old_sampling_params_args = {}
|
||||
if kwargs:
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(self.sampling_params, key):
|
||||
old_value = getattr(self.sampling_params, key)
|
||||
old_sampling_params_args[key] = old_value
|
||||
setattr(self.sampling_params, key, value)
|
||||
yield
|
||||
# roll back to previous sampling params
|
||||
# if len(old_sampling_params_args):
|
||||
for key, value in old_sampling_params_args.items():
|
||||
setattr(self.sampling_params, key, value)
|
||||
|
||||
@torch.no_grad()
|
||||
def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto:
|
||||
# rebuild vllm cache engine
|
||||
if self.config.free_cache_engine:
|
||||
self.inference_engine.init_cache_engine()
|
||||
|
||||
idx = prompts.batch['input_ids'] # (bs, prompt_length)
|
||||
# left-padded attention_mask
|
||||
attention_mask = prompts.batch['attention_mask']
|
||||
position_ids = prompts.batch['position_ids']
|
||||
|
||||
# used to construct attention_mask
|
||||
eos_token_id = prompts.meta_info['eos_token_id']
|
||||
|
||||
batch_size = idx.size(0)
|
||||
|
||||
idx_list = []
|
||||
# parse idx from torch.Tensor to List[List[str]]
|
||||
for i in range(batch_size):
|
||||
idx_list.append(_pre_process_inputs(self.pad_token_id, idx[i]))
|
||||
|
||||
do_sample = prompts.meta_info.get('do_sample', True)
|
||||
if not do_sample:
|
||||
kwargs = {
|
||||
'best_of': 1,
|
||||
'top_p': 1.0,
|
||||
'top_k': -1,
|
||||
'min_p': 0.0,
|
||||
'temperature': 0,
|
||||
'n': 1 # if greedy, only 1 response
|
||||
}
|
||||
|
||||
# users can customize different sampling_params for different runs
|
||||
with self.update_sampling_params(**kwargs):
|
||||
output = self.inference_engine.generate(
|
||||
prompts=None, # because we have already convert it to prompt token id
|
||||
sampling_params=self.sampling_params,
|
||||
prompt_token_ids=idx_list,
|
||||
use_tqdm=False)
|
||||
|
||||
# TODO(sgm): disable logprob when recompute_log_prob is enable
|
||||
# if n = 1: (bs, response_length) ; if n > 1: (bs * n, response_length)
|
||||
response = output[0].to(idx.device)
|
||||
log_probs = output[1].to(idx.device)
|
||||
|
||||
if response.shape[1] < self.config.response_length:
|
||||
response = pad_sequence_to_length(response, self.config.response_length, self.pad_token_id)
|
||||
log_probs = pad_sequence_to_length(log_probs, self.config.response_length, self.pad_token_id)
|
||||
|
||||
if self.config.n > 1 and do_sample:
|
||||
idx = idx.repeat_interleave(self.config.n, dim=0)
|
||||
attention_mask = attention_mask.repeat_interleave(self.config.n, dim=0)
|
||||
position_ids = position_ids.repeat_interleave(self.config.n, dim=0)
|
||||
batch_size = batch_size * self.config.n
|
||||
seq = torch.cat([idx, response], dim=-1)
|
||||
|
||||
response_length = response.size(1)
|
||||
delta_position_id = torch.arange(1, response_length + 1, device=position_ids.device)
|
||||
delta_position_id = delta_position_id.unsqueeze(0).repeat(batch_size, 1)
|
||||
|
||||
# TODO(sgm): fix position_ids on right_pad
|
||||
# prompt: left pad + response: right pad
|
||||
# attention_mask: [0,0,0,0,1,1,1,1, | 1,1,1,0,0,0,0,0]
|
||||
# position_ids: [0,0,0,0,0,1,2,3, | 4,5,6,7,8,9,10,11]
|
||||
response_position_ids = position_ids[:, -1:] + delta_position_id
|
||||
position_ids = torch.cat([position_ids, response_position_ids], dim=-1)
|
||||
response_attention_mask = get_eos_mask(response_id=response, eos_token=eos_token_id, dtype=attention_mask.dtype)
|
||||
attention_mask = torch.cat((attention_mask, response_attention_mask), dim=-1)
|
||||
|
||||
# all the tp ranks should contain the same data here. data in all ranks are valid
|
||||
batch = TensorDict(
|
||||
{
|
||||
'prompts': idx,
|
||||
'responses': response,
|
||||
'input_ids': seq, # here input_ids become the whole sentences
|
||||
# 'old_log_probs': log_probs, # we will recompute old log prob with actor
|
||||
'attention_mask': attention_mask,
|
||||
'position_ids': position_ids
|
||||
},
|
||||
batch_size=batch_size)
|
||||
|
||||
# free vllm cache engine
|
||||
if self.config.free_cache_engine:
|
||||
self.inference_engine.free_cache_engine()
|
||||
|
||||
return DataProto(batch=batch)
|
||||
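Aside: update_sampling_params above is a generic "temporarily override attributes, then restore" pattern. The following standalone sketch shows the same pattern on a plain object, with no vLLM dependency; the Params class is illustrative only. Wrapping the restore in try/finally additionally keeps the object consistent if the body raises, a detail the sketch adds on top of the original.

from contextlib import contextmanager
from dataclasses import dataclass


@dataclass
class Params:            # stand-in for a sampling-params object
    temperature: float = 1.0
    top_p: float = 1.0
    n: int = 1


@contextmanager
def update_params(obj, **kwargs):
    old = {}
    for key, value in kwargs.items():
        if hasattr(obj, key):
            old[key] = getattr(obj, key)
            setattr(obj, key, value)
    try:
        yield obj
    finally:
        # roll back to the previous values even if the body raises
        for key, value in old.items():
            setattr(obj, key, value)


params = Params()
with update_params(params, temperature=0.0, n=4):
    print(params)        # overridden for this block only
print(params)            # restored afterwards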
33
verl/workers/sharding_manager/__init__.py
Normal file
33
verl/workers/sharding_manager/__init__.py
Normal file
@@ -0,0 +1,33 @@
|
||||
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from verl.utils.import_utils import is_vllm_available, is_megatron_core_available
|
||||
|
||||
from .base import BaseShardingManager
|
||||
from .fsdp_ulysses import FSDPUlyssesShardingManager
|
||||
|
||||
AllGatherPPModel = None
|
||||
|
||||
if is_megatron_core_available() and is_vllm_available():
|
||||
from .megatron_vllm import AllGatherPPModel, MegatronVLLMShardingManager
|
||||
elif AllGatherPPModel is not None:
|
||||
pass
|
||||
else:
|
||||
AllGatherPPModel = None
|
||||
MegatronVLLMShardingManager = None
|
||||
|
||||
if is_vllm_available():
|
||||
from .fsdp_vllm import FSDPVLLMShardingManager
|
||||
else:
|
||||
FSDPVLLMShardingManager = None
|
||||
33
verl/workers/sharding_manager/base.py
Normal file
33
verl/workers/sharding_manager/base.py
Normal file
@@ -0,0 +1,33 @@
|
||||
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Sharding manager to implement HybridEngine
|
||||
"""
|
||||
|
||||
from verl import DataProto
|
||||
|
||||
|
||||
class BaseShardingManager:
|
||||
|
||||
def __enter__(self):
|
||||
pass
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
pass
|
||||
|
||||
def preprocess_data(self, data: DataProto) -> DataProto:
|
||||
return data
|
||||
|
||||
def postprocess_data(self, data: DataProto) -> DataProto:
|
||||
return data
|
||||
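Aside: BaseShardingManager is meant to be used as a context manager around generation, where entering syncs or reshards the weights, preprocess/postprocess reshape the data, and exiting releases the inference-side copies. The snippet below is only an illustrative usage pattern; the rollout and sharding_manager objects are placeholders, not a fixed API of this commit.

# illustrative pattern only
def generate_with_resharding(sharding_manager, rollout, prompts):
    with sharding_manager:                               # __enter__: load/sync weights
        prompts = sharding_manager.preprocess_data(prompts)
        output = rollout.generate_sequences(prompts)
        output = sharding_manager.postprocess_data(output)
    return output                                        # __exit__: free inference weights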
88
verl/workers/sharding_manager/fsdp_ulysses.py
Normal file
88
verl/workers/sharding_manager/fsdp_ulysses.py
Normal file
@@ -0,0 +1,88 @@
|
||||
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Contains a sharding manager that supports data resharding when using FSDP together with Ulysses sequence parallelism
|
||||
"""
|
||||
from typing import Optional
|
||||
from .base import BaseShardingManager
|
||||
|
||||
import random
|
||||
from torch.distributed.device_mesh import DeviceMesh
|
||||
|
||||
from verl.utils.torch_functional import allgather_dict_tensors
|
||||
from verl.utils.ulysses import set_ulysses_sequence_parallel_group, get_ulysses_sequence_parallel_group
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
import torch.distributed
|
||||
|
||||
from verl import DataProto
|
||||
|
||||
|
||||
class FSDPUlyssesShardingManager(BaseShardingManager):
|
||||
"""
|
||||
Sharding manager to support data resharding when using FSDP + Ulysses
|
||||
"""
|
||||
|
||||
def __init__(self, device_mesh: DeviceMesh):
|
||||
super().__init__()
|
||||
self.device_mesh = device_mesh
|
||||
self.seed_offset = 12345
|
||||
|
||||
def __enter__(self):
|
||||
if self.device_mesh is not None:
|
||||
# We have a global SP group
|
||||
# so we have to change to use model-specific sp group
|
||||
self.prev_sp_group = get_ulysses_sequence_parallel_group()
|
||||
set_ulysses_sequence_parallel_group(self.device_mesh['sp'].get_group())
|
||||
# TODO: check how to set seed for each model
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
# restore random states
|
||||
if self.device_mesh is not None:
|
||||
# revert to previous sp group
|
||||
set_ulysses_sequence_parallel_group(self.prev_sp_group)
|
||||
# TODO: check how to set seed for each model
|
||||
|
||||
def preprocess_data(self, data: DataProto) -> DataProto:
|
||||
"""
|
||||
AllGather data from sp region
|
||||
This is because the data is first sharded along the FSDP dimension as we utilize the DP_COMPUTE
|
||||
In Ulysses, we need to make sure the same data is used across a SP group
|
||||
"""
|
||||
if self.device_mesh is not None:
|
||||
sp_size = self.device_mesh['sp'].size()
|
||||
group = self.device_mesh['sp'].get_group()
|
||||
|
||||
prev_device = data.batch.device
|
||||
data.batch = data.batch.cuda(device=torch.cuda.current_device())
|
||||
data.batch = allgather_dict_tensors(data.batch.contiguous(), size=sp_size, group=group, dim=0)
|
||||
data.batch = data.batch.to(prev_device)
|
||||
# all gather non_tensor_batch
|
||||
all_non_tensor_batch = [None for _ in range(sp_size)]
|
||||
torch.distributed.all_gather_object(all_non_tensor_batch, data.non_tensor_batch, group=group)
|
||||
data.non_tensor_batch = {
|
||||
k: np.concatenate([d[k] for d in all_non_tensor_batch]) for k in data.non_tensor_batch
|
||||
}
|
||||
return data
|
||||
|
||||
def postprocess_data(self, data: DataProto) -> DataProto:
|
||||
"""
|
||||
Split the data to follow FSDP partition
|
||||
"""
|
||||
if self.device_mesh is not None:
|
||||
sp_size = self.device_mesh['sp'].size()
|
||||
sp_rank = self.device_mesh['sp'].get_local_rank()
|
||||
data = data.chunk(chunks=sp_size)[sp_rank]
|
||||
return data
|
||||
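Aside: the preprocess/postprocess pair above is an all-gather along the batch dimension followed by an even split, so each SP rank ends up with its original shard after generation. The sketch below illustrates that round trip in a single process with plain torch ops; the real code uses collectives over the 'sp' process group.

import torch

sp_size = 4
# each "rank" starts with its own shard of the batch
shards = [torch.arange(2) + 10 * r for r in range(sp_size)]

# preprocess_data: every rank sees the concatenation of all shards
gathered = torch.cat(shards, dim=0)

# postprocess_data: rank r keeps only its own chunk again
for sp_rank in range(sp_size):
    local = gathered.chunk(sp_size, dim=0)[sp_rank]
    assert torch.equal(local, shards[sp_rank])
print(gathered)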
133
verl/workers/sharding_manager/fsdp_vllm.py
Normal file
133
verl/workers/sharding_manager/fsdp_vllm.py
Normal file
@@ -0,0 +1,133 @@
|
||||
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import logging
|
||||
import torch
|
||||
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
|
||||
from torch.distributed.fsdp.api import ShardingStrategy, ShardedStateDictConfig, StateDictType, FullStateDictConfig
|
||||
from torch.distributed.device_mesh import DeviceMesh
|
||||
|
||||
from verl.third_party.vllm import LLM
|
||||
from verl.third_party.vllm import parallel_state as vllm_ps
|
||||
from verl import DataProto
|
||||
from verl.utils.torch_functional import (broadcast_dict_tensor, allgather_dict_tensors)
|
||||
from verl.utils.debug import log_gpu_memory_usage
|
||||
|
||||
from .base import BaseShardingManager
|
||||
|
||||
logger = logging.getLogger(__file__)
|
||||
logger.setLevel(os.getenv('VERL_PPO_LOGGING_LEVEL', 'WARN'))
|
||||
|
||||
|
||||
class FSDPVLLMShardingManager(BaseShardingManager):
|
||||
|
||||
def __init__(self,
|
||||
module: FSDP,
|
||||
inference_engine: LLM,
|
||||
model_config,
|
||||
full_params: bool = False,
|
||||
device_mesh: DeviceMesh = None):
|
||||
self.module = module
|
||||
self.inference_engine = inference_engine
|
||||
self.model_config = model_config
|
||||
self.device_mesh = device_mesh
|
||||
|
||||
# Full params
|
||||
self.full_params = full_params
|
||||
if full_params:
|
||||
FSDP.set_state_dict_type(self.module,
|
||||
state_dict_type=StateDictType.FULL_STATE_DICT,
|
||||
state_dict_config=FullStateDictConfig())
|
||||
else:
|
||||
FSDP.set_state_dict_type(self.module,
|
||||
state_dict_type=StateDictType.SHARDED_STATE_DICT,
|
||||
state_dict_config=ShardedStateDictConfig())
|
||||
|
||||
# Note that torch_random_states may be different on each dp rank
|
||||
self.torch_random_states = torch.cuda.get_rng_state()
|
||||
# get a random rng states
|
||||
if self.device_mesh is not None:
|
||||
gen_dp_rank = self.device_mesh['dp'].get_local_rank()
|
||||
torch.cuda.manual_seed(gen_dp_rank + 1000) # make sure all tp ranks have the same random states
|
||||
self.gen_random_states = torch.cuda.get_rng_state()
|
||||
torch.cuda.set_rng_state(self.torch_random_states)
|
||||
else:
|
||||
self.gen_random_states = None
|
||||
|
||||
def __enter__(self):
|
||||
log_gpu_memory_usage('Before state_dict() in sharding manager memory', logger=logger)
|
||||
params = self.module.state_dict()
|
||||
log_gpu_memory_usage('After state_dict() in sharding manager memory', logger=logger)
|
||||
# Copy, not share memory
|
||||
load_format = 'hf' if self.full_params else 'dtensor'
|
||||
self.inference_engine.sync_model_weights(params, load_format=load_format)
|
||||
log_gpu_memory_usage('After sync model weights in sharding manager', logger=logger)
|
||||
|
||||
del params
|
||||
torch.cuda.empty_cache()
|
||||
log_gpu_memory_usage('After del state_dict and empty_cache in sharding manager', logger=logger)
|
||||
|
||||
# TODO: offload FSDP model weights
|
||||
# self.module.cpu()
|
||||
# torch.cuda.empty_cache()
|
||||
# if torch.distributed.get_rank() == 0:
|
||||
# print(f'after model to cpu in sharding manager memory allocated: {torch.cuda.memory_allocated() / 1e9}GB, reserved: {torch.cuda.memory_reserved() / 1e9}GB')
|
||||
|
||||
# important: need to manually set the random states of each tp to be identical.
|
||||
if self.device_mesh is not None:
|
||||
self.torch_random_states = torch.cuda.get_rng_state()
|
||||
torch.cuda.set_rng_state(self.gen_random_states)
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
log_gpu_memory_usage('Before vllm offload in sharding manager', logger=logger)
|
||||
self.inference_engine.offload_model_weights()
|
||||
log_gpu_memory_usage('After vllm offload in sharding manager', logger=logger)
|
||||
|
||||
# self.module.to('cuda')
|
||||
# if torch.distributed.get_rank() == 0:
|
||||
# print(f'after actor module to cuda in sharding manager memory allocated: {torch.cuda.memory_allocated() / 1e9}GB, reserved: {torch.cuda.memory_reserved() / 1e9}GB')
|
||||
|
||||
self.module.train()
|
||||
|
||||
# add empty cache after each compute
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# restore random states
|
||||
if self.device_mesh is not None:
|
||||
self.gen_random_states = torch.cuda.get_rng_state()
|
||||
torch.cuda.set_rng_state(self.torch_random_states)
|
||||
|
||||
def preprocess_data(self, data: DataProto) -> DataProto:
|
||||
# TODO: Current impl doesn't consider FSDP with torch micro-dp
|
||||
data.batch = allgather_dict_tensors(data.batch.contiguous(),
|
||||
size=vllm_ps.get_tensor_model_parallel_world_size(),
|
||||
group=vllm_ps.get_tensor_model_parallel_group(),
|
||||
dim=0)
|
||||
|
||||
return data
|
||||
|
||||
def postprocess_data(self, data: DataProto) -> DataProto:
|
||||
# TODO: Current impl doesn't consider FSDP with torch micro-dp
|
||||
broadcast_dict_tensor(data.batch,
|
||||
src=vllm_ps.get_tensor_model_parallel_src_rank(),
|
||||
group=vllm_ps.get_tensor_model_parallel_group())
|
||||
dp_rank = torch.distributed.get_rank()
|
||||
dp_size = torch.distributed.get_world_size() # not consider torch micro-dp
|
||||
tp_size = vllm_ps.get_tensor_model_parallel_world_size()
|
||||
if tp_size > 1:
|
||||
# TODO: shall we build a micro_dp group for vllm when integrating with vLLM?
|
||||
local_prompts = data.chunk(chunks=tp_size)
|
||||
data = local_prompts[dp_rank % tp_size]
|
||||
return data
|
||||
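Aside: in postprocess_data above, every rank in a TP group holds the full generated batch after the broadcast and then keeps only one slice of it, chosen by its position inside the TP group. A toy sketch of that slice selection with pure indexing and no collectives; the sizes are illustrative only.

import torch

tp_size, world_size = 2, 4
full_batch = torch.arange(8).view(4, 2)          # what every rank sees after the broadcast

for dp_rank in range(world_size):
    shard = full_batch.chunk(tp_size, dim=0)[dp_rank % tp_size]
    print(f"rank {dp_rank} keeps:\n{shard}")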
428
verl/workers/sharding_manager/megatron_vllm.py
Normal file
428
verl/workers/sharding_manager/megatron_vllm.py
Normal file
@@ -0,0 +1,428 @@
|
||||
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
This file contains a Megatron style Hybrid Engine that shares the weights of the actor with the inference engine.
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from torch import nn
|
||||
|
||||
from megatron.core import parallel_state as mpu
|
||||
from megatron.core import DistributedDataParallel as LocalDDP
|
||||
from megatron.core.transformer.module import Float16Module
|
||||
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
|
||||
from verl.utils.megatron_utils import get_model, unwrap_model
|
||||
from verl.utils.memory_buffer import (
|
||||
build_memory_buffer,
|
||||
build_memory_reference_from_module,
|
||||
get_weight_buffer_meta_from_module,
|
||||
)
|
||||
|
||||
|
||||
class AllGatherPPModel:
|
||||
|
||||
def __init__(self, model_provider) -> None:
|
||||
|
||||
self._pp_group = mpu.get_pipeline_model_parallel_group()
|
||||
self._pp_rank = mpu.get_pipeline_model_parallel_rank()
|
||||
self._pp_size = mpu.get_pipeline_model_parallel_world_size()
|
||||
self._vpp_size = mpu.get_virtual_pipeline_model_parallel_world_size()
|
||||
self._model_chunk_size = self._vpp_size or 1
|
||||
|
||||
# each one holds a list of model_chunks in this pp stage
|
||||
self._pp_models = [None] * self.pp_size
|
||||
|
||||
rank_list = list(range(self.pp_size))
|
||||
# make current rank the last one to initialize
|
||||
rank_list[self.pp_rank], rank_list[-1] = rank_list[-1], rank_list[self.pp_rank]
|
||||
self._this_rank_models = None
|
||||
|
||||
# store the parameter of each pp stage
|
||||
self.memory_buffers = [None] * self.pp_size
|
||||
for cur_pp_rank in rank_list:
|
||||
print(
|
||||
f'create pp model', f'torch allocated {torch.cuda.memory_allocated() / 1e9:.4f} GB, '
|
||||
f'reserved {torch.cuda.memory_reserved() / 1e9:.4f} GB')
|
||||
# since the last initialized rank is the current pp rank, after init, the pp rank is still correct
|
||||
mpu.set_pipeline_model_parallel_rank(cur_pp_rank)
|
||||
if cur_pp_rank != self.pp_rank:
|
||||
models = get_model(model_provider, wrap_with_ddp=False)
|
||||
models = nn.ModuleList(models)
|
||||
assert len(models) == self._model_chunk_size, f"{len(models)} != {self._model_chunk_size}"
|
||||
self.pp_models[cur_pp_rank] = models
|
||||
else:
|
||||
# for regular model, we wrapped it with DDP
|
||||
models = get_model(model_provider)
|
||||
assert len(models) == self._model_chunk_size, f"{len(models)} != {self._model_chunk_size}"
|
||||
self._this_rank_models = nn.ModuleList(models)
|
||||
self.pp_models[cur_pp_rank] = nn.ModuleList(unwrap_model(models, (torchDDP, LocalDDP)))
|
||||
|
||||
self._build_param_buffer(cur_pp_rank)
|
||||
self._build_param_references(cur_pp_rank, maintain_weight=cur_pp_rank == self.pp_rank)
|
||||
|
||||
# TODO: after binding to the memory buffer, we can load the checkpoint here
|
||||
if cur_pp_rank != self.pp_rank:
|
||||
for model in self.pp_models[cur_pp_rank]:
|
||||
model.eval()
|
||||
self._offload_params_to_cpu(cur_pp_rank)
|
||||
|
||||
def _build_param_buffer(self, pp_rank):
|
||||
"""Build the parameter buffer in each pp rank"""
|
||||
model = self.pp_models[pp_rank]
|
||||
weight_buffer_meta = get_weight_buffer_meta_from_module(model)
|
||||
self.memory_buffers[pp_rank] = build_memory_buffer(weight_buffer_meta)
|
||||
|
||||
def _build_param_references(self, pp_rank, maintain_weight=False):
|
||||
model = self.pp_models[pp_rank]
|
||||
build_memory_reference_from_module(model, self.memory_buffers[pp_rank], maintain_weight=maintain_weight)
|
||||
|
||||
def _load_params_to_cuda(self, pp_rank, to_empty=False):
|
||||
assert pp_rank != self.pp_rank, f"unexpected to load current pp rank [{pp_rank}] back to cuda"
|
||||
for buffer in self.memory_buffers[pp_rank].values():
|
||||
if not to_empty:
|
||||
buffer.data = buffer.data.to(torch.cuda.current_device(), non_blocking=True)
|
||||
else:
|
||||
buffer.data = torch.empty_like(buffer.data, device='cuda')
|
||||
# rebuild reference after loading to CUDA
|
||||
self._build_param_references(pp_rank)
|
||||
|
||||
def _offload_params_to_cpu(self, pp_rank, to_empty=False):
|
||||
assert pp_rank != self.pp_rank, f"unexpected to offload current pp rank [{pp_rank}] to cpu"
|
||||
for buffer in self.memory_buffers[pp_rank].values():
|
||||
if not to_empty:
|
||||
# offload the whole memory buffer to CPU
|
||||
buffer.data = buffer.data.to('cpu', non_blocking=True)
|
||||
else:
|
||||
buffer.data = torch.empty_like(buffer.data, device='cpu')
|
||||
self._build_param_references(pp_rank)
|
||||
|
||||
def load_params_to_cuda(self, to_empty=False):
|
||||
"""load all model params to cuda"""
|
||||
for cur_pp_rank in range(self.pp_size):
|
||||
if cur_pp_rank != self.pp_rank:
|
||||
self._load_params_to_cuda(cur_pp_rank, to_empty=to_empty)
|
||||
|
||||
def allgather_params(self):
|
||||
"""allgather params of all pp ranks. Return a list of handles"""
|
||||
for cur_pp_rank in range(self.pp_size):
|
||||
global_src = dist.get_global_rank(group=self.pp_group, group_rank=cur_pp_rank)
|
||||
|
||||
# NOTE(sgm): the async op may cause memory leakage of the memory_buffer/pp_models
|
||||
for memory_buffer in self.memory_buffers[cur_pp_rank].values():
|
||||
dist.broadcast(tensor=memory_buffer.data, src=global_src, group=self.pp_group, async_op=False)
|
||||
|
||||
def forward(self, *inputs, **kwargs):
|
||||
try:
|
||||
prev_output = None
|
||||
for cur_chunk_rank in range(self._model_chunk_size):
|
||||
if self._vpp_size:
|
||||
mpu.set_virtual_pipeline_model_parallel_rank(cur_chunk_rank)
|
||||
|
||||
for cur_pp_rank in range(self.pp_size):
|
||||
mpu.set_pipeline_model_parallel_rank(cur_pp_rank)
|
||||
self.pp_models[cur_pp_rank][cur_chunk_rank].set_input_tensor(prev_output)
|
||||
ret = self.pp_models[cur_pp_rank][cur_chunk_rank](*inputs, **kwargs)
|
||||
self.pp_models[cur_pp_rank][cur_chunk_rank].set_input_tensor(None)
|
||||
prev_output = ret
|
||||
finally:
|
||||
if self._vpp_size:
|
||||
mpu.set_virtual_pipeline_model_parallel_rank(0)
|
||||
mpu.set_pipeline_model_parallel_rank(self.pp_rank)
|
||||
return ret

    def __call__(self, *inputs, **kwargs):
        return self.forward(*inputs, **kwargs)

    def eval(self):
        for model in self.pp_models[self.pp_rank]:
            model.eval()

    def train(self):
        for model in self.pp_models[self.pp_rank]:
            model.train()

    def offload_params_to_cpu(self, to_empty=False):
        """Offload the params of the models that do not belong to the current pp rank to CPU"""
        for cur_pp_rank in range(self.pp_size):
            if cur_pp_rank != self.pp_rank:
                self._offload_params_to_cpu(cur_pp_rank, to_empty=to_empty)

    def get_all_params(self):
        """Get all the parameters of the models in all pp ranks

        Returns:
            params: List[List[Dict[str, Tensor]]]: a nested list indexed first by pp rank and then by
                model chunk, where each entry is a dict mapping parameter names to tensors.

        """
        params = []
        for pp_rank in range(self.pp_size):
            params.append([])
            for model_chunk_idx in range(len(self.pp_models[pp_rank])):
                params[pp_rank].append({})
                pp_model = self.pp_models[pp_rank][model_chunk_idx]
                # unwrap torchDDP, LocalDDP and Float16Module wrappers to reach the underlying module
                pp_model = unwrap_model(pp_model, (torchDDP, LocalDDP, Float16Module))
                for name, param in pp_model.named_parameters():
                    # NOTE(gh) workaround: should not get lora params for inference
                    if 'lora' in name:
                        continue
                    params[pp_rank][model_chunk_idx][name] = param

        return params
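
    # Illustrative indexing of the structure returned by `get_all_params` (the parameter name shown
    # here is hypothetical and depends on the actual Megatron module):
    #   params = model.get_all_params()
    #   tensor = params[pp_rank][vpp_chunk_idx]['decoder.layers.0.self_attention.linear_qkv.weight']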

    def update_this_rank_models(self, new_models):
        self._this_rank_models = new_models
        self._pp_models[self.pp_rank] = unwrap_model(new_models, (torchDDP, LocalDDP))

    @property
    def this_rank_models(self):
        return self._this_rank_models

    @property
    def pp_size(self):
        return self._pp_size

    @property
    def pp_rank(self):
        return self._pp_rank

    @property
    def pp_group(self):
        return self._pp_group

    @property
    def pp_models(self):
        return self._pp_models


"""
Megatron Hybrid Engine:
- During training, only the current pp stage holds the parameters
- Before inference, broadcast the parameters of each pp rank to all other pp ranks (so that every pp rank holds all the parameters)
- Bind the parameters to the inference engine
- Do inference in tp. pp is treated as additional dp
- After inference, all the parameters that do not belong to this pp rank are freed
"""

from .base import BaseShardingManager

import torch
from torch import nn
import torch.distributed
from torch.distributed import new_group

from verl import DataProto
from verl.utils.torch_functional import (broadcast_dict_tensor, allgather_dict_tensors)
import verl.utils.megatron.tensor_parallel as tp_utils
from verl.third_party.vllm import parallel_state as vllm_ps
from verl.third_party.vllm import LLM
from verl.utils.model import normalize_pp_vpp_params

# Micro data parallel group. The micro data parallel group is an additional dp group that
# originates from splitting the training tp group into infer_tp and micro_dp. By default,
# we use the order micro_dp - tp.
_MICRO_DATA_PARALLEL_GROUP = None
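
# Illustrative example (the numbers are assumptions, not defaults): with train_tp = 4 and
# infer_tp = 2, micro_dp_size = 4 // 2 = 2, so group i covers the consecutive ranks
# [2 * i, 2 * i + 1]; these micro-dp groups act as extra data parallelism for vLLM inference.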


class MegatronVLLMShardingManager(BaseShardingManager):

    def __init__(self, module: AllGatherPPModel, inference_engine: LLM, model_config, layer_name_mapping):
        self.module = module
        self.inference_engine = inference_engine
        self.model_config = model_config
        self.layer_name_mapping = layer_name_mapping

        # initialize the micro_dp group for vllm inference
        global _MICRO_DATA_PARALLEL_GROUP
        world_size = torch.distributed.get_world_size()
        rank = torch.distributed.get_rank()
        train_tensor_parallel_size = mpu.get_tensor_model_parallel_world_size()
        infer_tensor_parallel_size = vllm_ps.get_tensor_model_parallel_world_size()

        # TODO(sgm): this may not be true for FSDP -> vLLM
        assert infer_tensor_parallel_size <= train_tensor_parallel_size, \
            'Not implemented for infer_tp > train_tp'
        assert train_tensor_parallel_size % infer_tensor_parallel_size == 0

        micro_dp_size = train_tensor_parallel_size // infer_tensor_parallel_size
        num_micro_dp_groups = world_size // micro_dp_size
        assert _MICRO_DATA_PARALLEL_GROUP is None, ("micro data parallel group is already initialized")
        for i in range(num_micro_dp_groups):
            ranks = range(i * micro_dp_size, (i + 1) * micro_dp_size)
            group = new_group(ranks=ranks)
            if rank in ranks:
                _MICRO_DATA_PARALLEL_GROUP = group

    def default_tp_concat_fn(self, name, param, infer_params, model_config):
        """Concatenate the tp shards gathered from the micro_dp group back into a full tensor.

        Args:
            name: name of the parameter
            param: the training parameter (the local tp shard)
            infer_params (List[torch.Tensor]): a list of parameters all-gathered from the micro_dp_group
            model_config: huggingface model_config

        TODO(zhangchi.usc1992): currently, the implementation is ad hoc. We can move this function to the model
        definition so that it is model-agnostic. If the model doesn't implement this function,
        we can throw an error to force the user to disable the TP HybridEngine.
        """
        if self.layer_name_mapping.get("qkv_layer_name") in name:
            # if the tensor is qkv, split each tp shard into q, k, v,
            # then concatenate the q, k and v pieces separately.
            q_lst = []
            k_lst = []
            v_lst = []
            assert model_config.num_attention_heads % model_config.num_key_value_heads == 0
            num_q_per_kv = model_config.num_attention_heads // model_config.num_key_value_heads
            assert infer_params[0].shape[0] % (num_q_per_kv + 2) == 0
            kv_size_per_tp = infer_params[0].shape[0] // (num_q_per_kv + 2)
            split_size = [kv_size_per_tp * num_q_per_kv, kv_size_per_tp, kv_size_per_tp]
            for infer_param in infer_params:
                q, k, v = infer_param.split(split_size)
                q_lst.append(q)
                k_lst.append(k)
                v_lst.append(v)
            q = torch.cat(q_lst, dim=0)
            k = torch.cat(k_lst, dim=0)
            v = torch.cat(v_lst, dim=0)

            infer_params = torch.cat((q, k, v), dim=0)

        elif self.layer_name_mapping.get("gate_proj_layer_name") in name:
            # if the tensor is the fused gate/up projection
            gate_lst = []
            up_lst = []
            for infer_param in infer_params:
                gate, up = infer_param.chunk(2)
                gate_lst.append(gate)
                up_lst.append(up)
            gate = torch.cat(gate_lst, dim=0)
            up = torch.cat(up_lst, dim=0)
            infer_params = torch.cat((gate, up), dim=0)

        else:
            # concatenate the shards along the tensor-parallel partition dim
            infer_params = torch.cat(infer_params, dim=tp_utils.get_tensor_parallel_partition_dim(param))

        return infer_params
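
    # Illustrative shape walk-through for the qkv branch (hypothetical config, not from a real model):
    # with num_attention_heads = 8 and num_key_value_heads = 4, num_q_per_kv = 2, so each gathered
    # shard has first-dim size (2 + 2) * kv_size_per_tp and is split into
    # [2 * kv_size_per_tp, kv_size_per_tp, kv_size_per_tp]; the q/k/v pieces from all micro-dp shards
    # are then concatenated separately and stacked back together as [q; k; v].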

    def _post_process_params(self, params):
        """For each param, if it is a tp-split param, all-gather it from the micro_dp group."""
        # here the params are in the train tp format. we iterate over the params and all-gather them
        # TODO(zhangchi.usc1992) We can consider copying the non-tp weights to another infer buffer.
        # In this way, all the params in the original memory_buffers can be offloaded.
        micro_dp_size = get_micro_data_parallel_world_size()
        micro_dp_group = get_micro_data_parallel_group()

        if micro_dp_size <= 1:
            return

        origin_params = {}
        for name in params.keys():
            param = params[name]
            if tp_utils.is_tensor_parallel_param(param):
                # allocate a new tensor with proper size
                infer_params = [torch.empty_like(param) for _ in range(micro_dp_size)]
                torch.distributed.all_gather(infer_params, param, group=micro_dp_group)
                infer_params = self.default_tp_concat_fn(name, param, infer_params, self.model_config)
                # replace the entry with the concatenated inference param and keep the original for restoring later
                params[name] = infer_params
                origin_params[name] = param

        return origin_params

    def __enter__(self):
        # create new cuda space for the parameters not held by this pp rank
        self.module.load_params_to_cuda()
        # broadcast the parameters from each pp rank to the other ranks
        self.module.allgather_params()
        # obtain a mapping from names to parameters across pp/vpp
        params = self.module.get_all_params()

        # bind the params to the inference engine
        self.params = normalize_pp_vpp_params(params=params,
                                              num_hidden_layers=self.model_config.num_hidden_layers,
                                              layer_name='layers')
        self.origin_params = self._post_process_params(self.params)
        self.inference_engine.sync_model_weights(self.params, load_format='megatron')

    def __exit__(self, exc_type, exc_value, traceback):
        # offload the parameters that don't belong to this pp rank
        self.module.offload_params_to_cpu()

        # FIXME(sgm): the best practice is to delete the cuda tensor
        # rebind the model weights, can be any cpu tensor
        if get_micro_data_parallel_world_size() > 1:
            for name in self.params.keys():
                self.params[name] = self.origin_params[name]

        # self.inference_engine.sync_model_weights(params)
        self.inference_engine.offload_model_weights()

        self.module.train()

        # add empty cache after each compute
        torch.cuda.empty_cache()
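
    # Intended usage (a sketch; the variable names are illustrative, not part of the API):
    #   sharding_manager = MegatronVLLMShardingManager(module, inference_engine, model_config, layer_name_mapping)
    #   with sharding_manager:
    #       # weights have been gathered and synced into vLLM here; run generation
    #       ...
    #   # on exit, non-local params are offloaded and the vLLM weights are freed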

    def preprocess_data(self, data: DataProto) -> DataProto:
        # prompts are identical within each training tp group. We select one chunk per micro-dp rank
        micro_dp_size = get_micro_data_parallel_world_size()
        micro_dp_rank = get_micro_data_parallel_rank()

        # broadcast from tp rank 0 to the other tp ranks
        broadcast_dict_tensor(data.batch,
                              src=mpu.get_tensor_model_parallel_src_rank(),
                              group=mpu.get_tensor_model_parallel_group())

        if micro_dp_size > 1:
            local_prompts = data.chunk(chunks=micro_dp_size)
            data = local_prompts[micro_dp_rank]

        return data

    def postprocess_data(self, data: DataProto) -> DataProto:
        meta_info = data.meta_info
        # all-gather the batch among the micro-dp groups
        micro_dp_size = get_micro_data_parallel_world_size()
        if micro_dp_size > 1:
            data.batch = allgather_dict_tensors(data.batch.contiguous(),
                                                size=get_micro_data_parallel_world_size(),
                                                group=get_micro_data_parallel_group(),
                                                dim=0)

        # all-gather the batch among the pp group
        if meta_info.get('allgather_pp_output', True):
            data.batch = allgather_dict_tensors(data.batch.contiguous(),
                                                size=mpu.get_pipeline_model_parallel_world_size(),
                                                group=mpu.get_pipeline_model_parallel_group(),
                                                dim=0)
        return data
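
    # NOTE: preprocess_data and postprocess_data form a round trip: prompts replicated across the
    # training tp group are chunked so each micro-dp rank generates on its own slice, and the outputs
    # are gathered back over the micro-dp group (and, optionally, the pp group).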


"""
Micro Data parallel group
"""


def get_micro_data_parallel_group():
    assert _MICRO_DATA_PARALLEL_GROUP is not None
    return _MICRO_DATA_PARALLEL_GROUP


def get_micro_data_parallel_world_size():
    return torch.distributed.get_world_size(group=get_micro_data_parallel_group())


def get_micro_data_parallel_rank():
    return torch.distributed.get_rank(group=get_micro_data_parallel_group())