Initial commit
This commit is contained in:
13
verl/utils/megatron/__init__.py
Normal file
13
verl/utils/megatron/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
41
verl/utils/megatron/memory.py
Normal file
41
verl/utils/megatron/memory.py
Normal file
@@ -0,0 +1,41 @@
|
||||
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class MemoryBuffer:
    """A pre-allocated, flat 1-D tensor on the current CUDA device.

    The buffer is allocated once (``numel_padded`` elements, zero-filled,
    ``requires_grad=False``) and callers carve fixed-shape tensors out of it
    as views via :meth:`get`, avoiding repeated allocations.
    """

    def __init__(self, numel, numel_padded, dtype):
        self.numel = numel
        self.numel_padded = numel_padded
        self.dtype = dtype
        # Single flat allocation; everything handed out by `get` aliases it.
        self.data = torch.zeros(
            self.numel_padded,
            dtype=self.dtype,
            device=torch.cuda.current_device(),
            requires_grad=False,
        )

    def zero(self):
        """Reset the buffer to zero."""
        self.data.zero_()

    def get(self, shape, start_index):
        """Return a tensor of ``shape`` viewing the 1-D data at ``start_index``.

        The result shares storage with the buffer — no copy is made.
        """
        stop_index = start_index + shape.numel()
        assert stop_index <= self.numel, \
            'requested tensor is out of the buffer range.'
        return self.data[start_index:stop_index].view(shape)
|
||||
92
verl/utils/megatron/optimizer.py
Normal file
92
verl/utils/megatron/optimizer.py
Normal file
@@ -0,0 +1,92 @@
|
||||
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from apex.optimizers import FusedAdam as Adam
|
||||
from apex.optimizers import FusedSGD as SGD
|
||||
from megatron.optimizer.distrib_optimizer import DistributedOptimizer
|
||||
from megatron.optimizer.grad_scaler import ConstantGradScaler, DynamicGradScaler
|
||||
from megatron.optimizer import Float16OptimizerWithFloat16Params, FP32Optimizer
|
||||
from megatron.optimizer import get_param_groups
|
||||
|
||||
from verl.utils.megatron.optimizer_config import OptimizerConfig
|
||||
|
||||
|
||||
def get_megatron_optimizer(
        model,
        config: OptimizerConfig,
        no_weight_decay_cond=None,
        scale_lr_cond=None,
        lr_mult=1.0,
        check_for_nan_in_loss_and_grad=False,
        overlap_param_gather=False  # add for verl
):
    """Build a Megatron optimizer for ``model`` from ``config``.

    Creates a fused Adam or SGD base optimizer over the param groups derived
    from ``model``, then wraps it in the Megatron optimizer class matching the
    precision / distributed-optimizer settings.

    Args:
        model: model (or list of model chunks) whose parameters are optimized;
            forwarded to ``get_param_groups`` and the optimizer wrappers.
        config: ``OptimizerConfig`` holding lr, precision and loss-scale knobs.
        no_weight_decay_cond: optional predicate selecting params that skip
            weight decay (forwarded to ``get_param_groups``).
        scale_lr_cond: optional predicate selecting params whose lr is scaled
            by ``lr_mult`` (forwarded to ``get_param_groups``).
        lr_mult: learning-rate multiplier for params matched by ``scale_lr_cond``.
        check_for_nan_in_loss_and_grad: forwarded to the Megatron optimizer
            wrappers.
        overlap_param_gather: verl addition; forwarded only to
            ``DistributedOptimizer``.

    Returns:
        ``DistributedOptimizer``, ``Float16OptimizerWithFloat16Params`` or
        ``FP32Optimizer`` wrapping the base optimizer.

    Raises:
        Exception: if ``config.optimizer`` is neither ``'adam'`` nor ``'sgd'``.
    """
    # Base optimizer.
    param_groups = get_param_groups(model, no_weight_decay_cond, scale_lr_cond, lr_mult)

    if config.optimizer == 'adam':
        optimizer = Adam(param_groups,
                         lr=config.lr,
                         weight_decay=config.weight_decay,
                         betas=(config.adam_beta1, config.adam_beta2),
                         eps=config.adam_eps)
    elif config.optimizer == 'sgd':
        optimizer = SGD(param_groups, lr=config.lr, weight_decay=config.weight_decay, momentum=config.sgd_momentum)
    else:
        raise Exception('{} optimizer is not supported.'.format(config.optimizer))

    # Determine whether the params have main-grad field.
    params_have_main_grad = True

    # Mixed precision optimizer.
    # - Note: both the Float16Optimizer and the DistributedOptimizer inherit
    #   from the MixedPrecisionOptimizer, which manages any optimizer where
    #   the model params and main params are distinct.
    if config.fp16 or config.bf16 or config.use_distributed_optimizer:

        # Grad scaler:
        #   if loss-scale is provided, instantiate the constant scaler.
        #   if we are using fp16 and loss-scale is not present, use a
        #   dynamic scaler.
        #   otherwise we are running in bf16 with no loss-scale so
        #   leave it as None.
        grad_scaler = None

        # Constant loss scale.
        if config.loss_scale:
            grad_scaler = ConstantGradScaler(config.loss_scale)

        # Dynamic loss scale.
        else:
            if config.fp16:
                grad_scaler = DynamicGradScaler(initial_scale=config.initial_loss_scale,
                                                min_scale=config.min_loss_scale,
                                                growth_factor=2.0,
                                                backoff_factor=0.5,
                                                growth_interval=config.loss_scale_window,
                                                hysteresis=config.hysteresis)

        # Megatron optimizer. Both branches return, so the FP32 fallback
        # below is only reached when the outer condition is false.
        if config.use_distributed_optimizer:
            return DistributedOptimizer(optimizer, config.clip_grad, config.log_num_zeros_in_grad,
                                        check_for_nan_in_loss_and_grad, params_have_main_grad, config.fp16,
                                        config.bf16, config.params_dtype, grad_scaler, model, overlap_param_gather)
        else:
            return Float16OptimizerWithFloat16Params(optimizer, config.clip_grad, config.log_num_zeros_in_grad,
                                                     check_for_nan_in_loss_and_grad, params_have_main_grad,
                                                     config.fp16, config.bf16, config.params_dtype, grad_scaler, model)

    # FP32.
    return FP32Optimizer(optimizer, config.clip_grad, config.log_num_zeros_in_grad, check_for_nan_in_loss_and_grad,
                         params_have_main_grad, model)
