Files
Search-R1/verl/workers/sharding_manager/megatron_vllm.py
PeterGriffinJin 068516be64 Initial commit
2025-02-28 15:16:19 +00:00

429 lines
18 KiB
Python

# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file contains a Megatron style Hybrid Engine that shares the weights of the actor with the inference engine.
"""
import torch
import torch.distributed as dist
from torch import nn
from megatron.core import parallel_state as mpu
from megatron.core import DistributedDataParallel as LocalDDP
from megatron.core.transformer.module import Float16Module
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from verl.utils.megatron_utils import get_model, unwrap_model
from verl.utils.memory_buffer import (
build_memory_buffer,
build_memory_reference_from_module,
get_weight_buffer_meta_from_module,
)
class AllGatherPPModel:
    """Hold the model chunks of every pipeline-parallel (pp) stage on this rank.

    The stage that belongs to this rank is built through ``get_model`` with its
    default wrapping; every other stage is built with ``wrap_with_ddp=False``,
    put in eval mode and offloaded to CPU. The parameters of each stage are
    bound to a per-stage memory buffer so a whole stage can be moved between
    devices or broadcast across the pp group buffer by buffer.
    """

    def __init__(self, model_provider) -> None:
        """Build one model (a list of model chunks) per pp stage.

        Args:
            model_provider: callable forwarded to ``get_model`` to build the
                model chunks of a stage.
        """
        self._pp_group = mpu.get_pipeline_model_parallel_group()
        self._pp_rank = mpu.get_pipeline_model_parallel_rank()
        self._pp_size = mpu.get_pipeline_model_parallel_world_size()
        self._vpp_size = mpu.get_virtual_pipeline_model_parallel_world_size()
        # without virtual pipeline parallelism every stage holds exactly one chunk
        self._model_chunk_size = self._vpp_size or 1

        # each one holds a list of model_chunks in this pp stage
        self._pp_models = [None] * self.pp_size
        self._this_rank_models = None
        # store the parameter of each pp stage
        self.memory_buffers = [None] * self.pp_size

        rank_list = list(range(self.pp_size))
        # make current rank the last one to initialize
        rank_list[self.pp_rank], rank_list[-1] = rank_list[-1], rank_list[self.pp_rank]

        try:
            for cur_pp_rank in rank_list:
                self._init_pp_stage(model_provider, cur_pp_rank)
        finally:
            # On success the last initialized stage is already the current pp rank, so this is a
            # no-op; on failure it prevents leaving the global pp rank pointing at a foreign stage.
            mpu.set_pipeline_model_parallel_rank(self._pp_rank)

    def _init_pp_stage(self, model_provider, cur_pp_rank):
        """Build the model chunks of one pp stage and bind them to a memory buffer."""
        print(
            'create pp model', f'torch allocated {torch.cuda.memory_allocated() / 1e9:.4f} GB, '
            f'reserved {torch.cuda.memory_reserved() / 1e9:.4f} GB')
        is_local_stage = cur_pp_rank == self.pp_rank
        # the model is built for whichever pp rank is currently set in mpu
        mpu.set_pipeline_model_parallel_rank(cur_pp_rank)
        if is_local_stage:
            # for regular model, we wrapped it with DDP
            models = get_model(model_provider)
            assert len(models) == self._model_chunk_size, f"{len(models)} != {self._model_chunk_size}"
            self._this_rank_models = nn.ModuleList(models)
            self.pp_models[cur_pp_rank] = nn.ModuleList(unwrap_model(models, (torchDDP, LocalDDP)))
        else:
            models = nn.ModuleList(get_model(model_provider, wrap_with_ddp=False))
            assert len(models) == self._model_chunk_size, f"{len(models)} != {self._model_chunk_size}"
            self.pp_models[cur_pp_rank] = models

        self._build_param_buffer(cur_pp_rank)
        self._build_param_references(cur_pp_rank, maintain_weight=is_local_stage)

        # TODO: after binding to the memory buffer, we can load the checkpoint here
        if not is_local_stage:
            for model in self.pp_models[cur_pp_rank]:
                model.eval()
            self._offload_params_to_cpu(cur_pp_rank)

    def _build_param_buffer(self, pp_rank):
        """Build the parameter buffer in each pp rank"""
        model = self.pp_models[pp_rank]
        weight_buffer_meta = get_weight_buffer_meta_from_module(model)
        self.memory_buffers[pp_rank] = build_memory_buffer(weight_buffer_meta)

    def _build_param_references(self, pp_rank, maintain_weight=False):
        """(Re)bind the parameters of stage ``pp_rank`` to its memory buffer."""
        model = self.pp_models[pp_rank]
        build_memory_reference_from_module(model, self.memory_buffers[pp_rank], maintain_weight=maintain_weight)

    def _load_params_to_cuda(self, pp_rank, to_empty=False):
        """Move the memory buffers of a foreign stage to CUDA.

        Args:
            pp_rank: stage to load; must not be the current pp rank.
            to_empty: if True, allocate uninitialized CUDA storage instead of copying the data.
        """
        assert pp_rank != self.pp_rank, f"unexpected to load current pp rank [{pp_rank}] back to cuda"
        for buffer in self.memory_buffers[pp_rank].values():
            if to_empty:
                buffer.data = torch.empty_like(buffer.data, device='cuda')
            else:
                buffer.data = buffer.data.to(torch.cuda.current_device(), non_blocking=True)
        # rebuild reference after loading to CUDA
        self._build_param_references(pp_rank)

    def _offload_params_to_cpu(self, pp_rank, to_empty=False):
        """Move the memory buffers of a foreign stage to CPU.

        Args:
            pp_rank: stage to offload; must not be the current pp rank.
            to_empty: if True, allocate uninitialized CPU storage instead of copying the data.
        """
        assert pp_rank != self.pp_rank, f"unexpected to offload current pp rank [{pp_rank}] to cpu"
        for buffer in self.memory_buffers[pp_rank].values():
            if to_empty:
                buffer.data = torch.empty_like(buffer.data, device='cpu')
            else:
                # offload the whole memory buffer to CPU
                buffer.data = buffer.data.to('cpu', non_blocking=True)
        # the buffers were replaced, so the parameters must be re-bound to them
        self._build_param_references(pp_rank)

    def load_params_to_cuda(self, to_empty=False):
        """load all model params to cuda"""
        for cur_pp_rank in range(self.pp_size):
            if cur_pp_rank != self.pp_rank:
                self._load_params_to_cuda(cur_pp_rank, to_empty=to_empty)

    def allgather_params(self):
        """Broadcast the memory buffers of every pp stage from its owner to the whole pp group.

        The broadcasts are synchronous (``async_op=False``); nothing is returned.
        """
        for cur_pp_rank in range(self.pp_size):
            global_src = dist.get_global_rank(group=self.pp_group, group_rank=cur_pp_rank)

            # NOTE(sgm): the async op may cause memory leakage of the memory_buffer/pp_models
            for memory_buffer in self.memory_buffers[cur_pp_rank].values():
                dist.broadcast(tensor=memory_buffer.data, src=global_src, group=self.pp_group, async_op=False)

    def forward(self, *inputs, **kwargs):
        """Run every model chunk of every pp stage in sequence on this rank.

        The output of one stage is handed to the next through ``set_input_tensor``.
        The (virtual) pipeline ranks in ``mpu`` are switched per stage and always
        restored afterwards, even when a stage raises.

        Returns:
            The output of the last executed stage (``None`` if there is no stage to run).
        """
        output = None
        try:
            for cur_chunk_rank in range(self._model_chunk_size):
                if self._vpp_size:
                    mpu.set_virtual_pipeline_model_parallel_rank(cur_chunk_rank)

                for cur_pp_rank in range(self.pp_size):
                    mpu.set_pipeline_model_parallel_rank(cur_pp_rank)
                    stage = self.pp_models[cur_pp_rank][cur_chunk_rank]
                    stage.set_input_tensor(output)
                    output = stage(*inputs, **kwargs)
                    # drop the reference to the previous stage's output
                    stage.set_input_tensor(None)
        finally:
            if self._vpp_size:
                mpu.set_virtual_pipeline_model_parallel_rank(0)
            mpu.set_pipeline_model_parallel_rank(self.pp_rank)
        return output

    def __call__(self, *inputs, **kwargs):
        return self.forward(*inputs, **kwargs)

    def eval(self):
        """Put the model chunks of the current pp rank in eval mode."""
        for model in self.pp_models[self.pp_rank]:
            model.eval()

    def train(self):
        """Put the model chunks of the current pp rank in train mode."""
        for model in self.pp_models[self.pp_rank]:
            model.train()

    def offload_params_to_cpu(self, to_empty=False):
        """offload params of models that are not of current pp rank to cpu"""
        for cur_pp_rank in range(self.pp_size):
            if cur_pp_rank != self.pp_rank:
                self._offload_params_to_cpu(cur_pp_rank, to_empty=to_empty)

    def get_all_params(self):
        """Get all the parameters of the models in all pp ranks

        Returns:
            params: List[List[Dict[str, Tensor]]]: a list of parameters in all pp, where each is a list of dict
                tensors of each model chunk
        """
        params = []
        for stage_chunks in self.pp_models:
            stage_params = []
            for pp_model in stage_chunks:
                pp_model = unwrap_model(pp_model, (torchDDP, LocalDDP, Float16Module))  # not use Float16Module
                # NOTE(gh) workaround: should not get lora params for inference
                stage_params.append(
                    {name: param for name, param in pp_model.named_parameters() if 'lora' not in name})
            params.append(stage_params)
        return params

    def update_this_rank_models(self, new_models):
        """Replace the models of the current pp rank (wrapped and unwrapped views)."""
        self._this_rank_models = new_models
        self._pp_models[self.pp_rank] = unwrap_model(new_models, (torchDDP, LocalDDP))

    @property
    def this_rank_models(self):
        return self._this_rank_models

    @property
    def pp_size(self):
        return self._pp_size

    @property
    def pp_rank(self):
        return self._pp_rank

    @property
    def pp_group(self):
        return self._pp_group

    @property
    def pp_models(self):
        return self._pp_models
"""
Megatron Hybrid Engine:
- During training, only the current pp stage holds the parameters
- Before inference, broadcast the parameters of the current pp rank to all other pp ranks (all pp ranks hold all the parameters)
- Bind the parameters to the inference engine
- Do inference in tp. pp is treated as additional dp
- After inference, all the parameters that don't belong to this pp rank are freed.
"""
from .base import BaseShardingManager
import torch
from torch import nn
import torch.distributed
from torch.distributed import new_group
from verl import DataProto
from verl.utils.torch_functional import (broadcast_dict_tensor, allgather_dict_tensors)
import verl.utils.megatron.tensor_parallel as tp_utils
from verl.third_party.vllm import parallel_state as vllm_ps
from verl.third_party.vllm import LLM
from verl.utils.model import normalize_pp_vpp_params
# Micro data parallel group: an additional data-parallel group that originates from splitting the training
# tensor-parallel size into infer_tp and micro_dp (micro_dp_size = train_tp // infer_tp).
# By default, we use order micro_dp - tp.
# Stays None until MegatronVLLMShardingManager.__init__ creates the groups and stores this rank's group here.
_MICRO_DATA_PARALLEL_GROUP = None
class MegatronVLLMShardingManager(BaseShardingManager):
    """Context manager that shares Megatron actor weights with a vLLM inference engine.

    On enter, the parameters of all pp stages are gathered on this rank, converted
    from the training tensor-parallel layout to the inference one and synced into
    the inference engine. On exit, foreign-stage parameters are offloaded, the
    original parameters are re-bound and the inference weights are offloaded.
    """

    def __init__(self, module: AllGatherPPModel, inference_engine: LLM, model_config, layer_name_mapping):
        """Store the collaborators and create the micro data parallel groups.

        Args:
            module: holder of the model chunks of all pp stages.
            inference_engine: engine receiving the weights via ``sync_model_weights``.
            model_config: huggingface model config (reads ``num_hidden_layers``,
                ``num_attention_heads`` and ``num_key_value_heads``).
            layer_name_mapping: mapping with optional keys ``"qkv_layer_name"`` and
                ``"gate_proj_layer_name"`` used to recognise fused parameters by name.

        Raises:
            AssertionError: if infer_tp > train_tp, train_tp is not a multiple of
                infer_tp, or the micro data parallel group already exists.
        """
        self.module = module
        self.inference_engine = inference_engine
        self.model_config = model_config
        self.layer_name_mapping = layer_name_mapping

        # initialize micro_dp group for vllm inference
        global _MICRO_DATA_PARALLEL_GROUP
        world_size = torch.distributed.get_world_size()
        rank = torch.distributed.get_rank()
        train_tensor_parallel_size = mpu.get_tensor_model_parallel_world_size()
        infer_tensor_parallel_size = vllm_ps.get_tensor_model_parallel_world_size()

        # TODO(sgm): this may not be true for FSDP -> vLLM
        assert infer_tensor_parallel_size <= train_tensor_parallel_size, \
            'Not implemented for infer_tp > train_tp'
        assert train_tensor_parallel_size % infer_tensor_parallel_size == 0

        micro_dp_size = train_tensor_parallel_size // infer_tensor_parallel_size
        num_micro_dp_groups = world_size // micro_dp_size
        assert _MICRO_DATA_PARALLEL_GROUP is None, ("micro data parallel group is already initialized")
        # every rank must call new_group for every group, including groups it is not a member of
        for i in range(num_micro_dp_groups):
            ranks = list(range(i * micro_dp_size, (i + 1) * micro_dp_size))
            group = new_group(ranks=ranks)
            if rank in ranks:
                _MICRO_DATA_PARALLEL_GROUP = group

    def default_tp_concat_fn(self, name, param, infer_params, model_config):
        """Concatenate the shards gathered from the micro_dp group into one inference parameter.

        name: name of the parameter
        param: training parameters
        infer_params (List[torch.Tensor]): a list of parameters all-gathered from micro_dp_group
        model_config: huggingface model_config

        Returns:
            torch.Tensor: the concatenated parameter.

        TODO(zhangchi.usc1992): currently, the implementation is adhoc. We can move this function to the model
        definition so that it is model-agnostic. If the model doesn't implement this function,
        we can throw an error to force user disable TP HybridEngine.
        """
        # A missing mapping entry yields None; `None in name` would raise TypeError, so guard it
        # and let such parameters fall through to the generic concatenation.
        qkv_layer_name = self.layer_name_mapping.get("qkv_layer_name")
        gate_proj_layer_name = self.layer_name_mapping.get("gate_proj_layer_name")

        if qkv_layer_name is not None and qkv_layer_name in name:
            # if the tensor is qkv, for each param on tp, split into q, k, v
            # concat q, k, v separately.
            assert model_config.num_attention_heads % model_config.num_key_value_heads == 0
            num_q_per_kv = model_config.num_attention_heads // model_config.num_key_value_heads
            assert infer_params[0].shape[0] % (num_q_per_kv + 2) == 0
            kv_size_per_tp = infer_params[0].shape[0] // (num_q_per_kv + 2)
            split_size = [kv_size_per_tp * num_q_per_kv, kv_size_per_tp, kv_size_per_tp]
            q_lst = []
            k_lst = []
            v_lst = []
            for infer_param in infer_params:
                q, k, v = infer_param.split(split_size)
                q_lst.append(q)
                k_lst.append(k)
                v_lst.append(v)
            # final layout: all q shards, then all k shards, then all v shards
            infer_params = torch.cat(q_lst + k_lst + v_lst, dim=0)
        elif gate_proj_layer_name is not None and gate_proj_layer_name in name:
            # if the tensor is gate and proj
            gate_lst = []
            up_lst = []
            for infer_param in infer_params:
                gate, up = infer_param.chunk(2)
                gate_lst.append(gate)
                up_lst.append(up)
            # final layout: all gate shards, then all up shards
            infer_params = torch.cat(gate_lst + up_lst, dim=0)
        else:
            # concat tensor
            infer_params = torch.cat(infer_params, dim=tp_utils.get_tensor_parallel_partition_dim(param))

        return infer_params

    def _post_process_params(self, params):
        """
        For each param, if it is a tp-splited param, we all-gather from micro_dp group.

        ``params`` is modified in place: every tensor-parallel entry is replaced by the
        gathered and concatenated tensor.

        Returns:
            dict: the original tensor of every replaced entry, keyed by parameter name
            (empty when the micro_dp group has a single rank).
        """
        # here the params are in train tp format. we iterate params and all-gather
        # TODO(zhangchi.usc1992) We can consider copy non-tp weight to another infer buffer.
        # In this way, all the params in the original memory_buffers and can be offload.
        origin_params = {}
        micro_dp_size = get_micro_data_parallel_world_size()
        micro_dp_group = get_micro_data_parallel_group()
        if micro_dp_size <= 1:
            return origin_params

        for name in list(params.keys()):
            param = params[name]
            if not tp_utils.is_tensor_parallel_param(param):
                continue
            # allocate a new tensor with proper size
            infer_params = [torch.empty_like(param) for _ in range(micro_dp_size)]
            torch.distributed.all_gather(infer_params, param, group=micro_dp_group)
            # replace the training shard with the gathered tensor and remember the original
            params[name] = self.default_tp_concat_fn(name, param, infer_params, self.model_config)
            origin_params[name] = param

        return origin_params

    def __enter__(self):
        # create a new cuda space for parameters not in this pp rank
        self.module.load_params_to_cuda()
        # broadcast the parameters from pp rank to other ranks
        self.module.allgather_params()
        # obtain name to parameters in pp/vpp
        params = self.module.get_all_params()

        # bind the params to inference engine
        self.params = normalize_pp_vpp_params(params=params,
                                              num_hidden_layers=self.model_config.num_hidden_layers,
                                              layer_name='layers')
        self.origin_params = self._post_process_params(self.params)
        self.inference_engine.sync_model_weights(self.params, load_format='megatron')

    def __exit__(self, exc_type, exc_value, traceback):
        # offload parameters doesn't belong to this pp rank
        self.module.offload_params_to_cpu()

        # FIXME(sgm): the best practice is to delete the cuda tensor
        # rebind the model weights, can be any cpu tensor
        # Only entries recorded in origin_params were replaced, so restore exactly those;
        # indexing origin_params with every name in self.params could raise KeyError.
        for name, origin_param in (self.origin_params or {}).items():
            self.params[name] = origin_param

        # self.inference_engine.sync_model_weights(params)
        self.inference_engine.offload_model_weights()

        self.module.train()

        # add empty cache after each compute
        torch.cuda.empty_cache()

    def preprocess_data(self, data: DataProto) -> DataProto:
        """Broadcast the batch inside the training tp group and keep this rank's micro_dp chunk."""
        # prompts are identical for each training tp. We select for each inference tp
        micro_dp_size = get_micro_data_parallel_world_size()
        micro_dp_rank = get_micro_data_parallel_rank()

        # broadcast from tp=0 to other tp ranks
        broadcast_dict_tensor(data.batch,
                              src=mpu.get_tensor_model_parallel_src_rank(),
                              group=mpu.get_tensor_model_parallel_group())

        if micro_dp_size > 1:
            local_prompts = data.chunk(chunks=micro_dp_size)
            data = local_prompts[micro_dp_rank]

        return data

    def postprocess_data(self, data: DataProto) -> DataProto:
        """All-gather the batch over the micro_dp group and, unless disabled, over the pp group.

        The pp gather is skipped when ``data.meta_info['allgather_pp_output']`` is falsy
        (it defaults to True when the key is absent).
        """
        meta_info = data.meta_info
        # all gather batch among micro-dp groups
        micro_dp_size = get_micro_data_parallel_world_size()
        if micro_dp_size > 1:
            data.batch = allgather_dict_tensors(data.batch.contiguous(),
                                                size=micro_dp_size,
                                                group=get_micro_data_parallel_group(),
                                                dim=0)

        # all gather batch among pp group
        if meta_info.get('allgather_pp_output', True):
            data.batch = allgather_dict_tensors(data.batch.contiguous(),
                                                size=mpu.get_pipeline_model_parallel_world_size(),
                                                group=mpu.get_pipeline_model_parallel_group(),
                                                dim=0)
        return data
"""
Micro Data parallel group
"""
def get_micro_data_parallel_group():
    """Return the micro data parallel process group stored at module level.

    Raises:
        AssertionError: if the module-level group is still ``None``.
    """
    group = _MICRO_DATA_PARALLEL_GROUP
    assert group is not None
    return group
def get_micro_data_parallel_world_size():
    """Return the number of ranks in the micro data parallel group."""
    micro_dp_group = get_micro_data_parallel_group()
    return torch.distributed.get_world_size(group=micro_dp_group)
def get_micro_data_parallel_rank():
    """Return this process's rank inside the micro data parallel group."""
    micro_dp_group = get_micro_data_parallel_group()
    return torch.distributed.get_rank(group=micro_dp_group)