# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pretrain utilities."""
from typing import Any, Dict
import time
from omegaconf import DictConfig
from verl.utils.torch_dtypes import PrecisionType
from verl.utils.memory_buffer import build_memory_reference_from_module
import torch
import torch.nn as nn
import torch.nn.functional as F
from megatron.core import mpu, tensor_parallel
from megatron.core.utils import get_model_config
from megatron.core.transformer import TransformerConfig
from megatron.core.transformer.module import Float16Module
# from megatron.core.distributed import DistributedDataParallelConfig
from megatron.core.distributed import DistributedDataParallel as DDP
from megatron.core.enums import ModelType


def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap_with_ddp=True):
    """Build the model."""

    # Build model.
    if mpu.get_pipeline_model_parallel_world_size() > 1 and \
            mpu.get_virtual_pipeline_model_parallel_world_size() is not None:
        assert model_type != ModelType.encoder_and_decoder, \
            "Interleaved schedule not supported for model with both encoder and decoder"
        model = []
        for i in range(mpu.get_virtual_pipeline_model_parallel_world_size()):
            mpu.set_virtual_pipeline_model_parallel_rank(i)
            # Set pre_process and post_process only after virtual rank is set.
            pre_process = mpu.is_pipeline_first_stage()
            post_process = mpu.is_pipeline_last_stage()
            this_model = model_provider_func(pre_process=pre_process, post_process=post_process)
            this_model.model_type = model_type
            model.append(this_model)
    else:
        pre_process = mpu.is_pipeline_first_stage()
        post_process = mpu.is_pipeline_last_stage()
        add_encoder = True
        add_decoder = True
        if model_type == ModelType.encoder_and_decoder:
            if mpu.get_pipeline_model_parallel_world_size() > 1:
                assert mpu.get_pipeline_model_parallel_split_rank() is not None, \
                    "Split rank needs to be specified for model with both encoder and decoder"
                rank = mpu.get_pipeline_model_parallel_rank()
                split_rank = mpu.get_pipeline_model_parallel_split_rank()
                world_size = mpu.get_pipeline_model_parallel_world_size()
                pre_process = rank == 0 or rank == split_rank
                post_process = (rank == (split_rank - 1)) or (rank == (world_size - 1))
                add_encoder = mpu.is_pipeline_stage_before_split()
                add_decoder = mpu.is_pipeline_stage_after_split()
            model = model_provider_func(pre_process=pre_process,
                                        post_process=post_process,
                                        add_encoder=add_encoder,
                                        add_decoder=add_decoder)
        else:
            model = model_provider_func(pre_process=pre_process, post_process=post_process)
        model.model_type = model_type

    if not isinstance(model, list):
        model = [model]

    # Set tensor model parallel attributes if not set.
    # Only parameters that are already tensor model parallel have these
    # attributes set for them. We should make sure the default attributes
    # are set for all params so the optimizer can use them.
    for model_module in model:
        for param in model_module.parameters():
            tensor_parallel.set_defaults_if_not_set_tensor_model_parallel_attributes(param)

    # Print number of parameters.
    if mpu.get_data_parallel_rank() == 0:
        print(' > number of parameters on (tensor, pipeline) '
              'model parallel rank ({}, {}): {}'.format(
                  mpu.get_tensor_model_parallel_rank(), mpu.get_pipeline_model_parallel_rank(),
                  sum([sum([p.nelement() for p in model_module.parameters()]) for model_module in model])),
              flush=True)

    # GPU allocation.
    for model_module in model:
        model_module.cuda(torch.cuda.current_device())

    # Fp16 conversion.
    config = get_model_config(model[0])
    if config.fp16 or config.bf16:  # the ModelParallelConfig in GPTModel
        model = [Float16Module(config, model_module) for model_module in model]

    if wrap_with_ddp:
        model = [
            DDP(config=config,
                module=model_chunk,
                data_parallel_group=mpu.get_data_parallel_group(with_context_parallel=True),
                accumulate_allreduce_grads_in_fp32=True,
                overlap_grad_reduce=False,
                use_distributed_optimizer=True,
                disable_bucketing=(model_chunk_idx > 0)) for (model_chunk_idx, model_chunk) in enumerate(model)
        ]

        # # Broadcast params from data parallel src rank to other data parallel ranks.
        # if args.data_parallel_random_init:
        for model_module in model:
            model_module.broadcast_params()

    return model
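
# Usage sketch (illustrative, not called in this module): `get_model` expects a callable
# that builds one pipeline chunk from `pre_process` / `post_process` flags. The provider
# and `make_gpt_model` below are hypothetical names standing in for whatever constructor
# the trainer actually uses.
#
#   def actor_model_provider(pre_process, post_process):
#       return make_gpt_model(transformer_config, pre_process=pre_process, post_process=post_process)
#
#   actor_modules = get_model(actor_model_provider, wrap_with_ddp=True)
#   # `actor_modules` is a list with one entry per virtual pipeline chunk on this rank.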


ALL_MODULE_WRAPPER_CLASSNAMES = (DDP, Float16Module)


def unwrap_model(model, module_instances=ALL_MODULE_WRAPPER_CLASSNAMES):
    """Strip DDP / Float16Module wrappers from a module or a list of modules."""
    return_list = True
    if not isinstance(model, list):
        model = [model]
        return_list = False
    unwrapped_model = []
    for model_module in model:
        while isinstance(model_module, module_instances):
            model_module = model_module.module
        unwrapped_model.append(model_module)
    if not return_list:
        return unwrapped_model[0]
    return unwrapped_model
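
# Example: peel the DDP / Float16Module wrappers off before inspecting raw parameters
# (continuing the hypothetical `actor_modules` sketch above). A list in gives a list out;
# a single wrapped module gives a single module back.
#
#   raw_modules = unwrap_model(actor_modules)
#   raw_module = unwrap_model(actor_modules[0])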


from transformers import PretrainedConfig


def convert_config(hf_config: PretrainedConfig, megatron_config) -> TransformerConfig:
    """Translate a Hugging Face model config into a megatron-core TransformerConfig."""
    print(f'megatron config: {megatron_config}')
    dt = PrecisionType.to_dtype(megatron_config['param_dtype'])
    print(f'pipeline_dtype={dt}')
    transformer_config = TransformerConfig(
        num_layers=hf_config.num_hidden_layers,
        hidden_size=hf_config.hidden_size,
        num_attention_heads=hf_config.num_attention_heads,
        num_query_groups=hf_config.num_key_value_heads,
        ffn_hidden_size=hf_config.intermediate_size,
        # max_position_embeddings=hf_config.max_position_embeddings,
        activation_func=F.silu,
        normalization='RMSNorm',
        # rotary_percent=False,  # default
        gated_linear_unit=True,  # for llama
        use_cpu_initialization=True,
        apply_residual_connection_post_layernorm=False,  # TODO: double-check the semantics of this flag
        add_bias_linear=False,
        tensor_model_parallel_size=mpu.get_tensor_model_parallel_world_size(),
        pipeline_model_parallel_size=mpu.get_pipeline_model_parallel_world_size(),
        virtual_pipeline_model_parallel_size=mpu.get_virtual_pipeline_model_parallel_world_size(),
        pipeline_dtype=dt,
        params_dtype=dt,
        sequence_parallel=megatron_config['sequence_parallel_enabled'],
        variable_seq_lengths=True,
        masked_softmax_fusion=True,
        bf16=(dt is torch.bfloat16))
    if torch.distributed.get_rank() == 0:
        print(f'tensor_model_parallel_size={transformer_config.tensor_model_parallel_size}\n'
              f'pipeline_model_parallel_size={transformer_config.pipeline_model_parallel_size}\n'
              f'virtual_pipeline_model_parallel_size={transformer_config.virtual_pipeline_model_parallel_size}\n'
              f'pipeline_dtype={transformer_config.pipeline_dtype}\n'
              f'params_dtype={transformer_config.params_dtype}\n'
              f'sequence_parallel={transformer_config.sequence_parallel}\n'
              f'variable_seq_lengths={transformer_config.variable_seq_lengths}\n'
              f'masked_softmax_fusion={transformer_config.masked_softmax_fusion}')
    return transformer_config
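
# Example (illustrative): converting a Hugging Face Llama-style config. Assumes mpu has
# already been initialized, and that 'bf16' is a dtype string PrecisionType.to_dtype
# understands; the checkpoint name is arbitrary.
#
#   from transformers import AutoConfig
#   hf_config = AutoConfig.from_pretrained('meta-llama/Llama-2-7b-hf')
#   tf_config = convert_config(hf_config, {'param_dtype': 'bf16', 'sequence_parallel_enabled': True})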


# from megatron.core.optimizer import OptimizerConfig
from verl.utils.megatron.optimizer_config import OptimizerConfig


def init_megatron_optim_config(optim_config: Dict) -> OptimizerConfig:
    """Build a megatron OptimizerConfig from a plain dict (expects at least `lr` and `clip_grad`)."""
    config = OptimizerConfig(
        optimizer='adam',
        lr=optim_config.get('lr'),
        clip_grad=optim_config.get('clip_grad'),
        weight_decay=1e-2,
        bf16=True,
        params_dtype=torch.bfloat16,
        use_distributed_optimizer=True,
    )
    return config
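
# Example (illustrative): only `lr` and `clip_grad` are read from the input dict here;
# the remaining optimizer settings are hard-coded above.
#
#   optim_cfg = init_megatron_optim_config({'lr': 1e-6, 'clip_grad': 1.0})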


from megatron.core import ModelParallelConfig


def init_model_parallel_config(config: DictConfig) -> ModelParallelConfig:
    """Build a megatron-core ModelParallelConfig from an OmegaConf section with parallelism and dtype settings."""
    # TODO(sgm): check how to disable megatron timers
    timers = FakeTimers()
    return ModelParallelConfig(tensor_model_parallel_size=config.get('tensor_model_parallel_size'),
                               pipeline_model_parallel_size=config.get('pipeline_model_parallel_size'),
                               virtual_pipeline_model_parallel_size=config.get('virtual_pipeline_model_parallel_size'),
                               sequence_parallel=config.get('sequence_parallel'),
                               params_dtype=PrecisionType.to_dtype(config.get('param_dtype')),
                               pipeline_dtype=PrecisionType.to_dtype(config.get('param_dtype')),
                               bf16=True,
                               fp16=False,
                               timers=timers)


class FakeTimers:
    """Stand-in for Megatron's timers that turns every timer lookup into a no-op DummyTimer."""

    def __init__(self):
        from megatron.timers import DummyTimer
        self.dummy_timer = DummyTimer()

    def __call__(self, *args: Any, **kwds: Any) -> Any:
        return self.dummy_timer
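
# Example (illustrative): the keys below mirror the `config.get(...)` calls in
# `init_model_parallel_config`; the parallel sizes are arbitrary placeholder values.
#
#   from omegaconf import OmegaConf
#   mp_config = init_model_parallel_config(OmegaConf.create({
#       'tensor_model_parallel_size': 2,
#       'pipeline_model_parallel_size': 1,
#       'virtual_pipeline_model_parallel_size': None,
#       'sequence_parallel': True,
#       'param_dtype': 'bf16',
#   }))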


def offload_megatron_param_and_grad(module_list: nn.ModuleList, offload_grad=False, hybrid_engine=None):
    """Move parameters (and optionally gradients) to CPU to free GPU memory.

    With a hybrid engine, the pre-allocated memory buffers of this pipeline stage are
    moved and the module parameters are re-pointed at them; otherwise each parameter
    tensor is moved individually.
    """
    if hybrid_engine is not None:
        pp_rank = mpu.get_pipeline_model_parallel_rank()
        for buffer in hybrid_engine.memory_buffers[pp_rank].values():
            buffer.data = buffer.data.to('cpu', non_blocking=True)
        build_memory_reference_from_module(module_list, hybrid_engine.memory_buffers[pp_rank], maintain_weight=True)
    else:
        for module in module_list:
            for _, param in module.named_parameters():
                param.data = param.data.to('cpu', non_blocking=True)
                if offload_grad and param.grad is not None:
                    param.grad = param.grad.to('cpu', non_blocking=True)
    torch.cuda.empty_cache()


def load_megatron_param_and_grad(module_list: nn.ModuleList, device_id, load_grad=False, hybrid_engine=None):
    """Move parameters (and optionally gradients) back onto `device_id`, mirroring the offload path above."""
    if hybrid_engine is not None:
        pp_rank = mpu.get_pipeline_model_parallel_rank()
        for buffer in hybrid_engine.memory_buffers[pp_rank].values():
            buffer.data = buffer.data.to(device_id, non_blocking=True)
        build_memory_reference_from_module(module_list, hybrid_engine.memory_buffers[pp_rank], maintain_weight=True)
    else:
        for module in module_list:
            for _, param in module.named_parameters():
                param.data = param.data.to(device_id, non_blocking=True)
                if load_grad and param.grad is not None:
                    param.grad = param.grad.to(device_id, non_blocking=True)
    torch.cuda.empty_cache()
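
# Example (illustrative): round-trip weights to host memory between training and rollout,
# continuing the hypothetical `actor_modules` sketch above. With a hybrid engine, the
# shared memory buffers are moved instead of individual parameter tensors.
#
#   offload_megatron_param_and_grad(actor_modules, offload_grad=True)
#   ...  # generate rollouts with the inference engine
#   load_megatron_param_and_grad(actor_modules, torch.cuda.current_device(), load_grad=True)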