|
||||
129
verl/utils/megatron/optimizer_config.py
Normal file
129
verl/utils/megatron/optimizer_config.py
Normal file
@@ -0,0 +1,129 @@
|
||||
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
@dataclass
class OptimizerConfig:
    """Configuration for optimizer."""

    ##############
    # General
    ##############
    optimizer: str = 'adam'
    """Optimizer to use (one of Adam or SGD)."""

    lr: Optional[float] = None
    """Initial learning rate. Depending on decay style and initial warmup, the learning rate at each
    iteration would be different.
    """

    min_lr: Optional[float] = None
    """Minumum value for learning rate. The scheduler clip values below this threshold."""

    decoupled_lr: Optional[float] = None
    """Separate learning rate for the input and output layer."""

    decoupled_min_lr: Optional[float] = None
    """Minimum value for learning rate for the input and output layer. The scheduler clip values
    below this threshold.
    """

    weight_decay: float = 0.01
    """Weight decay coefficient for L2 regularization."""

    ##############
    # Precision
    ##############
    fp16: bool = False
    """If true, train with fp16 mixed precision training. Defaults to False."""

    bf16: bool = False
    """If true, train with bf16 mixed precision training. Defaults to False."""

    params_dtype: torch.dtype = torch.float32
    """dtype used when intializing the weights. Defaults to torch.float32."""

    ###############
    # Loss scaling
    ###############
    loss_scale: Optional[float] = None
    """Static loss scaling, positive power of 2 values can improve fp16 convergence. If None,
    dynamic loss scaling is used.
    """

    initial_loss_scale: float = 2**32
    """Initial loss-scale for dynamic loss scaling."""

    min_loss_scale: float = 1.0
    """Minimum loss scale for dynamic loss scaling."""

    loss_scale_window: float = 1000
    """Window over which to raise/lower dynamic scale."""

    hysteresis: int = 2
    """Hysteresis for dynamic loss scaling."""

    ##############
    # Optimizer
    ##############
    # Adam
    adam_beta1: float = 0.9
    """First coefficient for computing running averages of gradient and its square in Adam
    optimizer.
    """

    adam_beta2: float = 0.999
    """Second coefficient for computing running averages of gradient and its square in Adam
    optimizer.
    """

    adam_eps: float = 1e-08
    """Term added to the denominator to improve numerical stability in Adam optimizer."""

    # SGD.
    sgd_momentum: float = 0.9
    """Momentum factor for SGD optimizer."""

    #######################
    # Distributed optimizer
    #######################
    use_distributed_optimizer: bool = False
    """Distribute optimizer state over data-parallel replicas."""

    overlap_grad_reduce: bool = False
    """If true, overlap grad reduce-scatter with backward compute in distributed optimizer."""

    overlap_param_gather: bool = False
    """If true, overlap param all-gather with forward compute in distributed optimizer."""

    ################
    # Miscellaneous
    ################
    clip_grad: float = 1.0
    """Gradient clipping based on global L2 norm."""

    log_num_zeros_in_grad: bool = False
    """If true, calculate and log the number of zeros in gradient."""

    barrier_with_L1_time: bool = False
    """If true, use barrier with level 1 time measurements."""

    # Annotation fixed: the default is None, so the field is Optional.
    timers: Optional[Callable] = None
    """Function to get timers."""
