# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import torch
from typing import Optional, Tuple

from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
from transformers.cache_utils import Cache
from transformers.utils import logging
from transformers.modeling_flash_attention_utils import _flash_attention_forward
from verl.utils.ulysses import (
    gather_heads_scatter_seq,
    gather_seq_scatter_heads,
    get_ulysses_sequence_parallel_world_size,
)

logger = logging.get_logger(__name__)

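
# Ulysses sequence-parallel replacement for the HF Qwen2 flash-attention forward.
# Each rank enters with its local sequence shard and all attention heads; an all-to-all
# trades that for the full sequence and a 1/sp_size shard of the heads, attention runs
# over the full sequence, and a second all-to-all restores the original (local-sequence,
# all-heads) layout before the output projection.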
def qwen2_flash_attn_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Cache] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
    cache_position: Optional[torch.LongTensor] = None,
    position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
):
    bsz, q_len, _ = hidden_states.size()

    query_states = self.q_proj(hidden_states)
    key_states = self.k_proj(hidden_states)
    value_states = self.v_proj(hidden_states)

    query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
    key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
    value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)

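    # Shape bookkeeping for the all-to-all below (illustrative numbers only, assuming
    # sp_size = 4, 32 query heads, head_dim = 128, and a local shard of 2048 tokens):
    #   before: query_states is (bsz, 32, 2048, 128)  -- all heads, 1/4 of the sequence
    #   after:  query_states is (bsz,  8, 8192, 128)  -- 1/4 of the heads, full sequence
    # The same exchange is applied to key_states and value_states.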
    ########## AlltoAll for Ulysses ##########
    ulysses_sp_size = get_ulysses_sequence_parallel_world_size()

    if ulysses_sp_size > 1:
        # (bsz, n_head, seq_len/n, head_dim) -> (bsz, n_head/n, seq_len, head_dim)
        query_states = gather_seq_scatter_heads(query_states, seq_dim=2, head_dim=1)
        key_states = gather_seq_scatter_heads(key_states, seq_dim=2, head_dim=1)
        value_states = gather_seq_scatter_heads(value_states, seq_dim=2, head_dim=1)

    full_q_len = query_states.size(2)  # full seq length

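    # From here on, query/key/value hold the full sequence (length full_q_len) but only
    # this rank's shard of the heads, so RoPE, the KV-cache update, and flash attention
    # all operate on the gathered sequence.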
    if position_embeddings is None:
        logger.warning_once(
            "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
            "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
            "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
            "removed and `position_embeddings` will be mandatory."
        )
        cos, sin = self.rotary_emb(value_states, position_ids)
    else:
        cos, sin = position_embeddings
    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

    if past_key_value is not None:
        cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}  # Specific to RoPE models
        key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

    # repeat k/v heads if n_kv_heads < n_heads
    key_states = repeat_kv(key_states, self.num_key_value_groups)
    value_states = repeat_kv(value_states, self.num_key_value_groups)
    dropout_rate = 0.0 if not self.training else self.attention_dropout

    # In PEFT, we usually cast the layer norms to float32 for training stability,
    # so the input hidden states get silently cast to float32. Hence, we need to
    # cast them back to float16 just to be sure everything works as expected.
    input_dtype = query_states.dtype
    if input_dtype == torch.float32:
        if torch.is_autocast_enabled():
            target_dtype = torch.get_autocast_gpu_dtype()
        # Handle the case where the model is quantized
        elif hasattr(self.config, "_pre_quantization_dtype"):
            target_dtype = self.config._pre_quantization_dtype
        else:
            target_dtype = self.q_proj.weight.dtype

        logger.warning_once(
            "The input hidden states seem to have been silently cast to float32; this might be because you have "
            f"upcast embedding or layer norm layers to float32. We will cast the input back to {target_dtype}."
        )

        query_states = query_states.to(target_dtype)
        key_states = key_states.to(target_dtype)
        value_states = value_states.to(target_dtype)

    # Reshape to the expected shape for Flash Attention
    query_states = query_states.transpose(1, 2)
    key_states = key_states.transpose(1, 2)
    value_states = value_states.transpose(1, 2)

    if (
        self.config.use_sliding_window
        and getattr(self.config, "sliding_window", None) is not None
        and self.layer_idx >= self.config.max_window_layers
    ):
        sliding_window = self.config.sliding_window
    else:
        sliding_window = None

    attn_output = _flash_attention_forward(
        query_states,
        key_states,
        value_states,
        attention_mask,
        full_q_len,
        position_ids=position_ids,
        dropout=dropout_rate,
        sliding_window=sliding_window,
        is_causal=self.is_causal,
        use_top_left_mask=self._flash_attn_uses_top_left_mask,
    )

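    # attn_output covers the full (gathered) sequence, so it is reshaped with full_q_len;
    # the reverse all-to-all below then gathers the heads back and re-scatters the
    # sequence so each rank returns to its original (bsz, q_len, hidden_size) shard.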
    # use full_q_len to reshape
    attn_output = attn_output.reshape(bsz, full_q_len, -1, self.head_dim).contiguous()
    ########## AlltoAll for Ulysses ##########
    if ulysses_sp_size > 1:
        attn_output = gather_heads_scatter_seq(attn_output, seq_dim=1, head_dim=2)
    attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
    attn_output = self.o_proj(attn_output)

    if not output_attentions:
        attn_weights = None

    return attn_output, attn_weights, past_key_value
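

# Usage sketch: this forward is meant to be monkey-patched over the HF Qwen2 flash-attention
# forward. Assuming a transformers version that still exposes `Qwen2FlashAttention2`
# (the class whose signature this function mirrors), one possible patch is:
#
#     from transformers.models.qwen2.modeling_qwen2 import Qwen2FlashAttention2
#
#     Qwen2FlashAttention2.forward = qwen2_flash_attn_forward
#
# after which a Qwen2 model loaded with `attn_implementation="flash_attention_2"` will run
# this Ulysses-aware attention on every layer.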