Initial commit
This commit is contained in:
33
verl/workers/sharding_manager/__init__.py
Normal file
33
verl/workers/sharding_manager/__init__.py
Normal file
@@ -0,0 +1,33 @@
|
||||
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from verl.utils.import_utils import is_vllm_available, is_megatron_core_available
|
||||
|
||||
from .base import BaseShardingManager
|
||||
from .fsdp_ulysses import FSDPUlyssesShardingManager
|
||||
|
||||
AllGatherPPModel = None
|
||||
|
||||
if is_megatron_core_available() and is_vllm_available():
|
||||
from .megatron_vllm import AllGatherPPModel, MegatronVLLMShardingManager
|
||||
elif AllGatherPPModel is not None:
|
||||
pass
|
||||
else:
|
||||
AllGatherPPModel = None
|
||||
MegatronVLLMShardingManager = None
|
||||
|
||||
if is_vllm_available():
|
||||
from .fsdp_vllm import FSDPVLLMShardingManager
|
||||
else:
|
||||
FSDPVLLMShardingManager = None
|
||||
33
verl/workers/sharding_manager/base.py
Normal file
33
verl/workers/sharding_manager/base.py
Normal file
@@ -0,0 +1,33 @@
|
||||
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Sharding manager to implement HybridEngine
|
||||
"""
|
||||
|
||||
from verl import DataProto
|
||||
|
||||
|
||||
class BaseShardingManager:
|
||||
|
||||
def __enter__(self):
|
||||
pass
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
pass
|
||||
|
||||
def preprocess_data(self, data: DataProto) -> DataProto:
|
||||
return data
|
||||
|
||||
def postprocess_data(self, data: DataProto) -> DataProto:
|
||||
return data
|
||||
88
verl/workers/sharding_manager/fsdp_ulysses.py
Normal file
88
verl/workers/sharding_manager/fsdp_ulysses.py
Normal file
@@ -0,0 +1,88 @@
|
||||
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Contains a resharding manager that binds weights from FSDP zero3 to XPerfGPT
|
||||
"""
|
||||
from typing import Optional
|
||||
from .base import BaseShardingManager
|
||||
|
||||
import random
|
||||
from torch.distributed.device_mesh import DeviceMesh
|
||||
|
||||
from verl.utils.torch_functional import allgather_dict_tensors
|
||||
from verl.utils.ulysses import set_ulysses_sequence_parallel_group, get_ulysses_sequence_parallel_group
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
import torch.distributed
|
||||
|
||||
from verl import DataProto
|
||||
|
||||
|
||||
class FSDPUlyssesShardingManager(BaseShardingManager):
|
||||
"""
|
||||
Sharding manager to support data resharding when using FSDP + Ulysses
|
||||
"""
|
||||
|
||||
def __init__(self, device_mesh: DeviceMesh):
|
||||
super().__init__()
|
||||
self.device_mesh = device_mesh
|
||||
self.seed_offset = 12345
|
||||
|
||||
def __enter__(self):
|
||||
if self.device_mesh is not None:
|
||||
# We have a global SP group
|
||||
# so we have to change to use model-specific sp group
|
||||
self.prev_sp_group = get_ulysses_sequence_parallel_group()
|
||||
set_ulysses_sequence_parallel_group(self.device_mesh['sp'].get_group())
|
||||
# TODO: check how to set seed for each model
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
# restore random states
|
||||
if self.device_mesh is not None:
|
||||
# revert to previous sp group
|
||||
set_ulysses_sequence_parallel_group(self.prev_sp_group)
|
||||
# TODO: check how to set seed for each model
|
||||
|
||||
def preprocess_data(self, data: DataProto) -> DataProto:
|
||||
"""
|
||||
AllGather data from sp region
|
||||
This is because the data is first sharded along the FSDP dimension as we utilize the DP_COMPUTE
|
||||
In Ulysses, we need to make sure the same data is used across a SP group
|
||||
"""
|
||||
if self.device_mesh is not None:
|
||||
sp_size = self.device_mesh['sp'].size()
|
||||
group = self.device_mesh['sp'].get_group()
|
||||
|
||||
prev_device = data.batch.device
|
||||
data.batch = data.batch.cuda(device=torch.cuda.current_device())
|
||||
data.batch = allgather_dict_tensors(data.batch.contiguous(), size=sp_size, group=group, dim=0)
|
||||
data.batch = data.batch.to(prev_device)
|
||||
# all gather non_tensor_batch
|
||||
all_non_tensor_batch = [None for _ in range(sp_size)]
|
||||
torch.distributed.all_gather_object(all_non_tensor_batch, data.non_tensor_batch, group=group)
|
||||
data.non_tensor_batch = {
|
||||
k: np.concatenate([d[k] for d in all_non_tensor_batch]) for k in data.non_tensor_batch
|
||||
}
|
||||
return data
|
||||
|
||||
def postprocess_data(self, data: DataProto) -> DataProto:
|
||||
"""
|
||||
Split the data to follow FSDP partition
|
||||
"""
|
||||
if self.device_mesh is not None:
|
||||
sp_size = self.device_mesh['sp'].size()
|
||||
sp_rank = self.device_mesh['sp'].get_local_rank()
|
||||
data = data.chunk(chunks=sp_size)[sp_rank]
|
||||
return data
|
||||
133
verl/workers/sharding_manager/fsdp_vllm.py
Normal file
133
verl/workers/sharding_manager/fsdp_vllm.py
Normal file
@@ -0,0 +1,133 @@
|
||||
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import logging
|
||||
import torch
|
||||
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
|
||||
from torch.distributed.fsdp.api import ShardingStrategy, ShardedStateDictConfig, StateDictType, FullStateDictConfig
|
||||
from torch.distributed.device_mesh import DeviceMesh
|
||||
|
||||
from verl.third_party.vllm import LLM
|
||||
from verl.third_party.vllm import parallel_state as vllm_ps
|
||||
from verl import DataProto
|
||||
from verl.utils.torch_functional import (broadcast_dict_tensor, allgather_dict_tensors)
|
||||
from verl.utils.debug import log_gpu_memory_usage
|
||||
|
||||
from .base import BaseShardingManager
|
||||
|
||||
logger = logging.getLogger(__file__)
|
||||
logger.setLevel(os.getenv('VERL_PPO_LOGGING_LEVEL', 'WARN'))
|
||||
|
||||
|
||||
class FSDPVLLMShardingManager(BaseShardingManager):
|
||||
|
||||
def __init__(self,
|
||||
module: FSDP,
|
||||
inference_engine: LLM,
|
||||
model_config,
|
||||
full_params: bool = False,
|
||||
device_mesh: DeviceMesh = None):
|
||||
self.module = module
|
||||
self.inference_engine = inference_engine
|
||||
self.model_config = model_config
|
||||
self.device_mesh = device_mesh
|
||||
|
||||
# Full params
|
||||
self.full_params = full_params
|
||||
if full_params:
|
||||
FSDP.set_state_dict_type(self.module,
|
||||
state_dict_type=StateDictType.FULL_STATE_DICT,
|
||||
state_dict_config=FullStateDictConfig())
|
||||
else:
|
||||
FSDP.set_state_dict_type(self.module,
|
||||
state_dict_type=StateDictType.SHARDED_STATE_DICT,
|
||||
state_dict_config=ShardedStateDictConfig())
|
||||
|
||||
# Note that torch_random_states may be different on each dp rank
|
||||
self.torch_random_states = torch.cuda.get_rng_state()
|
||||
# get a random rng states
|
||||
if self.device_mesh is not None:
|
||||
gen_dp_rank = self.device_mesh['dp'].get_local_rank()
|
||||
torch.cuda.manual_seed(gen_dp_rank + 1000) # make sure all tp ranks have the same random states
|
||||
self.gen_random_states = torch.cuda.get_rng_state()
|
||||
torch.cuda.set_rng_state(self.torch_random_states)
|
||||
else:
|
||||
self.gen_random_states = None
|
||||
|
||||
def __enter__(self):
|
||||
log_gpu_memory_usage('Before state_dict() in sharding manager memory', logger=logger)
|
||||
params = self.module.state_dict()
|
||||
log_gpu_memory_usage('After state_dict() in sharding manager memory', logger=logger)
|
||||
# Copy, not share memory
|
||||
load_format = 'hf' if self.full_params else 'dtensor'
|
||||
self.inference_engine.sync_model_weights(params, load_format=load_format)
|
||||
log_gpu_memory_usage('After sync model weights in sharding manager', logger=logger)
|
||||
|
||||
del params
|
||||
torch.cuda.empty_cache()
|
||||
log_gpu_memory_usage('After del state_dict and empty_cache in sharding manager', logger=logger)
|
||||
|
||||
# TODO: offload FSDP model weights
|
||||
# self.module.cpu()
|
||||
# torch.cuda.empty_cache()
|
||||
# if torch.distributed.get_rank() == 0:
|
||||
# print(f'after model to cpu in sharding manager memory allocated: {torch.cuda.memory_allocated() / 1e9}GB, reserved: {torch.cuda.memory_reserved() / 1e9}GB')
|
||||
|
||||
# important: need to manually set the random states of each tp to be identical.
|
||||
if self.device_mesh is not None:
|
||||
self.torch_random_states = torch.cuda.get_rng_state()
|
||||
torch.cuda.set_rng_state(self.gen_random_states)
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
log_gpu_memory_usage('Before vllm offload in sharding manager', logger=logger)
|
||||
self.inference_engine.offload_model_weights()
|
||||
log_gpu_memory_usage('After vllm offload in sharding manager', logger=logger)
|
||||
|
||||
# self.module.to('cuda')
|
||||
# if torch.distributed.get_rank() == 0:
|
||||
# print(f'after actor module to cuda in sharding manager memory allocated: {torch.cuda.memory_allocated() / 1e9}GB, reserved: {torch.cuda.memory_reserved() / 1e9}GB')
|
||||
|
||||
self.module.train()
|
||||
|
||||
# add empty cache after each compute
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# restore random states
|
||||
if self.device_mesh is not None:
|
||||
self.gen_random_states = torch.cuda.get_rng_state()
|
||||
torch.cuda.set_rng_state(self.torch_random_states)
|
||||
|
||||
def preprocess_data(self, data: DataProto) -> DataProto:
|
||||
# TODO: Current impl doesn't consider FSDP with torch micro-dp
|
||||
data.batch = allgather_dict_tensors(data.batch.contiguous(),
|
||||
size=vllm_ps.get_tensor_model_parallel_world_size(),
|
||||
group=vllm_ps.get_tensor_model_parallel_group(),
|
||||
dim=0)
|
||||
|
||||
return data
|
||||
|
||||
def postprocess_data(self, data: DataProto) -> DataProto:
|
||||
# TODO: Current impl doesn't consider FSDP with torch micro-dp
|
||||
broadcast_dict_tensor(data.batch,
|
||||
src=vllm_ps.get_tensor_model_parallel_src_rank(),
|
||||
group=vllm_ps.get_tensor_model_parallel_group())
|
||||
dp_rank = torch.distributed.get_rank()
|
||||
dp_size = torch.distributed.get_world_size() # not consider torch micro-dp
|
||||
tp_size = vllm_ps.get_tensor_model_parallel_world_size()
|
||||
if tp_size > 1:
|
||||
# TODO: shall we build a micro_dp group for vllm when integrating with vLLM?
|
||||
local_prompts = data.chunk(chunks=tp_size)
|
||||
data = local_prompts[dp_rank % tp_size]
|
||||
return data
|
||||
428
verl/workers/sharding_manager/megatron_vllm.py
Normal file
428
verl/workers/sharding_manager/megatron_vllm.py
Normal file
@@ -0,0 +1,428 @@
|
||||
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
This file contains a Megatron style Hybrid Engine that shares the weights of the actor with the inference engine.
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from torch import nn
|
||||
|
||||
from megatron.core import parallel_state as mpu
|
||||
from megatron.core import DistributedDataParallel as LocalDDP
|
||||
from megatron.core.transformer.module import Float16Module
|
||||
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
|
||||
from verl.utils.megatron_utils import get_model, unwrap_model
|
||||
from verl.utils.memory_buffer import (
|
||||
build_memory_buffer,
|
||||
build_memory_reference_from_module,
|
||||
get_weight_buffer_meta_from_module,
|
||||
)
|
||||
|
||||
|
||||
class AllGatherPPModel:
|
||||
|
||||
def __init__(self, model_provider) -> None:
|
||||
|
||||
self._pp_group = mpu.get_pipeline_model_parallel_group()
|
||||
self._pp_rank = mpu.get_pipeline_model_parallel_rank()
|
||||
self._pp_size = mpu.get_pipeline_model_parallel_world_size()
|
||||
self._vpp_size = mpu.get_virtual_pipeline_model_parallel_world_size()
|
||||
self._model_chunk_size = self._vpp_size or 1
|
||||
|
||||
# each one holds a list of model_chunks in this pp stage
|
||||
self._pp_models = [None] * self.pp_size
|
||||
|
||||
rank_list = list(range(self.pp_size))
|
||||
# make current rank the last one to initialize
|
||||
rank_list[self.pp_rank], rank_list[-1] = rank_list[-1], rank_list[self.pp_rank]
|
||||
self._this_rank_models = None
|
||||
|
||||
# store the parameter of each pp stage
|
||||
self.memory_buffers = [None] * self.pp_size
|
||||
for cur_pp_rank in rank_list:
|
||||
print(
|
||||
f'create pp model', f'torch allocated {torch.cuda.memory_allocated() / 1e9:.4f} GB, '
|
||||
f'reserved {torch.cuda.memory_reserved() / 1e9:.4f} GB')
|
||||
# since the last initialized rank is the current pp rank, after init, the pp rank is still correct
|
||||
mpu.set_pipeline_model_parallel_rank(cur_pp_rank)
|
||||
if cur_pp_rank != self.pp_rank:
|
||||
models = get_model(model_provider, wrap_with_ddp=False)
|
||||
models = nn.ModuleList(models)
|
||||
assert len(models) == self._model_chunk_size, f"{len(models)} != {self._model_chunk_size}"
|
||||
self.pp_models[cur_pp_rank] = models
|
||||
else:
|
||||
# for regular model, we wrapped it with DDP
|
||||
models = get_model(model_provider)
|
||||
assert len(models) == self._model_chunk_size, f"{len(models)} != {self._model_chunk_size}"
|
||||
self._this_rank_models = nn.ModuleList(models)
|
||||
self.pp_models[cur_pp_rank] = nn.ModuleList(unwrap_model(models, (torchDDP, LocalDDP)))
|
||||
|
||||
self._build_param_buffer(cur_pp_rank)
|
||||
self._build_param_references(cur_pp_rank, maintain_weight=cur_pp_rank == self.pp_rank)
|
||||
|
||||
# TODO: after binding to the memory buffer, we can load the checkpoint here
|
||||
if cur_pp_rank != self.pp_rank:
|
||||
for model in self.pp_models[cur_pp_rank]:
|
||||
model.eval()
|
||||
self._offload_params_to_cpu(cur_pp_rank)
|
||||
|
||||
def _build_param_buffer(self, pp_rank):
|
||||
"""Build the parameter buffer in each pp rank"""
|
||||
model = self.pp_models[pp_rank]
|
||||
weight_buffer_meta = get_weight_buffer_meta_from_module(model)
|
||||
self.memory_buffers[pp_rank] = build_memory_buffer(weight_buffer_meta)
|
||||
|
||||
def _build_param_references(self, pp_rank, maintain_weight=False):
|
||||
model = self.pp_models[pp_rank]
|
||||
build_memory_reference_from_module(model, self.memory_buffers[pp_rank], maintain_weight=maintain_weight)
|
||||
|
||||
def _load_params_to_cuda(self, pp_rank, to_empty=False):
|
||||
assert pp_rank != self.pp_rank, f"unexpected to load current pp rank [{pp_rank}] back to cuda"
|
||||
for buffer in self.memory_buffers[pp_rank].values():
|
||||
if not to_empty:
|
||||
buffer.data = buffer.data.to(torch.cuda.current_device(), non_blocking=True)
|
||||
else:
|
||||
buffer.data = torch.empty_like(buffer.data, device='cuda')
|
||||
# rebuild reference after loading to CUDA
|
||||
self._build_param_references(pp_rank)
|
||||
|
||||
def _offload_params_to_cpu(self, pp_rank, to_empty=False):
|
||||
assert pp_rank != self.pp_rank, f"unexpected to offload current pp rank [{pp_rank}] to cpu"
|
||||
for buffer in self.memory_buffers[pp_rank].values():
|
||||
if not to_empty:
|
||||
# offload the whole memory buffer to CPU
|
||||
buffer.data = buffer.data.to('cpu', non_blocking=True)
|
||||
else:
|
||||
buffer.data = torch.empty_like(buffer.data, device='cpu')
|
||||
self._build_param_references(pp_rank)
|
||||
|
||||
def load_params_to_cuda(self, to_empty=False):
|
||||
"""load all model params to cuda"""
|
||||
for cur_pp_rank in range(self.pp_size):
|
||||
if cur_pp_rank != self.pp_rank:
|
||||
self._load_params_to_cuda(cur_pp_rank, to_empty=to_empty)
|
||||
|
||||
def allgather_params(self):
|
||||
"""allgather params of all pp ranks. Return a list of handles"""
|
||||
for cur_pp_rank in range(self.pp_size):
|
||||
global_src = dist.get_global_rank(group=self.pp_group, group_rank=cur_pp_rank)
|
||||
|
||||
# NOTE(sgm): the async op may cause memory leakage of the memory_buffer/pp_models
|
||||
for memory_buffer in self.memory_buffers[cur_pp_rank].values():
|
||||
dist.broadcast(tensor=memory_buffer.data, src=global_src, group=self.pp_group, async_op=False)
|
||||
|
||||
def forward(self, *inputs, **kwargs):
|
||||
try:
|
||||
prev_output = None
|
||||
for cur_chunk_rank in range(self._model_chunk_size):
|
||||
if self._vpp_size:
|
||||
mpu.set_virtual_pipeline_model_parallel_rank(cur_chunk_rank)
|
||||
|
||||
for cur_pp_rank in range(self.pp_size):
|
||||
mpu.set_pipeline_model_parallel_rank(cur_pp_rank)
|
||||
self.pp_models[cur_pp_rank][cur_chunk_rank].set_input_tensor(prev_output)
|
||||
ret = self.pp_models[cur_pp_rank][cur_chunk_rank](*inputs, **kwargs)
|
||||
self.pp_models[cur_pp_rank][cur_chunk_rank].set_input_tensor(None)
|
||||
prev_output = ret
|
||||
finally:
|
||||
if self._vpp_size:
|
||||
mpu.set_virtual_pipeline_model_parallel_rank(0)
|
||||
mpu.set_pipeline_model_parallel_rank(self.pp_rank)
|
||||
return ret
|
||||
|
||||
def __call__(self, *inputs, **kwargs):
|
||||
return self.forward(*inputs, **kwargs)
|
||||
|
||||
def eval(self):
|
||||
for model in self.pp_models[self.pp_rank]:
|
||||
model.eval()
|
||||
|
||||
def train(self):
|
||||
for model in self.pp_models[self.pp_rank]:
|
||||
model.train()
|
||||
|
||||
def offload_params_to_cpu(self, to_empty=False):
|
||||
"""offload params of models that are not of current pp rank to cpu"""
|
||||
for cur_pp_rank in range(self.pp_size):
|
||||
if cur_pp_rank != self.pp_rank:
|
||||
self._offload_params_to_cpu(cur_pp_rank, to_empty=to_empty)
|
||||
|
||||
def get_all_params(self):
|
||||
"""Get all the parameters of the models in all pp ranks
|
||||
|
||||
Returns:
|
||||
params: List[List[Dict[str, Tensor]]]: a list of parameters in all pp, where each is a list of dict
|
||||
tensors of each model chunk
|
||||
|
||||
"""
|
||||
params = []
|
||||
for pp_rank in range(self.pp_size):
|
||||
params.append([])
|
||||
for model_chunk_idx in range(len(self.pp_models[pp_rank])):
|
||||
params[pp_rank].append({})
|
||||
pp_model = self.pp_models[pp_rank][model_chunk_idx]
|
||||
pp_model = unwrap_model(pp_model, ((torchDDP, LocalDDP, Float16Module))) # not use Float16Module
|
||||
for name, param in pp_model.named_parameters():
|
||||
# NOTE(gh) workaround: should not get lora params for inference
|
||||
if 'lora' in name:
|
||||
continue
|
||||
params[pp_rank][model_chunk_idx][name] = param
|
||||
|
||||
return params
|
||||
|
||||
def update_this_rank_models(self, new_models):
|
||||
self._this_rank_models = new_models
|
||||
self._pp_models[self.pp_rank] = unwrap_model(new_models, (torchDDP, LocalDDP))
|
||||
|
||||
@property
|
||||
def this_rank_models(self):
|
||||
return self._this_rank_models
|
||||
|
||||
@property
|
||||
def pp_size(self):
|
||||
return self._pp_size
|
||||
|
||||
@property
|
||||
def pp_rank(self):
|
||||
return self._pp_rank
|
||||
|
||||
@property
|
||||
def pp_group(self):
|
||||
return self._pp_group
|
||||
|
||||
@property
|
||||
def pp_models(self):
|
||||
return self._pp_models
|
||||
|
||||
|
||||
"""
|
||||
Megatron Hybrid Engine:
|
||||
- During training, only the current pp stage holds the parameters
|
||||
- Before inference, broadcast the parameters of the current pp rank to all other pp ranks (all pp ranks holds all the parameters)
|
||||
- Bind the parameters to the inference engine
|
||||
- Do inference in tp. pp is treated as additional dp
|
||||
- After inference, all the parameters that doesn't belong to this pp rank is freed.
|
||||
"""
|
||||
|
||||
from .base import BaseShardingManager
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.distributed
|
||||
from torch.distributed import new_group
|
||||
|
||||
from verl import DataProto
|
||||
from verl.utils.torch_functional import (broadcast_dict_tensor, allgather_dict_tensors)
|
||||
import verl.utils.megatron.tensor_parallel as tp_utils
|
||||
from verl.third_party.vllm import parallel_state as vllm_ps
|
||||
from verl.third_party.vllm import LLM
|
||||
from verl.utils.model import normalize_pp_vpp_params
|
||||
# Micro Data parallel group. Micro data parallel group is additional dp group that origins from splitting training tp
|
||||
# into infer_tp and micro_tp. By default, we use order micro_dp - tp
|
||||
_MICRO_DATA_PARALLEL_GROUP = None
|
||||
|
||||
|
||||
class MegatronVLLMShardingManager(BaseShardingManager):
|
||||
|
||||
def __init__(self, module: AllGatherPPModel, inference_engine: LLM, model_config, layer_name_mapping):
|
||||
self.module = module
|
||||
self.inference_engine = inference_engine
|
||||
self.model_config = model_config
|
||||
self.layer_name_mapping = layer_name_mapping
|
||||
|
||||
# initialize micro_dp group for vllm inference
|
||||
global _MICRO_DATA_PARALLEL_GROUP
|
||||
world_size = torch.distributed.get_world_size()
|
||||
rank = torch.distributed.get_rank()
|
||||
train_tensor_parallel_size = mpu.get_tensor_model_parallel_world_size()
|
||||
infer_tensor_parallel_size = vllm_ps.get_tensor_model_parallel_world_size()
|
||||
|
||||
# TODO(sgm): this may not be true for FSDP -> vLLM
|
||||
assert infer_tensor_parallel_size <= train_tensor_parallel_size, \
|
||||
'Not implemented for infer_tp > train_tp'
|
||||
assert train_tensor_parallel_size % infer_tensor_parallel_size == 0
|
||||
|
||||
micro_dp_size = train_tensor_parallel_size // infer_tensor_parallel_size
|
||||
num_micro_dp_groups = world_size // micro_dp_size
|
||||
assert _MICRO_DATA_PARALLEL_GROUP is None, ("micro data parallel group is already initialized")
|
||||
for i in range(num_micro_dp_groups):
|
||||
ranks = range(i * micro_dp_size, (i + 1) * micro_dp_size)
|
||||
group = new_group(ranks=ranks)
|
||||
if rank in ranks:
|
||||
_MICRO_DATA_PARALLEL_GROUP = group
|
||||
|
||||
def default_tp_concat_fn(self, name, param, infer_params, model_config):
|
||||
"""
|
||||
name: name of the parameter
|
||||
param: training parameters
|
||||
infer_params (List[torch.Tensor]): a list of parameters all-gathered from micro_dp_group
|
||||
model_config: huggingface model_config
|
||||
TODO(zhangchi.usc1992): currently, the implementation is adhoc. We can move this function to the model
|
||||
definition so that it is model-agnostic. If the model doesn't implement this function,
|
||||
we can throw an error to force user disable TP HybridEngine.
|
||||
"""
|
||||
|
||||
if self.layer_name_mapping.get("qkv_layer_name") in name:
|
||||
# if the tensor is qkv, for each param on tp, split into q, k, v
|
||||
# concat q, k, v separately.
|
||||
q_lst = []
|
||||
k_lst = []
|
||||
v_lst = []
|
||||
assert model_config.num_attention_heads % model_config.num_key_value_heads == 0
|
||||
num_q_per_kv = model_config.num_attention_heads // model_config.num_key_value_heads
|
||||
assert infer_params[0].shape[0] % (num_q_per_kv + 2) == 0
|
||||
kv_size_per_tp = infer_params[0].shape[0] // (num_q_per_kv + 2)
|
||||
split_size = [kv_size_per_tp * num_q_per_kv, kv_size_per_tp, kv_size_per_tp]
|
||||
for infer_param in infer_params:
|
||||
q, k, v = infer_param.split(split_size)
|
||||
q_lst.append(q)
|
||||
k_lst.append(k)
|
||||
v_lst.append(v)
|
||||
q = torch.cat(q_lst, dim=0)
|
||||
k = torch.cat(k_lst, dim=0)
|
||||
v = torch.cat(v_lst, dim=0)
|
||||
|
||||
infer_params = torch.cat((q, k, v), dim=0)
|
||||
|
||||
elif self.layer_name_mapping.get("gate_proj_layer_name") in name:
|
||||
# if the tensor is gate and proj
|
||||
gate_lst = []
|
||||
up_lst = []
|
||||
for infer_param in infer_params:
|
||||
gate, up = infer_param.chunk(2)
|
||||
gate_lst.append(gate)
|
||||
up_lst.append(up)
|
||||
gate = torch.cat(gate_lst, dim=0)
|
||||
up = torch.cat(up_lst, dim=0)
|
||||
infer_params = torch.cat((gate, up), dim=0)
|
||||
|
||||
else:
|
||||
# concat tensor
|
||||
infer_params = torch.cat(infer_params, dim=tp_utils.get_tensor_parallel_partition_dim(param))
|
||||
|
||||
return infer_params
|
||||
|
||||
def _post_process_params(self, params):
|
||||
"""
|
||||
For each param, if it is a tp-splited param, we all-gather from micro_dp group.
|
||||
"""
|
||||
# here the params are in train tp format. we iterate params and all-gather
|
||||
# TODO(zhangchi.usc1992) We can consider copy non-tp weight to another infer buffer.
|
||||
# In this way, all the params in the original memory_buffers and can be offload.
|
||||
micro_dp_size = get_micro_data_parallel_world_size()
|
||||
micro_dp_group = get_micro_data_parallel_group()
|
||||
|
||||
if micro_dp_size <= 1:
|
||||
return
|
||||
|
||||
origin_params = {}
|
||||
for name in params.keys():
|
||||
param = params[name]
|
||||
if tp_utils.is_tensor_parallel_param(param):
|
||||
# allocate a new tensor with proper size
|
||||
infer_params = [torch.empty_like(param) for _ in range(micro_dp_size)]
|
||||
torch.distributed.all_gather(infer_params, param, group=micro_dp_group)
|
||||
infer_params = self.default_tp_concat_fn(name, param, infer_params, self.model_config)
|
||||
# replace with original param
|
||||
params[name] = infer_params
|
||||
origin_params[name] = param
|
||||
|
||||
return origin_params
|
||||
|
||||
def __enter__(self):
|
||||
# create a new cuda space for parameters not in this pp rank
|
||||
self.module.load_params_to_cuda()
|
||||
# broadcast the parameters from pp rank to other ranks
|
||||
self.module.allgather_params()
|
||||
# obtain name to parameters in pp/vpp
|
||||
params = self.module.get_all_params()
|
||||
|
||||
# bind the params to inference engine
|
||||
self.params = normalize_pp_vpp_params(params=params,
|
||||
num_hidden_layers=self.model_config.num_hidden_layers,
|
||||
layer_name='layers')
|
||||
self.origin_params = self._post_process_params(self.params)
|
||||
self.inference_engine.sync_model_weights(self.params, load_format='megatron')
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
# offload parameters doesn't belong to this pp rank
|
||||
self.module.offload_params_to_cpu()
|
||||
|
||||
# FIXME(sgm): the best practice is to delete the cuda tensor
|
||||
# rebind the model weights, can be any cpu tensor
|
||||
if get_micro_data_parallel_world_size() > 1:
|
||||
for name in self.params.keys():
|
||||
self.params[name] = self.origin_params[name]
|
||||
|
||||
# self.inference_engine.sync_model_weights(params)
|
||||
self.inference_engine.offload_model_weights()
|
||||
|
||||
self.module.train()
|
||||
|
||||
# add empty cache after each compute
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def preprocess_data(self, data: DataProto) -> DataProto:
|
||||
# prompts are identical for each training tp. We select for each inference tp
|
||||
micro_dp_size = get_micro_data_parallel_world_size()
|
||||
micro_dp_rank = get_micro_data_parallel_rank()
|
||||
|
||||
# broadcast from tp=0 to other tp ranks
|
||||
broadcast_dict_tensor(data.batch,
|
||||
src=mpu.get_tensor_model_parallel_src_rank(),
|
||||
group=mpu.get_tensor_model_parallel_group())
|
||||
|
||||
if micro_dp_size > 1:
|
||||
local_prompts = data.chunk(chunks=micro_dp_size)
|
||||
data = local_prompts[micro_dp_rank]
|
||||
|
||||
return data
|
||||
|
||||
def postprocess_data(self, data: DataProto) -> DataProto:
|
||||
meta_info = data.meta_info
|
||||
# all gather batch among micro-dp groups
|
||||
micro_dp_size = get_micro_data_parallel_world_size()
|
||||
if micro_dp_size > 1:
|
||||
data.batch = allgather_dict_tensors(data.batch.contiguous(),
|
||||
size=get_micro_data_parallel_world_size(),
|
||||
group=get_micro_data_parallel_group(),
|
||||
dim=0)
|
||||
|
||||
# all gather batch among pp group
|
||||
if meta_info.get('allgather_pp_output', True):
|
||||
data.batch = allgather_dict_tensors(data.batch.contiguous(),
|
||||
size=mpu.get_pipeline_model_parallel_world_size(),
|
||||
group=mpu.get_pipeline_model_parallel_group(),
|
||||
dim=0)
|
||||
return data
|
||||
|
||||
|
||||
"""
|
||||
Micro Data parallel group
|
||||
"""
|
||||
|
||||
|
||||
def get_micro_data_parallel_group():
|
||||
assert _MICRO_DATA_PARALLEL_GROUP is not None
|
||||
return _MICRO_DATA_PARALLEL_GROUP
|
||||
|
||||
|
||||
def get_micro_data_parallel_world_size():
|
||||
return torch.distributed.get_world_size(group=get_micro_data_parallel_group())
|
||||
|
||||
|
||||
def get_micro_data_parallel_rank():
|
||||
return torch.distributed.get_rank(group=get_micro_data_parallel_group())
|
||||
Reference in New Issue
Block a user