|
||||
51
verl/utils/megatron/pipeline_parallel.py
Normal file
51
verl/utils/megatron/pipeline_parallel.py
Normal file
@@ -0,0 +1,51 @@
|
||||
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
from megatron.core import parallel_state as mpu
|
||||
|
||||
from .sequence_parallel import pad_to_sequence_parallel
|
||||
|
||||
|
||||
def compute_transformers_input_shapes(batches, meta_info):
    """Pre-compute the hidden-state input shape of every micro-batch.

    For each micro-batch the padding tokens are removed; the resulting shape is
    ``(total_nnz, 1, hidden_size)`` — with ``total_nnz`` first padded to and
    divided by the TP world size when ``meta_info['sequence_parallel']`` is set.

    Args:
        batches: iterable of dicts with 'input_ids' and 'attention_mask'.
        meta_info: dict with keys 'sequence_parallel' and 'hidden_size'.

    Returns:
        list of ``torch.Size``, one per micro-batch.
    """
    from flash_attn.bert_padding import unpad_input  # flash 2 is a must for Megatron

    hidden_size = meta_info['hidden_size']
    use_sp = meta_info['sequence_parallel']

    shapes = []
    for model_inputs in batches:
        # remove padding -> (total_nnz, 1)
        tokens_rmpad = unpad_input(model_inputs['input_ids'].unsqueeze(dim=-1),
                                   model_inputs['attention_mask'])[0]
        if use_sp:
            # align to TP world size, then each rank holds an equal slice
            tokens_rmpad = pad_to_sequence_parallel(tokens_rmpad)
            seqlen = tokens_rmpad.shape[0] // mpu.get_tensor_model_parallel_world_size()
        else:
            seqlen = tokens_rmpad.shape[0]
        shapes.append(torch.Size([seqlen, 1, hidden_size]))
    return shapes
|
||||
|
||||
|
||||
def make_batch_generator(batches, vpp_size):
    """Build the batch iterator(s) expected by Megatron's pipeline schedule.

    With virtual pipeline parallelism (``vpp_size > 1``) one iterator is
    needed per virtual model chunk, each replaying the same batches;
    otherwise a single iterator suffices.
    """
    if vpp_size > 1:
        # one iterator per virtual-pipeline model chunk
        return [iter(batches) for _ in range(vpp_size)]
    # no vpp
    return iter(batches)
