Initial commit
13
verl/models/transformers/__init__.py
Normal file
@@ -0,0 +1,13 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
145
verl/models/transformers/llama.py
Normal file
@@ -0,0 +1,145 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import torch
from typing import Optional, List, Union, Tuple, Unpack, Callable

from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
from transformers.cache_utils import Cache
from transformers.utils import logging
from transformers.modeling_flash_attention_utils import _flash_attention_forward
from verl.utils.ulysses import gather_heads_scatter_seq, gather_seq_scatter_heads, get_ulysses_sequence_parallel_world_size

logger = logging.get_logger(__name__)


def llama_flash_attn_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.LongTensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Cache] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
    cache_position: Optional[torch.LongTensor] = None,
    position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
    **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    """
    Adapted from transformers 4.47.1.
    """
    output_attentions = False

    bsz, q_len, _ = hidden_states.size()

    query_states = self.q_proj(hidden_states)
    key_states = self.k_proj(hidden_states)
    value_states = self.v_proj(hidden_states)

    # Flash attention requires the input to have the shape
    # batch_size x seq_length x head_dim x hidden_dim
    # therefore we just need to keep the original shape
    query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
    key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
    value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

    # trade-off: repeat KV first and then all-to-all
    # key_states = repeat_kv(key_states, self.num_key_value_groups)
    # value_states = repeat_kv(value_states, self.num_key_value_groups)

    ########## AlltoAll for Ulysses ##########
    ulysses_sp_size = get_ulysses_sequence_parallel_world_size()

    if ulysses_sp_size > 1:
        # (bsz, n_head, seq_len/n, head_dim) -> (bsz, n_head/n, seq_len, head_dim)
        query_states = gather_seq_scatter_heads(query_states, seq_dim=2, head_dim=1)
        key_states = gather_seq_scatter_heads(key_states, seq_dim=2, head_dim=1)
        value_states = gather_seq_scatter_heads(value_states, seq_dim=2, head_dim=1)

    full_q_len = query_states.size(2)  # full seq length

    if position_embeddings is None:
        logger.warning_once(
            "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
            "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
            "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
            "removed and `position_embeddings` will be mandatory.")
        cos, sin = self.rotary_emb(value_states, position_ids)
    else:
        cos, sin = position_embeddings
    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

    if past_key_value is not None:
        # sin and cos are specific to RoPE models; cache_position needed for the static cache
        cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
        key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

    # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
    # to be able to avoid many of these transpose/reshape/view.
    query_states = query_states.transpose(1, 2)
    key_states = key_states.transpose(1, 2)
    value_states = value_states.transpose(1, 2)

    dropout_rate = self.attention_dropout if self.training else 0.0

    # In PEFT, usually we cast the layer norms in float32 for training stability reasons
    # therefore the input hidden states gets silently casted in float32. Hence, we need
    # cast them back in the correct dtype just to be sure everything works as expected.
    # This might slowdown training & inference so it is recommended to not cast the LayerNorms
    # in fp32. (LlamaRMSNorm handles it correctly)

    input_dtype = query_states.dtype
    if input_dtype == torch.float32:
        if torch.is_autocast_enabled():
            target_dtype = torch.get_autocast_gpu_dtype()
        # Handle the case where the model is quantized
        elif hasattr(self.config, "_pre_quantization_dtype"):
            target_dtype = self.config._pre_quantization_dtype
        else:
            target_dtype = self.q_proj.weight.dtype

        logger.warning_once(
            f"The input hidden states seems to be silently casted in float32, this might be related to"
            f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
            f" {target_dtype}.")

        query_states = query_states.to(target_dtype)
        key_states = key_states.to(target_dtype)
        value_states = value_states.to(target_dtype)

    attn_output = _flash_attention_forward(
        query_states,
        key_states,
        value_states,
        attention_mask,
        full_q_len,
        position_ids=position_ids,
        dropout=dropout_rate,
        sliding_window=getattr(self, "sliding_window", None),
        use_top_left_mask=self._flash_attn_uses_top_left_mask,
        is_causal=self.is_causal,
        **kwargs,
    )

    attn_output = attn_output.reshape(bsz, full_q_len, -1, self.head_dim).contiguous()
    ########## AlltoAll for Ulysses ##########
    if ulysses_sp_size > 1:
        attn_output = gather_heads_scatter_seq(attn_output, seq_dim=1, head_dim=2)
    attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
    attn_output = self.o_proj(attn_output)

    if not output_attentions:
        attn_weights = None

    return attn_output, attn_weights, past_key_value
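The comments inside llama_flash_attn_forward only note the query/key/value shape change around the first all-to-all, so here is a minimal single-process sketch of the full shape round trip the Ulysses path is assumed to follow. The sizes are made up; the real gather_seq_scatter_heads / gather_heads_scatter_seq collectives live in verl.utils.ulysses, which is not part of this commit, and only the shape relations below mirror the code above.

# Illustrative shape bookkeeping for llama_flash_attn_forward under Ulysses sequence parallelism.
# Made-up sizes; key/value follow the same pattern with num_key_value_heads instead of n_head.
bsz, n_head, head_dim, sp_size = 2, 32, 128, 4
q_len = 512                                                 # local sequence shard held by this rank
full_q_len = q_len * sp_size                                # sequence length seen by attention

qkv_local = (bsz, n_head, q_len, head_dim)                  # after .view(...).transpose(1, 2)
qkv_sp = (bsz, n_head // sp_size, full_q_len, head_dim)     # after gather_seq_scatter_heads(seq_dim=2, head_dim=1)
attn_out = (bsz, full_q_len, n_head // sp_size, head_dim)   # after _flash_attention_forward + first reshape
out_local = (bsz, q_len, n_head, head_dim)                  # after gather_heads_scatter_seq(seq_dim=1, head_dim=2)
o_proj_in = (bsz, q_len, n_head * head_dim)                 # final reshape fed into self.o_proj
print(qkv_local, qkv_sp, attn_out, out_local, o_proj_in)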
74
verl/models/transformers/monkey_patch.py
Normal file
@@ -0,0 +1,74 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Apply monkey-patch function to models
"""

#### Open Source Models
#### transformers version < 4.48


def apply_monkey_patch_to_llama():
    from transformers.models.llama.modeling_llama import LlamaFlashAttention2
    from verl.models.transformers.llama import llama_flash_attn_forward
    LlamaFlashAttention2.forward = llama_flash_attn_forward


def apply_monkey_patch_to_qwen2():
    from transformers.models.qwen2.modeling_qwen2 import Qwen2FlashAttention2
    from verl.models.transformers.qwen2 import qwen2_flash_attn_forward
    Qwen2FlashAttention2.forward = qwen2_flash_attn_forward


_PATCH_NAME_TO_FUNC = {
    'llama': apply_monkey_patch_to_llama,
    'qwen2': apply_monkey_patch_to_qwen2,
}

from transformers import PretrainedConfig


def apply_monkey_patch(config: PretrainedConfig, verbose=True):
    if not is_transformers_version_in_range("4.45.0", "4.47.1"):
        raise AssertionError("The installed `transformers` version doesn't support ulysses patch. "
                             "Please install a version between 4.45.0 and 4.47.1 to use this ulysses feature.")
    success_apply_monkey_patch = False
    if config.model_type in _PATCH_NAME_TO_FUNC:
        _PATCH_NAME_TO_FUNC[config.model_type]()
        success_apply_monkey_patch = True

    if success_apply_monkey_patch and verbose:
        print(f'Applying monkey patch to model {config.model_type}')
    elif not success_apply_monkey_patch:
        raise NotImplementedError(f'Ulysses for model {config.model_type} is not implemented, '
                                  f'please set `ulysses_sequence_parallel_size=1`')

    return success_apply_monkey_patch


from functools import lru_cache
from packaging import version
import importlib.metadata


@lru_cache()
def is_transformers_version_in_range(min_version: str, max_version: str) -> bool:
    try:
        # Get the installed version of the transformers library
        transformers_version = importlib.metadata.version("transformers")
    except importlib.metadata.PackageNotFoundError:
        raise ModuleNotFoundError("The `transformers` package is not installed.")

    # Check if the version is within the specified range
    return version.parse(min_version) <= version.parse(transformers_version) <= version.parse(max_version)
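A hypothetical usage sketch, not part of this commit: apply_monkey_patch is meant to be called with the checkpoint's config before the model is instantiated, and the patched forward only takes effect when the model actually uses the FlashAttention2 attention classes. The checkpoint name and dtype below are placeholders.

# Hypothetical usage (assumes transformers between 4.45.0 and 4.47.1, per the version check above).
import torch
from transformers import AutoConfig, AutoModelForCausalLM
from verl.models.transformers.monkey_patch import apply_monkey_patch

model_name = "meta-llama/Llama-3.1-8B-Instruct"  # placeholder checkpoint
config = AutoConfig.from_pretrained(model_name)
apply_monkey_patch(config)                       # patches LlamaFlashAttention2.forward for model_type 'llama'

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    config=config,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",     # the patched forward targets the FlashAttention2 classes
)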
137
verl/models/transformers/qwen2.py
Normal file
@@ -0,0 +1,137 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import torch
from typing import Optional, Tuple

from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
from transformers.cache_utils import Cache
from transformers.utils import logging
from transformers.modeling_flash_attention_utils import _flash_attention_forward
from verl.utils.ulysses import gather_heads_scatter_seq, gather_seq_scatter_heads, get_ulysses_sequence_parallel_world_size

logger = logging.get_logger(__name__)


def qwen2_flash_attn_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Cache] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
    cache_position: Optional[torch.LongTensor] = None,
    position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
):
    bsz, q_len, _ = hidden_states.size()

    query_states = self.q_proj(hidden_states)
    key_states = self.k_proj(hidden_states)
    value_states = self.v_proj(hidden_states)

    query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
    key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
    value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)

    ########## AlltoAll for Ulysses ##########
    ulysses_sp_size = get_ulysses_sequence_parallel_world_size()

    if ulysses_sp_size > 1:
        # (bsz, n_head, seq_len/n, head_dim) -> (bsz, n_head/n, seq_len, head_dim)
        query_states = gather_seq_scatter_heads(query_states, seq_dim=2, head_dim=1)
        key_states = gather_seq_scatter_heads(key_states, seq_dim=2, head_dim=1)
        value_states = gather_seq_scatter_heads(value_states, seq_dim=2, head_dim=1)

    full_q_len = query_states.size(2)  # full seq length

    if position_embeddings is None:
        logger.warning_once(
            "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
            "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
            "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
            "removed and `position_embeddings` will be mandatory.")
        cos, sin = self.rotary_emb(value_states, position_ids)
    else:
        cos, sin = position_embeddings
    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

    if past_key_value is not None:
        cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}  # Specific to RoPE models
        key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

    # repeat k/v heads if n_kv_heads < n_heads
    key_states = repeat_kv(key_states, self.num_key_value_groups)
    value_states = repeat_kv(value_states, self.num_key_value_groups)
    dropout_rate = 0.0 if not self.training else self.attention_dropout

    # In PEFT, usually we cast the layer norms in float32 for training stability reasons
    # therefore the input hidden states gets silently casted in float32. Hence, we need
    # cast them back in float16 just to be sure everything works as expected.
    input_dtype = query_states.dtype
    if input_dtype == torch.float32:
        if torch.is_autocast_enabled():
            target_dtype = torch.get_autocast_gpu_dtype()
        # Handle the case where the model is quantized
        elif hasattr(self.config, "_pre_quantization_dtype"):
            target_dtype = self.config._pre_quantization_dtype
        else:
            target_dtype = self.q_proj.weight.dtype

        logger.warning_once(
            f"The input hidden states seems to be silently casted in float32, this might be related to"
            f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
            f" {target_dtype}.")

        query_states = query_states.to(target_dtype)
        key_states = key_states.to(target_dtype)
        value_states = value_states.to(target_dtype)

    # Reshape to the expected shape for Flash Attention
    query_states = query_states.transpose(1, 2)
    key_states = key_states.transpose(1, 2)
    value_states = value_states.transpose(1, 2)

    if (self.config.use_sliding_window and getattr(self.config, "sliding_window", None) is not None and
            self.layer_idx >= self.config.max_window_layers):
        sliding_window = self.config.sliding_window
    else:
        sliding_window = None

    attn_output = _flash_attention_forward(
        query_states,
        key_states,
        value_states,
        attention_mask,
        full_q_len,
        position_ids=position_ids,
        dropout=dropout_rate,
        sliding_window=sliding_window,
        is_causal=self.is_causal,
        use_top_left_mask=self._flash_attn_uses_top_left_mask,
    )

    # use full_q_len to reshape
    attn_output = attn_output.reshape(bsz, full_q_len, -1, self.head_dim).contiguous()
    ########## AlltoAll for Ulysses ##########
    if ulysses_sp_size > 1:
        attn_output = gather_heads_scatter_seq(attn_output, seq_dim=1, head_dim=2)
    attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
    attn_output = self.o_proj(attn_output)

    if not output_attentions:
        attn_weights = None

    return attn_output, attn_weights, past_key_value
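Compared with the Llama forward above, the Qwen2 version adds an explicit repeat_kv call and a per-layer sliding-window choice. The helper below restates that sliding-window condition on made-up config values, purely as an illustration of the logic in qwen2_flash_attn_forward; it is not part of the commit.

# Illustration only: same condition as in qwen2_flash_attn_forward, with placeholder config values.
class _DummyQwen2Config:
    use_sliding_window = True
    sliding_window = 4096
    max_window_layers = 28

def pick_sliding_window(config, layer_idx):
    if (config.use_sliding_window and getattr(config, "sliding_window", None) is not None and
            layer_idx >= config.max_window_layers):
        return config.sliding_window
    return None

print(pick_sliding_window(_DummyQwen2Config, 10))  # None: earlier layers attend over the full sequence
print(pick_sliding_window(_DummyQwen2Config, 30))  # 4096: layers at or beyond max_window_layers use the window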