# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
|
|
import torch
|
|
from typing import Optional, List, Union, Tuple, Unpack, Callable
|
|
|
|
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
|
|
from transformers.cache_utils import Cache
|
|
from transformers.utils import logging
|
|
from transformers.modeling_flash_attention_utils import _flash_attention_forward
|
|
from verl.utils.ulysses import gather_heads_scatter_seq, gather_seq_scatter_heads, get_ulysses_sequence_parallel_world_size
|
|
|
|
logger = logging.get_logger(__name__)
|
|
|
|
def llama_flash_attn_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.LongTensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Cache] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
    cache_position: Optional[torch.LongTensor] = None,
    position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
    **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    """Llama flash-attention forward pass with Ulysses sequence-parallel hooks.

    Adapted from transformers 4.47.1. Intended to be bound as the ``forward``
    of a Llama attention module: it reads ``self.q_proj`` / ``k_proj`` /
    ``v_proj`` / ``o_proj``, the head-count attributes, ``self.rotary_emb``,
    ``self.layer_idx`` and the flash-attention flags from ``self``.

    When the Ulysses sequence-parallel world size is greater than one, q/k/v
    are exchanged before attention (sequence gathered, heads scattered) and
    the attention output is exchanged back afterwards (heads gathered,
    sequence scattered). With a world size of one both exchanges are skipped.

    Args:
        hidden_states: Input of shape ``(batch, local_seq_len, hidden)``.
        attention_mask: Forwarded unchanged to ``_flash_attention_forward``.
        position_ids: Used to compute RoPE when ``position_embeddings`` is
            ``None``; always forwarded to ``_flash_attention_forward``.
            NOTE(review): under sequence parallelism these presumably need to
            cover the gathered (full) sequence -- confirm against the caller.
        past_key_value: Optional cache; updated in place via ``.update``.
        output_attentions: Ignored. Attention weights are never returned.
        use_cache: Not read by this function.
        cache_position: Passed to the cache update as ``"cache_position"``.
        position_embeddings: Optional precomputed ``(cos, sin)`` pair.
        **kwargs: Forwarded to ``_flash_attention_forward``.

    Returns:
        ``(attn_output, None, past_key_value)`` where ``attn_output`` is the
        result of ``self.o_proj`` on the ``(batch, local_seq_len, -1)``
        attention output.
    """
    batch, local_len, _ = hidden_states.size()

    def _project_heads(proj, n_heads):
        # (batch, local_len, hidden) -> (batch, n_heads, local_len, head_dim)
        return proj(hidden_states).view(batch, local_len, n_heads, self.head_dim).transpose(1, 2)

    q = _project_heads(self.q_proj, self.num_heads)
    k = _project_heads(self.k_proj, self.num_key_value_heads)
    v = _project_heads(self.v_proj, self.num_key_value_heads)

    # Trade-off kept from the original: KV heads are NOT repeated (repeat_kv)
    # before the all-to-all, so the exchange runs on num_key_value_heads heads.

    ########## AlltoAll for Ulysses ##########
    sp_world_size = get_ulysses_sequence_parallel_world_size()
    sequence_parallel = sp_world_size > 1
    if sequence_parallel:
        # (bsz, n_head, seq_len/n, head_dim) -> (bsz, n_head/n, seq_len, head_dim)
        q, k, v = [gather_seq_scatter_heads(t, seq_dim=2, head_dim=1) for t in (q, k, v)]

    # Sequence length seen by attention (equals local_len without sequence parallelism).
    total_len = q.size(2)

    if position_embeddings is not None:
        cos, sin = position_embeddings
    else:
        logger.warning_once(
            "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
            "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
            "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
            "removed and `position_embeddings` will be mandatory.")
        cos, sin = self.rotary_emb(v, position_ids)
    q, k = apply_rotary_pos_emb(q, k, cos, sin)

    if past_key_value is not None:
        # sin and cos are specific to RoPE models; cache_position needed for the static cache
        cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
        k, v = past_key_value.update(k, v, self.layer_idx, cache_kwargs)

    # Flash Attention expects [batch_size, sequence_length, num_heads, head_dim],
    # so swap the head and sequence axes back.
    # TODO: these transposes are inefficient; avoiding them would need a KV-cache refactor.
    q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)

    # Dropout is only applied in training mode.
    dropout_rate = 0.0
    if self.training:
        dropout_rate = self.attention_dropout

    # In PEFT, layer norms are usually kept in float32 for training stability,
    # which silently upcasts the hidden states. Cast q/k/v back to the expected
    # dtype before calling flash attention. This may slow down training and
    # inference, so keeping LayerNorms out of fp32 is recommended
    # (LlamaRMSNorm handles it correctly).
    if q.dtype == torch.float32:
        if torch.is_autocast_enabled():
            target_dtype = torch.get_autocast_gpu_dtype()
        # Handle the case where the model is quantized
        elif hasattr(self.config, "_pre_quantization_dtype"):
            target_dtype = self.config._pre_quantization_dtype
        else:
            target_dtype = self.q_proj.weight.dtype

        logger.warning_once(
            "The input hidden states seems to be silently casted in float32, this might be related to"
            " the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
            f" {target_dtype}.")

        q = q.to(target_dtype)
        k = k.to(target_dtype)
        v = v.to(target_dtype)

    attn_output = _flash_attention_forward(
        q,
        k,
        v,
        attention_mask,
        total_len,
        position_ids=position_ids,
        dropout=dropout_rate,
        sliding_window=getattr(self, "sliding_window", None),
        use_top_left_mask=self._flash_attn_uses_top_left_mask,
        is_causal=self.is_causal,
        **kwargs,
    )

    attn_output = attn_output.reshape(batch, total_len, -1, self.head_dim).contiguous()

    ########## AlltoAll for Ulysses ##########
    if sequence_parallel:
        # Inverse exchange: gather heads, scatter the sequence back to the local shard.
        attn_output = gather_heads_scatter_seq(attn_output, seq_dim=1, head_dim=2)

    attn_output = self.o_proj(attn_output.reshape(batch, local_len, -1).contiguous())

    # Attention weights are never produced on the flash-attention path.
    return attn_output, None, past_key_value
|