|
||||
54
verl/utils/megatron/sequence_parallel.py
Normal file
54
verl/utils/megatron/sequence_parallel.py
Normal file
@@ -0,0 +1,54 @@
|
||||
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from megatron.core import parallel_state as mpu
|
||||
|
||||
|
||||
def mark_parameter_as_sequence_parallel(parameter):
    """Tag ``parameter`` with the ``sequence_parallel`` attribute (set to True)."""
    parameter.sequence_parallel = True
|
||||
|
||||
|
||||
def is_sequence_parallel_param(param):
    """Return whether ``param`` carries a truthy ``sequence_parallel`` tag."""
    return getattr(param, 'sequence_parallel', False)
|
||||
|
||||
|
||||
def pad_to_sequence_parallel(unpad_tokens: torch.Tensor):
    """pad the tokens such that the total length is a multiple of sp world size

    Args:
        unpad_tokens: (total_nnz, ...). Tokens after removing padding

    Returns:
        The input right-padded with zeros along dim 0 so its length is a
        multiple of the tensor-model-parallel world size (returned unchanged
        when already aligned).

    Raises:
        NotImplementedError: if ``unpad_tokens`` has more than 2 dims.
    """
    total_nnz = unpad_tokens.shape[0]
    sp_world_size = mpu.get_tensor_model_parallel_world_size()

    # rows needed to reach the next multiple of sp_world_size (0 when aligned)
    if total_nnz % sp_world_size == 0:
        pad_size = 0
    else:
        pad_size = sp_world_size - total_nnz % sp_world_size

    if pad_size > 0:
        if unpad_tokens.ndim == 1:
            unpad_tokens = F.pad(unpad_tokens, (0, pad_size))
        elif unpad_tokens.ndim == 2:
            unpad_tokens = F.pad(unpad_tokens, (0, 0, 0, pad_size))
        else:
            # BUG FIX: `ndim` is an int attribute, not a method — the original
            # `unpad_tokens.ndim()` raised "TypeError: 'int' object is not
            # callable" instead of the intended NotImplementedError.
            raise NotImplementedError(f'Padding dim {unpad_tokens.ndim} is not supported')

    return unpad_tokens
