Initial commit

verl/models/llama/megatron/modeling_llama_megatron.py (new file, 656 lines)

@@ -0,0 +1,656 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch LLaMA model with Megatron-style acceleration."""
|
||||
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from megatron.core import tensor_parallel
|
||||
from megatron.core import ModelParallelConfig
|
||||
from torch import nn
|
||||
from transformers.modeling_outputs import BaseModelOutputWithPast
|
||||
from transformers.models.llama.configuration_llama import LlamaConfig
|
||||
from transformers.models.llama.modeling_llama import CausalLMOutputWithPast
|
||||
|
||||
from verl.utils.megatron import sequence_parallel as sp_utils
|
||||
from verl.utils.megatron import tensor_parallel as tp_utils
|
||||
from .layers import ParallelLlamaDecoderLayer, ParallelLlamaRMSNorm, ParallelLlamaDecoderLayerRmPad
|
||||
"""
|
||||
TODO:
|
||||
1. Add weight initialization. Here we need to be careful on TP weight init.
|
||||
2. Add sequence parallel
|
||||
3. Load checkpoint from meta LLama pretrained checkpoint
|
||||
"""
# Copied from transformers.models.bart.modeling_bart._make_causal_mask
def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device):
    """
    Make causal mask used for bi-directional self-attention.
    """
    bsz, tgt_len = input_ids_shape
    mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
    mask_cond = torch.arange(mask.size(-1), device=device)
    mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
    mask = mask.to(dtype)
    return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len)


# Copied from transformers.models.bart.modeling_bart._expand_mask
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
    """
    Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
    """
    bsz, src_len = mask.size()
    tgt_len = tgt_len if tgt_len is not None else src_len

    expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)

    inverted_mask = 1.0 - expanded_mask

    return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)


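# Illustrative sketch (not used by the model): the shapes the two mask helpers above produce for a
# hypothetical batch of 2 right-padded sequences of length 4.
def _mask_helpers_example():
    dtype, device = torch.float32, torch.device('cpu')
    # additive causal mask: (2, 1, 4, 4); positions above the diagonal hold torch.finfo(dtype).min
    causal_mask = _make_causal_mask(torch.Size((2, 4)), dtype, device=device)
    # padding mask (1 = keep, 0 = pad) expanded to (2, 1, 4, 4); padded columns become the large negative value
    attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]], device=device)
    padding_mask = _expand_mask(attention_mask, dtype)
    # this sum is what ParallelLlamaModel._prepare_decoder_attention_mask (below) returns
    return causal_mask + padding_mask

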
class ParallelLlamaModel(nn.Module):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]

    Args:
        config: LlamaConfig
    """

    def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig):
        super().__init__()
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.megatron_config = megatron_config
        embedding_kwargs = tp_utils.get_default_kwargs_for_parallel_embedding()
        if megatron_config is not None:
            assert embedding_kwargs.get('config', False), 'must have ModelParallelConfig'
            tp_utils.update_kwargs_with_config(embedding_kwargs, self.megatron_config)
        self.embed_tokens = tensor_parallel.VocabParallelEmbedding(num_embeddings=config.vocab_size,
                                                                   embedding_dim=config.hidden_size,
                                                                   **embedding_kwargs)

        self.layers = nn.ModuleList(
            [ParallelLlamaDecoderLayer(config, megatron_config) for _ in range(config.num_hidden_layers)])
        self.norm = ParallelLlamaRMSNorm(config, megatron_config)

    # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
    def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds):
        # create causal mask
        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        combined_attention_mask = None
        if input_shape[-1] > 1:
            combined_attention_mask = _make_causal_mask(
                input_shape,
                inputs_embeds.dtype,
                device=inputs_embeds.device,
            )

        if attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype,
                                              tgt_len=input_shape[-1]).to(inputs_embeds.device)
            combined_attention_mask = (expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask +
                                       combined_attention_mask)

        return combined_attention_mask

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """

        Args:
            input_ids: input ids. shape (batch_size, seq_length)
            attention_mask: attention mask. shape (batch_size, seq_length)
            position_ids: position ids. shape (batch_size, seq_length)

        Returns:

        """
        batch_size, seq_length = input_ids.shape
        inputs_embeds = self.embed_tokens(input_ids)
        # embed positions
        attention_mask = self._prepare_decoder_attention_mask(attention_mask, (batch_size, seq_length), inputs_embeds)

        hidden_states = inputs_embeds

        for idx, decoder_layer in enumerate(self.layers):
            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
            )

            hidden_states = layer_outputs

        hidden_states = self.norm(hidden_states)

        return hidden_states


class ParallelLlamaForCausalLM(nn.Module):

    def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig):
        super().__init__()
        self.megatron_config = megatron_config
        self.model = ParallelLlamaModel(config, megatron_config=megatron_config)
        self.vocab_size = config.vocab_size

        column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear()
        if megatron_config is not None:
            assert column_kwargs.get('config', False), 'must have ModelParallelConfig'
            tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config)

        self.lm_head = tensor_parallel.ColumnParallelLinear(input_size=config.hidden_size,
                                                            output_size=config.vocab_size,
                                                            bias=False,
                                                            gather_output=False,
                                                            skip_bias_add=False,
                                                            **column_kwargs)

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Returns:
        """

        # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
        )

        hidden_states = outputs
        logits = self.lm_head(hidden_states)[0]

        logits = tensor_parallel.gather_from_tensor_model_parallel_region(logits)

        logits = logits.float()
        return CausalLMOutputWithPast(
            loss=None,
            logits=logits,
            past_key_values=None,
            hidden_states=None,
            attentions=None,
        )


from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa


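# Illustrative sketch (not used by the model): how the rm-pad ("remove padding") classes below use
# unpad_input/pad_input. The concrete shapes are hypothetical; see ParallelLlamaForCausalLMRmPad.forward
# for the real flow.
def _rmpad_roundtrip_example(input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    batch_size, seqlen = input_ids.shape
    # flatten to (total_nnz, 1), keeping only non-padded tokens; indices/cu_seqlens record where they came from
    input_ids_rmpad, indices, cu_seqlens, max_seqlen_in_batch, *_ = unpad_input(
        input_ids.unsqueeze(dim=-1), attention_mask)
    # ... the model would run on input_ids_rmpad.transpose(0, 1) here ...
    # scatter the flat tokens back to (batch_size, seqlen, 1), re-inserting zeros at padded positions
    return pad_input(input_ids_rmpad, indices, batch_size, seqlen=seqlen)

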
class ParallelLlamaModelRmPad(nn.Module):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]

    Args:
        config: LlamaConfig
    """

    def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig):
        super().__init__()
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        embedding_kwargs = tp_utils.get_default_kwargs_for_parallel_embedding()
        self.megatron_config = megatron_config
        if megatron_config is not None:
            assert embedding_kwargs.get('config', False), 'must have ModelParallelConfig'
            tp_utils.update_kwargs_with_config(embedding_kwargs, self.megatron_config)
        self.embed_tokens = tensor_parallel.VocabParallelEmbedding(num_embeddings=config.vocab_size,
                                                                   embedding_dim=config.hidden_size,
                                                                   **embedding_kwargs)

        self.layers = nn.ModuleList(
            [ParallelLlamaDecoderLayerRmPad(config, megatron_config) for _ in range(config.num_hidden_layers)])
        self.norm = ParallelLlamaRMSNorm(config, megatron_config)

    def forward(self,
                input_ids: torch.Tensor,
                position_ids: Optional[torch.LongTensor] = None,
                sequence_length: int = None,
                indices: torch.Tensor = None,
                cu_seqlens: torch.Tensor = None,
                max_seqlen_in_batch: int = None) -> Union[Tuple, BaseModelOutputWithPast]:
        """

        Args:
            input_ids: input ids. shape (1, total_nnz)
            position_ids: position ids. shape (batch_size, seq_length)

        Returns:

        """
        inputs_embeds = self.embed_tokens(input_ids)  # (1, total_nnz) -> (1, total_nnz, hidden_size)

        # (1, total_nnz, hidden_size) -> (total_nnz, 1, hidden_size) -> (total_nnz // sp, 1, hidden_size)
        inputs_embeds = inputs_embeds.transpose(0, 1)
        if self.megatron_config.sequence_parallel:
            inputs_embeds = tensor_parallel.scatter_to_sequence_parallel_region(inputs_embeds)

        hidden_states = inputs_embeds
        for idx, decoder_layer in enumerate(self.layers):
            layer_outputs = decoder_layer(hidden_states,
                                          position_ids=position_ids,
                                          sequence_length=sequence_length,
                                          indices=indices,
                                          cu_seqlens=cu_seqlens,
                                          max_seqlen_in_batch=max_seqlen_in_batch)

            hidden_states = layer_outputs

        hidden_states = self.norm(hidden_states)

        return hidden_states


class ParallelLlamaForCausalLMRmPad(nn.Module):

    def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig):
        super().__init__()
        self.config = config
        self.megatron_config = megatron_config
        self.model = ParallelLlamaModelRmPad(config, megatron_config=megatron_config)
        self.vocab_size = config.vocab_size
        self._init_head()

    def _init_head(self):
        column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear()
        if self.megatron_config is not None:
            assert column_kwargs.get('config', False), 'must have ModelParallelConfig'
            tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config)
        self.lm_head = tensor_parallel.ColumnParallelLinear(input_size=self.config.hidden_size,
                                                            output_size=self.config.vocab_size,
                                                            bias=False,
                                                            gather_output=False,
                                                            skip_bias_add=False,
                                                            **column_kwargs)

    def _forward_head(self, hidden_states):
        # all_gather from sequence parallel region is performed inside lm_head
        logits = self.lm_head(hidden_states)[0]
        logits = logits.float()  # (total_nnz_padded, 1, vocab_size // tp)
        logits = tensor_parallel.gather_from_tensor_model_parallel_region(logits)  # (total_nnz_padded, 1, vocab_size)
        return logits

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Returns:
        """
        batch_size, sequence_length = input_ids.shape

        # remove padding here
        input_ids, indices, cu_seqlens, max_seqlen_in_batch, *_ = unpad_input(input_ids.unsqueeze(dim=-1),
                                                                              attention_mask)  # (total_nnz, 1)

        # pad input_ids to a multiple of the tp size so that all tp ranks hold the same length
        # TODO: for better performance, the sp padding could be removed at each layer; the performance gap is unclear
        if self.megatron_config.sequence_parallel:
            input_ids = sp_utils.pad_to_sequence_parallel(input_ids)
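            # Worked example (hypothetical numbers): with batch_size = 2 and sequences of lengths 13 and 16
            # (total_nnz = 29) under tensor-parallel size 8, the tokens are padded to 32 rows so that
            # scatter_to_sequence_parallel_region can split them evenly across tp ranks; the 3 extra rows
            # are dropped again after the head via cu_seqlens[-1] below.
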
        input_ids = input_ids.transpose(0, 1)  # (1, total_nnz+pad)

        outputs = self.model(input_ids=input_ids,
                             position_ids=position_ids,
                             sequence_length=sequence_length,
                             indices=indices,
                             cu_seqlens=cu_seqlens,
                             max_seqlen_in_batch=max_seqlen_in_batch)

        hidden_states = outputs

        logits = self._forward_head(hidden_states)

        # remove padding from sequence parallel
        if self.megatron_config.sequence_parallel:
            total_nnz = cu_seqlens[-1]
            logits = logits[:total_nnz]  # (total_nnz, 1, vocab_size)

        logits = torch.squeeze(logits, dim=1)  # remove the artificial batch dimension
        # add removed padding back
        logits = pad_input(logits, indices, batch_size,
                           seqlen=sequence_length)  # (batch_size, sequence_length, vocab_size)

        return CausalLMOutputWithPast(
            loss=None,
            logits=logits,
            past_key_values=None,
            hidden_states=None,
            attentions=None,
        )


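# Illustrative usage sketch (assumptions: Megatron's distributed/model-parallel state is already
# initialized, `config`/`megatron_config` come from the caller, and real usage would also cast the
# model to bf16/fp16 for flash attention and load pretrained weights). Not executed here.
def _causal_lm_rmpad_usage_example(config: LlamaConfig, megatron_config: ModelParallelConfig):
    model = ParallelLlamaForCausalLMRmPad(config, megatron_config).cuda()
    input_ids = torch.randint(0, config.vocab_size, (2, 16), device='cuda')  # (batch_size, seq_len)
    attention_mask = torch.ones_like(input_ids)                              # 1 = real token, 0 = padding
    position_ids = torch.arange(16, device='cuda').unsqueeze(0).expand(2, -1)
    output = model(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids)
    return output.logits  # (batch_size, seq_len, vocab_size), already padded back by pad_input

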
class ParallelLlamaForValueRmPad(ParallelLlamaForCausalLMRmPad):

    def _init_head(self):
        column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear()
        if self.megatron_config is not None:
            assert column_kwargs.get('config', False), 'must have ModelParallelConfig'
            tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config)
        self.lm_head = nn.Linear(in_features=self.config.hidden_size, out_features=1, bias=False)
        # the value head weight is replicated and applied to sequence-parallel activations,
        # so it is effectively a sequence-parallel parameter
        sp_utils.mark_parameter_as_sequence_parallel(self.lm_head.weight)

    def _forward_head(self, hidden_states):
        logits = self.lm_head(hidden_states)  # (total_nnz_padded // tp, 1, 1)
        logits = logits.float()
        if self.megatron_config.sequence_parallel:
            logits = tensor_parallel.gather_from_sequence_parallel_region(logits, tensor_parallel_output_grad=False)
        return logits

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        output = super().forward(input_ids, attention_mask, position_ids)
        output.logits = torch.squeeze(output.logits, dim=-1)
        return output


"""
Support pipeline parallelism
"""


class ParallelLlamaModelRmPadPP(nn.Module):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]
    This model definition supports pipeline parallelism. To support pp and vpp:
    - This model only contains the layers of its own pp stage and vpp chunk.
    - When Megatron's get_model is called, this rank instantiates all the vpp chunks of its pp stage.
    Args:
        config: LlamaConfig
    """

    def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig, pre_process, post_process):
        super().__init__()
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.pre_process = pre_process
        self.post_process = post_process
        self.megatron_config = megatron_config
        embedding_kwargs = tp_utils.get_default_kwargs_for_parallel_embedding()
        if megatron_config is not None:
            assert embedding_kwargs.get('config', False), 'must have ModelParallelConfig'
            tp_utils.update_kwargs_with_config(embedding_kwargs, self.megatron_config)
        if pre_process:
            self.embed_tokens = tensor_parallel.VocabParallelEmbedding(num_embeddings=config.vocab_size,
                                                                       embedding_dim=config.hidden_size,
                                                                       **embedding_kwargs)
        else:
            self.embed_tokens = None

        # pp_rank = megatron_config.pipeline_model_parallel_rank
        pp_size = megatron_config.pipeline_model_parallel_size
        self.num_layer_per_pp = config.num_hidden_layers // pp_size
        vpp_size = megatron_config.virtual_pipeline_model_parallel_size

        if vpp_size is not None:
            self.num_layer_vpp_chunk = self.num_layer_per_pp // vpp_size
            self.num_layer_this_model = self.num_layer_vpp_chunk
            # vpp_rank = megatron_config.virtual_pipeline_model_parallel_rank
            # self.offset = vpp_rank * (
            #     config.num_hidden_layers // megatron_config.virtual_pipeline_model_parallel_size) + \
            #     (megatron_config.pipeline_model_parallel_rank * self.num_layer_vpp_chunk)
        else:
            self.num_layer_this_model = self.num_layer_per_pp
            # self.offset = pp_rank * self.num_layer_per_pp
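        # Worked example (hypothetical numbers): with config.num_hidden_layers = 32,
        # pipeline_model_parallel_size = 4 and virtual_pipeline_model_parallel_size = 2,
        # num_layer_per_pp = 32 // 4 = 8 and this module builds num_layer_vpp_chunk = 8 // 2 = 4
        # decoder layers; without vpp it builds all 8 layers of its pp stage.
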
        layers = []
        for i in range(self.num_layer_this_model):
            layer = ParallelLlamaDecoderLayerRmPad(config, megatron_config)
            # setattr(layer, 'hidden_layer_index', self.offset + i)
            layers.append(layer)

        self.layers = nn.ModuleList(layers)

        if post_process:
            self.norm = ParallelLlamaRMSNorm(config, megatron_config)
        else:
            self.norm = None

    def set_input_tensor(self, input_tensor):
        """Set input tensor to be used instead of forward()'s input.

        When doing pipeline parallelism the input from the previous
        stage comes from communication, not from the input, so the
        model's forward_step_func won't have it. This function is thus
        used by internal code to bypass the input provided by the
        forward_step_func."""
        self.input_tensor = input_tensor

    def forward(self,
                input_ids: torch.Tensor,
                position_ids: Optional[torch.LongTensor] = None,
                sequence_length: int = None,
                indices: torch.Tensor = None,
                cu_seqlens: torch.Tensor = None,
                max_seqlen_in_batch: int = None) -> Union[Tuple, BaseModelOutputWithPast]:
        """

        Args:
            input_ids: input ids. shape (1, total_nnz)
            position_ids: position ids. shape (batch_size, seq_length)

        Returns:

        """
        if self.pre_process:
            inputs_embeds = self.embed_tokens(input_ids)  # (1, total_nnz) -> (1, total_nnz, hidden_size)

            # the vocab parallel embedding in open-source Megatron does not perform the sequence-parallel
            # reduce-scatter, so we handle it here:
            # (1, total_nnz, hidden_size) -> (total_nnz, 1, hidden_size) -> (total_nnz // sp, 1, hidden_size)
            inputs_embeds = inputs_embeds.transpose(0, 1)
            if self.megatron_config.sequence_parallel:
                inputs_embeds = tensor_parallel.scatter_to_sequence_parallel_region(inputs_embeds)

            hidden_states = inputs_embeds
        else:
            # for intermediate pipeline stages, hidden_states is provided by Megatron via set_input_tensor
            hidden_states = self.input_tensor

        for idx, decoder_layer in enumerate(self.layers):
            layer_outputs = decoder_layer(hidden_states,
                                          position_ids=position_ids,
                                          sequence_length=sequence_length,
                                          indices=indices,
                                          cu_seqlens=cu_seqlens,
                                          max_seqlen_in_batch=max_seqlen_in_batch)

            hidden_states = layer_outputs

        if self.post_process:
            hidden_states = self.norm(hidden_states)

        return hidden_states


class ParallelLlamaForCausalLMRmPadPP(nn.Module):

    def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig, pre_process, post_process):
        super().__init__()
        self.config = config
        self.megatron_config = megatron_config
        self.model = ParallelLlamaModelRmPadPP(config,
                                               megatron_config=megatron_config,
                                               pre_process=pre_process,
                                               post_process=post_process)
        self.share_embeddings_and_output_weights = None  # workaround, megatron requires this attr
        self.vocab_size = config.vocab_size
        self.pre_process = pre_process
        self.post_process = post_process
        if post_process:
            self._init_head()

    def set_input_tensor(self, input_tensor):
        """Set input tensor to be used instead of forward()'s input.

        When doing pipeline parallelism the input from the previous
        stage comes from communication, not from the input, so the
        model's forward_step_func won't have it. This function is thus
        used by internal code to bypass the input provided by the
        forward_step_func."""
        assert len(input_tensor) == 1
        self.model.set_input_tensor(input_tensor[0])

    def _init_head(self):
        column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear()
        if self.megatron_config is not None:
            assert column_kwargs.get('config', False), 'must have ModelParallelConfig'
            tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config)
        self.lm_head = tensor_parallel.ColumnParallelLinear(input_size=self.config.hidden_size,
                                                            output_size=self.config.vocab_size,
                                                            bias=False,
                                                            gather_output=False,
                                                            skip_bias_add=False,
                                                            **column_kwargs)

    def _forward_head(self, hidden_states):
        # all_gather from sequence parallel region is performed inside lm_head
        # hidden_states: (total_nnz_padded // sp, 1, hidden_size)
        logits = self.lm_head(hidden_states)[0]
        logits = logits.float()  # (total_nnz_padded, 1, vocab_size // tp)
        return logits

    def forward(
        self,
        # original input
        *,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Returns:
        """

        # Note that input_ids, attention_mask and position_ids should be passed to every pp stage.
        # In the first pp stage input_ids is used; in the other pp stages hidden_states is used inside self.model.
        batch_size, sequence_length = input_ids.shape
        # remove padding here
        input_ids_rmpad, indices, cu_seqlens, max_seqlen_in_batch, *_ = unpad_input(input_ids.unsqueeze(dim=-1),
                                                                                    attention_mask)  # (total_nnz, 1)

        # pad input_ids to a multiple of the tp size so that all tp ranks hold the same length
        # TODO: for better performance, the sp padding could be removed at each layer; the performance gap is unclear
        if self.megatron_config.sequence_parallel:
            input_ids_rmpad = sp_utils.pad_to_sequence_parallel(input_ids_rmpad)

        input_ids_rmpad = input_ids_rmpad.transpose(0, 1)  # (1, total_nnz+pad)

        outputs = self.model(input_ids=input_ids_rmpad,
                             position_ids=position_ids,
                             sequence_length=sequence_length,
                             indices=indices,
                             cu_seqlens=cu_seqlens,
                             max_seqlen_in_batch=max_seqlen_in_batch)

        if self.post_process:
            hidden_states = outputs
            logits = self._forward_head(hidden_states)
            logits = torch.squeeze(logits, dim=1)  # remove the artificial batch dimension

            # remove padding from sequence parallel
            if self.megatron_config.sequence_parallel:
                total_nnz = cu_seqlens[-1]
                logits = logits[:total_nnz]  # (total_nnz, vocab_size)
            # add removed padding back. If the input is already rmpad, we let the caller call pad_input
            logits = pad_input(logits, indices, batch_size,
                               seqlen=sequence_length)  # (batch_size, sequence_length, vocab_size)

            return CausalLMOutputWithPast(
                loss=None,
                logits=logits,
                past_key_values=None,
                hidden_states=None,
                attentions=None,
            )
        else:
            return outputs


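# Illustrative sketch (not part of verl): a Megatron-style model_provider for this pp/vpp-aware class.
# The assumption here is the standard Megatron-Core pattern in which get_model calls the provider once
# per vpp chunk on each pp rank, passing pre_process/post_process based on the rank's pipeline position.
def _example_megatron_model_provider(config: LlamaConfig, megatron_config: ModelParallelConfig):

    def model_provider(pre_process=True, post_process=True):
        # pre_process=True only on the first stage (owns the embedding);
        # post_process=True only on the last stage (owns the final norm and lm_head)
        return ParallelLlamaForCausalLMRmPadPP(config,
                                               megatron_config=megatron_config,
                                               pre_process=pre_process,
                                               post_process=post_process)

    return model_provider

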
class ParallelLlamaForValueRmPadPP(ParallelLlamaForCausalLMRmPadPP):

    def _init_head(self):
        column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear()
        if self.megatron_config is not None:
            assert column_kwargs.get('config', False), 'must have ModelParallelConfig'
            tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config)
        self.lm_head = nn.Linear(in_features=self.config.hidden_size, out_features=1, bias=False)
        # the value head weight is replicated and applied to sequence-parallel activations,
        # so it is effectively a sequence-parallel parameter
        sp_utils.mark_parameter_as_sequence_parallel(self.lm_head.weight)

    def _forward_head(self, hidden_states):
        logits = self.lm_head(hidden_states)  # (total_nnz_padded // tp, 1, 1)
        logits = logits.float()
        if self.megatron_config.sequence_parallel:
            logits = tensor_parallel.gather_from_sequence_parallel_region(logits, tensor_parallel_output_grad=False)
        return logits

    def forward(
        self,
        *,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        output = super().forward(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids)
        if self.post_process:
            output.logits = torch.squeeze(output.logits, dim=-1)
            return output
        else:
            return output