Initial commit
13
verl/models/llama/__init__.py
Normal file
@@ -0,0 +1,13 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
24
verl/models/llama/megatron/__init__.py
Normal file
@@ -0,0 +1,24 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .modeling_llama_megatron import (
    # original model with megatron
    ParallelLlamaModel,
    ParallelLlamaForCausalLM,
    # rmpad with megatron
    ParallelLlamaForCausalLMRmPad,
    ParallelLlamaForValueRmPad,
    # rmpad with megatron and pipeline parallelism
    ParallelLlamaForCausalLMRmPadPP,
    ParallelLlamaForValueRmPadPP)
13
verl/models/llama/megatron/checkpoint_utils/__init__.py
Normal file
@@ -0,0 +1,13 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
446
verl/models/llama/megatron/checkpoint_utils/llama_loader.py
Normal file
@@ -0,0 +1,446 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import time
from typing import Dict, Any, Callable, Optional
import torch.distributed as dist


def _megatron_calc_layer_map(config):
    """Calculate the mapping of global layer_idx to local layer_idx
    Returns:
        layer_map (Dict: int -> tuple(int, int, int)):
            mapping from the global layer index to
            a tuple of (pp_rank, virtual_pp_rank, layer_idx inside model)
    """
    import megatron
    from megatron.core import mpu

    pp_size = mpu.get_pipeline_model_parallel_world_size()
    virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1

    layer_map = dict()
    num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size
    assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers

    for pp_rank_idx in range(pp_size):
        for virtual_pp_rank_idx in range(virtual_pp_size):
            layer_offset = (virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) +
                            pp_rank_idx * num_layers_per_model)
            for layer_idx in range(num_layers_per_model):
                layer_map[layer_offset + layer_idx] = (
                    pp_rank_idx,
                    virtual_pp_rank_idx,
                    layer_idx,
                )
    return layer_map
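
# Worked example of the mapping above (hypothetical sizes, for illustration only):
# with num_hidden_layers=8, pp_size=2 and virtual_pp_size=2, each model chunk holds
# 8 // 2 // 2 = 2 layers and _megatron_calc_layer_map returns
#   {0: (0, 0, 0), 1: (0, 0, 1), 2: (1, 0, 0), 3: (1, 0, 1),
#    4: (0, 1, 0), 5: (0, 1, 1), 6: (1, 1, 0), 7: (1, 1, 1)}
# e.g. global layer 5 is the second local layer of the virtual-chunk-1 model on pp_rank 0.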


def load_state_dict_to_megatron_llama(state_dict, wrapped_models, config, params_dtype, is_value_model=False):
    """Load merged state_dict to sharded Megatron module in training.
    """
    import megatron
    from megatron.core import mpu
    from megatron.utils import print_rank_0, unwrap_model
    from megatron.core.transformer.module import Float16Module
    from megatron.core import DistributedDataParallel as LocalDDP
    from torch.nn.parallel import DistributedDataParallel as torchDDP

    start_time = time.time()

    def _get_gpt_model(model):
        return model

    def broadcast_params(module):
        for param in module.parameters():
            torch.distributed.broadcast(param.data,
                                        src=mpu.get_data_parallel_src_rank(),
                                        group=mpu.get_data_parallel_group())

    dp_rank = mpu.get_data_parallel_rank()
    pp_rank = mpu.get_pipeline_model_parallel_rank()
    pp_size = mpu.get_pipeline_model_parallel_world_size()
    virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1
    mp_group = mpu.get_model_parallel_group()

    if torch.distributed.get_rank() == 0:
        assert mp_group.rank() == 0, f"mp_rank:[{mp_group.rank()}] != 0 on rank #0"
        assert pp_rank == 0, f"pp_rank:[{pp_rank}] != 0 on rank #0"
        assert dp_rank == 0, f"dp_rank:[{dp_rank}] != 0 on rank #0"

    if not isinstance(wrapped_models, (list, tuple)):
        wrapped_models = [wrapped_models]

    assert len(wrapped_models) == virtual_pp_size
    num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size
    assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers

    models = [None] * len(wrapped_models)

    for i, wrapped_model in enumerate(wrapped_models):
        models[i] = unwrap_model(wrapped_model, (torchDDP, LocalDDP, Float16Module))
        gpt_model_module = _get_gpt_model(models[i])
        assert len(gpt_model_module.model.layers) == num_layers_per_model

    def _broadcast_tensor(tensor, name) -> torch.Tensor:
        """broadcast tensor from rank0 across mp_group"""
        nonlocal state_dict
        nonlocal mp_group
        if torch.distributed.get_rank() == 0:
            if name in state_dict:
                weight = state_dict[name]
                tensor_shape = weight.shape
            else:
                tensor_shape = None
        else:
            weight = None
            tensor_shape = None

        obj_list = [tensor_shape]
        dist.broadcast_object_list(obj_list, src=0, group=mp_group)
        tensor_shape = obj_list[0]

        if tensor_shape is None:
            # all or none ranks in the mp_group should reach here
            print_rank_0(f"tensor:[{name}] not in state_dict, skip load")
            return

        if tensor is None:
            tensor = torch.empty(
                tensor_shape,
                dtype=params_dtype,
                device=torch.cuda.current_device(),
                requires_grad=False,
            )
        if torch.distributed.get_rank() == 0:
            tensor.data.copy_(weight)
        dist.broadcast(tensor, src=0, group=mp_group)
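
    # Note on the broadcast protocol shared by the _broadcast_* helpers below (descriptive
    # comment added for clarity): rank 0 of the model-parallel group first broadcasts the
    # tensor *shape* via broadcast_object_list so that every rank can allocate (or verify)
    # a buffer, and only then broadcasts the payload itself; a None shape is the agreed
    # signal that the key is absent from state_dict and the load is skipped on all ranks.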

    def _broadcast_tp_shard_tensor_vocab(tensor, name, chunk_dim=0, mutate_func=None) -> torch.Tensor:
        """broadcast tensor in tp shards across mp_group"""
        nonlocal state_dict
        nonlocal mp_group
        tp_rank = mpu.get_tensor_model_parallel_rank()
        tp_size = mpu.get_tensor_model_parallel_world_size()

        if torch.distributed.get_rank() == 0:
            if name in state_dict:
                full_weight = state_dict[name]

                if mutate_func is not None:
                    full_weight = mutate_func(full_weight)
                tensor_chunk = torch.chunk(full_weight, tp_size, dim=chunk_dim)
                chunk_shape = tensor_chunk[0].shape
            else:
                chunk_shape = None
        else:
            chunk_shape = None

        obj_list = [chunk_shape]
        dist.broadcast_object_list(obj_list, src=0, group=mp_group)
        chunk_shape = obj_list[0]
        if chunk_shape is None:
            # all or none ranks in the mp_group should reach here
            print_rank_0(f"tp_shard tensor:[{name}] not in state_dict, skip loading")
            return

        if tensor is None:
            sync_tensor = torch.empty(
                chunk_shape,
                dtype=params_dtype,
                device=torch.cuda.current_device(),
                requires_grad=False,
            )
        else:
            assert (tensor.shape == chunk_shape
                   ), f"rank #{torch.distributed.get_rank()} tensor {name} shape {tensor.shape} != {chunk_shape}"
            sync_tensor = torch.empty_like(tensor, device=torch.cuda.current_device(), requires_grad=False)

        for i in range(tp_size):
            if torch.distributed.get_rank() == 0:
                sync_tensor.data.copy_(tensor_chunk[i])
            dist.broadcast(sync_tensor, src=0, group=mp_group)
            if (i == tp_rank) and (tensor is not None):
                tensor.data.copy_(sync_tensor)

    def _broadcast_tp_shard_tensor(tensor, name, chunk_dim=0, mutate_func=None) -> torch.Tensor:
        """broadcast tensor in tp shards across mp_group"""
        nonlocal state_dict
        nonlocal mp_group
        tp_rank = mpu.get_tensor_model_parallel_rank()
        tp_size = mpu.get_tensor_model_parallel_world_size()

        if torch.distributed.get_rank() == 0:
            if name in state_dict:
                full_weight = state_dict[name]
                if mutate_func is not None:
                    full_weight = mutate_func(full_weight)
                tensor_chunk = torch.chunk(full_weight, tp_size, dim=chunk_dim)
                chunk_shape = tensor_chunk[0].shape
            else:
                chunk_shape = None
        else:
            chunk_shape = None

        obj_list = [chunk_shape]
        dist.broadcast_object_list(obj_list, src=0, group=mp_group)
        chunk_shape = obj_list[0]
        if chunk_shape is None:
            # all or none ranks in the mp_group should reach here
            print_rank_0(f"tp_shard tensor:[{name}] not in state_dict, skip loading")
            return

        if tensor is None:
            sync_tensor = torch.empty(
                chunk_shape,
                dtype=params_dtype,
                device=torch.cuda.current_device(),
                requires_grad=False,
            )
        else:
            assert (tensor.shape == chunk_shape
                   ), f"rank #{torch.distributed.get_rank()} tensor {name} shape {tensor.shape} != {chunk_shape}"
            sync_tensor = torch.empty_like(tensor, device=torch.cuda.current_device(), requires_grad=False)

        for i in range(tp_size):
            if torch.distributed.get_rank() == 0:
                sync_tensor.data.copy_(tensor_chunk[i])
            dist.broadcast(sync_tensor, src=0, group=mp_group)
            if (i == tp_rank) and (tensor is not None):
                tensor.data.copy_(sync_tensor)

    def _broadcast_tp_shard_tensor_gate_up(tensor, gate_name, up_name) -> torch.Tensor:
        """broadcast tensor in tp shards across mp_group"""
        nonlocal state_dict
        nonlocal mp_group
        tp_rank = mpu.get_tensor_model_parallel_rank()
        tp_size = mpu.get_tensor_model_parallel_world_size()

        if torch.distributed.get_rank() == 0:
            gate_weight = state_dict[gate_name]
            up_weight = state_dict[up_name]
            new_gate_up_weight = torch.empty(config.intermediate_size * 2,
                                             config.hidden_size,
                                             dtype=params_dtype,
                                             device=torch.cuda.current_device())
            for i in range(tp_size):
                intermediate_size_tp = config.intermediate_size // tp_size
                gate_weight_tp = gate_weight[i * intermediate_size_tp:(i + 1) * intermediate_size_tp]
                up_weight_tp = up_weight[i * intermediate_size_tp:(i + 1) * intermediate_size_tp]
                new_gate_up_weight[intermediate_size_tp * 2 * i:intermediate_size_tp * 2 * (i + 1)].copy_(
                    torch.cat([gate_weight_tp, up_weight_tp], dim=0))

            tensor_chunk = torch.chunk(new_gate_up_weight, tp_size, dim=0)
            chunk_shape = tensor_chunk[0].shape
        else:
            chunk_shape = None

        obj_list = [chunk_shape]
        dist.broadcast_object_list(obj_list, src=0, group=mp_group)
        chunk_shape = obj_list[0]
        if chunk_shape is None:
            # all or none ranks in the mp_group should reach here
            print_rank_0(f"tp_shard tensor:[{gate_name, up_name}] not in state_dict, skip loading")
            return

        if tensor is None:
            sync_tensor = torch.empty(
                chunk_shape,
                dtype=params_dtype,
                device=torch.cuda.current_device(),
                requires_grad=False,
            )
        else:
            assert (
                tensor.shape == chunk_shape
            ), f"rank #{torch.distributed.get_rank()} tensor {gate_name, up_name} shape {tensor.shape} != {chunk_shape}"
            sync_tensor = torch.empty_like(tensor, device=torch.cuda.current_device(), requires_grad=False)

        for i in range(tp_size):
            if torch.distributed.get_rank() == 0:
                sync_tensor.data.copy_(tensor_chunk[i])
            dist.broadcast(sync_tensor, src=0, group=mp_group)
            if (i == tp_rank) and (tensor is not None):
                tensor.data.copy_(sync_tensor)

    def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name) -> torch.Tensor:
        """broadcast tensor in tp shards across mp_group"""
        nonlocal state_dict
        nonlocal mp_group
        tp_rank = mpu.get_tensor_model_parallel_rank()
        tp_size = mpu.get_tensor_model_parallel_world_size()

        if torch.distributed.get_rank() == 0:
            assert (q_name in state_dict and k_name in state_dict and v_name in state_dict)
            full_weight_q = state_dict[q_name]
            full_weight_k = state_dict[k_name]
            full_weight_v = state_dict[v_name]

            hidden_size_per_head = config.hidden_size // config.num_attention_heads

            if config.num_key_value_heads >= tp_size:
                q_size_tp = config.hidden_size // tp_size
                kv_size_tp = hidden_size_per_head * config.num_key_value_heads // tp_size
                total_size = q_size_tp + 2 * kv_size_tp
                new_weight_qkv = torch.empty(total_size * tp_size,
                                             config.hidden_size,
                                             dtype=params_dtype,
                                             device=torch.cuda.current_device())
                for i in range(tp_size):
                    q_part = full_weight_q[i * q_size_tp:(i + 1) * q_size_tp]
                    k_part = full_weight_k[i * kv_size_tp:(i + 1) * kv_size_tp]
                    v_part = full_weight_v[i * kv_size_tp:(i + 1) * kv_size_tp]
                    new_weight_qkv[i * total_size:(i + 1) * total_size].copy_(torch.cat([q_part, k_part, v_part],
                                                                                        dim=0))

            else:
                q_size_tp = config.hidden_size // tp_size
                kv_size_tp = hidden_size_per_head
                total_size = q_size_tp + 2 * kv_size_tp
                new_weight_qkv = torch.empty(total_size * tp_size,
                                             config.hidden_size,
                                             dtype=params_dtype,
                                             device=torch.cuda.current_device())
                for i in range(tp_size):
                    q_part = full_weight_q[i * q_size_tp:(i + 1) * q_size_tp]
                    start_idx = i * config.num_key_value_heads // tp_size * hidden_size_per_head
                    end_idx = (i * config.num_key_value_heads // tp_size + 1) * hidden_size_per_head
                    k_part = full_weight_k[start_idx:end_idx]
                    v_part = full_weight_v[start_idx:end_idx]
                    new_weight_qkv[i * total_size:(i + 1) * total_size].copy_(torch.cat([q_part, k_part, v_part],
                                                                                        dim=0))

            tensor_chunk = torch.chunk(new_weight_qkv, tp_size, dim=0)
            chunk_shape = tensor_chunk[0].shape
        else:
            chunk_shape = None

        obj_list = [chunk_shape]
        dist.broadcast_object_list(obj_list, src=0, group=mp_group)
        chunk_shape = obj_list[0]
        if chunk_shape is None:
            # all or none ranks in the mp_group should reach here
            print_rank_0(f"tp_shard tensor:[{q_name}] not in state_dict, skip loading")
            return

        if tensor is None:
            sync_tensor = torch.empty(
                chunk_shape,
                dtype=params_dtype,
                device=torch.cuda.current_device(),
                requires_grad=False,
            )
        else:
            assert (tensor.shape == chunk_shape
                   ), f"rank #{torch.distributed.get_rank()} tensor {q_name} shape {tensor.shape} != {chunk_shape}"
            sync_tensor = torch.empty_like(tensor, device=torch.cuda.current_device(), requires_grad=False)

        for i in range(tp_size):
            if torch.distributed.get_rank() == 0:
                sync_tensor.data.copy_(tensor_chunk[i])
            dist.broadcast(sync_tensor, src=0, group=mp_group)
            if (i == tp_rank) and (tensor is not None):
                tensor.data.copy_(sync_tensor)
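
    # Layout note (illustrative numbers, not taken from this file): with hidden_size=4096,
    # num_attention_heads=32, num_key_value_heads=8 and tp_size=8, each TP rank receives a
    # contiguous [q | k | v] block of 512 query rows followed by 128 key rows and 128 value
    # rows; this is the fused per-rank layout that the model's qkv_proj weight is assumed
    # to use, which is why Q/K/V are re-packed here before the shard-wise broadcast.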

    if dp_rank == 0:
        # Embeddings
        # -------------------
        print_rank_0("loading embeddings...")
        gpt_model_module = _get_gpt_model(models[0])
        embed_tokens_weight = None
        if pp_rank == 0:
            embed_tokens_weight = gpt_model_module.model.embed_tokens.weight
        _broadcast_tp_shard_tensor_vocab(embed_tokens_weight, "model.embed_tokens.weight")

        # Transformer layers
        # -------------------
        layer_map = _megatron_calc_layer_map(config)

        for layer in range(config.num_hidden_layers):
            print_rank_0(f"loading layer #{layer}...")
            layer_name = f"model.layers.{layer}"
            dst_pp_rank, dst_virtual_pp_rank, dst_layer_idx = layer_map[layer]

            gpt_model_module = _get_gpt_model(models[dst_virtual_pp_rank])
            sync_layer = gpt_model_module.model.layers[dst_layer_idx]

            _broadcast_tensor(
                sync_layer.input_layernorm.weight if dst_pp_rank == pp_rank else None,
                f"{layer_name}.input_layernorm.weight",
            )

            _broadcast_tp_shard_tensor_qkv(
                sync_layer.self_attn.qkv_proj.weight if dst_pp_rank == pp_rank else None,
                f"{layer_name}.self_attn.q_proj.weight",
                f"{layer_name}.self_attn.k_proj.weight",
                f"{layer_name}.self_attn.v_proj.weight",
            )

            _broadcast_tp_shard_tensor(
                sync_layer.self_attn.o_proj.weight if dst_pp_rank == pp_rank else None,
                f"{layer_name}.self_attn.o_proj.weight",
                chunk_dim=1,
            )

            _broadcast_tensor(
                sync_layer.post_attention_layernorm.weight if dst_pp_rank == pp_rank else None,
                f"{layer_name}.post_attention_layernorm.weight",
            )

            _broadcast_tp_shard_tensor_gate_up(sync_layer.mlp.gate_up_proj.weight if dst_pp_rank == pp_rank else None,
                                               f"{layer_name}.mlp.gate_proj.weight", f"{layer_name}.mlp.up_proj.weight")

            _broadcast_tp_shard_tensor(
                sync_layer.mlp.down_proj.weight if dst_pp_rank == pp_rank else None,
                f"{layer_name}.mlp.down_proj.weight",
                chunk_dim=1,
            )
        # Final Layernorm
        # -------------------
        print_rank_0("loading final layernorm...")
        gpt_model_module = _get_gpt_model(models[-1])
        _broadcast_tensor(
            getattr(gpt_model_module.model.norm, "weight", None),
            "model.norm.weight",
        )

        print_rank_0("loading lm_head...")
        lm_head_weight = None
        if pp_rank + 1 == pp_size:
            lm_head_weight = gpt_model_module.lm_head.weight

        if is_value_model:
            # if torch.distributed.get_rank() == 0:
            if 'lm_head.weight' in state_dict and state_dict['lm_head.weight'].shape[0] == 1:
                _broadcast_tensor(lm_head_weight, "lm_head.weight")
            elif 'reward_head.weight' in state_dict and state_dict['reward_head.weight'].shape[0] == 1:
                _broadcast_tensor(lm_head_weight, "reward_head.weight")
                print_rank_0('load lm_head from value_head weight')
            else:
                _broadcast_tensor(None, "lm_head.weight")
                print_rank_0('fail to match lm_head in value_model')
            # else:
            #     _broadcast_tensor(lm_head_weight, "lm_head.weight")

        else:
            _broadcast_tp_shard_tensor(lm_head_weight, "lm_head.weight")
    dist.barrier()
    # Broadcast weights inside data parallel groups
    for wrapped_model in wrapped_models:
        broadcast_params(wrapped_model)

    torch.cuda.empty_cache()
    print_rank_0(f"loading megatron ckpt done, time elapsed {time.time() - start_time}s")
449
verl/models/llama/megatron/checkpoint_utils/llama_saver.py
Normal file
@@ -0,0 +1,449 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import megatron
from megatron.core import mpu
from megatron.utils import print_rank_0, unwrap_model
from megatron.model import Float16Module
from megatron.model import DistributedDataParallel as LocalDDP
from torch.nn.parallel import DistributedDataParallel as torchDDP
import torch
import time
from typing import Optional
import torch.distributed as dist
from megatron import get_args


def _megatron_calc_global_rank(tp_rank: int = 0, dp_rank: int = 0, pp_rank: int = 0):
    """Given TP, DP and PP ranks, return the global rank."""

    args = get_args()
    tp_size = mpu.get_tensor_model_parallel_world_size()
    dp_size = mpu.get_data_parallel_world_size()
    pp_size = mpu.get_pipeline_model_parallel_world_size()
    assert (tp_size * dp_size * pp_size == torch.distributed.get_world_size()
           ), f"{tp_size} x {dp_size} x {pp_size} != {torch.distributed.get_world_size()}"
    if args.switch_dp_and_pp_grouping:
        # TP-PP-DP grouping
        return (dp_rank * pp_size + pp_rank) * tp_size + tp_rank
    else:
        # TP-DP-PP grouping
        return (pp_rank * dp_size + dp_rank) * tp_size + tp_rank
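

# Worked example (hypothetical sizes): with tp_size=2, dp_size=2, pp_size=2 and the default
# TP-DP-PP grouping, (tp_rank=1, dp_rank=0, pp_rank=1) maps to global rank
# (1 * 2 + 0) * 2 + 1 = 5; with switch_dp_and_pp_grouping enabled the same tuple maps to
# (0 * 2 + 1) * 2 + 1 = 3.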


def _megatron_calc_layer_map(config):
    """Calculate the mapping of global layer_idx to local layer_idx
    Returns:
        layer_map (Dict: int -> tuple(int, int, int)):
            mapping from the global layer index to
            a tuple of (pp_rank, virtual_pp_rank, layer_idx inside model)
    """
    import megatron
    from megatron.core import mpu

    pp_size = mpu.get_pipeline_model_parallel_world_size()
    virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1

    args = megatron.get_args()
    layer_map = dict()
    num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size
    assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers

    for pp_rank_idx in range(pp_size):
        for virtual_pp_rank_idx in range(virtual_pp_size):
            layer_offset = (virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) +
                            pp_rank_idx * num_layers_per_model)
            for layer_idx in range(num_layers_per_model):
                layer_map[layer_offset + layer_idx] = (
                    pp_rank_idx,
                    virtual_pp_rank_idx,
                    layer_idx,
                )
    return layer_map


def merge_megatron_ckpt_llama(wrapped_models, config, is_value_model=False, dtype='bf16'):
    """Merge sharded parameters of a Megatron module into a merged checkpoint.

    Args:
        wrapped_models (list of megatron.model.DistributedDataParallel):
            The local DDP wrapped megatron modules.
        dtype (str or None):
            The data type of state_dict. If None, the data type of the original parameters
            is used.
        gpt_model_key: key to access model
    Returns:
        state_dict (dict):
            The merged state_dict in rank 0, and an empty dictionary in other ranks.
    """
    start_time = time.time()
    args = megatron.get_args()

    def _get_gpt_model(model):
        return model

    dp_rank = mpu.get_data_parallel_rank()
    pp_size = mpu.get_pipeline_model_parallel_world_size()
    pp_rank = mpu.get_pipeline_model_parallel_rank()
    virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1
    mp_group = mpu.get_model_parallel_group()

    if dist.get_rank() == 0:
        assert mp_group.rank() == 0, f"mp_rank:[{mp_group.rank()}] != 0 on rank #0"
        assert pp_rank == 0, f"pp_rank:[{pp_rank}] != 0 on rank #0"
        assert dp_rank == 0, f"dp_rank:[{dp_rank}] != 0 on rank #0"

    if not isinstance(wrapped_models, (list, tuple)):
        wrapped_models = [wrapped_models]

    assert len(wrapped_models) == virtual_pp_size
    num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size
    assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers

    models = [None] * len(wrapped_models)

    for i, wrapped_model in enumerate(wrapped_models):
        models[i] = unwrap_model(wrapped_model, (torchDDP, LocalDDP, Float16Module))
        assert len(models[i].model.layers
                  ) == num_layers_per_model, 'len model layers {} not equal to num_layers_per_model {}'.format(
                      len(models[i].model.layers), num_layers_per_model)

    state_dict = dict()

    def _get_cpu_tensor(tensor: torch.Tensor):
        if tensor is None:
            return None
        if tensor.device == torch.device("cpu"):
            return tensor.detach().clone()
        return tensor.detach().cpu()

    def _broadcast_tensor(tensor, name, src_pp_rank) -> torch.Tensor:
        """broadcast tensor across mp_group"""
        nonlocal state_dict
        nonlocal mp_group
        src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank)

        if torch.distributed.get_rank() == src_rank:
            if tensor is None:
                weight = None
                tensor_shape = None
            else:
                weight = tensor
                tensor_shape = weight.shape
        else:
            weight = None
            tensor_shape = None

        obj_list = [tensor_shape]
        dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group)
        tensor_shape = obj_list[0]

        if tensor_shape is None:
            # all or none ranks in the mp_group should reach here
            print_rank_0(f"tensor:[{name}] not exist, skip collect")
            return

        if weight is None:
            weight = torch.empty(
                tensor_shape,
                dtype=args.params_dtype,
                device=torch.cuda.current_device(),
                requires_grad=False,
            )

        dist.broadcast(weight, src=src_rank, group=mp_group)

        if torch.distributed.get_rank() == 0:
            state_dict[name] = _get_cpu_tensor(weight)

    def _broadcast_tp_shard_tensor(tensor, name, src_pp_rank, concat_dim=0, mutate_func=None) -> torch.Tensor:
        """broadcast tensor in tp shards across mp_group"""
        nonlocal state_dict
        nonlocal mp_group
        tp_rank = mpu.get_tensor_model_parallel_rank()
        tp_size = mpu.get_tensor_model_parallel_world_size()
        src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank)

        if torch.distributed.get_rank() == src_rank:
            chunk_shape = tensor.shape
        else:
            chunk_shape = None

        obj_list = [chunk_shape]
        dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group)
        chunk_shape = obj_list[0]
        if chunk_shape is None:
            # all or none ranks in the mp_group should reach here
            print_rank_0(f"tp_shard tensor:[{name}] not exist, skip collecting")
            return

        buffer_tensor = torch.empty(
            chunk_shape,
            dtype=args.params_dtype,
            device=torch.cuda.current_device(),
            requires_grad=False,
        )

        chunk_tensors = [None] * tp_size

        for i in range(tp_size):
            cur_src_rank = _megatron_calc_global_rank(tp_rank=i, dp_rank=0, pp_rank=src_pp_rank)
            sync_tensor = tensor if torch.distributed.get_rank() == cur_src_rank else buffer_tensor
            dist.broadcast(sync_tensor, src=cur_src_rank, group=mp_group)

            if torch.distributed.get_rank() == 0:
                chunk_tensors[i] = _get_cpu_tensor(sync_tensor)

        if torch.distributed.get_rank() == 0:
            full_tensor = torch.concat(chunk_tensors, dim=concat_dim)
            if mutate_func is not None:
                full_tensor = mutate_func(full_tensor)
            state_dict[name] = full_tensor

    def _broadcast_tp_shard_tensor_gate_up(tensor, gate_name, up_name, src_pp_rank) -> torch.Tensor:
        """broadcast tensor in tp shards across mp_group"""
        nonlocal state_dict
        nonlocal mp_group
        tp_rank = mpu.get_tensor_model_parallel_rank()
        tp_size = mpu.get_tensor_model_parallel_world_size()
        src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank)

        if torch.distributed.get_rank() == src_rank:
            chunk_shape = tensor.shape
        else:
            chunk_shape = None

        obj_list = [chunk_shape]
        dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group)
        chunk_shape = obj_list[0]
        if chunk_shape is None:
            # all or none ranks in the mp_group should reach here
            print_rank_0(f"tp_shard tensor:[{gate_name, up_name}] not exist, skip collecting")
            return

        buffer_tensor = torch.empty(
            chunk_shape,
            dtype=args.params_dtype,
            device=torch.cuda.current_device(),
            requires_grad=False,
        )

        chunk_tensors = [None] * tp_size

        for i in range(tp_size):
            cur_src_rank = _megatron_calc_global_rank(tp_rank=i, dp_rank=0, pp_rank=src_pp_rank)
            sync_tensor = tensor if torch.distributed.get_rank() == cur_src_rank else buffer_tensor
            dist.broadcast(sync_tensor, src=cur_src_rank, group=mp_group)

            if torch.distributed.get_rank() == 0:
                chunk_tensors[i] = _get_cpu_tensor(sync_tensor)

        if torch.distributed.get_rank() == 0:
            full_tensor = torch.concat(chunk_tensors, dim=0)
            intermediate_size_tp = config.intermediate_size // tp_size
            gate_weight_list = []
            up_weight_list = []
            for i in range(tp_size):
                gate_up_weight_tp = full_tensor[intermediate_size_tp * 2 * i:intermediate_size_tp * 2 * (i + 1)]
                gate_weight_tp = gate_up_weight_tp[:intermediate_size_tp]
                up_weight_tp = gate_up_weight_tp[intermediate_size_tp:]
                gate_weight_list.append(gate_weight_tp)
                up_weight_list.append(up_weight_tp)

            state_dict[gate_name] = torch.cat(gate_weight_list, dim=0)
            state_dict[up_name] = torch.cat(up_weight_list, dim=0)
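
            # Note: this unpacking is the inverse of the packing done in the loader's
            # _broadcast_tp_shard_tensor_gate_up: each TP shard stores [gate | up] rows back
            # to back, so every shard is split at intermediate_size_tp and the per-shard
            # halves are re-concatenated into full gate_proj / up_proj weights.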

    def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, src_pp_rank):
        """broadcast tensor in tp shards across mp_group"""
        nonlocal state_dict
        nonlocal mp_group
        tp_rank = mpu.get_tensor_model_parallel_rank()
        tp_size = mpu.get_tensor_model_parallel_world_size()
        src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank)

        if torch.distributed.get_rank() == src_rank:
            chunk_shape = tensor.shape
        else:
            chunk_shape = None

        obj_list = [chunk_shape]
        dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group)
        chunk_shape = obj_list[0]
        if chunk_shape is None:
            # all or none ranks in the mp_group should reach here
            print_rank_0(f"tp_shard tensor:[{q_name}] not exist, skip collecting")
            return

        buffer_tensor = torch.empty(
            chunk_shape,
            dtype=args.params_dtype,
            device=torch.cuda.current_device(),
            requires_grad=False,
        )

        chunk_tensors = [None] * tp_size

        for i in range(tp_size):
            cur_src_rank = _megatron_calc_global_rank(tp_rank=i, dp_rank=0, pp_rank=src_pp_rank)
            sync_tensor = tensor if torch.distributed.get_rank() == cur_src_rank else buffer_tensor
            dist.broadcast(sync_tensor, src=cur_src_rank, group=mp_group)

            if torch.distributed.get_rank() == 0:
                chunk_tensors[i] = _get_cpu_tensor(sync_tensor)

        if torch.distributed.get_rank() == 0:
            full_tensor = torch.concat(chunk_tensors, dim=0)
            q_weight_list = []
            k_weight_list = []
            v_weight_list = []
            hidden_size_per_head = config.hidden_size // config.num_attention_heads

            if config.num_key_value_heads >= tp_size:
                q_size_tp = config.hidden_size // tp_size
                kv_size_tp = hidden_size_per_head * config.num_key_value_heads // tp_size
                total_size = q_size_tp + 2 * kv_size_tp
                for i in range(tp_size):
                    qkv_part = full_tensor[i * total_size:(i + 1) * total_size]
                    q_part = qkv_part[:q_size_tp]
                    k_part = qkv_part[q_size_tp:q_size_tp + kv_size_tp]
                    v_part = qkv_part[q_size_tp + kv_size_tp:total_size]
                    q_weight_list.append(q_part)
                    k_weight_list.append(k_part)
                    v_weight_list.append(v_part)
            else:
                q_size_tp = config.hidden_size // tp_size
                kv_size_tp = hidden_size_per_head
                total_size = q_size_tp + 2 * kv_size_tp
                for i in range(tp_size):
                    qkv_part = full_tensor[i * total_size:(i + 1) * total_size]
                    q_part = qkv_part[:q_size_tp]
                    k_part = qkv_part[q_size_tp:q_size_tp + kv_size_tp]
                    v_part = qkv_part[q_size_tp + kv_size_tp:total_size]
                    q_weight_list.append(q_part)
                    if i * config.num_key_value_heads % tp_size == 0:
                        k_weight_list.append(k_part)
                        v_weight_list.append(v_part)

            state_dict[q_name] = torch.cat(q_weight_list, dim=0)
            state_dict[k_name] = torch.cat(k_weight_list, dim=0)
            state_dict[v_name] = torch.cat(v_weight_list, dim=0)

    # empty cache before collecting weights
    torch.cuda.empty_cache()

    if dp_rank == 0:
        # Embeddings
        # -------------------
        print_rank_0("collecting embeddings...")
        gpt_model_module = _get_gpt_model(models[0])
        _broadcast_tp_shard_tensor(
            gpt_model_module.model.embed_tokens.weight if pp_rank == 0 else None,
            "model.embed_tokens.weight",
            src_pp_rank=0,
        )

        # Transformer layers
        # -------------------
        layer_map = _megatron_calc_layer_map(config)
        for layer in range(config.num_hidden_layers):
            print_rank_0(f"collecting layer #{layer}...")
            layer_name = f"model.layers.{layer}"
            src_pp_rank, src_virtual_pp_rank, src_layer_idx = layer_map[layer]

            gpt_model_module = _get_gpt_model(models[src_virtual_pp_rank])
            sync_layer = gpt_model_module.model.layers[src_layer_idx]

            _broadcast_tensor(
                sync_layer.input_layernorm.weight,
                f"{layer_name}.input_layernorm.weight",
                src_pp_rank=src_pp_rank,
            )

            _broadcast_tp_shard_tensor_qkv(
                sync_layer.self_attn.qkv_proj.weight,
                f"{layer_name}.self_attn.q_proj.weight",
                f"{layer_name}.self_attn.k_proj.weight",
                f"{layer_name}.self_attn.v_proj.weight",
                src_pp_rank=src_pp_rank,
            )

            _broadcast_tp_shard_tensor(
                sync_layer.self_attn.o_proj.weight,
                f"{layer_name}.self_attn.o_proj.weight",
                concat_dim=1,
                src_pp_rank=src_pp_rank,
            )

            _broadcast_tensor(
                sync_layer.post_attention_layernorm.weight,
                f"{layer_name}.post_attention_layernorm.weight",
                src_pp_rank=src_pp_rank,
            )

            _broadcast_tp_shard_tensor_gate_up(sync_layer.mlp.gate_up_proj.weight,
                                               f"{layer_name}.mlp.gate_proj.weight",
                                               f"{layer_name}.mlp.up_proj.weight",
                                               src_pp_rank=src_pp_rank)

            _broadcast_tp_shard_tensor(
                sync_layer.mlp.down_proj.weight,
                f"{layer_name}.mlp.down_proj.weight",
                concat_dim=1,
                src_pp_rank=src_pp_rank,
            )

        # Final Layernorm
        # -------------------
        print_rank_0("collecting final layernorm...")
        gpt_model_module = _get_gpt_model(models[-1])
        _broadcast_tensor(
            getattr(gpt_model_module.model.norm, "weight", None),
            "model.norm.weight",
            src_pp_rank=pp_size - 1,
        )

        print_rank_0("collecting lm_head...")

        if is_value_model:
            _broadcast_tensor(getattr(gpt_model_module.lm_head, "weight", None) if pp_rank == pp_size - 1 else None,
                              "reward_head.weight",
                              src_pp_rank=pp_size - 1)

        else:
            _broadcast_tp_shard_tensor(
                getattr(gpt_model_module.lm_head, "weight", None) if pp_rank == pp_size - 1 else None,
                "lm_head.weight",
                src_pp_rank=pp_size - 1,
            )

    dist.barrier()

    torch.cuda.empty_cache()
    if torch.distributed.get_rank() == 0:
        if dtype == "fp16":
            dtype = torch.float16
        elif dtype == "bf16":
            dtype = torch.bfloat16
        elif dtype is None or dtype == "fp32":
            dtype = torch.float32
        else:
            print(f'Unknown/unsupported dtype to save: {dtype}')
            exit(1)
        for k, v in state_dict.items():
            if dtype != v.dtype:
                state_dict[k] = v.to(dtype)

    print_rank_0(f"merge megatron ckpt done, time elapsed {time.time() - start_time}s")
    return state_dict
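

# Hedged usage sketch (illustrative only; the surrounding object names are hypothetical):
#
#   state_dict = merge_megatron_ckpt_llama(ddp_model_chunks,   # list, one entry per virtual pp stage
#                                          config=hf_config,
#                                          is_value_model=False,
#                                          dtype='bf16')
#   if torch.distributed.get_rank() == 0:
#       torch.save(state_dict, "merged_hf_state_dict.pt")      # only rank 0 holds the merged dict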
18
verl/models/llama/megatron/layers/__init__.py
Normal file
@@ -0,0 +1,18 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .parallel_attention import ParallelLlamaAttention
from .parallel_decoder import ParallelLlamaDecoderLayer, ParallelLlamaDecoderLayerRmPad
from .parallel_mlp import ParallelLlamaMLP
from .parallel_rmsnorm import ParallelLlamaRMSNorm
418
verl/models/llama/megatron/layers/parallel_attention.py
Normal file
@@ -0,0 +1,418 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from typing import Optional, Tuple

import torch
from megatron.core import parallel_state as mpu
from megatron.core import tensor_parallel
from megatron.core import ModelParallelConfig
from torch import nn
from transformers import LlamaConfig
from verl.models.llama.megatron.layers.parallel_linear import QKVParallelLinear

from verl.utils.megatron import tensor_parallel as tp_utils


class LlamaRotaryEmbedding(nn.Module):

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()

        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        inv_freq = 1.0 / (self.base**(torch.arange(0, self.dim, 2).float().to(device) / self.dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(seq_len=max_position_embeddings,
                                device=self.inv_freq.device,
                                dtype=torch.get_default_dtype())

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)

        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

        return (
            self.cos_cached[:seq_len].to(dtype=x.dtype),
            self.sin_cached[:seq_len].to(dtype=x.dtype),
        )


class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
    """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
        self.scaling_factor = scaling_factor
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
        t = t / self.scaling_factor

        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)


class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
    """LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
        self.scaling_factor = scaling_factor
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len

        if seq_len > self.max_position_embeddings:
            base = self.base * ((self.scaling_factor * seq_len / self.max_position_embeddings) -
                                (self.scaling_factor - 1))**(self.dim / (self.dim - 2))
            inv_freq = 1.0 / (base**(torch.arange(0, self.dim, 2).float().to(device) / self.dim))
            self.register_buffer("inv_freq", inv_freq, persistent=False)

        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)

        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)


def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., :x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2:]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
    cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
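

# Example (shapes only, illustrative): for a GQA model with 32 attention heads and 8 KV heads,
# n_rep = 32 // 8 = 4, so a key/value tensor of shape (batch, 8, seqlen, head_dim) becomes
# (batch, 32, seqlen, head_dim) after repeat_kv, with each KV head duplicated 4 times.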


class ParallelLlamaAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig):
        super().__init__()
        self.config = config
        self.megatron_config = megatron_config
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta

        # assign values after tp
        tp_size = mpu.get_tensor_model_parallel_world_size()
        assert self.num_heads % tp_size == 0, f'num_head must be divisible by tp_size. Got num_head={self.num_heads}, tp_size={tp_size}'
        assert self.num_key_value_heads % tp_size == 0, \
            f'num_key_value_heads must be divisible by tp_size. Got num_key_value_heads={self.num_key_value_heads}, tp_size={tp_size}'

        self.num_heads_per_tp = self.num_heads // tp_size
        self.num_key_value_heads_per_tp = self.num_key_value_heads // tp_size
        self.hidden_size_per_tp = self.hidden_size // tp_size

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                             f" and `num_heads`: {self.num_heads}).")

        column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear()
        row_kwargs = tp_utils.get_default_kwargs_for_row_parallel_linear()

        if megatron_config is not None:
            assert column_kwargs.get('config', False), 'must have ModelParallelConfig'
            assert row_kwargs.get('config', False), 'must have ModelParallelConfig'
            tp_utils.update_kwargs_with_config(column_kwargs, megatron_config)
            tp_utils.update_kwargs_with_config(row_kwargs, megatron_config)

        # [self.q_size, self.k_size, self.v_size]
        self.qkv_proj = QKVParallelLinear(input_size=self.hidden_size,
                                          num_heads=self.num_heads,
                                          num_key_value_heads=self.num_key_value_heads,
                                          head_dim=self.head_dim,
                                          bias=config.attention_bias,
                                          gather_output=False,
                                          skip_bias_add=False,
                                          **column_kwargs)

        self.q_size = self.num_heads_per_tp * self.head_dim
        self.k_size = self.num_key_value_heads_per_tp * self.head_dim
        self.v_size = self.num_key_value_heads_per_tp * self.head_dim

        self.o_proj = tensor_parallel.RowParallelLinear(input_size=self.num_heads * self.head_dim,
                                                        output_size=self.hidden_size,
                                                        bias=config.attention_bias,
                                                        input_is_parallel=True,
                                                        skip_bias_add=False,
                                                        **row_kwargs)

        self._init_rope()

    def _init_rope(self):
        if self.config.rope_scaling is None:
            self.rotary_emb = LlamaRotaryEmbedding(
                self.head_dim,
                max_position_embeddings=self.max_position_embeddings,
                base=self.rope_theta,
            )
        else:
            scaling_type = self.config.rope_scaling["type"]
            scaling_factor = self.config.rope_scaling["factor"]
            if scaling_type == "linear":
                self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
                    self.head_dim,
                    max_position_embeddings=self.max_position_embeddings,
                    scaling_factor=scaling_factor,
                    base=self.rope_theta,
                )
            elif scaling_type == "dynamic":
                self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
                    self.head_dim,
                    max_position_embeddings=self.max_position_embeddings,
                    scaling_factor=scaling_factor,
                    base=self.rope_theta,
                )
            else:
                raise ValueError(f"Unknown RoPE scaling type {scaling_type}")

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        bsz, q_len, _ = hidden_states.size()
        qkv = self.qkv_proj(hidden_states)[0]
        query_states, key_states, value_states = qkv.split([self.q_size, self.k_size, self.v_size], dim=-1)

        query_states = query_states.view(bsz, q_len, self.num_heads_per_tp, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads_per_tp, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads_per_tp, self.head_dim).transpose(1, 2)

        kv_seq_len = key_states.shape[-2]
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

        if attn_weights.size() != (bsz, self.num_heads_per_tp, q_len, kv_seq_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz, self.num_heads_per_tp, q_len, kv_seq_len)}, but is"
                f" {attn_weights.size()}")

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}")
            attn_weights = attn_weights + attention_mask

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads_per_tp, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads_per_tp, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}")

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size_per_tp)
        attn_output = self.o_proj(attn_output)[0]
        return attn_output
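
    # Shape walkthrough of forward() above (per TP rank, descriptive comment): the fused
    # qkv_proj output of width q_size + k_size + v_size is split and reshaped to
    # (bsz, heads_per_tp, q_len, head_dim); after RoPE and repeat_kv, the usual
    # softmax(QK^T / sqrt(head_dim)) V product yields (bsz, heads_per_tp, q_len, head_dim),
    # which is flattened back to (bsz, q_len, hidden_size_per_tp) before o_proj combines the
    # partial outputs across the TP group (a reduction inside Megatron's RowParallelLinear).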


"""
Remove padding Attention
- Using Flash-attn 2
- Compatible with sequence parallel
"""

from transformers.utils import is_flash_attn_2_available
import torch.nn.functional as F

from einops import rearrange

if is_flash_attn_2_available():
    from flash_attn import flash_attn_varlen_func
    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa


def apply_rotary_pos_emb_rmpad(q, k, cos, sin, position_ids, indices, sequence_length):
    batch_size = position_ids.shape[0]

    q = pad_input(q, indices, batch_size, sequence_length)  # (batch_size, seqlen, num_head, head_dim)
    k = pad_input(k, indices, batch_size, sequence_length)
    cos = cos[position_ids].unsqueeze(2)  # [bs, seq_len, 1, dim]
    sin = sin[position_ids].unsqueeze(2)  # [bs, seq_len, 1, dim]
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)

    q_embed = index_first_axis(rearrange(q_embed, "b s ... -> (b s) ..."), indices)
    k_embed = index_first_axis(rearrange(k_embed, "b s ... -> (b s) ..."), indices)

    return q_embed, k_embed


from flash_attn.layers.rotary import apply_rotary_emb


# use flash-attn rotary embeddings with rmpad
# cos/sin should be: (seq_length, rotary_dim / 2)
def apply_rotary_pos_emb_rmpad_flash(q, k, cos, sin, cu_seqlens, max_seqlen):
    q_embed = apply_rotary_emb(q,
                               cos,
                               sin,
                               interleaved=False,
                               inplace=False,
                               cu_seqlens=cu_seqlens,
                               max_seqlen=max_seqlen)
    k_embed = apply_rotary_emb(k,
                               cos,
                               sin,
                               interleaved=False,
                               inplace=False,
                               cu_seqlens=cu_seqlens,
                               max_seqlen=max_seqlen)
    return q_embed, k_embed


class ParallelLlamaAttentionRmPad(ParallelLlamaAttention):

    def forward(self,
                hidden_states: torch.Tensor,
                position_ids: Optional[torch.LongTensor] = None,
                sequence_length: int = None,
                indices: torch.Tensor = None,
                cu_seqlens: torch.Tensor = None,
                max_seqlen_in_batch: int = None):
        total_nnz, _, _ = hidden_states.size()  # This is the total_nnz padded after sequence parallel

        if self.megatron_config.sequence_parallel:
            total_nnz = total_nnz * mpu.get_tensor_model_parallel_world_size()

        qkv = self.qkv_proj(hidden_states)[0]
        query_states, key_states, value_states = qkv.split([self.q_size, self.k_size, self.v_size],
                                                           dim=-1)  # (total_nnz, 1, hidden_size)

        if self.megatron_config.sequence_parallel:
            sequence_parallel_pad = total_nnz - cu_seqlens[-1]
            total_nnz = cu_seqlens[-1]  # total_nnz before sp padding
            query_states = query_states[:total_nnz]
            key_states = key_states[:total_nnz]
            value_states = value_states[:total_nnz]

        # Flash attention requires the input to have the shape
        # batch_size x seq_length x head_dim x hidden_dim
        # therefore we just need to keep the original shape
        query_states = query_states.view(total_nnz, self.num_heads_per_tp, self.head_dim)
        key_states = key_states.view(total_nnz, self.num_key_value_heads_per_tp, self.head_dim)
        value_states = value_states.view(total_nnz, self.num_key_value_heads_per_tp, self.head_dim)

        cos, sin = self.rotary_emb(value_states, seq_len=sequence_length)
        cos, sin = cos[:, :cos.shape[1] // 2], sin[:, :sin.shape[1] // 2]  # flash attn only needs half
        query_states, key_states = apply_rotary_pos_emb_rmpad_flash(query_states,
                                                                    key_states,
                                                                    cos,
                                                                    sin,
                                                                    cu_seqlens=cu_seqlens,
                                                                    max_seqlen=max_seqlen_in_batch)
        # query_states, key_states = apply_rotary_pos_emb_rmpad(query_states, key_states, cos, sin, position_ids, indices,

        # TODO: llama does not have dropout in the config??
        # It is recommended to use dropout with FA according to the docs
        # when training.
        dropout_rate = 0.0  # if not self.training else self.attn_dropout

        # In PEFT, usually we cast the layer norms to float32 for training stability reasons,
        # therefore the input hidden states get silently cast to float32. Hence, we need to
        # cast them back to float16 just to be sure everything works as expected.
        # This might slow down training & inference, so it is recommended not to cast the
        # LayerNorms to fp32. (LlamaRMSNorm handles it correctly)
        input_dtype = query_states.dtype
        if input_dtype == torch.float32:
            query_states = query_states.to(torch.float16)
            key_states = key_states.to(torch.float16)
            value_states = value_states.to(torch.float16)

        attn_output_unpad = flash_attn_varlen_func(
            query_states,
            key_states,
            value_states,
            cu_seqlens_q=cu_seqlens,
            cu_seqlens_k=cu_seqlens,
            max_seqlen_q=max_seqlen_in_batch,
            max_seqlen_k=max_seqlen_in_batch,
            dropout_p=dropout_rate,
            softmax_scale=None,
            causal=True,
        )

        attn_output_unpad = attn_output_unpad.to(input_dtype)
        attn_output_unpad = attn_output_unpad.reshape(total_nnz, 1, self.hidden_size_per_tp).contiguous()

        # sequence parallel reduce_scatter is performed inside RowColumnParallel if enabled
        # Here we need to repad
        if self.megatron_config.sequence_parallel:
            attn_output_unpad = F.pad(attn_output_unpad, pad=(0, 0, 0, 0, 0, sequence_parallel_pad))

        attn_output_unpad = self.o_proj(attn_output_unpad)[0]
        return attn_output_unpad
|
||||
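A minimal usage sketch (not part of parallel_attention.py): how the indices, cu_seqlens and max_seqlen_in_batch that ParallelLlamaAttentionRmPad consumes are typically produced from a padded batch with flash-attn's padding utilities. The tensor sizes below are made up for illustration.

# Illustrative sketch only; shapes and sizes are arbitrary example values.
import torch
from flash_attn.bert_padding import unpad_input

batch_size, seqlen, hidden_size = 2, 8, 16
hidden = torch.randn(batch_size, seqlen, hidden_size)
attention_mask = torch.ones(batch_size, seqlen, dtype=torch.bool)
attention_mask[1, 5:] = False  # second sequence has only 5 real tokens

# unpad_input flattens the batch into (total_nnz, hidden) and returns the metadata
# (indices, cu_seqlens, max_seqlen) that flash_attn_varlen_func and pad_input need.
hidden_unpad, indices, cu_seqlens, max_seqlen_in_batch, *_ = unpad_input(hidden, attention_mask)
print(hidden_unpad.shape)   # torch.Size([13, 16]): 8 + 5 valid tokens
print(cu_seqlens)           # prefix sums of valid tokens per sequence: [0, 8, 13]
print(max_seqlen_in_batch)  # 8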
146
verl/models/llama/megatron/layers/parallel_decoder.py
Normal file
@@ -0,0 +1,146 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Tuple

import torch
from torch import nn
from transformers import LlamaConfig
from megatron.core import ModelParallelConfig

from .parallel_attention import ParallelLlamaAttention, ParallelLlamaAttentionRmPad
from .parallel_mlp import ParallelLlamaMLP
from .parallel_rmsnorm import ParallelLlamaRMSNorm


class ParallelLlamaDecoderLayer(nn.Module):

    def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = ParallelLlamaAttention(config=config, megatron_config=megatron_config)

        self.mlp = ParallelLlamaMLP(config, megatron_config=megatron_config)
        self.input_layernorm = ParallelLlamaRMSNorm(config, megatron_config)
        self.post_attention_layernorm = ParallelLlamaRMSNorm(config, megatron_config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            position_ids (`torch.LongTensor`, *optional*): position indices of shape `(batch, seq_len)`
        """

        residual = hidden_states

        hidden_states = self.input_layernorm(hidden_states)

        # Note: sequence parallel is hidden inside ColumnParallelLinear
        # reduce scatter is hidden inside RowParallelLinear

        # Self Attention
        hidden_states = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
        )

        # TODO: add sequence parallel operator reduce_scatter here

        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)

        # TODO: add sequence parallel operator all_gather here

        hidden_states = self.mlp(hidden_states)

        # TODO: add sequence parallel operator reduce_scatter here

        hidden_states = residual + hidden_states

        outputs = hidden_states

        return outputs


class ParallelLlamaDecoderLayerRmPad(nn.Module):

    def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig):
        super().__init__()
        self.config = config
        self.megatron_config = megatron_config
        self.hidden_size = config.hidden_size
        self.self_attn = ParallelLlamaAttentionRmPad(config=config, megatron_config=megatron_config)

        self.mlp = ParallelLlamaMLP(config, megatron_config=megatron_config)
        self.input_layernorm = ParallelLlamaRMSNorm(config, megatron_config)
        self.post_attention_layernorm = ParallelLlamaRMSNorm(config, megatron_config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: Optional[torch.LongTensor] = None,
        sequence_length: int = None,
        indices: torch.Tensor = None,
        cu_seqlens: int = None,
        max_seqlen_in_batch: int = None
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        residual = hidden_states  # (total_nnz // sp, 1, hidden_size)

        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        # (total_nnz // sp, 1, hidden_size) -> all-gather (total_nnz, 1, hidden_size)
        # -> col + row -> reduce-scatter -> (total_nnz // sp, 1, hidden_size)
        hidden_states = self.self_attn(hidden_states=hidden_states,
                                       position_ids=position_ids,
                                       sequence_length=sequence_length,
                                       indices=indices,
                                       cu_seqlens=cu_seqlens,
                                       max_seqlen_in_batch=max_seqlen_in_batch)

        hidden_states = residual + hidden_states

        # Fully Connected
        # shape changes same as attn
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = hidden_states

        return outputs
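A simplified, non-parallel analogue (not part of this file) of the pre-norm residual structure that both decoder layers above implement; ordinary Linear modules stand in for the parallel attention and MLP, and all names and sizes are hypothetical.

# Illustrative sketch of the pre-norm residual pattern; placeholders replace the parallel modules.
import torch
from torch import nn

class TinyPreNormBlock(nn.Module):
    def __init__(self, hidden_size: int):
        super().__init__()
        self.input_layernorm = nn.LayerNorm(hidden_size)
        self.post_attention_layernorm = nn.LayerNorm(hidden_size)
        self.self_attn = nn.Linear(hidden_size, hidden_size)  # placeholder for attention
        self.mlp = nn.Linear(hidden_size, hidden_size)        # placeholder for the MLP

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        residual = hidden_states
        hidden_states = self.self_attn(self.input_layernorm(hidden_states))
        hidden_states = residual + hidden_states               # first residual add

        residual = hidden_states
        hidden_states = self.mlp(self.post_attention_layernorm(hidden_states))
        return residual + hidden_states                        # second residual add

out = TinyPreNormBlock(16)(torch.randn(4, 3, 16))
print(out.shape)  # torch.Size([4, 3, 16])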
74
verl/models/llama/megatron/layers/parallel_linear.py
Normal file
@@ -0,0 +1,74 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023 The vLLM team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/linear.py

from typing import Optional, Tuple

from megatron.core import tensor_parallel


class QKVParallelLinear(tensor_parallel.ColumnParallelLinear):

    def __init__(self,
                 input_size,
                 num_heads,
                 num_key_value_heads,
                 head_dim,
                 *,
                 bias=True,
                 gather_output=True,
                 skip_bias_add=False,
                 **kwargs):
        # Keep input parameters; head counts here refer to the full (unsharded) model
        self.input_size = input_size
        self.q_output_size = num_heads * head_dim
        self.kv_output_size = num_key_value_heads * head_dim
        self.head_dim = head_dim
        self.gather_output = gather_output
        self.skip_bias_add = skip_bias_add

        input_size = self.input_size
        output_size = (num_heads + 2 * num_key_value_heads) * self.head_dim

        super().__init__(input_size=input_size,
                         output_size=output_size,
                         bias=bias,
                         gather_output=gather_output,
                         skip_bias_add=skip_bias_add,
                         **kwargs)


class MergedColumnParallelLinear(tensor_parallel.ColumnParallelLinear):

    def __init__(self,
                 input_size,
                 gate_ouput_size,
                 up_output_size,
                 *,
                 bias=True,
                 gather_output=True,
                 skip_bias_add=False,
                 **kwargs):
        # Keep input parameters; the merged output is the gate projection followed by the up projection
        self.input_size = input_size
        self.output_size = gate_ouput_size + up_output_size
        self.gather_output = gather_output
        self.skip_bias_add = skip_bias_add

        super().__init__(input_size=self.input_size,
                         output_size=self.output_size,
                         bias=bias,
                         gather_output=gather_output,
                         skip_bias_add=skip_bias_add,
                         **kwargs)
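A small sketch (not part of this file) of why the fused QKV output is later split with sizes derived from num_heads and head_dim, as done in the attention forward pass. With tensor parallelism the per-rank sizes are divided by the TP world size; the numbers below are made up and assume a single rank.

# Illustrative split of a fused QKV projection output; all sizes are example values.
import torch

num_heads, num_key_value_heads, head_dim = 4, 2, 8
q_size = num_heads * head_dim             # 32
kv_size = num_key_value_heads * head_dim  # 16

total_nnz = 6
qkv = torch.randn(total_nnz, 1, q_size + 2 * kv_size)  # fused projection output
query, key, value = qkv.split([q_size, kv_size, kv_size], dim=-1)
print(query.shape, key.shape, value.shape)
# torch.Size([6, 1, 32]) torch.Size([6, 1, 16]) torch.Size([6, 1, 16])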
74
verl/models/llama/megatron/layers/parallel_mlp.py
Normal file
@@ -0,0 +1,74 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from megatron.core import parallel_state as mpu
from megatron.core import tensor_parallel
from megatron.core import ModelParallelConfig
from torch import nn
from transformers.activations import ACT2FN
from verl.models.llama.megatron.layers.parallel_linear import MergedColumnParallelLinear

from verl.utils.megatron import tensor_parallel as tp_utils


class ParallelLlamaMLP(nn.Module):

    def __init__(self, config, megatron_config: ModelParallelConfig = None) -> None:
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        # The weight is only [hidden_size, intermediate_size // model_parallel_world_size]

        column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear()
        row_kwargs = tp_utils.get_default_kwargs_for_row_parallel_linear()

        if megatron_config is not None:
            assert column_kwargs.get('config', False), 'must have ModelParallelConfig'
            assert row_kwargs.get('config', False), 'must have ModelParallelConfig'
            tp_utils.update_kwargs_with_config(row_kwargs, megatron_config)
            tp_utils.update_kwargs_with_config(column_kwargs, megatron_config)

        tp_size = mpu.get_tensor_model_parallel_world_size()

        self.gate_up_proj = MergedColumnParallelLinear(
            input_size=self.hidden_size,
            gate_ouput_size=self.intermediate_size,
            up_output_size=self.intermediate_size,
            bias=False,
            gather_output=False,
            skip_bias_add=False,
            **column_kwargs,
        )
        self.gate_size = self.intermediate_size // tp_size

        self.down_proj = tensor_parallel.RowParallelLinear(input_size=self.intermediate_size,
                                                           output_size=self.hidden_size,
                                                           bias=False,
                                                           input_is_parallel=True,
                                                           skip_bias_add=False,
                                                           **row_kwargs)

        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        gate_up = self.gate_up_proj(x)[0]
        gate, up = gate_up.split(self.gate_size, dim=-1)
        return self.down_proj(self.act_fn(gate) * up)[0]
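A dense reference sketch (no tensor parallelism, not part of this file) of the SwiGLU computation that ParallelLlamaMLP shards across ranks: gate and up are fused into one projection, split, gated with SiLU, and projected back down. In the parallel version each rank holds intermediate_size // tp_size columns of gate and up (hence self.gate_size) and the row-parallel down projection reduces across ranks. All sizes below are example values.

# Illustrative dense SwiGLU MLP; gate_up_proj and down_proj are ordinary Linear layers here.
import torch
from torch import nn
import torch.nn.functional as F

hidden_size, intermediate_size = 16, 40
gate_up_proj = nn.Linear(hidden_size, 2 * intermediate_size, bias=False)  # fused gate + up
down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)

x = torch.randn(2, 3, hidden_size)
gate, up = gate_up_proj(x).split(intermediate_size, dim=-1)
y = down_proj(F.silu(gate) * up)  # SiLU gating, matching LLaMA's hidden_act="silu"
print(y.shape)  # torch.Size([2, 3, 16])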
46
verl/models/llama/megatron/layers/parallel_rmsnorm.py
Normal file
@@ -0,0 +1,46 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numbers
import torch
from megatron.core import ModelParallelConfig
from torch import nn
from transformers import LlamaConfig

from apex.normalization.fused_layer_norm import fused_rms_norm_affine
from verl.utils.megatron import sequence_parallel as sp_utils


class ParallelLlamaRMSNorm(nn.Module):

    def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig):
        """
        LlamaRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        if isinstance(config.hidden_size, numbers.Integral):
            normalized_shape = (config.hidden_size,)
        self.normalized_shape = torch.Size(normalized_shape)
        self.weight = nn.Parameter(torch.ones(self.normalized_shape))
        self.variance_epsilon = config.rms_norm_eps

        if megatron_config.sequence_parallel:
            sp_utils.mark_parameter_as_sequence_parallel(self.weight)

    def forward(self, hidden_states):
        return fused_rms_norm_affine(input=hidden_states,
                                     weight=self.weight,
                                     normalized_shape=self.normalized_shape,
                                     eps=self.variance_epsilon,
                                     memory_efficient=True)
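A reference sketch (not part of this file) of the RMSNorm math that the fused Apex kernel computes: normalize each vector by its root mean square over the last dimension, then apply the learned elementwise scale. The eps and sizes below are arbitrary example values.

# Plain-PyTorch RMSNorm, shown for clarity; the fused kernel computes the same quantity.
import torch

def rms_norm_ref(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # root-mean-square over the last dimension, then scale by the learned weight
    variance = x.float().pow(2).mean(dim=-1, keepdim=True)
    x_normed = x.float() * torch.rsqrt(variance + eps)
    return (weight * x_normed).to(x.dtype)

x = torch.randn(2, 5, 16)
w = torch.ones(16)
print(rms_norm_ref(x, w).shape)  # torch.Size([2, 5, 16])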
656
verl/models/llama/megatron/modeling_llama_megatron.py
Normal file
@@ -0,0 +1,656 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch LLaMA model with Megatron-style acceleration."""

from typing import Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from megatron.core import tensor_parallel
from megatron.core import ModelParallelConfig
from torch import nn
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import CausalLMOutputWithPast

from verl.utils.megatron import sequence_parallel as sp_utils
from verl.utils.megatron import tensor_parallel as tp_utils
from .layers import ParallelLlamaDecoderLayer, ParallelLlamaRMSNorm, ParallelLlamaDecoderLayerRmPad
"""
TODO:
1. Add weight initialization. Here we need to be careful on TP weight init.
2. Add sequence parallel
3. Load checkpoint from Meta LLaMA pretrained checkpoint
"""


# Copied from transformers.models.bart.modeling_bart._make_causal_mask
def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device):
    """
    Make causal mask used for bi-directional self-attention.
    """
    bsz, tgt_len = input_ids_shape
    mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
    mask_cond = torch.arange(mask.size(-1), device=device)
    mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
    mask = mask.to(dtype)
    return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len)


# Copied from transformers.models.bart.modeling_bart._expand_mask
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
    """
    Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
    """
    bsz, src_len = mask.size()
    tgt_len = tgt_len if tgt_len is not None else src_len

    expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)

    inverted_mask = 1.0 - expanded_mask

    return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)


class ParallelLlamaModel(nn.Module):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`ParallelLlamaDecoderLayer`]

    Args:
        config: LlamaConfig
    """

    def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig):
        super().__init__()
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.megatron_config = megatron_config
        embedding_kwargs = tp_utils.get_default_kwargs_for_parallel_embedding()
        if megatron_config is not None:
            assert embedding_kwargs.get('config', False), 'must have ModelParallelConfig'
            tp_utils.update_kwargs_with_config(embedding_kwargs, self.megatron_config)
        self.embed_tokens = tensor_parallel.VocabParallelEmbedding(num_embeddings=config.vocab_size,
                                                                   embedding_dim=config.hidden_size,
                                                                   **embedding_kwargs)

        self.layers = nn.ModuleList(
            [ParallelLlamaDecoderLayer(config, megatron_config) for _ in range(config.num_hidden_layers)])
        self.norm = ParallelLlamaRMSNorm(config, megatron_config)

    # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
    def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds):
        # create causal mask
        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        combined_attention_mask = None
        if input_shape[-1] > 1:
            combined_attention_mask = _make_causal_mask(
                input_shape,
                inputs_embeds.dtype,
                device=inputs_embeds.device,
            )

        if attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype,
                                              tgt_len=input_shape[-1]).to(inputs_embeds.device)
            combined_attention_mask = (expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask +
                                       combined_attention_mask)

        return combined_attention_mask

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """

        Args:
            input_ids: input ids. shape (batch_size, seq_length)
            attention_mask: attention_mask. shape (batch_size, seq_length)
            position_ids: position ids. shape (batch_size, seq_length)

        Returns:

        """
        batch_size, seq_length = input_ids.shape
        inputs_embeds = self.embed_tokens(input_ids)
        # embed positions

        attention_mask = self._prepare_decoder_attention_mask(attention_mask, (batch_size, seq_length), inputs_embeds)

        hidden_states = inputs_embeds

        for idx, decoder_layer in enumerate(self.layers):
            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
            )

            hidden_states = layer_outputs

        hidden_states = self.norm(hidden_states)

        return hidden_states


class ParallelLlamaForCausalLM(nn.Module):

    def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig):
        super().__init__()
        self.megatron_config = megatron_config
        self.model = ParallelLlamaModel(config, megatron_config=megatron_config)
        self.vocab_size = config.vocab_size

        column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear()
        if megatron_config is not None:
            assert column_kwargs.get('config', False), 'must have ModelParallelConfig'
            tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config)

        self.lm_head = tensor_parallel.ColumnParallelLinear(input_size=config.hidden_size,
                                                            output_size=config.vocab_size,
                                                            bias=False,
                                                            gather_output=False,
                                                            skip_bias_add=False,
                                                            **column_kwargs)

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """
        Args:
            input_ids: input ids. shape (batch_size, seq_length)
            attention_mask: attention mask. shape (batch_size, seq_length)
            position_ids: position ids. shape (batch_size, seq_length)

        Returns:
            CausalLMOutputWithPast with gathered logits of shape (batch_size, seq_length, vocab_size)
        """

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
        )

        hidden_states = outputs
        logits = self.lm_head(hidden_states)[0]

        logits = tensor_parallel.gather_from_tensor_model_parallel_region(logits)

        logits = logits.float()
        return CausalLMOutputWithPast(
            loss=None,
            logits=logits,
            past_key_values=None,
            hidden_states=None,
            attentions=None,
        )


from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa


class ParallelLlamaModelRmPad(nn.Module):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`ParallelLlamaDecoderLayerRmPad`]

    Args:
        config: LlamaConfig
    """

    def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig):
        super().__init__()
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        embedding_kwargs = tp_utils.get_default_kwargs_for_parallel_embedding()
        self.megatron_config = megatron_config
        if megatron_config is not None:
            assert embedding_kwargs.get('config', False), 'must have ModelParallelConfig'
            tp_utils.update_kwargs_with_config(embedding_kwargs, self.megatron_config)
        self.embed_tokens = tensor_parallel.VocabParallelEmbedding(num_embeddings=config.vocab_size,
                                                                   embedding_dim=config.hidden_size,
                                                                   **embedding_kwargs)

        self.layers = nn.ModuleList(
            [ParallelLlamaDecoderLayerRmPad(config, megatron_config) for _ in range(config.num_hidden_layers)])
        self.norm = ParallelLlamaRMSNorm(config, megatron_config)

    def forward(self,
                input_ids: torch.Tensor,
                position_ids: Optional[torch.LongTensor] = None,
                sequence_length: int = None,
                indices: torch.Tensor = None,
                cu_seqlens: int = None,
                max_seqlen_in_batch: int = None) -> Union[Tuple, BaseModelOutputWithPast]:
        """

        Args:
            input_ids: input ids. shape (1, total_nnz)
            position_ids: position ids. shape (batch_size, seq_length)

        Returns:

        """
        inputs_embeds = self.embed_tokens(input_ids)  # (1, total_nnz) -> (1, total_nnz, hidden_size)

        # (1, total_nnz, hidden_size) -> (total_nnz, 1, hidden_size) -> (total_nnz // sp, 1, hidden_size)
        inputs_embeds = inputs_embeds.transpose(0, 1)
        if self.megatron_config.sequence_parallel:
            inputs_embeds = tensor_parallel.scatter_to_sequence_parallel_region(inputs_embeds)

        hidden_states = inputs_embeds
        for idx, decoder_layer in enumerate(self.layers):
            layer_outputs = decoder_layer(hidden_states,
                                          position_ids=position_ids,
                                          sequence_length=sequence_length,
                                          indices=indices,
                                          cu_seqlens=cu_seqlens,
                                          max_seqlen_in_batch=max_seqlen_in_batch)

            hidden_states = layer_outputs

        hidden_states = self.norm(hidden_states)

        return hidden_states


class ParallelLlamaForCausalLMRmPad(nn.Module):

    def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig):
        super().__init__()
        self.config = config
        self.megatron_config = megatron_config
        self.model = ParallelLlamaModelRmPad(config, megatron_config=megatron_config)
        self.vocab_size = config.vocab_size
        self._init_head()

    def _init_head(self):
        column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear()
        if self.megatron_config is not None:
            assert column_kwargs.get('config', False), 'must have ModelParallelConfig'
            tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config)
        self.lm_head = tensor_parallel.ColumnParallelLinear(input_size=self.config.hidden_size,
                                                            output_size=self.config.vocab_size,
                                                            bias=False,
                                                            gather_output=False,
                                                            skip_bias_add=False,
                                                            **column_kwargs)

    def _forward_head(self, hidden_states):
        # all_gather from sequence parallel region is performed inside lm_head
        logits = self.lm_head(hidden_states)[0]
        logits = logits.float()  # (total_nnz_padded, 1, vocab_size // tp)
        logits = tensor_parallel.gather_from_tensor_model_parallel_region(logits)  # (total_nnz_padded, 1, vocab_size)
        return logits

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """
        Args:
            input_ids: input ids. shape (batch_size, seq_length)
            attention_mask: attention mask. shape (batch_size, seq_length)
            position_ids: position ids. shape (batch_size, seq_length)

        Returns:
            CausalLMOutputWithPast with re-padded logits of shape (batch_size, seq_length, vocab_size)
        """
        batch_size, sequence_length = input_ids.shape

        # remove padding here
        input_ids, indices, cu_seqlens, max_seqlen_in_batch, *_ = unpad_input(input_ids.unsqueeze(dim=-1),
                                                                              attention_mask)  # (total_nnz, 1)

        # pad input_ids to multiple of tp for all tp ranks
        # TODO: for better performance, the sp padding should be removed at each layer. Not sure about the performance gap
        if self.megatron_config.sequence_parallel:
            input_ids = sp_utils.pad_to_sequence_parallel(input_ids)

        input_ids = input_ids.transpose(0, 1)  # (1, total_nnz+pad)

        outputs = self.model(input_ids=input_ids,
                             position_ids=position_ids,
                             sequence_length=sequence_length,
                             indices=indices,
                             cu_seqlens=cu_seqlens,
                             max_seqlen_in_batch=max_seqlen_in_batch)

        hidden_states = outputs

        logits = self._forward_head(hidden_states)

        # remove padding from sequence parallel
        if self.megatron_config.sequence_parallel:
            total_nnz = cu_seqlens[-1]
            logits = logits[:total_nnz]  # (total_nnz_padded)

        logits = torch.squeeze(logits, dim=1)  # remove the artificial batch dimension
        # add removed padding back
        logits = pad_input(logits, indices, batch_size,
                           seqlen=sequence_length)  # (batch_size, sequence_length, vocab_size)

        return CausalLMOutputWithPast(
            loss=None,
            logits=logits,
            past_key_values=None,
            hidden_states=None,
            attentions=None,
        )
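A toy sketch (not part of this file) of the unpad -> compute -> repad round trip that ParallelLlamaForCausalLMRmPad.forward performs; a simple elementwise operation stands in for the transformer and lm_head, and all sizes are made up.

# Illustrative rmpad round trip on a toy tensor instead of a model.
import torch
from flash_attn.bert_padding import pad_input, unpad_input

batch_size, seqlen, dim = 2, 6, 4
x = torch.randn(batch_size, seqlen, dim)
mask = torch.ones(batch_size, seqlen, dtype=torch.bool)
mask[1, 4:] = False  # second sequence has 4 real tokens

x_rmpad, indices, cu_seqlens, max_seqlen, *_ = unpad_input(x, mask)  # (total_nnz, dim)
y_rmpad = x_rmpad * 2.0  # stand-in for the model running only on unpadded tokens
y = pad_input(y_rmpad, indices, batch_size, seqlen)  # back to (batch_size, seqlen, dim)
print(x_rmpad.shape, y.shape)  # torch.Size([10, 4]) torch.Size([2, 6, 4])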
class ParallelLlamaForValueRmPad(ParallelLlamaForCausalLMRmPad):

    def _init_head(self):
        column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear()
        if self.megatron_config is not None:
            assert column_kwargs.get('config', False), 'must have ModelParallelConfig'
            tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config)
        self.lm_head = nn.Linear(in_features=self.config.hidden_size, out_features=1, bias=False)
        # lm_head is effectively the same as sequence parallel
        sp_utils.mark_parameter_as_sequence_parallel(self.lm_head.weight)

    def _forward_head(self, hidden_states):
        logits = self.lm_head(hidden_states)  # (total_nnz_padded // tp, 1, 1)
        logits = logits.float()
        if self.megatron_config.sequence_parallel:
            logits = tensor_parallel.gather_from_sequence_parallel_region(logits, tensor_parallel_output_grad=False)
        return logits

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        output = super().forward(input_ids, attention_mask, position_ids)
        output.logits = torch.squeeze(output.logits, dim=-1)
        return output


"""
Support pipeline parallelism
"""


class ParallelLlamaModelRmPadPP(nn.Module):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`ParallelLlamaDecoderLayerRmPad`]
    This model definition supports pipeline parallelism. To support pp and vpp,
    - This model only contains the layers of this pp stage and vpp chunk
    - When calling get_model in Megatron, this rank will instantiate all the vpp chunks of this pp stage.

    Args:
        config: LlamaConfig
    """

    def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig, pre_process, post_process):
        super().__init__()
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.pre_process = pre_process
        self.post_process = post_process
        self.megatron_config = megatron_config
        embedding_kwargs = tp_utils.get_default_kwargs_for_parallel_embedding()
        if megatron_config is not None:
            assert embedding_kwargs.get('config', False), 'must have ModelParallelConfig'
            tp_utils.update_kwargs_with_config(embedding_kwargs, self.megatron_config)
        if pre_process:
            self.embed_tokens = tensor_parallel.VocabParallelEmbedding(num_embeddings=config.vocab_size,
                                                                       embedding_dim=config.hidden_size,
                                                                       **embedding_kwargs)
        else:
            self.embed_tokens = None

        # pp_rank = megatron_config.pipeline_model_parallel_rank
        pp_size = megatron_config.pipeline_model_parallel_size
        self.num_layer_per_pp = config.num_hidden_layers // pp_size
        vpp_size = megatron_config.virtual_pipeline_model_parallel_size

        if vpp_size is not None:
            self.num_layer_vpp_chunk = self.num_layer_per_pp // vpp_size
            self.num_layer_this_model = self.num_layer_vpp_chunk
            # vpp_rank = megatron_config.virtual_pipeline_model_parallel_rank
            # self.offset = vpp_rank * (
            #     config.num_hidden_layers // megatron_config.virtual_pipeline_model_parallel_size) + \
            #     (megatron_config.pipeline_model_parallel_rank * self.num_layer_vpp_chunk)
        else:
            self.num_layer_this_model = self.num_layer_per_pp
            # self.offset = pp_rank * self.num_layer_per_pp

        layers = []
        for i in range(self.num_layer_this_model):
            layer = ParallelLlamaDecoderLayerRmPad(config, megatron_config)
            # setattr(layer, 'hidden_layer_index', self.offset + i)
            layers.append(layer)

        self.layers = nn.ModuleList(layers)

        if post_process:
            self.norm = ParallelLlamaRMSNorm(config, megatron_config)
        else:
            self.norm = None

    def set_input_tensor(self, input_tensor):
        """Set input tensor to be used instead of forward()'s input.

        When doing pipeline parallelism the input from the previous
        stage comes from communication, not from the input, so the
        model's forward_step_func won't have it. This function is thus
        used by internal code to bypass the input provided by the
        forward_step_func"""
        self.input_tensor = input_tensor

    def forward(self,
                input_ids: torch.Tensor,
                position_ids: Optional[torch.LongTensor] = None,
                sequence_length: int = None,
                indices: torch.Tensor = None,
                cu_seqlens: int = None,
                max_seqlen_in_batch: int = None) -> Union[Tuple, BaseModelOutputWithPast]:
        """

        Args:
            input_ids: input ids. shape (1, total_nnz)
            position_ids: position ids. shape (batch_size, seq_length)

        Returns:

        """
        if self.pre_process:
            inputs_embeds = self.embed_tokens(input_ids)  # (1, total_nnz) -> (1, total_nnz, hidden_size)

            # vocab parallel embedding will not do sequence parallel reduce-scatter in open source megatron
            # so we need to handle it here:
            # (1, total_nnz, hidden_size) -> (total_nnz, 1, hidden_size) -> (total_nnz // sp, 1, hidden_size)
            inputs_embeds = inputs_embeds.transpose(0, 1)
            if self.megatron_config.sequence_parallel:
                inputs_embeds = tensor_parallel.scatter_to_sequence_parallel_region(inputs_embeds)

            hidden_states = inputs_embeds
        else:
            # self.hidden_states should be passed by Megatron
            hidden_states = self.input_tensor

        for idx, decoder_layer in enumerate(self.layers):
            layer_outputs = decoder_layer(hidden_states,
                                          position_ids=position_ids,
                                          sequence_length=sequence_length,
                                          indices=indices,
                                          cu_seqlens=cu_seqlens,
                                          max_seqlen_in_batch=max_seqlen_in_batch)

            hidden_states = layer_outputs

        if self.post_process:
            hidden_states = self.norm(hidden_states)

        return hidden_states


class ParallelLlamaForCausalLMRmPadPP(nn.Module):

    def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig, pre_process, post_process):
        super().__init__()
        self.config = config
        self.megatron_config = megatron_config
        self.model = ParallelLlamaModelRmPadPP(config,
                                               megatron_config=megatron_config,
                                               pre_process=pre_process,
                                               post_process=post_process)
        self.share_embeddings_and_output_weights = None  # workaround, megatron requires this attr
        self.vocab_size = config.vocab_size
        self.pre_process = pre_process
        self.post_process = post_process
        if post_process:
            self._init_head()

    def set_input_tensor(self, input_tensor):
        """Set input tensor to be used instead of forward()'s input.

        When doing pipeline parallelism the input from the previous
        stage comes from communication, not from the input, so the
        model's forward_step_func won't have it. This function is thus
        used by internal code to bypass the input provided by the
        forward_step_func"""
        assert len(input_tensor) == 1
        self.model.set_input_tensor(input_tensor[0])

    def _init_head(self):
        column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear()
        if self.megatron_config is not None:
            assert column_kwargs.get('config', False), 'must have ModelParallelConfig'
            tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config)
        self.lm_head = tensor_parallel.ColumnParallelLinear(input_size=self.config.hidden_size,
                                                            output_size=self.config.vocab_size,
                                                            bias=False,
                                                            gather_output=False,
                                                            skip_bias_add=False,
                                                            **column_kwargs)

    def _forward_head(self, hidden_states):
        # all_gather from sequence parallel region is performed inside lm_head
        # logits shape before forward_head hidden_states.shape: [4, 32, 4096]
        logits = self.lm_head(hidden_states)[0]
        # logits shape after forward_head logits.shape: [8, 32, 8]
        logits = logits.float()  # (total_nnz_padded, 1, vocab_size // tp)
        return logits

    def forward(
        self,
        # original input
        *,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """
        Args:
            input_ids: input ids. shape (batch_size, seq_length)
            attention_mask: attention mask. shape (batch_size, seq_length)
            position_ids: position ids. shape (batch_size, seq_length)

        Returns:
            On the last pipeline stage, a CausalLMOutputWithPast with re-padded logits of shape
            (batch_size, seq_length, vocab_size); otherwise the hidden states to be sent to the next stage.
        """

        # Note that input_ids, attention_mask and position_ids should be passed to every pp layer.
        # In the first pp stage, input_ids will be used; in the other pp stages hidden_states will be used inside self.model
        batch_size, sequence_length = input_ids.shape
        # remove padding here
        input_ids_rmpad, indices, cu_seqlens, max_seqlen_in_batch, *_ = unpad_input(input_ids.unsqueeze(dim=-1),
                                                                                    attention_mask)  # (total_nnz, 1)

        # pad input_ids to multiple of tp for all tp ranks
        # TODO: for better performance, the sp padding should be removed at each layer. Not sure about the performance gap
        if self.megatron_config.sequence_parallel:
            input_ids_rmpad = sp_utils.pad_to_sequence_parallel(input_ids_rmpad)

        input_ids_rmpad = input_ids_rmpad.transpose(0, 1)  # (1, total_nnz+pad)

        outputs = self.model(input_ids=input_ids_rmpad,
                             position_ids=position_ids,
                             sequence_length=sequence_length,
                             indices=indices,
                             cu_seqlens=cu_seqlens,
                             max_seqlen_in_batch=max_seqlen_in_batch)

        if self.post_process:
            hidden_states = outputs
            # print(f'hidden_states.shape = {hidden_states.shape}')  # torch.Size([4, 32, 4096])
            logits = self._forward_head(hidden_states)
            logits = torch.squeeze(logits, dim=1)  # remove the artificial batch dimension  # torch.Size([8, 32, 16])

            # remove padding from sequence parallel
            if self.megatron_config.sequence_parallel:
                total_nnz = cu_seqlens[-1]
                logits = logits[:total_nnz]  # (total_nnz_padded)
            # add removed padding back. If input is already rmpad, we let the caller pad_input
            logits = pad_input(logits, indices, batch_size,
                               seqlen=sequence_length)  # (batch_size, sequence_length, vocab_size)

            return CausalLMOutputWithPast(
                loss=None,
                logits=logits,
                past_key_values=None,
                hidden_states=None,
                attentions=None,
            )
        else:
            return outputs


class ParallelLlamaForValueRmPadPP(ParallelLlamaForCausalLMRmPadPP):

    def _init_head(self):
        column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear()
        if self.megatron_config is not None:
            assert column_kwargs.get('config', False), 'must have ModelParallelConfig'
            tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config)
        self.lm_head = nn.Linear(in_features=self.config.hidden_size, out_features=1, bias=False)
        # lm_head is effectively the same as sequence parallel
        sp_utils.mark_parameter_as_sequence_parallel(self.lm_head.weight)

    def _forward_head(self, hidden_states):
        logits = self.lm_head(hidden_states)  # (total_nnz_padded // tp, 1, 1)
        logits = logits.float()
        if self.megatron_config.sequence_parallel:
            logits = tensor_parallel.gather_from_sequence_parallel_region(logits, tensor_parallel_output_grad=False)
        return logits

    def forward(
        self,
        *,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        output = super().forward(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids)
        if self.post_process:
            output.logits = torch.squeeze(output.logits, dim=-1)
            return output
        else:
            return output
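A small arithmetic sketch (not part of the file) of how many decoder layers one ParallelLlamaModelRmPadPP instance owns under pipeline (pp) and virtual pipeline (vpp) parallelism, following the computation in its __init__. The layer and world-size numbers below are illustrative.

# Illustrative layer partitioning under pp/vpp; all numbers are example values.
num_hidden_layers = 32
pp_size = 4                       # pipeline stages
vpp_size = 2                      # virtual pipeline chunks per stage (or None)

num_layer_per_pp = num_hidden_layers // pp_size           # 8 layers per pipeline stage
if vpp_size is not None:
    num_layer_this_model = num_layer_per_pp // vpp_size   # 4 layers per vpp chunk
else:
    num_layer_this_model = num_layer_per_pp

print(num_layer_this_model)  # 4; Megatron's get_model builds vpp_size such chunks on this stage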