|
||||
184
verl/utils/megatron/tensor_parallel.py
Normal file
184
verl/utils/megatron/tensor_parallel.py
Normal file
@@ -0,0 +1,184 @@
|
||||
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Utilities for using tensor_parallel in megatron
|
||||
"""
|
||||
from typing import Dict
|
||||
import torch
|
||||
from torch.nn import init
|
||||
import torch.distributed as dist
|
||||
from megatron.core import ModelParallelConfig
|
||||
from megatron.core import parallel_state as mpu, tensor_parallel
|
||||
import verl.utils.torch_functional as verl_F
|
||||
|
||||
|
||||
def update_kwargs_with_config(dictionary: Dict, config: ModelParallelConfig):
    """Store ``config`` under the ``'config'`` key of ``dictionary`` and return it.

    The dict is mutated in place; the same object is returned for convenience.
    """
    dictionary.update(config=config)
    return dictionary
|
||||
|
||||
|
||||
def get_default_kwargs_for_model_parallel_config():
    """Return the default kwargs used to build a megatron ``ModelParallelConfig``."""
    return {
        'params_dtype': torch.float32,
        'use_cpu_initialization': False,
        'perform_initialization': True,
        'gradient_accumulation_fusion': False,
        'sequence_parallel': False,
    }
|
||||
|
||||
|
||||
def get_default_model_parallel_config():
    """Instantiate a megatron ``ModelParallelConfig`` from the default kwargs."""
    kwargs = get_default_kwargs_for_model_parallel_config()
    return ModelParallelConfig(**kwargs)
|
||||
|
||||
|
||||
def get_common_default_kwargs_for_parallel_linear():
    """Return the default kwargs shared by column- and row-parallel linear layers."""
    return {
        'init_method': init.xavier_normal_,
        'stride': 1,
        'keep_master_weight_for_test': False,
        'config': get_default_model_parallel_config(),
    }
|
||||
|
||||
|
||||
def get_default_kwargs_for_column_parallel_linear():
    """Return the default kwargs for a megatron column-parallel linear layer."""
    # column-parallel linear additionally disables the async TP all-reduce
    mp_kwargs = get_default_kwargs_for_model_parallel_config()
    mp_kwargs['async_tensor_model_parallel_allreduce'] = False

    kwargs = get_common_default_kwargs_for_parallel_linear()
    kwargs['config'] = ModelParallelConfig(**mp_kwargs)
    return kwargs
|
||||
|
||||
|
||||
def get_default_kwargs_for_row_parallel_linear():
    """Return the default kwargs for a megatron row-parallel linear layer.

    No row-specific overrides are needed; the common defaults apply as-is.
    """
    return get_common_default_kwargs_for_parallel_linear()
|
||||
|
||||
|
||||
def get_default_kwargs_for_parallel_embedding():
    """Return the default kwargs for a megatron vocab-parallel embedding."""
    mp_kwargs = get_default_kwargs_for_model_parallel_config()
    return {
        'init_method': init.xavier_normal_,
        'config': ModelParallelConfig(**mp_kwargs),
    }
|
||||
|
||||
|
||||
def is_tensor_parallel_param(param):
    """Return whether ``param`` carries a truthy ``tensor_model_parallel`` tag."""
    return getattr(param, 'tensor_model_parallel', False)
|
||||
|
||||
|
||||
def get_tensor_parallel_partition_dim(param):
    """Return the dimension along which ``param`` is partitioned across TP ranks.

    Asserts that ``param`` is tagged as tensor-parallel; reads the
    ``partition_dim`` attribute set by megatron when the param was sharded.
    """
    assert is_tensor_parallel_param(param)
    return param.partition_dim
|
||||
|
||||
|
||||
def get_tensor_parallel_partition_stride(param):
    """Return the partition stride recorded on a tensor-parallel ``param``.

    Asserts that ``param`` is tagged as tensor-parallel; reads the
    ``partition_stride`` attribute set by megatron when the param was sharded.
    """
    assert is_tensor_parallel_param(param)
    return param.partition_stride
|
||||
|
||||
|
||||
class _VocabParallelEntropy(torch.autograd.Function):
    """Autograd function computing per-token entropy over vocab-sharded logits.

    Each tensor-parallel rank holds a slice of the vocab dimension; all-reduces
    over the TP group reconstruct the full-vocab softmax statistics without
    ever gathering the logits, keeping peak memory low.
    """

    @staticmethod
    def forward(ctx, vocab_parallel_logits: torch.Tensor) -> torch.Tensor:
        # Global max across all vocab shards, for numerical stability.
        logits_max = vocab_parallel_logits.max(dim=-1, keepdim=True).values
        dist.all_reduce(logits_max, op=dist.ReduceOp.MAX, group=mpu.get_tensor_model_parallel_group())
        normalized_vocab_parallel_logits = vocab_parallel_logits - logits_max
        normalized_exp_logits = normalized_vocab_parallel_logits.exp()
        # Partition function Z in the max-shifted frame (summed across shards).
        normalized_sum_exp_logits = normalized_exp_logits.sum(dim=-1, keepdim=True)
        dist.all_reduce(normalized_sum_exp_logits, group=mpu.get_tensor_model_parallel_group())
        softmax_logits = normalized_exp_logits / normalized_sum_exp_logits
        # E_p[logits], summed across shards.
        sum_softmax_times_logits = (softmax_logits * vocab_parallel_logits).sum(dim=-1, keepdim=True)
        dist.all_reduce(sum_softmax_times_logits, group=mpu.get_tensor_model_parallel_group())
        # H = log Z - E_p[logits]; adding logits_max undoes the max shift in log Z.
        entropy = logits_max + normalized_sum_exp_logits.log() - sum_softmax_times_logits
        ctx.save_for_backward(vocab_parallel_logits, softmax_logits, sum_softmax_times_logits)
        return entropy.squeeze(dim=-1)

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
        vocab_parallel_logits, softmax_logits, sum_softmax_times_logits = ctx.saved_tensors
        # dH/dz_i = p_i * (E_p[z] - z_i), chained with the incoming gradient.
        grad_input = grad_output.unsqueeze(dim=-1) * softmax_logits * (sum_softmax_times_logits - vocab_parallel_logits)
        return grad_input
|
||||
|
||||
|
||||
def vocab_parallel_entropy(vocab_parallel_logits: torch.Tensor) -> torch.Tensor:
    """Compute entropy when the logits are sharded in tp ranks

    Args:
        vocab_parallel_logits: (total_nnz, vocab_size // tp_size)

    Returns: (total_nnz,)

    """
    # Custom autograd Function: the backward pass reuses the softmax
    # statistics already all-reduced in the forward pass.
    return _VocabParallelEntropy.apply(vocab_parallel_logits)
|
||||
|
||||
|
||||
def vocab_parallel_log_probs_from_logits(logits, labels):
    """TODO(zhangchi.usc1992): We may change the implementation later"""
    # cross entropy is -log p(label); negate to recover the log-prob
    nll = tensor_parallel.vocab_parallel_cross_entropy(vocab_parallel_logits=logits, target=labels)
    return -nll
|
||||
|
||||
|
||||
def vocab_parallel_log_probs_from_logits_response_rmpad(input_ids, attention_mask, logits_rmpad, response_length):
    """Similar to log_probs_from_logits_response_rmpad, but the logits_rmpad is now split across tensor parallel region.
    This will further reduce the peak memory usage during training

    Args:
        input_ids: [batch_size, seqlen]
        attention_mask: [batch_size, seqlen]
        logits_rmpad: [total_nnz, vocab_size // tp_size]
        response_length: int

    Returns:
        [batch_size, response_length] log-probs of the response tokens.
    """
    from flash_attn.bert_padding import pad_input, unpad_input

    batch_size, seqlen = input_ids.shape
    # Remove padding so the token axis matches logits_rmpad (total_nnz,).
    input_ids_rmpad, indices, *_ = unpad_input(input_ids.unsqueeze(-1), attention_mask=attention_mask)
    input_ids_rmpad = input_ids_rmpad.squeeze(-1)
    # Shift labels left by one: logits at position t predict the token at t+1.
    input_ids_rmpad_rolled = torch.roll(input_ids_rmpad, shifts=-1, dims=0)
    full_log_probs_rmpad = vocab_parallel_log_probs_from_logits(logits=logits_rmpad,
                                                                labels=input_ids_rmpad_rolled)  # (total_nnz,)
    # Re-pad back to [batch_size, seqlen, 1] so response positions can be sliced.
    full_output = pad_input(hidden_states=full_log_probs_rmpad.unsqueeze(-1),
                            indices=indices,
                            batch=batch_size,
                            seqlen=seqlen)
    # The log-prob of response token t sits one position earlier in the sequence.
    output = full_output.squeeze(-1)[:, -response_length - 1:-1]  # [batch_size, response_length]
    return output
|
||||
|
||||
|
||||
def vocab_parallel_compute_entropy_loss(logits, eos_mask):
    """Compute Categorical entropy loss

    Args:
        logits: `(torch.Tensor)`
            shape: (bs, response_length, vocab_size)
        eos_mask: `(torch.Tensor)`
            shape: (bs, response_length)

    Returns:
        entropy: a scalar torch.Tensor

    """
    # per-token entropy, then masked mean over the response tokens
    token_entropy = vocab_parallel_entropy(logits)
    return verl_F.masked_mean(token_entropy, mask=eos_mask)
|
||||
Reference in New Issue
Block a user