Initial commit

This commit is contained in:
PeterGriffinJin
2025-02-28 15:16:19 +00:00
commit 068516be64
207 changed files with 33063 additions and 0 deletions

27
verl/__init__.py Normal file
View File

@@ -0,0 +1,27 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
version_folder = os.path.dirname(os.path.join(os.path.abspath(__file__)))
with open(os.path.join(version_folder, 'version/version')) as f:
__version__ = f.read().strip()
from .protocol import DataProto
from .utils.logging_utils import set_basic_config
import logging
set_basic_config(level=logging.WARNING)

35
verl/models/README.md Normal file
View File

@@ -0,0 +1,35 @@
# Models
Common modelzoo such as huggingface/transformers stuggles when using Pytorch native model parallelism. Following the design principle of vLLM, we keep a simple, parallelizable, highly-optimized with packed inputs in verl.
## Adding a New Huggingface Model
### Step 1: Copy the model file from HF to verl
- Add a new file under verl/models/hf
- Copy ONLY the model file from huggingface/transformers/models to verl/models/hf
### Step 2: Modify the model file to use packed inputs
- Remove all the code related to inference (kv cache)
- Modify the inputs to include only
- input_ids (total_nnz,)
- cu_seqlens (total_nnz + 1,)
- max_seqlen_in_batch: int
- Note that this requires using flash attention with causal mask.
### Step 2.5: Add tests
- Add a test to compare this version and the huggingface version
- Following the infrastructure and add tests to tests/models/hf
### Step 3: Add a function to apply tensor parallelism
- Please follow
- https://pytorch.org/docs/stable/distributed.tensor.parallel.html
- https://pytorch.org/tutorials/intermediate/TP_tutorial.html
- General comments
- Tensor Parallelism in native Pytorch is NOT auto-parallelism. The way it works is to specify how model parameters and input/output reshards using configs. These configs are then registered as hooks to perform input/output resharding before/after model forward.
### Step 4: Add a function to apply data parallelism
- Please use FSDP2 APIs
- See demo here https://github.com/pytorch/torchtitan/blob/main/torchtitan/parallelisms/parallelize_llama.py#L413
### Step 5: Add a function to apply pipeline parallelism
- Comes in Pytorch 2.4
- Currently only in alpha in nightly version
- Check torchtitan for more details

13
verl/models/__init__.py Normal file
View File

@@ -0,0 +1,13 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@@ -0,0 +1,13 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@@ -0,0 +1,24 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .modeling_llama_megatron import (
# original model with megatron
ParallelLlamaModel,
ParallelLlamaForCausalLM,
# rmpad with megatron
ParallelLlamaForCausalLMRmPad,
ParallelLlamaForValueRmPad,
# rmpad with megatron and pipeline parallelism
ParallelLlamaForCausalLMRmPadPP,
ParallelLlamaForValueRmPadPP)

View File

@@ -0,0 +1,13 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@@ -0,0 +1,446 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import time
from typing import Dict, Any, Callable, Optional
import torch.distributed as dist
def _megatron_calc_layer_map(config):
"""Calculate the mapping of global layer_idx to local layer_idx
Returns:
layer_map (Dict: int -> tuple(int, int, int)):
mapping from the global layer index to
a tuple of (pp_rank, virtual_pp_rank, layer_idx inside model)
"""
import megatron
from megatron.core import mpu
pp_size = mpu.get_pipeline_model_parallel_world_size()
virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1
layer_map = dict()
num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size
assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers
for pp_rank_idx in range(pp_size):
for virtual_pp_rank_idx in range(virtual_pp_size):
layer_offset = (virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) +
pp_rank_idx * num_layers_per_model)
for layer_idx in range(num_layers_per_model):
layer_map[layer_offset + layer_idx] = (
pp_rank_idx,
virtual_pp_rank_idx,
layer_idx,
)
return layer_map
def load_state_dict_to_megatron_llama(state_dict, wrapped_models, config, params_dtype, is_value_model=False):
"""Load merged state_dict to sharded Megatron module in training.
"""
import megatron
from megatron.core import mpu
from megatron.utils import print_rank_0, unwrap_model
from megatron.core.transformer.module import Float16Module
from megatron.core import DistributedDataParallel as LocalDDP
from torch.nn.parallel import DistributedDataParallel as torchDDP
start_time = time.time()
def _get_gpt_model(model):
return model
def broadcast_params(module):
for param in module.parameters():
torch.distributed.broadcast(param.data,
src=mpu.get_data_parallel_src_rank(),
group=mpu.get_data_parallel_group())
dp_rank = mpu.get_data_parallel_rank()
pp_rank = mpu.get_pipeline_model_parallel_rank()
pp_size = mpu.get_pipeline_model_parallel_world_size()
virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1
mp_group = mpu.get_model_parallel_group()
if torch.distributed.get_rank() == 0:
assert mp_group.rank() == 0, f"mp_rank:[{mp_group.rank}] != 0 on rank #0"
assert pp_rank == 0, f"pp_rank:[{pp_rank}] != 0 on rank #0"
assert dp_rank == 0, f"dp_rank:[{dp_rank}] != 0 on rank #0"
if not isinstance(wrapped_models, (list, tuple)):
wrapped_models = list(wrapped_models)
assert len(wrapped_models) == virtual_pp_size
num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size
assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers
models = [None] * len(wrapped_models)
for i, wrapped_model in enumerate(wrapped_models):
models[i] = unwrap_model(wrapped_model, (torchDDP, LocalDDP, Float16Module))
gpt_model_module = _get_gpt_model(models[i])
assert len(gpt_model_module.model.layers) == num_layers_per_model
def _broadcast_tensor(tensor, name) -> torch.Tensor:
"""broadcast tensor from rank0 across mp_group"""
nonlocal state_dict
nonlocal mp_group
if torch.distributed.get_rank() == 0:
if name in state_dict:
weight = state_dict[name]
tensor_shape = weight.shape
else:
tensor_shape = None
else:
weight = None
tensor_shape = None
obj_list = [tensor_shape]
dist.broadcast_object_list(obj_list, src=0, group=mp_group)
tensor_shape = obj_list[0]
if tensor_shape is None:
# all or none ranks in the mp_group should reach here
print_rank_0(f"tensor:[{name}] not in state_dict, skip load")
return
if tensor is None:
tensor = torch.empty(
tensor_shape,
dtype=params_dtype,
device=torch.cuda.current_device(),
requires_grad=False,
)
if torch.distributed.get_rank() == 0:
tensor.data.copy_(weight)
dist.broadcast(tensor, src=0, group=mp_group)
def _broadcast_tp_shard_tensor_vocab(tensor, name, chunk_dim=0, mutate_func=None) -> torch.Tensor:
"""broadcast tensor in tp shards across mp_group"""
nonlocal state_dict
nonlocal mp_group
tp_rank = mpu.get_tensor_model_parallel_rank()
tp_size = mpu.get_tensor_model_parallel_world_size()
if torch.distributed.get_rank() == 0:
if name in state_dict:
full_weight = state_dict[name]
if mutate_func is not None:
full_weight = mutate_func(full_weight)
tensor_chunk = torch.chunk(full_weight, tp_size, dim=chunk_dim)
chunk_shape = tensor_chunk[0].shape
else:
chunk_shape = None
else:
chunk_shape = None
obj_list = [chunk_shape]
dist.broadcast_object_list(obj_list, src=0, group=mp_group)
chunk_shape = obj_list[0]
if chunk_shape is None:
# all or none ranks in the mp_group should reach here
print_rank_0(f"tp_shard tensor:[{name}] not in state_dict, skip loading")
return
if tensor is None:
sync_tensor = torch.empty(
chunk_shape,
dtype=params_dtype,
device=torch.cuda.current_device(),
requires_grad=False,
)
else:
assert (tensor.shape == chunk_shape
), f"rank #{torch.distributed.get_rank()} tensor {name} shape {tensor.shape} != {chunk_shape}"
sync_tensor = torch.empty_like(tensor, device=torch.cuda.current_device(), requires_grad=False)
for i in range(tp_size):
if torch.distributed.get_rank() == 0:
sync_tensor.data.copy_(tensor_chunk[i])
dist.broadcast(sync_tensor, src=0, group=mp_group)
if (i == tp_rank) and (tensor is not None):
tensor.data.copy_(sync_tensor)
def _broadcast_tp_shard_tensor(tensor, name, chunk_dim=0, mutate_func=None) -> torch.Tensor:
"""broadcast tensor in tp shards across mp_group"""
nonlocal state_dict
nonlocal mp_group
tp_rank = mpu.get_tensor_model_parallel_rank()
tp_size = mpu.get_tensor_model_parallel_world_size()
if torch.distributed.get_rank() == 0:
if name in state_dict:
full_weight = state_dict[name]
if mutate_func is not None:
full_weight = mutate_func(full_weight)
tensor_chunk = torch.chunk(full_weight, tp_size, dim=chunk_dim)
chunk_shape = tensor_chunk[0].shape
else:
chunk_shape = None
else:
chunk_shape = None
obj_list = [chunk_shape]
dist.broadcast_object_list(obj_list, src=0, group=mp_group)
chunk_shape = obj_list[0]
if chunk_shape is None:
# all or none ranks in the mp_group should reach here
print_rank_0(f"tp_shard tensor:[{name}] not in state_dict, skip loading")
return
if tensor is None:
sync_tensor = torch.empty(
chunk_shape,
dtype=params_dtype,
device=torch.cuda.current_device(),
requires_grad=False,
)
else:
assert (tensor.shape == chunk_shape
), f"rank #{torch.distributed.get_rank()} tensor {name} shape {tensor.shape} != {chunk_shape}"
sync_tensor = torch.empty_like(tensor, device=torch.cuda.current_device(), requires_grad=False)
for i in range(tp_size):
if torch.distributed.get_rank() == 0:
sync_tensor.data.copy_(tensor_chunk[i])
dist.broadcast(sync_tensor, src=0, group=mp_group)
if (i == tp_rank) and (tensor is not None):
tensor.data.copy_(sync_tensor)
def _broadcast_tp_shard_tensor_gate_up(tensor, gate_name, up_name) -> torch.Tensor:
"""broadcast tensor in tp shards across mp_group"""
nonlocal state_dict
nonlocal mp_group
tp_rank = mpu.get_tensor_model_parallel_rank()
tp_size = mpu.get_tensor_model_parallel_world_size()
if torch.distributed.get_rank() == 0:
gate_weight = state_dict[gate_name]
up_weight = state_dict[up_name]
new_gate_up_weight = torch.empty(config.intermediate_size * 2,
config.hidden_size,
dtype=params_dtype,
device=torch.cuda.current_device())
for i in range(tp_size):
intermediate_size_tp = config.intermediate_size // tp_size
gate_weight_tp = gate_weight[i * intermediate_size_tp:(i + 1) * intermediate_size_tp]
up_weight_tp = up_weight[i * intermediate_size_tp:(i + 1) * intermediate_size_tp]
new_gate_up_weight[intermediate_size_tp * 2 * i:intermediate_size_tp * 2 * (i + 1)].copy_(
torch.cat([gate_weight_tp, up_weight_tp], dim=0))
tensor_chunk = torch.chunk(new_gate_up_weight, tp_size, dim=0)
chunk_shape = tensor_chunk[0].shape
else:
chunk_shape = None
obj_list = [chunk_shape]
dist.broadcast_object_list(obj_list, src=0, group=mp_group)
chunk_shape = obj_list[0]
if chunk_shape is None:
# all or none ranks in the mp_group should reach here
print_rank_0(f"tp_shard tensor:[{gate_name, up_name}] not in state_dict, skip loading")
return
if tensor is None:
sync_tensor = torch.empty(
chunk_shape,
dtype=params_dtype,
device=torch.cuda.current_device(),
requires_grad=False,
)
else:
assert (
tensor.shape == chunk_shape
), f"rank #{torch.distributed.get_rank() == 0:} tensor {gate_name, up_name} shape {tensor.shape} != {chunk_shape}"
sync_tensor = torch.empty_like(tensor, device=torch.cuda.current_device(), requires_grad=False)
for i in range(tp_size):
if torch.distributed.get_rank() == 0:
sync_tensor.data.copy_(tensor_chunk[i])
dist.broadcast(sync_tensor, src=0, group=mp_group)
if (i == tp_rank) and (tensor is not None):
tensor.data.copy_(sync_tensor)
def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name) -> torch.Tensor:
"""broadcast tensor in tp shards across mp_group"""
nonlocal state_dict
nonlocal mp_group
tp_rank = mpu.get_tensor_model_parallel_rank()
tp_size = mpu.get_tensor_model_parallel_world_size()
if torch.distributed.get_rank() == 0:
assert (q_name in state_dict and k_name in state_dict and v_name in state_dict)
full_weight_q = state_dict[q_name]
full_weight_k = state_dict[k_name]
full_weight_v = state_dict[v_name]
hidden_size_per_head = config.hidden_size // config.num_attention_heads
if config.num_key_value_heads >= tp_size:
q_size_tp = config.hidden_size // tp_size
kv_size_tp = hidden_size_per_head * config.num_key_value_heads // tp_size
total_size = q_size_tp + 2 * kv_size_tp
new_weight_qkv = torch.empty(total_size * tp_size,
config.hidden_size,
dtype=params_dtype,
device=torch.cuda.current_device())
for i in range(tp_size):
q_part = full_weight_q[i * q_size_tp:(i + 1) * q_size_tp]
k_part = full_weight_k[i * kv_size_tp:(i + 1) * kv_size_tp]
v_part = full_weight_v[i * kv_size_tp:(i + 1) * kv_size_tp]
new_weight_qkv[i * total_size:(i + 1) * total_size].copy_(torch.cat([q_part, k_part, v_part],
dim=0))
else:
q_size_tp = config.hidden_size // tp_size
kv_size_tp = hidden_size_per_head
total_size = q_size_tp + 2 * kv_size_tp
new_weight_qkv = torch.empty(total_size * tp_size,
config.hidden_size,
dtype=params_dtype,
device=torch.cuda.current_device())
for i in range(tp_size):
q_part = full_weight_q[i * q_size_tp:(i + 1) * q_size_tp]
start_idx = i * config.num_key_value_heads // tp_size * hidden_size_per_head
end_idx = (i * config.num_key_value_heads // tp_size + 1) * hidden_size_per_head
k_part = full_weight_k[start_idx:end_idx]
v_part = full_weight_v[start_idx:end_idx]
new_weight_qkv[i * total_size:(i + 1) * total_size].copy_(torch.cat([q_part, k_part, v_part],
dim=0))
tensor_chunk = torch.chunk(new_weight_qkv, tp_size, dim=0)
chunk_shape = tensor_chunk[0].shape
else:
chunk_shape = None
obj_list = [chunk_shape]
dist.broadcast_object_list(obj_list, src=0, group=mp_group)
chunk_shape = obj_list[0]
if chunk_shape is None:
# all or none ranks in the mp_group should reach here
print_rank_0(f"tp_shard tensor:[{name}] not in state_dict, skip loading")
return
if tensor is None:
sync_tensor = torch.empty(
chunk_shape,
dtype=params_dtype,
device=torch.cuda.current_device(),
requires_grad=False,
)
else:
assert (tensor.shape == chunk_shape
), f"rank #{torch.distributed.get_rank()} tensor {q_name} shape {tensor.shape} != {chunk_shape}"
sync_tensor = torch.empty_like(tensor, device=torch.cuda.current_device(), requires_grad=False)
for i in range(tp_size):
if torch.distributed.get_rank() == 0:
sync_tensor.data.copy_(tensor_chunk[i])
dist.broadcast(sync_tensor, src=0, group=mp_group)
if (i == tp_rank) and (tensor is not None):
tensor.data.copy_(sync_tensor)
if dp_rank == 0:
# Embeddings
# -------------------
print_rank_0("loading embeddings...")
gpt_model_module = _get_gpt_model(models[0])
embed_tokens_weight = None
if pp_rank == 0:
embed_tokens_weight = gpt_model_module.model.embed_tokens.weight
_broadcast_tp_shard_tensor_vocab(embed_tokens_weight, "model.embed_tokens.weight")
# Transformer layers
# -------------------
layer_map = _megatron_calc_layer_map(config)
for layer in range(config.num_hidden_layers):
print_rank_0(f"loading layer #{layer}...")
layer_name = f"model.layers.{layer}"
dst_pp_rank, dst_virtual_pp_rank, dst_layer_idx = layer_map[layer]
gpt_model_module = _get_gpt_model(models[dst_virtual_pp_rank])
sync_layer = gpt_model_module.model.layers[dst_layer_idx]
_broadcast_tensor(
sync_layer.input_layernorm.weight if dst_pp_rank == pp_rank else None,
f"{layer_name}.input_layernorm.weight",
)
_broadcast_tp_shard_tensor_qkv(
sync_layer.self_attn.qkv_proj.weight if dst_pp_rank == pp_rank else None,
f"{layer_name}.self_attn.q_proj.weight",
f"{layer_name}.self_attn.k_proj.weight",
f"{layer_name}.self_attn.v_proj.weight",
)
_broadcast_tp_shard_tensor(
sync_layer.self_attn.o_proj.weight if dst_pp_rank == pp_rank else None,
f"{layer_name}.self_attn.o_proj.weight",
chunk_dim=1,
)
_broadcast_tensor(
sync_layer.post_attention_layernorm.weight if dst_pp_rank == pp_rank else None,
f"{layer_name}.post_attention_layernorm.weight",
)
_broadcast_tp_shard_tensor_gate_up(sync_layer.mlp.gate_up_proj.weight if dst_pp_rank == pp_rank else None,
f"{layer_name}.mlp.gate_proj.weight", f"{layer_name}.mlp.up_proj.weight")
_broadcast_tp_shard_tensor(
sync_layer.mlp.down_proj.weight if dst_pp_rank == pp_rank else None,
f"{layer_name}.mlp.down_proj.weight",
chunk_dim=1,
)
# Final Layernorm
# -------------------
print_rank_0("loading final layernorm...")
gpt_model_module = _get_gpt_model(models[-1])
_broadcast_tensor(
getattr(gpt_model_module.model.norm, "weight", None),
"model.norm.weight",
)
print_rank_0("loading lm_head...")
lm_head_weight = None
if pp_rank + 1 == pp_size:
lm_head_weight = gpt_model_module.lm_head.weight
if is_value_model:
# if torch.distributed.get_rank() == 0:
if 'lm_head.weight' in state_dict and state_dict['lm_head.weight'].shape[0] == 1:
_broadcast_tensor(lm_head_weight, "lm_head.weight")
elif 'reward_head.weight' in state_dict and state_dict['reward_head.weight'].shape[0] == 1:
_broadcast_tensor(lm_head_weight, "reward_head.weight")
print_rank_0('load lm_head from value_head weight')
else:
_broadcast_tensor(None, "lm_head.weight")
print_rank_0('fail to match lm_head in value_model')
# else:
# _broadcast_tensor(lm_head_weight, "lm_head.weight")
else:
_broadcast_tp_shard_tensor(lm_head_weight, "lm_head.weight")
dist.barrier()
# Broadcast weights inside data parallel groups
for wrapped_model in wrapped_models:
broadcast_params(wrapped_model)
torch.cuda.empty_cache()
print_rank_0(f"loading megatron ckpt done, time elapsed {time.time() - start_time}s")

View File

@@ -0,0 +1,449 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import megatron
from megatron.core import mpu
from megatron.utils import print_rank_0, unwrap_model
from megatron.model import Float16Module
from megatron.model import DistributedDataParallel as LocalDDP
from torch.nn.parallel import DistributedDataParallel as torchDDP
import torch
import time
from typing import Optional
import torch.distributed as dist
from megatron import get_args
def _megatron_calc_global_rank(tp_rank: int = 0, dp_rank: int = 0, pp_rank: int = 0):
"""given TP,DP,PP rank to get the global rank."""
args = get_args()
tp_size = mpu.get_tensor_model_parallel_world_size()
dp_size = mpu.get_data_parallel_world_size()
pp_size = mpu.get_pipeline_model_parallel_world_size()
assert (tp_size * dp_size * pp_size == torch.distributed.get_world_size()
), f"{tp_size} x {dp_size} x {pp_size} != {torch.distributed.get_world_size()}"
if args.switch_dp_and_pp_grouping:
# TP-PP-DP grouping
return (dp_rank * pp_size + pp_rank) * tp_size + tp_rank
else:
# TP-DP-PP grouping
return (pp_rank * dp_size + dp_rank) * tp_size + tp_rank
def _megatron_calc_layer_map(config):
"""Calculate the mapping of global layer_idx to local layer_idx
Returns:
layer_map (Dict: int -> tuple(int, int, int)):
mapping from the global layer index to
a tuple of (pp_rank, virtual_pp_rank, layer_idx inside model)
"""
import megatron
from megatron.core import mpu
pp_size = mpu.get_pipeline_model_parallel_world_size()
virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1
args = megatron.get_args()
layer_map = dict()
num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size
assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers
for pp_rank_idx in range(pp_size):
for virtual_pp_rank_idx in range(virtual_pp_size):
layer_offset = (virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) +
pp_rank_idx * num_layers_per_model)
for layer_idx in range(num_layers_per_model):
layer_map[layer_offset + layer_idx] = (
pp_rank_idx,
virtual_pp_rank_idx,
layer_idx,
)
return layer_map
def merge_megatron_ckpt_llama(wrapped_models, config, is_value_model=False, dtype='bf16'):
"""Merge sharded parameters of a Megatron module into a merged checkpoint.
Args:
wrapped_modelss (list of megatron.model.DistributedDataParallel):
The local DDP wrapped megatron modules.
dtype (str or None):
The data type of state_dict. if None, the data type of the original parameters
is used.
gpt_model_key: key to access model
Returns:
state_dict (dict):
The merged state_dict in rank 0, and an empty dictionary in other ranks.
"""
start_time = time.time()
args = megatron.get_args()
def _get_gpt_model(model):
return model
dp_rank = mpu.get_data_parallel_rank()
pp_size = mpu.get_pipeline_model_parallel_world_size()
pp_rank = mpu.get_pipeline_model_parallel_rank()
virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1
mp_group = mpu.get_model_parallel_group()
if dist.get_rank() == 0:
assert mp_group.rank() == 0, f"mp_rank:[{mp_group.rank}] != 0 on rank #0"
assert pp_rank == 0, f"pp_rank:[{pp_rank}] != 0 on rank #0"
assert dp_rank == 0, f"dp_rank:[{dp_rank}] != 0 on rank #0"
if not isinstance(wrapped_models, (list, tuple)):
wrapped_models = list(wrapped_models)
assert len(wrapped_models) == virtual_pp_size
num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size
assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers
models = [None] * len(wrapped_models)
for i, wrapped_model in enumerate(wrapped_models):
models[i] = unwrap_model(wrapped_model, (torchDDP, LocalDDP, Float16Module))
assert len(models[i].model.layers
) == num_layers_per_model, 'len model layers {} not equal to num_layers_per_model {}'.format(
len(models[i].model.layers), num_layers_per_model)
state_dict = dict()
def _get_cpu_tensor(tensor: torch.Tensor):
if tensor is None:
return None
if tensor.device == torch.device("cpu"):
return tensor.detach().clone()
return tensor.detach().cpu()
def _broadcast_tensor(tensor, name, src_pp_rank) -> torch.Tensor:
"""broadcast tensor across mp_group"""
nonlocal state_dict
nonlocal mp_group
src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank)
if torch.distributed.get_rank() == src_rank:
if tensor is None:
weight = None
tensor_shape = None
else:
weight = tensor
tensor_shape = weight.shape
else:
weight = None
tensor_shape = None
obj_list = [tensor_shape]
dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group)
tensor_shape = obj_list[0]
if tensor_shape is None:
# all or none ranks in the mp_group should reach here
print_rank_0(f"tensor:[{name}] not exist, skip collect")
return
if weight is None:
weight = torch.empty(
tensor_shape,
dtype=args.params_dtype,
device=torch.cuda.current_device(),
requires_grad=False,
)
dist.broadcast(weight, src=src_rank, group=mp_group)
if torch.distributed.get_rank() == 0:
state_dict[name] = _get_cpu_tensor(weight)
def _broadcast_tp_shard_tensor(tensor, name, src_pp_rank, concat_dim=0, mutate_func=None) -> torch.Tensor:
"""broadcast tensor in tp shards across mp_group"""
nonlocal state_dict
nonlocal mp_group
tp_rank = mpu.get_tensor_model_parallel_rank()
tp_size = mpu.get_tensor_model_parallel_world_size()
src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank)
if torch.distributed.get_rank() == src_rank:
chunk_shape = tensor.shape
else:
chunk_shape = None
obj_list = [chunk_shape]
dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group)
chunk_shape = obj_list[0]
if chunk_shape is None:
# all or none ranks in the mp_group should reach here
print_rank_0(f"tp_shard tensor:[{name}] not exist, skip collecting")
return
buffer_tensor = torch.empty(
chunk_shape,
dtype=args.params_dtype,
device=torch.cuda.current_device(),
requires_grad=False,
)
chunk_tensors = [None] * tp_size
for i in range(tp_size):
cur_src_rank = _megatron_calc_global_rank(tp_rank=i, dp_rank=0, pp_rank=src_pp_rank)
sync_tensor = tensor if torch.distributed.get_rank() == cur_src_rank else buffer_tensor
dist.broadcast(sync_tensor, src=cur_src_rank, group=mp_group)
if torch.distributed.get_rank() == 0:
chunk_tensors[i] = _get_cpu_tensor(sync_tensor)
if torch.distributed.get_rank() == 0:
full_tensor = torch.concat(chunk_tensors, dim=concat_dim)
if mutate_func is not None:
full_tensor = mutate_func(full_tensor)
state_dict[name] = full_tensor
def _broadcast_tp_shard_tensor_gate_up(tensor, gate_name, up_name, src_pp_rank) -> torch.Tensor:
"""broadcast tensor in tp shards across mp_group"""
nonlocal state_dict
nonlocal mp_group
tp_rank = mpu.get_tensor_model_parallel_rank()
tp_size = mpu.get_tensor_model_parallel_world_size()
src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank)
if torch.distributed.get_rank() == src_rank:
chunk_shape = tensor.shape
else:
chunk_shape = None
obj_list = [chunk_shape]
dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group)
chunk_shape = obj_list[0]
if chunk_shape is None:
# all or none ranks in the mp_group should reach here
print_rank_0(f"tp_shard tensor:[{gate_name, up_name}] not exist, skip collecting")
return
buffer_tensor = torch.empty(
chunk_shape,
dtype=args.params_dtype,
device=torch.cuda.current_device(),
requires_grad=False,
)
chunk_tensors = [None] * tp_size
for i in range(tp_size):
cur_src_rank = _megatron_calc_global_rank(tp_rank=i, dp_rank=0, pp_rank=src_pp_rank)
sync_tensor = tensor if torch.distributed.get_rank() == cur_src_rank else buffer_tensor
dist.broadcast(sync_tensor, src=cur_src_rank, group=mp_group)
if torch.distributed.get_rank() == 0:
chunk_tensors[i] = _get_cpu_tensor(sync_tensor)
if torch.distributed.get_rank() == 0:
full_tensor = torch.concat(chunk_tensors, dim=0)
intermediate_size_tp = config.intermediate_size // tp_size
gate_weight_list = []
up_weight_list = []
for i in range(tp_size):
gate_up_weight_tp = full_tensor[intermediate_size_tp * 2 * i:intermediate_size_tp * 2 * (i + 1)]
gate_weight_tp = gate_up_weight_tp[:intermediate_size_tp]
up_weight_tp = gate_up_weight_tp[intermediate_size_tp:]
gate_weight_list.append(gate_weight_tp)
up_weight_list.append(up_weight_tp)
state_dict[gate_name] = torch.cat(gate_weight_list, dim=0)
state_dict[up_name] = torch.cat(up_weight_list, dim=0)
def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, src_pp_rank):
"""broadcast tensor in tp shards across mp_group"""
nonlocal state_dict
nonlocal mp_group
tp_rank = mpu.get_tensor_model_parallel_rank()
tp_size = mpu.get_tensor_model_parallel_world_size()
src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank)
if torch.distributed.get_rank() == src_rank:
chunk_shape = tensor.shape
else:
chunk_shape = None
obj_list = [chunk_shape]
dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group)
chunk_shape = obj_list[0]
if chunk_shape is None:
# all or none ranks in the mp_group should reach here
print_rank_0(f"tp_shard tensor:[{q_name}] not exist, skip collecting")
return
buffer_tensor = torch.empty(
chunk_shape,
dtype=args.params_dtype,
device=torch.cuda.current_device(),
requires_grad=False,
)
chunk_tensors = [None] * tp_size
for i in range(tp_size):
cur_src_rank = _megatron_calc_global_rank(tp_rank=i, dp_rank=0, pp_rank=src_pp_rank)
sync_tensor = tensor if torch.distributed.get_rank() == cur_src_rank else buffer_tensor
dist.broadcast(sync_tensor, src=cur_src_rank, group=mp_group)
if torch.distributed.get_rank() == 0:
chunk_tensors[i] = _get_cpu_tensor(sync_tensor)
if torch.distributed.get_rank() == 0:
full_tensor = torch.concat(chunk_tensors, dim=0)
q_weight_list = []
k_weight_list = []
v_weight_list = []
hidden_size_per_head = config.hidden_size // config.num_attention_heads
if config.num_key_value_heads >= tp_size:
q_size_tp = config.hidden_size // tp_size
kv_size_tp = hidden_size_per_head * config.num_key_value_heads // tp_size
total_size = q_size_tp + 2 * kv_size_tp
for i in range(tp_size):
qkv_part = full_tensor[i * total_size:(i + 1) * total_size]
q_part = qkv_part[:q_size_tp]
k_part = qkv_part[q_size_tp:q_size_tp + kv_size_tp]
v_part = qkv_part[q_size_tp + kv_size_tp:total_size]
q_weight_list.append(q_part)
k_weight_list.append(k_part)
v_weight_list.append(v_part)
else:
q_size_tp = config.hidden_size // tp_size
kv_size_tp = hidden_size_per_head
total_size = q_size_tp + 2 * kv_size_tp
for i in range(tp_size):
qkv_part = full_tensor[i * total_size:(i + 1) * total_size]
q_part = qkv_part[:q_size_tp]
k_part = qkv_part[q_size_tp:q_size_tp + kv_size_tp]
v_part = qkv_part[q_size_tp + kv_size_tp:total_size]
q_weight_list.append(q_part)
if i * config.num_key_value_heads % tp_size == 0:
k_weight_list.append(k_part)
v_weight_list.append(v_part)
state_dict[q_name] = torch.cat(q_weight_list, dim=0)
state_dict[k_name] = torch.cat(k_weight_list, dim=0)
state_dict[v_name] = torch.cat(v_weight_list, dim=0)
# empty cache before collecting weights
torch.cuda.empty_cache()
# Embeddings
# -------------------
if dp_rank == 0:
# Embeddings
# -------------------
print_rank_0("collecting embeddings...")
gpt_model_module = _get_gpt_model(models[0])
_broadcast_tp_shard_tensor(
gpt_model_module.model.embed_tokens.weight if pp_rank == 0 else None,
"model.embed_tokens.weight",
src_pp_rank=0,
)
# Transformer layers
# -------------------
layer_map = _megatron_calc_layer_map(config)
for layer in range(config.num_hidden_layers):
print_rank_0(f"collecting layer #{layer}...")
layer_name = f"model.layers.{layer}"
src_pp_rank, src_virtual_pp_rank, src_layer_idx = layer_map[layer]
gpt_model_module = _get_gpt_model(models[src_virtual_pp_rank])
sync_layer = gpt_model_module.model.layers[src_layer_idx]
_broadcast_tensor(
sync_layer.input_layernorm.weight,
f"{layer_name}.input_layernorm.weight",
src_pp_rank=src_pp_rank,
)
_broadcast_tp_shard_tensor_qkv(
sync_layer.self_attn.qkv_proj.weight,
f"{layer_name}.self_attn.q_proj.weight",
f"{layer_name}.self_attn.k_proj.weight",
f"{layer_name}.self_attn.v_proj.weight",
src_pp_rank=src_pp_rank,
)
_broadcast_tp_shard_tensor(
sync_layer.self_attn.o_proj.weight,
f"{layer_name}.self_attn.o_proj.weight",
concat_dim=1,
src_pp_rank=src_pp_rank,
)
_broadcast_tensor(
sync_layer.post_attention_layernorm.weight,
f"{layer_name}.post_attention_layernorm.weight",
src_pp_rank=src_pp_rank,
)
_broadcast_tp_shard_tensor_gate_up(sync_layer.mlp.gate_up_proj.weight,
f"{layer_name}.mlp.gate_proj.weight",
f"{layer_name}.mlp.up_proj.weight",
src_pp_rank=src_pp_rank)
_broadcast_tp_shard_tensor(
sync_layer.mlp.down_proj.weight,
f"{layer_name}.mlp.down_proj.weight",
concat_dim=1,
src_pp_rank=src_pp_rank,
)
# Final Layernorm
# -------------------
print_rank_0("collecting final layernorm...")
gpt_model_module = _get_gpt_model(models[-1])
_broadcast_tensor(
getattr(gpt_model_module.model.norm, "weight", None),
"model.norm.weight",
src_pp_rank=pp_size - 1,
)
print_rank_0("collecting lm_head...")
if is_value_model:
_broadcast_tensor(getattr(gpt_model_module.lm_head, "weight", None) if pp_rank == pp_size - 1 else None,
"reward_head.weight",
src_pp_rank=pp_size - 1)
else:
_broadcast_tp_shard_tensor(
getattr(gpt_model_module.lm_head, "weight", None) if pp_rank == pp_size - 1 else None,
"lm_head.weight",
src_pp_rank=pp_size - 1,
)
dist.barrier()
torch.cuda.empty_cache()
if torch.distributed.get_rank() == 0:
if dtype == "fp16":
dtype = torch.float16
elif dtype == "bf16":
dtype = torch.bfloat16
elif dtype is None or dtype == "fp32":
dtype = torch.float32
else:
print(f'Unknown/unsupported dtype to save: {dtype}"')
exit(1)
for k, v in state_dict.items():
if dtype != v.dtype:
state_dict[k] = v.to(dtype)
print_rank_0(f"merge megatron ckpt done, time elapsed {time.time() - start_time}s")
return state_dict

View File

@@ -0,0 +1,18 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .parallel_attention import ParallelLlamaAttention
from .parallel_decoder import ParallelLlamaDecoderLayer, ParallelLlamaDecoderLayerRmPad
from .parallel_mlp import ParallelLlamaMLP
from .parallel_rmsnorm import ParallelLlamaRMSNorm

View File

@@ -0,0 +1,418 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Optional, Tuple
import torch
from megatron.core import parallel_state as mpu
from megatron.core import tensor_parallel
from megatron.core import ModelParallelConfig
from torch import nn
from transformers import LlamaConfig
from verl.models.llama.megatron.layers.parallel_linear import QKVParallelLinear
from verl.utils.megatron import tensor_parallel as tp_utils
class LlamaRotaryEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
super().__init__()
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
inv_freq = 1.0 / (self.base**(torch.arange(0, self.dim, 2).float().to(device) / self.dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
# Build here to make `torch.jit.trace` work.
self._set_cos_sin_cache(seq_len=max_position_embeddings,
device=self.inv_freq.device,
dtype=torch.get_default_dtype())
def _set_cos_sin_cache(self, seq_len, device, dtype):
self.max_seq_len_cached = seq_len
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
def forward(self, x, seq_len=None):
# x: [bs, num_attention_heads, seq_len, head_size]
if seq_len > self.max_seq_len_cached:
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
return (
self.cos_cached[:seq_len].to(dtype=x.dtype),
self.sin_cached[:seq_len].to(dtype=x.dtype),
)
class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
"""LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
self.scaling_factor = scaling_factor
super().__init__(dim, max_position_embeddings, base, device)
def _set_cos_sin_cache(self, seq_len, device, dtype):
self.max_seq_len_cached = seq_len
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
t = t / self.scaling_factor
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
"""LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
self.scaling_factor = scaling_factor
super().__init__(dim, max_position_embeddings, base, device)
def _set_cos_sin_cache(self, seq_len, device, dtype):
self.max_seq_len_cached = seq_len
if seq_len > self.max_position_embeddings:
base = self.base * ((self.scaling_factor * seq_len / self.max_position_embeddings) -
(self.scaling_factor - 1))**(self.dim / (self.dim - 2))
inv_freq = 1.0 / (base**(torch.arange(0, self.dim, 2).float().to(device) / self.dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., :x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2:]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
"""
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
class ParallelLlamaAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig):
super().__init__()
self.config = config
self.megatron_config = megatron_config
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.hidden_size // self.num_heads
self.num_key_value_heads = config.num_key_value_heads
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
self.max_position_embeddings = config.max_position_embeddings
self.rope_theta = config.rope_theta
# assign values after tp
tp_size = mpu.get_tensor_model_parallel_world_size()
assert self.num_heads % tp_size == 0, f'num_head must be divisible by tp_size. Got num_head={self.num_heads}, tp_size={tp_size}'
assert self.num_key_value_heads % tp_size == 0, \
f'num_key_value_heads must be divisible by tp_size. Got num_key_value_heads={self.num_key_value_heads}, tp_size={tp_size}'
self.num_heads_per_tp = self.num_heads // tp_size
self.num_key_value_heads_per_tp = self.num_key_value_heads // tp_size
self.hidden_size_per_tp = self.hidden_size // tp_size
if (self.head_dim * self.num_heads) != self.hidden_size:
raise ValueError(f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
f" and `num_heads`: {self.num_heads}).")
column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear()
row_kwargs = tp_utils.get_default_kwargs_for_row_parallel_linear()
if megatron_config is not None:
assert column_kwargs.get('config', False), 'must have ModelParallelConfig'
assert row_kwargs.get('config', False), 'must have ModelParallelConfig'
tp_utils.update_kwargs_with_config(column_kwargs, megatron_config)
tp_utils.update_kwargs_with_config(row_kwargs, megatron_config)
# [self.q_size, self.k_size, self.v_size]
self.qkv_proj = QKVParallelLinear(input_size=self.hidden_size,
num_heads=self.num_heads,
num_key_value_heads=self.num_key_value_heads,
head_dim=self.head_dim,
bias=config.attention_bias,
gather_output=False,
skip_bias_add=False,
**column_kwargs)
self.q_size = self.num_heads_per_tp * self.head_dim
self.k_size = self.num_key_value_heads_per_tp * self.head_dim
self.v_size = self.num_key_value_heads_per_tp * self.head_dim
self.o_proj = tensor_parallel.RowParallelLinear(input_size=self.num_heads * self.head_dim,
output_size=self.hidden_size,
bias=config.attention_bias,
input_is_parallel=True,
skip_bias_add=False,
**row_kwargs)
self._init_rope()
def _init_rope(self):
if self.config.rope_scaling is None:
self.rotary_emb = LlamaRotaryEmbedding(
self.head_dim,
max_position_embeddings=self.max_position_embeddings,
base=self.rope_theta,
)
else:
scaling_type = self.config.rope_scaling["type"]
scaling_factor = self.config.rope_scaling["factor"]
if scaling_type == "linear":
self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
self.head_dim,
max_position_embeddings=self.max_position_embeddings,
scaling_factor=scaling_factor,
base=self.rope_theta,
)
elif scaling_type == "dynamic":
self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
self.head_dim,
max_position_embeddings=self.max_position_embeddings,
scaling_factor=scaling_factor,
base=self.rope_theta,
)
else:
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
qkv = self.qkv_proj(hidden_states)[0]
query_states, key_states, value_states = qkv.split([self.q_size, self.k_size, self.v_size], dim=-1)
query_states = query_states.view(bsz, q_len, self.num_heads_per_tp, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads_per_tp, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads_per_tp, self.head_dim).transpose(1, 2)
kv_seq_len = key_states.shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
if attn_weights.size() != (bsz, self.num_heads_per_tp, q_len, kv_seq_len):
raise ValueError(
f"Attention weights should be of size {(bsz, self.num_heads_per_tp, q_len, kv_seq_len)}, but is"
f" {attn_weights.size()}")
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}")
attn_weights = attn_weights + attention_mask
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_output = torch.matmul(attn_weights, value_states)
if attn_output.size() != (bsz, self.num_heads_per_tp, q_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads_per_tp, q_len, self.head_dim)}, but is"
f" {attn_output.size()}")
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size_per_tp)
attn_output = self.o_proj(attn_output)[0]
return attn_output
"""
Remove padding Attention
- Using Flash-attn 2
- Compatible with sequence parallel
"""
from transformers.utils import is_flash_attn_2_available
import torch.nn.functional as F
from einops import rearrange
if is_flash_attn_2_available():
from flash_attn import flash_attn_varlen_func
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
def apply_rotary_pos_emb_rmpad(q, k, cos, sin, position_ids, indices, sequence_length):
batch_size = position_ids.shape[0]
q = pad_input(q, indices, batch_size, sequence_length) # (batch_size, seqlen, num_head, head_dim)
k = pad_input(k, indices, batch_size, sequence_length)
cos = cos[position_ids].unsqueeze(2) # [bs, seq_len, 1, dim]
sin = sin[position_ids].unsqueeze(2) # [bs, seq_len, 1, dim]
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
q_embed = index_first_axis(rearrange(q_embed, "b s ... -> (b s) ..."), indices)
k_embed = index_first_axis(rearrange(k_embed, "b s ... -> (b s) ..."), indices)
return q_embed, k_embed
from flash_attn.layers.rotary import apply_rotary_emb
# use flash-attn rotary embeddings with rmpad
# cos/sin shoudl be: (seq_length, rotary_dim / 2)
def apply_rotary_pos_emb_rmpad_flash(q, k, cos, sin, cu_seqlens, max_seqlen):
q_embed = apply_rotary_emb(q,
cos,
sin,
interleaved=False,
inplace=False,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen)
k_embed = apply_rotary_emb(k,
cos,
sin,
interleaved=False,
inplace=False,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen)
return q_embed, k_embed
class ParallelLlamaAttentionRmPad(ParallelLlamaAttention):
def forward(self,
hidden_states: torch.Tensor,
position_ids: Optional[torch.LongTensor] = None,
sequence_length: int = None,
indices: torch.Tensor = None,
cu_seqlens: torch.Tensor = None,
max_seqlen_in_batch: int = None):
total_nnz, _, _ = hidden_states.size() # This is the total_nnz padded after sequence parallel
if self.megatron_config.sequence_parallel:
total_nnz = total_nnz * mpu.get_tensor_model_parallel_world_size()
qkv = self.qkv_proj(hidden_states)[0]
query_states, key_states, value_states = qkv.split([self.q_size, self.k_size, self.v_size],
dim=-1) # (total_nnz, 1, hidden_size)
if self.megatron_config.sequence_parallel:
sequence_parallel_pad = total_nnz - cu_seqlens[-1]
total_nnz = cu_seqlens[-1] # total_nnz before sp padding
query_states = query_states[:total_nnz]
key_states = key_states[:total_nnz]
value_states = value_states[:total_nnz]
# Flash attention requires the input to have the shape
# batch_size x seq_length x head_dime x hidden_dim
# therefore we just need to keep the original shape
query_states = query_states.view(total_nnz, self.num_heads_per_tp, self.head_dim)
key_states = key_states.view(total_nnz, self.num_key_value_heads_per_tp, self.head_dim)
value_states = value_states.view(total_nnz, self.num_key_value_heads_per_tp, self.head_dim)
cos, sin = self.rotary_emb(value_states, seq_len=sequence_length)
cos, sin = cos[:, :cos.shape[1] // 2], sin[:, :sin.shape[1] // 2] # flash attn only needs half
query_states, key_states = apply_rotary_pos_emb_rmpad_flash(query_states,
key_states,
cos,
sin,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen_in_batch)
# query_states, key_states = apply_rotary_pos_emb_rmpad(query_states, key_states, cos, sin, position_ids, indices,
# TODO: llama does not have dropout in the config??
# It is recommended to use dropout with FA according to the docs
# when training.
dropout_rate = 0.0 # if not self.training else self.attn_dropout
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
# therefore the input hidden states gets silently casted in float32. Hence, we need
# cast them back in float16 just to be sure everything works as expected.
# This might slowdown training & inference so it is recommended to not cast the LayerNorms
# in fp32. (LlamaRMSNorm handles it correctly)
input_dtype = query_states.dtype
if input_dtype == torch.float32:
query_states = query_states.to(torch.float16)
key_states = key_states.to(torch.float16)
value_states = value_states.to(torch.float16)
attn_output_unpad = flash_attn_varlen_func(
query_states,
key_states,
value_states,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=max_seqlen_in_batch,
max_seqlen_k=max_seqlen_in_batch,
dropout_p=dropout_rate,
softmax_scale=None,
causal=True,
)
attn_output_unpad = attn_output_unpad.to(input_dtype)
attn_output_unpad = attn_output_unpad.reshape(total_nnz, 1, self.hidden_size_per_tp).contiguous()
# sequence parallel reduce_scatter is performed inside RowColumnParallel if enabled
# Here we need to repad
if self.megatron_config.sequence_parallel:
attn_output_unpad = F.pad(attn_output_unpad, pad=(0, 0, 0, 0, 0, sequence_parallel_pad))
attn_output_unpad = self.o_proj(attn_output_unpad)[0]
return attn_output_unpad

View File

@@ -0,0 +1,146 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple
import torch
from torch import nn
from transformers import LlamaConfig
from megatron.core import ModelParallelConfig
from .parallel_attention import ParallelLlamaAttention, ParallelLlamaAttentionRmPad
from .parallel_mlp import ParallelLlamaMLP
from .parallel_rmsnorm import ParallelLlamaRMSNorm
class ParallelLlamaDecoderLayer(nn.Module):
def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = ParallelLlamaAttention(config=config, megatron_config=megatron_config)
self.mlp = ParallelLlamaMLP(config, megatron_config=megatron_config)
self.input_layernorm = ParallelLlamaRMSNorm(config, megatron_config)
self.post_attention_layernorm = ParallelLlamaRMSNorm(config, megatron_config)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
(see `past_key_values`).
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
"""
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
# Note: sequence parallel is hidden inside ColumnParallelLinear
# reduce scatter is hidden inside RowParallelLinear
# Self Attention
hidden_states = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
)
# TODO: add sequence parallel operator reduce_scatter here
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
# TODO: add sequence parallel operator all_gather here
hidden_states = self.mlp(hidden_states)
# TODO: add sequence parallel operator reduce_scatter here
hidden_states = residual + hidden_states
outputs = hidden_states
return outputs
class ParallelLlamaDecoderLayerRmPad(nn.Module):
def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig):
super().__init__()
self.config = config
self.megatron_config = megatron_config
self.hidden_size = config.hidden_size
self.self_attn = ParallelLlamaAttentionRmPad(config=config, megatron_config=megatron_config)
self.mlp = ParallelLlamaMLP(config, megatron_config=megatron_config)
self.input_layernorm = ParallelLlamaRMSNorm(config, megatron_config)
self.post_attention_layernorm = ParallelLlamaRMSNorm(config, megatron_config)
def forward(
self,
hidden_states: torch.Tensor,
position_ids: Optional[torch.LongTensor] = None,
sequence_length: int = None,
indices: torch.Tensor = None,
cu_seqlens: int = None,
max_seqlen_in_batch: int = None
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
residual = hidden_states # (total_nnz // sp, 1, hidden_size)
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
# (total_nnz // sp, 1, hidden_size) -> all-gather (total_nnz, 1, hidden_size)
# -> col + row -> reduce-scatter -> (total_nnz // sp, 1, hidden_size)
hidden_states = self.self_attn(hidden_states=hidden_states,
position_ids=position_ids,
sequence_length=sequence_length,
indices=indices,
cu_seqlens=cu_seqlens,
max_seqlen_in_batch=max_seqlen_in_batch)
hidden_states = residual + hidden_states
# Fully Connected
# shape changes same as attn
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
outputs = hidden_states
return outputs

View File

@@ -0,0 +1,74 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023 The vLLM team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/linear.py
from typing import Optional, Tuple
from megatron.core import tensor_parallel
class QKVParallelLinear(tensor_parallel.ColumnParallelLinear):
def __init__(self,
input_size,
num_heads,
num_key_value_heads,
head_dim,
*,
bias=True,
gather_output=True,
skip_bias_add=False,
**kwargs):
# Keep input parameters, and already restrict the head numbers
self.input_size = input_size
self.q_output_size = num_heads * head_dim
self.kv_output_size = num_key_value_heads * head_dim
self.head_dim = head_dim
self.gather_output = gather_output
self.skip_bias_add = skip_bias_add
input_size = self.input_size
output_size = (num_heads + 2 * num_key_value_heads) * self.head_dim
super().__init__(input_size=input_size,
output_size=output_size,
bias=bias,
gather_output=gather_output,
skip_bias_add=skip_bias_add,
**kwargs)
class MergedColumnParallelLinear(tensor_parallel.ColumnParallelLinear):
def __init__(self,
input_size,
gate_ouput_size,
up_output_size,
*,
bias=True,
gather_output=True,
skip_bias_add=False,
**kwargs):
# Keep input parameters, and already restrict the head numbers
self.input_size = input_size
self.output_size = gate_ouput_size + up_output_size
self.gather_output = gather_output
self.skip_bias_add = skip_bias_add
super().__init__(input_size=self.input_size,
output_size=self.output_size,
bias=bias,
gather_output=gather_output,
skip_bias_add=skip_bias_add,
**kwargs)

View File

@@ -0,0 +1,74 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from megatron.core import parallel_state as mpu
from megatron.core import tensor_parallel
from megatron.core import ModelParallelConfig
from torch import nn
from transformers.activations import ACT2FN
from verl.models.llama.megatron.layers.parallel_linear import MergedColumnParallelLinear
from verl.utils.megatron import tensor_parallel as tp_utils
class ParallelLlamaMLP(nn.Module):
def __init__(self, config, megatron_config: ModelParallelConfig = None) -> None:
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size
# The weight is only [hidden_size, intermediate_size // model_parallel_world_size]
column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear()
row_kwargs = tp_utils.get_default_kwargs_for_row_parallel_linear()
if megatron_config is not None:
assert column_kwargs.get('config', False), 'must have ModelParallelConfig'
assert row_kwargs.get('config', False), 'must have ModelParallelConfig'
tp_utils.update_kwargs_with_config(row_kwargs, megatron_config)
tp_utils.update_kwargs_with_config(column_kwargs, megatron_config)
tp_size = mpu.get_tensor_model_parallel_world_size()
self.gate_up_proj = MergedColumnParallelLinear(
input_size=self.hidden_size,
gate_ouput_size=self.intermediate_size,
up_output_size=self.intermediate_size,
bias=False,
gather_output=False,
skip_bias_add=False,
**column_kwargs,
)
self.gate_size = self.intermediate_size // tp_size
self.down_proj = tensor_parallel.RowParallelLinear(input_size=self.intermediate_size,
output_size=self.hidden_size,
bias=False,
input_is_parallel=True,
skip_bias_add=False,
**row_kwargs)
self.act_fn = ACT2FN[config.hidden_act]
def forward(self, x):
gate_up = self.gate_up_proj(x)[0]
gate, up = gate_up.split(self.gate_size, dim=-1)
return self.down_proj(self.act_fn(gate) * up)[0]

View File

@@ -0,0 +1,46 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numbers
import torch
from megatron.core import ModelParallelConfig
from torch import nn
from transformers import LlamaConfig
from apex.normalization.fused_layer_norm import fused_rms_norm_affine
from verl.utils.megatron import sequence_parallel as sp_utils
class ParallelLlamaRMSNorm(nn.Module):
def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig):
"""
LlamaRMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
if isinstance(config.hidden_size, numbers.Integral):
normalized_shape = (config.hidden_size,)
self.normalized_shape = torch.Size(normalized_shape)
self.weight = nn.Parameter(torch.ones(self.normalized_shape))
self.variance_epsilon = config.rms_norm_eps
if megatron_config.sequence_parallel:
sp_utils.mark_parameter_as_sequence_parallel(self.weight)
def forward(self, hidden_states):
return fused_rms_norm_affine(input=hidden_states,
weight=self.weight,
normalized_shape=self.normalized_shape,
eps=self.variance_epsilon,
memory_efficient=True)

View File

@@ -0,0 +1,656 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch LLaMA model with Megatron-style acceleration."""
from typing import Optional, Tuple, Union
import torch
import torch.utils.checkpoint
from megatron.core import tensor_parallel
from megatron.core import ModelParallelConfig
from torch import nn
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import CausalLMOutputWithPast
from verl.utils.megatron import sequence_parallel as sp_utils
from verl.utils.megatron import tensor_parallel as tp_utils
from .layers import ParallelLlamaDecoderLayer, ParallelLlamaRMSNorm, ParallelLlamaDecoderLayerRmPad
"""
TODO:
1. Add weight initialization. Here we need to be careful on TP weight init.
2. Add sequence parallel
3. Load checkpoint from meta LLama pretrained checkpoint
"""
# Copied from transformers.models.bart.modeling_bart._make_causal_mask
def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device):
"""
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
mask_cond = torch.arange(mask.size(-1), device=device)
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len)
# Copied from transformers.models.bart.modeling_bart._expand_mask
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""
bsz, src_len = mask.size()
tgt_len = tgt_len if tgt_len is not None else src_len
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
inverted_mask = 1.0 - expanded_mask
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
class ParallelLlamaModel(nn.Module):
"""
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]
Args:
config: LlamaConfig
"""
def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig):
super().__init__()
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
embedding_kwargs = tp_utils.get_default_kwargs_for_parallel_embedding()
if megatron_config is not None:
assert embedding_kwargs.get('config', False), 'must have ModelParallelConfig'
tp_utils.update_kwargs_with_config(embedding_kwargs, self.megatron_config)
self.embed_tokens = tensor_parallel.VocabParallelEmbedding(num_embeddings=config.vocab_size,
embedding_dim=config.hidden_size,
**embedding_kwargs)
self.layers = nn.ModuleList(
[ParallelLlamaDecoderLayer(config, megatron_config) for _ in range(config.num_hidden_layers)])
self.norm = ParallelLlamaRMSNorm(config, megatron_config)
# Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds):
# create causal mask
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
combined_attention_mask = None
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(
input_shape,
inputs_embeds.dtype,
device=inputs_embeds.device,
)
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype,
tgt_len=input_shape[-1]).to(inputs_embeds.device)
combined_attention_mask = (expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask +
combined_attention_mask)
return combined_attention_mask
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
"""
Args:
input_ids: input ids. shape (batch_size, seq_length)
attention_mask: attention_mask. shape (batch_size, seq_length)
position_ids: position ids. shape (batch_size, seq_length)
Returns:
"""
batch_size, seq_length = input_ids.shape
inputs_embeds = self.embed_tokens(input_ids)
# embed positions
attention_mask = self._prepare_decoder_attention_mask(attention_mask, (batch_size, seq_length), inputs_embeds)
hidden_states = inputs_embeds
for idx, decoder_layer in enumerate(self.layers):
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
)
hidden_states = layer_outputs
hidden_states = self.norm(hidden_states)
return hidden_states
class ParallelLlamaForCausalLM(nn.Module):
def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig):
super().__init__()
self.model = ParallelLlamaModel(config, megatron_config=megatron_config)
self.vocab_size = config.vocab_size
column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear()
if megatron_config is not None:
assert column_kwargs.get('config', False), 'must have ModelParallelConfig'
tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config)
self.lm_head = tensor_parallel.ColumnParallelLinear(input_size=config.hidden_size,
output_size=config.vocab_size,
bias=False,
gather_output=False,
skip_bias_add=False,
**column_kwargs)
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
Returns:
```"""
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
)
hidden_states = outputs
logits = self.lm_head(hidden_states)[0]
logits = tensor_parallel.gather_from_tensor_model_parallel_region(logits)
logits = logits.float()
return CausalLMOutputWithPast(
loss=None,
logits=logits,
past_key_values=None,
hidden_states=None,
attentions=None,
)
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
class ParallelLlamaModelRmPad(nn.Module):
"""
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]
Args:
config: LlamaConfig
"""
def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig):
super().__init__()
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
embedding_kwargs = tp_utils.get_default_kwargs_for_parallel_embedding()
self.megatron_config = megatron_config
if megatron_config is not None:
assert embedding_kwargs.get('config', False), 'must have ModelParallelConfig'
tp_utils.update_kwargs_with_config(embedding_kwargs, self.megatron_config)
self.embed_tokens = tensor_parallel.VocabParallelEmbedding(num_embeddings=config.vocab_size,
embedding_dim=config.hidden_size,
**embedding_kwargs)
self.layers = nn.ModuleList(
[ParallelLlamaDecoderLayerRmPad(config, megatron_config) for _ in range(config.num_hidden_layers)])
self.norm = ParallelLlamaRMSNorm(config, megatron_config)
def forward(self,
input_ids: torch.Tensor,
position_ids: Optional[torch.LongTensor] = None,
sequence_length: int = None,
indices: torch.Tensor = None,
cu_seqlens: int = None,
max_seqlen_in_batch: int = None) -> Union[Tuple, BaseModelOutputWithPast]:
"""
Args:
input_ids: input ids. shape (1, totol_nnz)
position_ids: position ids. shape (batch_size, seq_length)
Returns:
"""
inputs_embeds = self.embed_tokens(input_ids) # (1, total_nnz) -> (1, total_nnz, hidden_size)
# (1, total_nnz, hidden_size) -> (total_nnz, 1, hidden_size) -> (total_nnz // sp, 1, hidden_size)
inputs_embeds = inputs_embeds.transpose(0, 1)
if self.megatron_config.sequence_parallel:
inputs_embeds = tensor_parallel.scatter_to_sequence_parallel_region(inputs_embeds)
hidden_states = inputs_embeds
for idx, decoder_layer in enumerate(self.layers):
layer_outputs = decoder_layer(hidden_states,
position_ids=position_ids,
sequence_length=sequence_length,
indices=indices,
cu_seqlens=cu_seqlens,
max_seqlen_in_batch=max_seqlen_in_batch)
hidden_states = layer_outputs
hidden_states = self.norm(hidden_states)
return hidden_states
class ParallelLlamaForCausalLMRmPad(nn.Module):
def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig):
super().__init__()
self.config = config
self.megatron_config = megatron_config
self.model = ParallelLlamaModelRmPad(config, megatron_config=megatron_config)
self.vocab_size = config.vocab_size
self._init_head()
def _init_head(self):
column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear()
if self.megatron_config is not None:
assert column_kwargs.get('config', False), 'must have ModelParallelConfig'
tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config)
self.lm_head = tensor_parallel.ColumnParallelLinear(input_size=self.config.hidden_size,
output_size=self.config.vocab_size,
bias=False,
gather_output=False,
skip_bias_add=False,
**column_kwargs)
def _forward_head(self, hidden_states):
# all_gather from sequence parallel region is performed inside lm_head
logits = self.lm_head(hidden_states)[0]
logits = logits.float() # (total_nnz_padded, 1, vocab_size // tp)
logits = tensor_parallel.gather_from_tensor_model_parallel_region(logits) # (total_nnz_padded, 1, vocab_size)
return logits
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
Returns:
```"""
batch_size, sequence_length = input_ids.shape
# remove padding here
input_ids, indices, cu_seqlens, max_seqlen_in_batch, *_ = unpad_input(input_ids.unsqueeze(dim=-1),
attention_mask) # (total_nnz, 1)
# pad input_ids to multiple of tp for all tp ranks
# TODO: for better performance, the sp padding should be removed at each layer. Not sure the performance gap
if self.megatron_config.sequence_parallel:
input_ids = sp_utils.pad_to_sequence_parallel(input_ids)
input_ids = input_ids.transpose(0, 1) # (1, total_nnz+pad)
outputs = self.model(input_ids=input_ids,
position_ids=position_ids,
sequence_length=sequence_length,
indices=indices,
cu_seqlens=cu_seqlens,
max_seqlen_in_batch=max_seqlen_in_batch)
hidden_states = outputs
logits = self._forward_head(hidden_states)
# remove padding from sequence parallel
if self.megatron_config.sequence_parallel:
totol_nnz = cu_seqlens[-1]
logits = logits[:totol_nnz] # (total_nnz_padded)
logits = torch.squeeze(logits, dim=1) # remove the artificial batch dimension
# add removed padding back
logits = pad_input(logits, indices, batch_size,
seqlen=sequence_length) # (batch_size, sequence_length, vocab_size)
return CausalLMOutputWithPast(
loss=None,
logits=logits,
past_key_values=None,
hidden_states=None,
attentions=None,
)
class ParallelLlamaForValueRmPad(ParallelLlamaForCausalLMRmPad):
def _init_head(self):
column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear()
if self.megatron_config is not None:
assert column_kwargs.get('config', False), 'must have ModelParallelConfig'
tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config)
self.lm_head = nn.Linear(in_features=self.config.hidden_size, out_features=1, bias=False)
# lm_head is effectively the same as sequence parallel
sp_utils.mark_parameter_as_sequence_parallel(self.lm_head.weight)
def _forward_head(self, hidden_states):
logits = self.lm_head(hidden_states) # (total_nnz_padded // tp, 1, 1)
logits = logits.float()
if self.megatron_config.sequence_parallel:
logits = tensor_parallel.gather_from_sequence_parallel_region(logits, tensor_parallel_output_grad=False)
return logits
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
output = super().forward(input_ids, attention_mask, position_ids)
output.logits = torch.squeeze(output.logits, dim=-1)
return output
"""
Support pipeline parallelism
"""
class ParallelLlamaModelRmPadPP(nn.Module):
"""
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]
This model definition supports pipeline parallelism. To support pp and vpp,
- This model only contains layer in this pp stage and vpp chunk
- When calling get_model in Megatron, this rank will instantiate all the vpp chunks in this pp.
Args:
config: LlamaConfig
"""
def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig, pre_process, post_process):
super().__init__()
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.pre_process = pre_process
self.post_process = post_process
self.megatron_config = megatron_config
embedding_kwargs = tp_utils.get_default_kwargs_for_parallel_embedding()
if megatron_config is not None:
assert embedding_kwargs.get('config', False), 'must have ModelParallelConfig'
tp_utils.update_kwargs_with_config(embedding_kwargs, self.megatron_config)
if pre_process:
self.embed_tokens = tensor_parallel.VocabParallelEmbedding(num_embeddings=config.vocab_size,
embedding_dim=config.hidden_size,
**embedding_kwargs)
else:
self.embed_tokens = None
# pp_rank = megatron_config.pipeline_model_parallel_rank
pp_size = megatron_config.pipeline_model_parallel_size
self.num_layer_per_pp = config.num_hidden_layers // pp_size
vpp_size = megatron_config.virtual_pipeline_model_parallel_size
if vpp_size is not None:
self.num_layer_vpp_chunk = self.num_layer_per_pp // vpp_size
self.num_layer_this_model = self.num_layer_vpp_chunk
# vpp_rank = megatron_config.virtual_pipeline_model_parallel_rank
# self.offset = vpp_rank * (
# config.num_hidden_layers // megatron_config.virtual_pipeline_model_parallel_size) + \
# (megatron_config.pipeline_model_parallel_rank * self.num_layer_vpp_chunk)
else:
self.num_layer_this_model = self.num_layer_per_pp
# self.offset = pp_rank * self.num_layer_per_pp
layers = []
for i in range(self.num_layer_this_model):
layer = ParallelLlamaDecoderLayerRmPad(config, megatron_config)
# setattr(layer, 'hidden_layer_index', self.offset + i)
layers.append(layer)
self.layers = nn.ModuleList(layers)
if post_process:
self.norm = ParallelLlamaRMSNorm(config, megatron_config)
else:
self.norm = None
def set_input_tensor(self, input_tensor):
"""Set input tensor to be used instead of forward()'s input.
When doing pipeline parallelism the input from the previous
stage comes from communication, not from the input, so the
model's forward_step_func won't have it. This function is thus
used by internal code to bypass the input provided by the
forward_step_func"""
self.input_tensor = input_tensor
def forward(self,
input_ids: torch.Tensor,
position_ids: Optional[torch.LongTensor] = None,
sequence_length: int = None,
indices: torch.Tensor = None,
cu_seqlens: int = None,
max_seqlen_in_batch: int = None) -> Union[Tuple, BaseModelOutputWithPast]:
"""
Args:
input_ids: input ids. shape (1, totol_nnz)
position_ids: position ids. shape (batch_size, seq_length)
Returns:
"""
if self.pre_process:
inputs_embeds = self.embed_tokens(input_ids) # (1, total_nnz) -> (1, total_nnz, hidden_size)
# vocab parallel embedding will not do sequence parallel reduce-scatter in open source megatron
# so need to deal with it by handle here:
# (1, total_nnz, hidden_size) -> (total_nnz, 1, hidden_size) -> (total_nnz // sp, 1, hidden_size)
inputs_embeds = inputs_embeds.transpose(0, 1)
if self.megatron_config.sequence_parallel:
inputs_embeds = tensor_parallel.scatter_to_sequence_parallel_region(inputs_embeds)
hidden_states = inputs_embeds
else:
# self.hidden_states should be passed by Megatron
hidden_states = self.input_tensor
for idx, decoder_layer in enumerate(self.layers):
layer_outputs = decoder_layer(hidden_states,
position_ids=position_ids,
sequence_length=sequence_length,
indices=indices,
cu_seqlens=cu_seqlens,
max_seqlen_in_batch=max_seqlen_in_batch)
hidden_states = layer_outputs
if self.post_process:
hidden_states = self.norm(hidden_states)
return hidden_states
class ParallelLlamaForCausalLMRmPadPP(nn.Module):
def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig, pre_process, post_process):
super().__init__()
self.config = config
self.megatron_config = megatron_config
self.model = ParallelLlamaModelRmPadPP(config,
megatron_config=megatron_config,
pre_process=pre_process,
post_process=post_process)
self.share_embeddings_and_output_weights = None # workaround, megatron requires this attr
self.vocab_size = config.vocab_size
self.pre_process = pre_process
self.post_process = post_process
if post_process:
self._init_head()
def set_input_tensor(self, input_tensor):
"""Set input tensor to be used instead of forward()'s input.
When doing pipeline parallelism the input from the previous
stage comes from communication, not from the input, so the
model's forward_step_func won't have it. This function is thus
used by internal code to bypass the input provided by the
forward_step_func"""
assert len(input_tensor) == 1
self.model.set_input_tensor(input_tensor[0])
def _init_head(self):
column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear()
if self.megatron_config is not None:
assert column_kwargs.get('config', False), 'must have ModelParallelConfig'
tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config)
self.lm_head = tensor_parallel.ColumnParallelLinear(input_size=self.config.hidden_size,
output_size=self.config.vocab_size,
bias=False,
gather_output=False,
skip_bias_add=False,
**column_kwargs)
def _forward_head(self, hidden_states):
# all_gather from sequence parallel region is performed inside lm_head
# logits shape before forward_head hidden_states.shape: [4, 32, 4096]
logits = self.lm_head(hidden_states)[0]
# logits shape after forward_head logits.shape: [8, 32, 8]
logits = logits.float() # (total_nnz_padded, 1, vocab_size // tp)
return logits
def forward(
self,
# original input
*,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
Returns:
```"""
# Note that input_ids, attention_mask and position_ids should be passed to every pp layer.
# In the first pp, input_ids will be used, in other pp layers hidden_states will be used inside self.model
batch_size, sequence_length = input_ids.shape
# remove padding here
input_ids_rmpad, indices, cu_seqlens, max_seqlen_in_batch, *_ = unpad_input(input_ids.unsqueeze(dim=-1),
attention_mask) # (total_nnz, 1)
# pad input_ids to multiple of tp for all tp ranks
# TODO: for better performance, the sp padding should be removed at each layer. Not sure the performance gap
if self.megatron_config.sequence_parallel:
input_ids_rmpad = sp_utils.pad_to_sequence_parallel(input_ids_rmpad)
input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz+pad)
outputs = self.model(input_ids=input_ids_rmpad,
position_ids=position_ids,
sequence_length=sequence_length,
indices=indices,
cu_seqlens=cu_seqlens,
max_seqlen_in_batch=max_seqlen_in_batch)
if self.post_process:
hidden_states = outputs
# print(f'hidden_states.shape = {hidden_states.shape}') # torch.Size([4, 32, 4096])
logits = self._forward_head(hidden_states)
logits = torch.squeeze(logits, dim=1) # remove the artificial batch dimension # torch.Size([8, 32, 16])
# remove padding from sequence parallel
if self.megatron_config.sequence_parallel:
totol_nnz = cu_seqlens[-1]
logits = logits[:totol_nnz] # (total_nnz_padded)
# add removed padding back. If input is already rmpad, we let the caller pad_input
logits = pad_input(logits, indices, batch_size,
seqlen=sequence_length) # (batch_size, sequence_length, vocab_size)
return CausalLMOutputWithPast(
loss=None,
logits=logits,
past_key_values=None,
hidden_states=None,
attentions=None,
)
else:
return outputs
class ParallelLlamaForValueRmPadPP(ParallelLlamaForCausalLMRmPadPP):
def _init_head(self):
column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear()
if self.megatron_config is not None:
assert column_kwargs.get('config', False), 'must have ModelParallelConfig'
tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config)
self.lm_head = nn.Linear(in_features=self.config.hidden_size, out_features=1, bias=False)
# lm_head is effectively the same as sequence parallel
sp_utils.mark_parameter_as_sequence_parallel(self.lm_head.weight)
def _forward_head(self, hidden_states):
logits = self.lm_head(hidden_states) # (total_nnz_padded // tp, 1, 1)
logits = logits.float()
if self.megatron_config.sequence_parallel:
logits = tensor_parallel.gather_from_sequence_parallel_region(logits, tensor_parallel_output_grad=False)
return logits
def forward(
self,
*,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
output = super().forward(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids)
if self.post_process:
output.logits = torch.squeeze(output.logits, dim=-1)
return output
else:
return output

66
verl/models/registry.py Normal file
View File

@@ -0,0 +1,66 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
from typing import List, Optional, Type
import torch.nn as nn
# Supported models using HF Rmpad
# TODO(sgm): HF may supported more than listed here, we should add more after testing
from transformers import LlamaConfig, MistralConfig, GemmaConfig, Qwen2Config
_REOVEPAD_MODELS = {'llama': LlamaConfig, 'mistral': MistralConfig, 'gemma': GemmaConfig, 'qwen2': Qwen2Config}
def check_model_support_rmpad(model_type: str):
assert isinstance(model_type, str)
if not model_type in _REOVEPAD_MODELS.keys():
raise ValueError(f"Model architecture {model_type} is not supported for now. "
f"RMPad supported architectures: {_REOVEPAD_MODELS.keys()}."
f"Please set `use_remove_padding=False` in the model config.")
# Supported models in Megatron-LM
# Architecture -> (module, class).
_MODELS = {
"LlamaForCausalLM":
("llama", ("ParallelLlamaForCausalLMRmPadPP", "ParallelLlamaForValueRmPadPP", "ParallelLlamaForCausalLMRmPad")),
"MistralForCausalLM": ("mistral", ("ParallelMistralForCausalLMRmPadPP", "ParallelMistralForValueRmPadPP",
"ParallelMistralForCausalLMRmPad"))
}
# return model class
class ModelRegistry:
@staticmethod
def load_model_cls(model_arch: str, value=False) -> Optional[Type[nn.Module]]:
if model_arch not in _MODELS:
return None
megatron = "megatron"
module_name, model_cls_name = _MODELS[model_arch]
if not value: # actor/ref
model_cls_name = model_cls_name[0]
elif value: # critic/rm
model_cls_name = model_cls_name[1]
module = importlib.import_module(f"verl.models.{module_name}.{megatron}.modeling_{module_name}_megatron")
return getattr(module, model_cls_name, None)
@staticmethod
def get_supported_archs() -> List[str]:
return list(_MODELS.keys())

View File

@@ -0,0 +1,13 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@@ -0,0 +1,145 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import torch
from typing import Optional, List, Union, Tuple, Unpack, Callable
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
from transformers.cache_utils import Cache
from transformers.utils import logging
from transformers.modeling_flash_attention_utils import _flash_attention_forward
from verl.utils.ulysses import gather_heads_scatter_seq, gather_seq_scatter_heads, get_ulysses_sequence_parallel_world_size
logger = logging.get_logger(__name__)
def llama_flash_attn_forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""
adapt from transformers 4.47.1
"""
output_attentions = False
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
# Flash attention requires the input to have the shape
# batch_size x seq_length x head_dim x hidden_dim
# therefore we just need to keep the original shape
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
# trade off: repeat first and then all to all
# key_states = repeat_kv(key_states, self.num_key_value_groups)
# value_states = repeat_kv(value_states, self.num_key_value_groups)
########## AlltoAll for Ulysses ##########
ulysses_sp_size = get_ulysses_sequence_parallel_world_size()
if ulysses_sp_size > 1:
# (bsz, n_head, seq_len/n, head_dim) -> (bsz, n_head/n, seq_len, head_dim)
query_states = gather_seq_scatter_heads(query_states, seq_dim=2, head_dim=1)
key_states = gather_seq_scatter_heads(key_states, seq_dim=2, head_dim=1)
value_states = gather_seq_scatter_heads(value_states, seq_dim=2, head_dim=1)
full_q_len = query_states.size(2) # full seq length
if position_embeddings is None:
logger.warning_once(
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
"removed and `position_embeddings` will be mandatory.")
cos, sin = self.rotary_emb(value_states, position_ids)
else:
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
# TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
# to be able to avoid many of these transpose/reshape/view.
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
dropout_rate = self.attention_dropout if self.training else 0.0
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
# therefore the input hidden states gets silently casted in float32. Hence, we need
# cast them back in the correct dtype just to be sure everything works as expected.
# This might slowdown training & inference so it is recommended to not cast the LayerNorms
# in fp32. (LlamaRMSNorm handles it correctly)
input_dtype = query_states.dtype
if input_dtype == torch.float32:
if torch.is_autocast_enabled():
target_dtype = torch.get_autocast_gpu_dtype()
# Handle the case where the model is quantized
elif hasattr(self.config, "_pre_quantization_dtype"):
target_dtype = self.config._pre_quantization_dtype
else:
target_dtype = self.q_proj.weight.dtype
logger.warning_once(
f"The input hidden states seems to be silently casted in float32, this might be related to"
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
f" {target_dtype}.")
query_states = query_states.to(target_dtype)
key_states = key_states.to(target_dtype)
value_states = value_states.to(target_dtype)
attn_output = _flash_attention_forward(
query_states,
key_states,
value_states,
attention_mask,
full_q_len,
position_ids=position_ids,
dropout=dropout_rate,
sliding_window=getattr(self, "sliding_window", None),
use_top_left_mask=self._flash_attn_uses_top_left_mask,
is_causal=self.is_causal,
**kwargs,
)
attn_output = attn_output.reshape(bsz, full_q_len, -1, self.head_dim).contiguous()
########## AlltoAll for Ulysses ##########
if ulysses_sp_size > 1:
attn_output = gather_heads_scatter_seq(attn_output, seq_dim=1, head_dim=2)
attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value

View File

@@ -0,0 +1,74 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Apply monkey-patch function to models
"""
#### Open Source Models
#### transformers version < 4.48
def apply_monkey_patch_to_llama():
from transformers.models.llama.modeling_llama import LlamaFlashAttention2
from verl.models.transformers.llama import llama_flash_attn_forward
LlamaFlashAttention2.forward = llama_flash_attn_forward
def apply_monkey_patch_to_qwen2():
from transformers.models.qwen2.modeling_qwen2 import Qwen2FlashAttention2
from verl.models.transformers.qwen2 import qwen2_flash_attn_forward
Qwen2FlashAttention2.forward = qwen2_flash_attn_forward
_PATCH_NAME_TO_FUNC = {
'llama': apply_monkey_patch_to_llama,
'qwen2': apply_monkey_patch_to_qwen2,
}
from transformers import PretrainedConfig
def apply_monkey_patch(config: PretrainedConfig, verbose=True):
if not is_transformers_version_in_range("4.45.0", "4.47.1"):
raise AssertionError("The installed `transformers` version doesn't support ulysses patch. "
"Please install a version between 4.45.0 and 4.47.1 to use this ulysses feature.")
success_apply_monkey_patch = False
if config.model_type in _PATCH_NAME_TO_FUNC:
_PATCH_NAME_TO_FUNC[config.model_type]()
success_apply_monkey_patch = True
if success_apply_monkey_patch and verbose:
print(f'Applying monkey patch to model {config.model_type}')
elif not success_apply_monkey_patch:
raise NotImplementedError(f'Ulysses for model {config.model_type} is not implemented, \
please set `ulysses_sequence_parallel_size=1`')
return success_apply_monkey_patch
from functools import lru_cache
from packaging import version
import importlib.metadata
@lru_cache()
def is_transformers_version_in_range(min_version: str, max_version: str) -> bool:
try:
# Get the installed version of the transformers library
transformers_version = importlib.metadata.version("transformers")
except importlib.metadata.PackageNotFoundError:
raise ModuleNotFoundError("The `transformers` package is not installed.")
# Check if the version is within the specified range
return version.parse(min_version) <= version.parse(transformers_version) <= version.parse(max_version)

View File

@@ -0,0 +1,137 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import torch
from typing import Optional, Tuple
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
from transformers.cache_utils import Cache
from transformers.utils import logging
from transformers.modeling_flash_attention_utils import _flash_attention_forward
from verl.utils.ulysses import gather_heads_scatter_seq, gather_seq_scatter_heads, get_ulysses_sequence_parallel_world_size
logger = logging.get_logger(__name__)
def qwen2_flash_attn_forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
):
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
########## AlltoAll for Ulysses ##########
ulysses_sp_size = get_ulysses_sequence_parallel_world_size()
if ulysses_sp_size > 1:
# (bsz, n_head, seq_len/n, head_dim) -> (bsz, n_head/n, seq_len, head_dim)
query_states = gather_seq_scatter_heads(query_states, seq_dim=2, head_dim=1)
key_states = gather_seq_scatter_heads(key_states, seq_dim=2, head_dim=1)
value_states = gather_seq_scatter_heads(value_states, seq_dim=2, head_dim=1)
full_q_len = query_states.size(2) # full seq length
if position_embeddings is None:
logger.warning_once(
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
"removed and `position_embeddings` will be mandatory.")
cos, sin = self.rotary_emb(value_states, position_ids)
else:
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
dropout_rate = 0.0 if not self.training else self.attention_dropout
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
# therefore the input hidden states gets silently casted in float32. Hence, we need
# cast them back in float16 just to be sure everything works as expected.
input_dtype = query_states.dtype
if input_dtype == torch.float32:
if torch.is_autocast_enabled():
target_dtype = torch.get_autocast_gpu_dtype()
# Handle the case where the model is quantized
elif hasattr(self.config, "_pre_quantization_dtype"):
target_dtype = self.config._pre_quantization_dtype
else:
target_dtype = self.q_proj.weight.dtype
logger.warning_once(
f"The input hidden states seems to be silently casted in float32, this might be related to"
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
f" {target_dtype}.")
query_states = query_states.to(target_dtype)
key_states = key_states.to(target_dtype)
value_states = value_states.to(target_dtype)
# Reashape to the expected shape for Flash Attention
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
if (self.config.use_sliding_window and getattr(self.config, "sliding_window", None) is not None and
self.layer_idx >= self.config.max_window_layers):
sliding_window = self.config.sliding_window
else:
sliding_window = None
attn_output = _flash_attention_forward(
query_states,
key_states,
value_states,
attention_mask,
full_q_len,
position_ids=position_ids,
dropout=dropout_rate,
sliding_window=sliding_window,
is_causal=self.is_causal,
use_top_left_mask=self._flash_attn_uses_top_left_mask,
)
# use full_q_len to reshape
attn_output = attn_output.reshape(bsz, full_q_len, -1, self.head_dim).contiguous()
########## AlltoAll for Ulysses ##########
if ulysses_sp_size > 1:
attn_output = gather_heads_scatter_seq(attn_output, seq_dim=1, head_dim=2)
attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value

View File

@@ -0,0 +1,23 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
def get_weight_loader(arch: str):
from verl.models.llama.megatron.checkpoint_utils.llama_loader import load_state_dict_to_megatron_llama
_MODEL_WEIGHT_MEGATRON_LOADER_REGISTRY = {'LlamaForCausalLM': load_state_dict_to_megatron_llama}
if arch in _MODEL_WEIGHT_MEGATRON_LOADER_REGISTRY:
return _MODEL_WEIGHT_MEGATRON_LOADER_REGISTRY[arch]
raise ValueError(f"Model architectures {arch} are not supported for now. "
f"Supported architectures: {_MODEL_WEIGHT_MEGATRON_LOADER_REGISTRY.keys()}")

639
verl/protocol.py Normal file
View File

@@ -0,0 +1,639 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Implement base data transfer protocol between any two functions, modules.
We can subclass Protocol to define more detailed batch info with specific keys
"""
import pickle
import numpy as np
import copy
from dataclasses import dataclass, field
from typing import Callable, Dict, List, Union
import torch
import tensordict
from tensordict import TensorDict
from torch.utils.data import DataLoader, Dataset
from verl.utils.py_functional import union_two_dict
__all__ = ['DataProto', 'union_tensor_dict']
try:
tensordict.set_lazy_legacy(False).set()
except:
pass
def pad_dataproto_to_divisor(data: 'DataProto', size_divisor: int):
"""Pad a DataProto to size divisible by size_divisor
Args:
size_divisor (int): size divisor
Returns:
data: (DataProto): the padded DataProto
pad_size (int)
"""
assert isinstance(data, DataProto), 'data must be a DataProto'
if len(data) % size_divisor != 0:
pad_size = size_divisor - len(data) % size_divisor
data_padded = DataProto.concat([data, data[:pad_size]])
else:
pad_size = 0
data_padded = data
return data_padded, pad_size
def unpad_dataproto(data: 'DataProto', pad_size):
if pad_size != 0:
data = data[:-pad_size]
return data
def union_tensor_dict(tensor_dict1: TensorDict, tensor_dict2: TensorDict) -> TensorDict:
"""Union two tensordicts."""
assert tensor_dict1.batch_size == tensor_dict2.batch_size, \
f'Two tensor dict must have identical batch size. Got {tensor_dict1.batch_size} and {tensor_dict2.batch_size}'
for key in tensor_dict2.keys():
if key not in tensor_dict1.keys():
tensor_dict1[key] = tensor_dict2[key]
else:
assert tensor_dict1[key].equal(tensor_dict2[key]), \
f'{key} in tensor_dict1 and tensor_dict2 are not the same object'
return tensor_dict1
def union_numpy_dict(tensor_dict1: dict[np.ndarray], tensor_dict2: dict[np.ndarray]) -> dict[np.ndarray]:
for key, val in tensor_dict2.items():
if key in tensor_dict1:
assert isinstance(tensor_dict2[key], np.ndarray)
assert isinstance(tensor_dict1[key], np.ndarray)
assert np.all(tensor_dict2[key] == tensor_dict1[key]), \
f'{key} in tensor_dict1 and tensor_dict2 are not the same object'
tensor_dict1[key] = val
return tensor_dict1
def list_of_dict_to_dict_of_list(list_of_dict: list[dict]):
if len(list_of_dict) == 0:
return {}
keys = list_of_dict[0].keys()
output = {key: [] for key in keys}
for data in list_of_dict:
for key, item in data.items():
assert key in output
output[key].append(item)
return output
def fold_batch_dim(data: 'DataProto', new_batch_size):
"""
Fold a batch dim from [bsz, xxx] into [new_bsz, bsz // new_bsz, xxx]
"""
batch_size = data.batch.batch_size[0]
assert batch_size % new_batch_size == 0
tensor: TensorDict = data.batch
non_tensor = data.non_tensor_batch
tensor = tensor.view(new_batch_size, -1)
tensor.auto_batch_size_(batch_dims=1)
for key, val in non_tensor.items():
non_tensor[key] = np.reshape(val, newshape=(new_batch_size, -1, *val.shape[1:]))
return DataProto(batch=tensor, non_tensor_batch=non_tensor, meta_info=data.meta_info)
def unfold_batch_dim(data: 'DataProto', batch_dims=2):
"""
Unfold the first n dims as new batch dim
"""
tensor: TensorDict = data.batch
non_tensor = data.non_tensor_batch
tensor.auto_batch_size_(batch_dims=batch_dims)
tensor = tensor.view(-1)
batch_size = tensor.batch_size[0]
non_tensor_new = {}
for key, val in non_tensor.items():
non_tensor_new[key] = np.reshape(val, newshape=(batch_size, *val.shape[batch_dims:]))
return DataProto(batch=tensor, non_tensor_batch=non_tensor_new, meta_info=data.meta_info)
def collate_fn(x: list['DataProtoItem']):
batch = []
non_tensor_batch = []
for data in x:
batch.append(data.batch)
non_tensor_batch.append(data.non_tensor_batch)
batch = torch.stack(batch).contiguous()
non_tensor_batch = list_of_dict_to_dict_of_list(non_tensor_batch)
for key, val in non_tensor_batch.items():
non_tensor_batch[key] = np.array(val, dtype=object)
return DataProto(batch=batch, non_tensor_batch=non_tensor_batch)
@dataclass
class DataProtoItem:
# TODO(zhangchi.usc1992) add consistency check
batch: TensorDict = None
non_tensor_batch: Dict = field(default_factory=dict)
meta_info: Dict = field(default_factory=dict)
@dataclass
class DataProto:
"""
A DataProto is a data structure that aims to provide a standard protocol for data exchange between functions.
It contains a batch (TensorDict) and a meta_info (Dict). The batch is a TensorDict https://pytorch.org/tensordict/.
TensorDict allows you to manipulate a dictionary of Tensors like a single Tensor. Ideally, the tensors with the
same batch size should be put inside batch.
"""
batch: TensorDict = None
non_tensor_batch: Dict = field(default_factory=dict)
meta_info: Dict = field(default_factory=dict)
def __post_init__(self):
# perform necessary checking
self.check_consistency()
def __len__(self):
if self.batch is not None:
return self.batch.batch_size[0]
elif self.non_tensor_batch is not None and len(self.non_tensor_batch) > 0:
random_key = list(self.non_tensor_batch.keys())[0]
return self.non_tensor_batch[random_key].shape[0]
else:
return 0
def __getitem__(self, item):
tensor_data = self.batch[item]
non_tensor_data = {key: val[item] for key, val in self.non_tensor_batch.items()}
return DataProtoItem(batch=tensor_data, non_tensor_batch=non_tensor_data, meta_info=self.meta_info)
def __getstate__(self):
import io
buffer = io.BytesIO()
if tensordict.__version__ >= '0.5.0' and self.batch is not None:
self.batch = self.batch.contiguous()
self.batch = self.batch.consolidate()
torch.save(self.batch, buffer)
buffer_bytes = buffer.getvalue()
return buffer_bytes, self.non_tensor_batch, self.meta_info
def __setstate__(self, data):
import io
batch_deserialized_bytes, non_tensor_batch, meta_info = data
batch_deserialized = io.BytesIO(initial_bytes=batch_deserialized_bytes)
batch = torch.load(batch_deserialized,
weights_only=False,
map_location='cpu' if not torch.cuda.is_available() else None)
self.batch = batch
self.non_tensor_batch = non_tensor_batch
self.meta_info = meta_info
def save_to_disk(self, filepath):
with open(filepath, 'wb') as f:
pickle.dump(self, f)
@staticmethod
def load_from_disk(filepath) -> 'DataProto':
with open(filepath, 'rb') as f:
data = pickle.load(f)
return data
def print_size(self, prefix=""):
size_of_tensordict = 0
for key, tensor in self.batch.items():
size_of_tensordict += tensor.element_size() * tensor.numel()
size_of_numpy_array = 0
for key, numpy_array in self.non_tensor_batch.items():
size_of_numpy_array += numpy_array.nbytes
size_of_numpy_array /= 1024**3
size_of_tensordict /= 1024**3
message = f'Size of tensordict: {size_of_tensordict} GB, size of non_tensor_batch: {size_of_numpy_array} GB'
if prefix:
message = f'{prefix}, ' + message
print(message)
def check_consistency(self):
"""Check the consistency of the DataProto. Mainly for batch and non_tensor_batch
We expose this function as a public one so that user can call themselves directly
"""
if self.batch is not None:
assert len(self.batch.batch_size) == 1, 'only support num_batch_dims=1'
if self.non_tensor_batch is not None:
for key, val in self.non_tensor_batch.items():
assert isinstance(val, np.ndarray)
if self.batch is not None and len(self.non_tensor_batch) != 0:
# TODO: we can actually lift this restriction if needed
assert len(self.batch.batch_size) == 1, 'only support num_batch_dims=1 when non_tensor_batch is not empty.'
batch_size = self.batch.batch_size[0]
for key, val in self.non_tensor_batch.items():
assert isinstance(
val, np.ndarray
) and val.dtype == object, 'data in the non_tensor_batch must be a numpy.array with dtype=object'
assert val.shape[
0] == batch_size, f'key {key} length {len(val)} is not equal to batch size {batch_size}'
@classmethod
def from_single_dict(cls, data: Dict[str, Union[torch.Tensor, np.ndarray]], meta_info=None):
tensors = {}
non_tensors = {}
for key, val in data.items():
if isinstance(val, torch.Tensor):
tensors[key] = val
elif isinstance(val, np.ndarray):
non_tensors[key] = val
else:
raise ValueError(f'Unsupported type in data {type(val)}')
return DataProto.from_dict(tensors=tensors, non_tensors=non_tensors, meta_info=meta_info)
@classmethod
def from_dict(cls, tensors: Dict[str, torch.Tensor], non_tensors=None, meta_info=None, num_batch_dims=1):
"""Create a DataProto from a dict of tensors. This assumes that
1. All the tensor in tensors have the same dim0
2. Only dim0 is the batch dim
"""
assert len(tensors) > 0, 'tensors must not be empty'
assert num_batch_dims > 0, 'num_batch_dims must be greater than zero'
if non_tensors is not None:
assert num_batch_dims == 1, 'only support num_batch_dims=1 when non_tensors is not None.'
if meta_info is None:
meta_info = {}
if non_tensors is None:
non_tensors = {}
assert isinstance(non_tensors, dict)
# get and check batch size
batch_size = None
pivot_key = None
for key, tensor in tensors.items():
if batch_size is None:
batch_size = tensor.shape[:num_batch_dims]
pivot_key = key
else:
current_batch = tensor.shape[:num_batch_dims]
assert batch_size == current_batch, \
f'Not all the tensor in tensors have the same batch size with batch_dims={num_batch_dims}. Got {pivot_key} has {batch_size}, {key} has {current_batch}'
for key, val in non_tensors.items():
non_tensors[key] = np.array(val, dtype=object)
tensor_dict = TensorDict(source=tensors, batch_size=batch_size)
return cls(batch=tensor_dict, non_tensor_batch=non_tensors, meta_info=meta_info)
def to(self, device) -> 'DataProto':
"""move the batch to device
Args:
device (torch.device, str): torch device
Returns:
DataProto: the current DataProto
"""
if self.batch is not None:
self.batch = self.batch.to(device)
return self
def select(self, batch_keys=None, non_tensor_batch_keys=None, meta_info_keys=None, deepcopy=False) -> 'DataProto':
"""Select a subset of the DataProto via batch_keys and meta_info_keys
Args:
batch_keys (list, optional): a list of strings indicating the keys in batch to select
meta_info_keys (list, optional): a list of keys indicating the meta info to select
Returns:
DataProto: the DataProto with the selected batch_keys and meta_info_keys
"""
# TODO (zhangchi.usc1992) whether to copy
if batch_keys is not None:
batch_keys = tuple(batch_keys)
sub_batch = self.batch.select(*batch_keys)
else:
sub_batch = self.batch
if non_tensor_batch_keys is not None:
non_tensor_batch = {key: val for key, val in self.non_tensor_batch.items() if key in non_tensor_batch_keys}
else:
non_tensor_batch = self.non_tensor_batch
if deepcopy:
non_tensor_batch = copy.deepcopy(non_tensor_batch)
if meta_info_keys is not None:
sub_meta_info = {key: val for key, val in self.meta_info.items() if key in meta_info_keys}
else:
sub_meta_info = self.meta_info
if deepcopy:
sub_meta_info = copy.deepcopy(sub_meta_info)
return DataProto(batch=sub_batch, non_tensor_batch=non_tensor_batch, meta_info=sub_meta_info)
def pop(self, batch_keys=None, non_tensor_batch_keys=None, meta_info_keys=None) -> 'DataProto':
"""Pop a subset of the DataProto via `batch_keys` and `meta_info_keys`
Args:
batch_keys (list, optional): a list of strings indicating the keys in batch to pop
meta_info_keys (list, optional): a list of keys indicating the meta info to pop
Returns:
DataProto: the DataProto with the poped batch_keys and meta_info_keys
"""
assert batch_keys is not None
if meta_info_keys is None:
meta_info_keys = []
if non_tensor_batch_keys is None:
non_tensor_batch_keys = []
tensors = {}
# tensor batch
for key in batch_keys:
assert key in self.batch.keys()
tensors[key] = self.batch.pop(key)
non_tensors = {}
# non tensor batch
for key in non_tensor_batch_keys:
assert key in self.non_tensor_batch.keys()
non_tensors[key] = self.non_tensor_batch.pop(key)
meta_info = {}
for key in meta_info_keys:
assert key in self.meta_info.keys()
meta_info[key] = self.meta_info.pop(key)
return DataProto.from_dict(tensors=tensors, non_tensors=non_tensors, meta_info=meta_info)
def rename(self, old_keys=None, new_keys=None) -> 'DataProto':
"""
Note that this function only rename the key in the batch
"""
def validate_input(keys):
if keys is not None:
if isinstance(keys, str):
keys = [keys]
elif isinstance(keys, list):
pass
else:
raise TypeError(f'keys must be a list or a string, but got {type(keys)}')
return keys
old_keys = validate_input(old_keys)
new_keys = validate_input(new_keys)
if len(new_keys) != len(old_keys):
raise ValueError(
f'new_keys and old_keys must have the same length, but got {len(new_keys)} and {len(old_keys)}')
self.batch.rename_key_(tuple(old_keys), tuple(new_keys))
return self
def union(self, other: 'DataProto') -> 'DataProto':
"""Union with another DataProto. Union batch and meta_info separately.
Throw an error if
- there are conflict keys in batch and they are not equal
- the batch size of two data batch is not the same
- there are conflict keys in meta_info and they are not the same.
Args:
other (DataProto): another DataProto to union
Returns:
DataProto: the DataProto after union
"""
self.batch = union_tensor_dict(self.batch, other.batch)
self.non_tensor_batch = union_numpy_dict(self.non_tensor_batch, other.non_tensor_batch)
self.meta_info = union_two_dict(self.meta_info, other.meta_info)
return self
def make_iterator(self, mini_batch_size, epochs, seed=None, dataloader_kwargs=None):
"""Make an iterator from the DataProto. This is built upon that TensorDict can be used as a normal Pytorch
dataset. See https://pytorch.org/tensordict/tutorials/data_fashion for more details.
Args:
mini_batch_size (int): mini-batch size when iterating the dataset. We require that
``batch.batch_size[0] % mini_batch_size == 0``
epochs (int): number of epochs when iterating the dataset.
dataloader_kwargs: internally, it returns a DataLoader over the batch.
The dataloader_kwargs is the kwargs passed to the DataLoader
Returns:
Iterator: an iterator that yields a mini-batch data at a time. The total number of iteration steps is
``self.batch.batch_size * epochs // mini_batch_size``
"""
assert self.batch.batch_size[0] % mini_batch_size == 0, f"{self.batch.batch_size[0]} % {mini_batch_size} != 0"
# we can directly create a dataloader from TensorDict
if dataloader_kwargs is None:
dataloader_kwargs = {}
if seed is not None:
generator = torch.Generator()
generator.manual_seed(seed)
else:
generator = None
assert isinstance(dataloader_kwargs, Dict)
train_dataloader = DataLoader(dataset=self,
batch_size=mini_batch_size,
collate_fn=collate_fn,
generator=generator,
**dataloader_kwargs)
def get_data():
for _ in range(epochs):
for d in train_dataloader:
d.meta_info = self.meta_info
yield d
return iter(get_data())
def chunk(self, chunks: int) -> List['DataProto']:
"""Split the batch among dim=0 into chunks. The meta_info is passed to each DataProto after split.
Args:
chunks (int): the number of chunks to split on dim=0
Returns:
List[DataProto]: a list of DataProto after splitting
"""
assert len(
self) % chunks == 0, f'only support equal chunk. Got size of DataProto {len(self)} and chunk {chunks}.'
if self.batch is not None:
batch_lst = self.batch.chunk(chunks=chunks, dim=0)
else:
batch_lst = [None for _ in range(chunks)]
non_tensor_batch_lst = [{} for _ in range(chunks)]
for key, val in self.non_tensor_batch.items():
assert isinstance(val, np.ndarray)
non_tensor_lst = np.array_split(val, chunks)
assert len(non_tensor_lst) == chunks
for i in range(chunks):
non_tensor_batch_lst[i][key] = non_tensor_lst[i]
output = []
for i in range(chunks):
output.append(
DataProto(batch=batch_lst[i], non_tensor_batch=non_tensor_batch_lst[i], meta_info=self.meta_info))
return output
@staticmethod
def concat(data: List['DataProto']) -> 'DataProto':
"""Concat a list of DataProto. The batch is concatenated among dim=0.
The meta_info is assumed to be identical and will use the first one.
Args:
data (List[DataProto]): list of DataProto
Returns:
DataProto: concatenated DataProto
"""
batch_lst = []
for batch in data:
batch_lst.append(batch.batch)
if batch_lst[0] is not None:
new_batch = torch.cat(batch_lst, dim=0)
else:
new_batch = None
non_tensor_batch = list_of_dict_to_dict_of_list(list_of_dict=[d.non_tensor_batch for d in data])
for key, val in non_tensor_batch.items():
non_tensor_batch[key] = np.concatenate(val, axis=0)
return DataProto(batch=new_batch, non_tensor_batch=non_tensor_batch, meta_info=data[0].meta_info)
def reorder(self, indices):
"""
Note that this operation is in-place
"""
indices_np = indices.detach().numpy()
self.batch = self.batch[indices]
self.non_tensor_batch = {key: val[indices_np] for key, val in self.non_tensor_batch.items()}
def repeat(self, repeat_times=2, interleave=True):
"""
Repeat the batch data a specified number of times.
Args:
repeat_times (int): Number of times to repeat the data.
interleave (bool): Whether to interleave the repeated data.
Returns:
DataProto: A new DataProto with repeated data.
"""
if self.batch is not None:
if interleave:
# Interleave the data
repeated_tensors = {
key: tensor.repeat_interleave(repeat_times, dim=0) for key, tensor in self.batch.items()
}
else:
# Stack the data
repeated_tensors = {
key: tensor.unsqueeze(0).expand(repeat_times, *tensor.shape).reshape(-1, *tensor.shape[1:])
for key, tensor in self.batch.items()
}
repeated_batch = TensorDict(
source=repeated_tensors,
batch_size=(self.batch.batch_size[0] * repeat_times,),
)
else:
repeated_batch = None
repeated_non_tensor_batch = {}
for key, val in self.non_tensor_batch.items():
if interleave:
repeated_non_tensor_batch[key] = np.repeat(val, repeat_times, axis=0)
else:
repeated_non_tensor_batch[key] = np.tile(val, (repeat_times,) + (1,) * (val.ndim - 1))
return DataProto(
batch=repeated_batch,
non_tensor_batch=repeated_non_tensor_batch,
meta_info=self.meta_info,
)
import ray
@dataclass
class DataProtoFuture:
"""
DataProtoFuture aims to eliminate actual data fetching on driver. By doing so, the driver doesn't have to wait
for data so that asynchronous execution becomes possible.
DataProtoFuture contains a list of futures from another WorkerGroup of size world_size.
- collect_fn is a Callable that reduces the list of futures to a DataProto
- dispatch_fn is a Callable that partitions the DataProto into a list of DataProto of size world_size and then select
Potential issue: we can optimize dispatch_fn(collect_fn) such that only needed data is fetched on destination
- DataProtoFuture only supports directly passing from the output of a method to another input. You can't perform any
operation on the DataProtoFuture in driver.
"""
collect_fn: Callable
futures: List[ray.ObjectRef]
dispatch_fn: Callable = None
@staticmethod
def concat(data: List[ray.ObjectRef]) -> 'DataProtoFuture':
output = DataProtoFuture(collect_fn=DataProto.concat, futures=data)
return output
def chunk(self, chunks: int) -> List['DataProtoFuture']:
from functools import partial
arg_future_lst = []
for i in range(chunks):
# note that we can't directly pass i and chunks
def dispatch_fn(x, i, chunks):
return x.chunk(chunks=chunks)[i]
arg_future = DataProtoFuture(collect_fn=self.collect_fn,
dispatch_fn=partial(dispatch_fn, i=i, chunks=chunks),
futures=self.futures)
arg_future_lst.append(arg_future)
return arg_future_lst
def get(self):
output = ray.get(self.futures) # dp_size.
for o in output:
assert isinstance(o, DataProto)
output = self.collect_fn(output) # select dp, concat
if self.dispatch_fn is not None:
output = self.dispatch_fn(output) # split in batch dim, select using dp
return output

View File

@@ -0,0 +1,20 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
version_folder = os.path.dirname(os.path.join(os.path.abspath(__file__)))
with open(os.path.join(version_folder, 'version/version')) as f:
__version__ = f.read().strip()

View File

@@ -0,0 +1,16 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .worker import Worker
from .worker_group import WorkerGroup, ClassWithInitArgs, ResourcePool

View File

@@ -0,0 +1,410 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from enum import Enum
from functools import wraps
from typing import Dict, List, Tuple
from types import FunctionType
from verl.protocol import DataProtoFuture
# here we add a magic number of avoid user-defined function already have this attribute
MAGIC_ATTR = 'attrs_3141562937'
class Dispatch(Enum):
RANK_ZERO = 0
ONE_TO_ALL = 1
ALL_TO_ALL = 2
MEGATRON_COMPUTE = 3
MEGATRON_PP_AS_DP = 4
MEGATRON_PP_ONLY = 5
MEGATRON_COMPUTE_PROTO = 6
MEGATRON_PP_AS_DP_PROTO = 7
DP_COMPUTE = 8
DP_COMPUTE_PROTO = 9
DP_COMPUTE_PROTO_WITH_FUNC = 10
DP_COMPUTE_METRIC = 11
class Execute(Enum):
ALL = 0
RANK_ZERO = 1
def _split_args_kwargs_data_proto(chunks, *args, **kwargs):
from verl.protocol import DataProto, DataProtoFuture
splitted_args = []
for arg in args:
assert isinstance(arg, (DataProto, DataProtoFuture))
splitted_args.append(arg.chunk(chunks=chunks))
splitted_kwargs = {}
for key, val in kwargs.items():
assert isinstance(val, (DataProto, DataProtoFuture))
splitted_kwargs[key] = val.chunk(chunks=chunks)
return splitted_args, splitted_kwargs
def dispatch_one_to_all(worker_group, *args, **kwargs):
args = tuple([arg] * worker_group.world_size for arg in args)
kwargs = {k: [v] * worker_group.world_size for k, v in kwargs.items()}
return args, kwargs
def dispatch_all_to_all(worker_group, *args, **kwargs):
return args, kwargs
def collect_all_to_all(worker_group, output):
return output
def dispatch_megatron_compute(worker_group, *args, **kwargs):
"""
User passes in dp data. The data is dispatched to all tp/pp ranks with the same dp
"""
from verl.single_controller.base.megatron.worker_group import MegatronWorkerGroup
assert isinstance(worker_group,
MegatronWorkerGroup), f'worker_group must be MegatronWorkerGroup, Got {type(worker_group)}'
all_args = []
for arg in args:
assert isinstance(arg, (Tuple, List)) and len(arg) == worker_group.dp_size
transformed_args = []
for i in range(worker_group.world_size):
local_dp_rank = worker_group.get_megatron_rank_info(rank=i).dp_rank
transformed_args.append(arg[local_dp_rank])
all_args.append(transformed_args)
all_args = tuple(all_args)
all_kwargs = {}
for k, v in kwargs.items():
assert isinstance(v, (Tuple, List)) and len(v) == worker_group.dp_size
transformed_v = []
for i in range(worker_group.world_size):
local_dp_rank = worker_group.get_megatron_rank_info(rank=i).dp_rank
transformed_v.append(v[local_dp_rank])
all_kwargs[k] = transformed_v
return all_args, all_kwargs
def collect_megatron_compute(worker_group, output):
"""
Only collect the data from the tp=0 and pp=last and every dp ranks
"""
from verl.single_controller.base.megatron.worker_group import MegatronWorkerGroup
assert isinstance(worker_group, MegatronWorkerGroup)
output_in_dp = []
pp_size = worker_group.get_megatron_global_info().pp_size
for global_rank in range(worker_group.world_size):
local_rank_info = worker_group.get_megatron_rank_info(rank=global_rank)
if local_rank_info.tp_rank == 0 and local_rank_info.pp_rank == pp_size - 1:
output_in_dp.append(output[global_rank])
return output_in_dp
def dispatch_megatron_compute_data_proto(worker_group, *args, **kwargs):
"""
All the args and kwargs must be DataProto. The batch will be chunked by dp_size and passed to each rank
"""
from verl.single_controller.base.megatron.worker_group import MegatronWorkerGroup
assert isinstance(worker_group, MegatronWorkerGroup)
splitted_args, splitted_kwargs = _split_args_kwargs_data_proto(worker_group.dp_size, *args, **kwargs)
return dispatch_megatron_compute(worker_group, *splitted_args, **splitted_kwargs)
def _concat_data_proto_or_future(output: List):
from verl.protocol import DataProto, DataProtoFuture
import ray
# make sure all the elements in output has the same type
for o in output:
assert type(o) == type(output[0])
o = output[0]
if isinstance(o, DataProto):
return DataProto.concat(output)
elif isinstance(o, ray.ObjectRef):
return DataProtoFuture.concat(output)
else:
raise NotImplementedError
def collect_megatron_compute_data_proto(worker_group, output):
"""
Each output must be a DataProto. We concat the dim=0 of output
"""
from verl.protocol import DataProto
import ray
output = collect_megatron_compute(worker_group, output)
for o in output:
assert isinstance(o, (DataProto, ray.ObjectRef)), f"expecting {o} to be DataProto, but got {type(o)}"
return _concat_data_proto_or_future(output)
def dispatch_megatron_pp_as_dp(worker_group, *args, **kwargs):
"""
treat pp as dp.
"""
from verl.single_controller.base.megatron.worker_group import MegatronWorkerGroup
assert isinstance(worker_group, MegatronWorkerGroup)
pp_size = worker_group.pp_size
dp_size = worker_group.dp_size
pp_dp_size = pp_size * dp_size
all_args = []
for arg in args:
assert isinstance(arg, (List, Tuple)) and len(arg) == pp_dp_size
transformed_args = []
for i in range(worker_group.world_size):
local_dp_rank = worker_group.get_megatron_rank_info(rank=i).dp_rank
local_pp_rank = worker_group.get_megatron_rank_info(rank=i).pp_rank
# compute the rank in arg. Note that the order is dp then pp
# Also note that the outputs within a pp group will be firstly allgathered, then only the output of pp0 will be collected.
# For pp=2 dp=4, a batch of data "ABCDEFGH" should be dispatched and collected in below order:
# dispatch: pp_allgther: collect:
# dp 0 1 2 3 dp 0 1 2 3
# pp +---------+ pp +-------------+
# 0 | A C E G | 0 | AB CD EF GH | ABCDEFGH
# 1 | B D F H | 1 | AB CD EF GH |
# +---------+ +-------------+
arg_rank = local_dp_rank * worker_group.pp_size + local_pp_rank
transformed_args.append(arg[arg_rank])
all_args.append(transformed_args)
all_args = tuple(all_args)
all_kwargs = {}
for k, v in kwargs.items():
assert isinstance(v, (List, Tuple)) and len(v) == pp_dp_size, f'expect len(v)=={pp_dp_size}, got {len(v)}'
transformed_v = []
for i in range(worker_group.world_size):
local_dp_rank = worker_group.get_megatron_rank_info(rank=i).dp_rank
local_pp_rank = worker_group.get_megatron_rank_info(rank=i).pp_rank
# compute the rank in arg. Note that the order is dp then pp
arg_rank = local_dp_rank * worker_group.pp_size + local_pp_rank
transformed_v.append(v[arg_rank])
all_kwargs[k] = transformed_v
return all_args, all_kwargs
def collect_megatron_pp_as_dp(worker_group, output):
"""
treat pp as dp. Only collect data on tp=0
"""
from verl.single_controller.base.megatron.worker_group import MegatronWorkerGroup
assert isinstance(worker_group, MegatronWorkerGroup)
output_in_dp = []
for global_rank in range(worker_group.world_size):
local_rank_info = worker_group.get_megatron_rank_info(rank=global_rank)
if local_rank_info.tp_rank == 0 and local_rank_info.pp_rank == 0:
output_in_dp.append(output[global_rank])
return output_in_dp
def collect_megatron_pp_only(worker_group, output):
"""
Only collect output of megatron pp. This is useful when examine weight names as they are identical in tp/dp
"""
from verl.single_controller.base.megatron.worker_group import MegatronWorkerGroup
assert isinstance(worker_group, MegatronWorkerGroup)
output_in_pp = []
for global_rank in range(worker_group.world_size):
local_rank_info = worker_group.get_megatron_rank_info(rank=global_rank)
if local_rank_info.tp_rank == 0 and local_rank_info.dp_rank == 0:
output_in_pp.append(output[global_rank])
return output_in_pp
def dispatch_megatron_pp_as_dp_data_proto(worker_group, *args, **kwargs):
from verl.single_controller.base.megatron.worker_group import MegatronWorkerGroup
assert isinstance(worker_group, MegatronWorkerGroup)
pp_dp_size = worker_group.dp_size * worker_group.pp_size
splitted_args, splitted_kwargs = _split_args_kwargs_data_proto(pp_dp_size, *args, **kwargs)
return dispatch_megatron_pp_as_dp(worker_group, *splitted_args, **splitted_kwargs)
def collect_megatron_pp_as_dp_data_proto(worker_group, output):
from verl.protocol import DataProto
from verl.single_controller.base.megatron.worker_group import MegatronWorkerGroup
assert isinstance(worker_group, MegatronWorkerGroup)
output = collect_megatron_pp_as_dp(worker_group, output)
return _concat_data_proto_or_future(output)
def dispatch_dp_compute(worker_group, *args, **kwargs):
from verl.single_controller.base.worker_group import WorkerGroup
assert isinstance(worker_group, WorkerGroup)
for arg in args:
assert isinstance(arg, (Tuple, List)) and len(arg) == worker_group.world_size
for k, v in kwargs.items():
assert isinstance(v, (Tuple, List)) and len(v) == worker_group.world_size
return args, kwargs
def collect_dp_compute(worker_group, output):
from verl.single_controller.base.worker_group import WorkerGroup
assert isinstance(worker_group, WorkerGroup)
assert len(output) == worker_group.world_size
return output
def dispatch_dp_compute_data_proto(worker_group, *args, **kwargs):
from verl.single_controller.base.worker_group import WorkerGroup
assert isinstance(worker_group, WorkerGroup)
splitted_args, splitted_kwargs = _split_args_kwargs_data_proto(worker_group.world_size, *args, **kwargs)
return splitted_args, splitted_kwargs
def dispatch_dp_compute_data_proto_with_func(worker_group, *args, **kwargs):
from verl.single_controller.base.worker_group import WorkerGroup
assert isinstance(worker_group, WorkerGroup)
assert type(args[0]) == FunctionType # NOTE: The first one args is a function!
splitted_args, splitted_kwargs = _split_args_kwargs_data_proto(worker_group.world_size, *args[1:], **kwargs)
splitted_args_with_func = [[args[0]] * worker_group.world_size] + splitted_args
return splitted_args_with_func, splitted_kwargs
def collect_dp_compute_data_proto(worker_group, output):
from verl.protocol import DataProto
import ray
for o in output:
assert isinstance(o, (DataProto, ray.ObjectRef)), f"expecting {o} to be DataProto, but got {type(o)}"
output = collect_dp_compute(worker_group, output)
return _concat_data_proto_or_future(output)
def get_predefined_dispatch_fn(dispatch_mode):
predefined_dispatch_mode_fn = {
Dispatch.ONE_TO_ALL: {
'dispatch_fn': dispatch_one_to_all,
'collect_fn': collect_all_to_all,
},
Dispatch.ALL_TO_ALL: {
'dispatch_fn': dispatch_all_to_all,
'collect_fn': collect_all_to_all,
},
Dispatch.MEGATRON_COMPUTE: {
'dispatch_fn': dispatch_megatron_compute,
'collect_fn': collect_megatron_compute,
},
Dispatch.MEGATRON_PP_AS_DP: {
'dispatch_fn': dispatch_megatron_pp_as_dp,
'collect_fn': collect_megatron_pp_as_dp,
},
Dispatch.MEGATRON_PP_ONLY: {
'dispatch_fn': dispatch_one_to_all,
'collect_fn': collect_megatron_pp_only
},
Dispatch.MEGATRON_COMPUTE_PROTO: {
'dispatch_fn': dispatch_megatron_compute_data_proto,
'collect_fn': collect_megatron_compute_data_proto
},
Dispatch.MEGATRON_PP_AS_DP_PROTO: {
'dispatch_fn': dispatch_megatron_pp_as_dp_data_proto,
'collect_fn': collect_megatron_pp_as_dp_data_proto
},
Dispatch.DP_COMPUTE: {
'dispatch_fn': dispatch_dp_compute,
'collect_fn': collect_dp_compute
},
Dispatch.DP_COMPUTE_PROTO: {
'dispatch_fn': dispatch_dp_compute_data_proto,
'collect_fn': collect_dp_compute_data_proto
},
Dispatch.DP_COMPUTE_PROTO_WITH_FUNC: {
'dispatch_fn': dispatch_dp_compute_data_proto_with_func,
'collect_fn': collect_dp_compute_data_proto
},
Dispatch.DP_COMPUTE_METRIC: {
'dispatch_fn': dispatch_dp_compute_data_proto,
'collect_fn': collect_dp_compute
}
}
return predefined_dispatch_mode_fn[dispatch_mode]
def get_predefined_execute_fn(execute_mode):
"""
Note that here we only asks execute_all and execute_rank_zero to be implemented
Leave the choice of how these two functions handle argument 'blocking' to users
"""
predefined_execute_mode_fn = {
Execute.ALL: {
'execute_fn_name': 'execute_all'
},
Execute.RANK_ZERO: {
'execute_fn_name': 'execute_rank_zero'
}
}
return predefined_execute_mode_fn[execute_mode]
def _check_dispatch_mode(dispatch_mode):
assert isinstance(dispatch_mode,
(Dispatch, Dict)), f'dispatch_mode must be a Dispatch or a Dict. Got {dispatch_mode}'
if isinstance(dispatch_mode, Dict):
necessary_keys = ['dispatch_fn', 'collect_fn']
for key in necessary_keys:
assert key in dispatch_mode, f'key {key} should be in dispatch_mode if it is a dictionary'
def _check_execute_mode(execute_mode):
assert isinstance(execute_mode, Execute), f'execute_mode must be a Execute. Got {execute_mode}'
def _materialize_futures(*args, **kwargs):
new_args = []
for arg in args:
if isinstance(arg, DataProtoFuture):
arg = arg.get()
# add more type to materialize
new_args.append(arg)
for k, v in kwargs.items():
if isinstance(v, DataProtoFuture):
kwargs[k] = v.get()
new_args = tuple(new_args)
return new_args, kwargs
def register(dispatch_mode=Dispatch.ALL_TO_ALL, execute_mode=Execute.ALL, blocking=True, materialize_futures=True):
_check_dispatch_mode(dispatch_mode=dispatch_mode)
_check_execute_mode(execute_mode=execute_mode)
def decorator(func):
@wraps(func)
def inner(*args, **kwargs):
if materialize_futures:
args, kwargs = _materialize_futures(*args, **kwargs)
return func(*args, **kwargs)
attrs = {'dispatch_mode': dispatch_mode, 'execute_mode': execute_mode, 'blocking': blocking}
setattr(inner, MAGIC_ATTR, attrs)
return inner
return decorator

View File

@@ -0,0 +1,13 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@@ -0,0 +1,39 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from dataclasses import dataclass
from verl.single_controller.base.worker import Worker, DistRankInfo, DistGlobalInfo
class MegatronWorker(Worker):
def __init__(self, cuda_visible_devices=None) -> None:
super().__init__(cuda_visible_devices)
def get_megatron_global_info(self):
from megatron.core import parallel_state as mpu
tp_size = mpu.get_tensor_model_parallel_world_size()
dp_size = mpu.get_data_parallel_world_size()
pp_size = mpu.get_pipeline_model_parallel_world_size()
info = DistGlobalInfo(tp_size=tp_size, dp_size=dp_size, pp_size=pp_size)
return info
def get_megatron_rank_info(self):
from megatron.core import parallel_state as mpu
tp_rank = mpu.get_tensor_model_parallel_rank()
dp_rank = mpu.get_data_parallel_rank()
pp_rank = mpu.get_pipeline_model_parallel_rank()
info = DistRankInfo(tp_rank=tp_rank, dp_rank=dp_rank, pp_rank=pp_rank)
return info

View File

@@ -0,0 +1,51 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict
from .worker import DistRankInfo, DistGlobalInfo
from verl.single_controller.base import ResourcePool, WorkerGroup
class MegatronWorkerGroup(WorkerGroup):
def __init__(self, resource_pool: ResourcePool, **kwargs):
super().__init__(resource_pool=resource_pool, **kwargs)
self._megatron_rank_info = None
self._megatron_global_info: DistGlobalInfo = None
def init_megatron(self, default_megatron_kwargs: Dict = None):
raise NotImplementedError(f"MegatronWorkerGroup.init_megatron should be overwritten")
def get_megatron_rank_info(self, rank: int) -> DistRankInfo:
assert 0 <= rank < self.world_size, f'rank must be from [0, world_size), Got {rank}'
return self._megatron_rank_info[rank]
@property
def tp_size(self):
assert self._megatron_global_info is not None, "MegatronWorkerGroup._megatron_global_info must be initialized"
return self._megatron_global_info.tp_size
@property
def dp_size(self):
assert self._megatron_global_info is not None, "MegatronWorkerGroup._megatron_global_info must be initialized"
return self._megatron_global_info.dp_size
@property
def pp_size(self):
assert self._megatron_global_info is not None, "MegatronWorkerGroup._megatron_global_info must be initialized"
return self._megatron_global_info.pp_size
def get_megatron_global_info(self):
return self._megatron_global_info

View File

@@ -0,0 +1,13 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@@ -0,0 +1,29 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import ray
@ray.remote
class WorkerGroupRegisterCenter:
def __init__(self, rank_zero_info):
self.rank_zero_info = rank_zero_info
def get_rank_zero_info(self):
return self.rank_zero_info
def create_worker_group_register_center(name, info):
return WorkerGroupRegisterCenter.options(name=name).remote(info)

View File

@@ -0,0 +1,186 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
the class for Worker
"""
import os
import socket
from dataclasses import dataclass
from verl.single_controller.base.decorator import register, Dispatch, Execute
@dataclass
class DistRankInfo:
tp_rank: int
dp_rank: int
pp_rank: int
@dataclass
class DistGlobalInfo:
tp_size: int
dp_size: int
pp_size: int
class WorkerHelper:
def _get_node_ip(self):
def get_node_ip_by_sdk():
if os.getenv("WG_BACKEND", None) == "ray":
import ray
return ray._private.services.get_node_ip_address()
elif os.getenv("WG_BACKEND", None) == "torch_rpc":
from verl.single_controller.torchrpc.k8s_client import get_ip_addr
return get_ip_addr()
return None
host_ipv4 = os.getenv("MY_HOST_IP", None)
host_ipv6 = os.getenv("MY_HOST_IPV6", None)
host_ip_by_env = host_ipv4 or host_ipv6
host_ip_by_sdk = get_node_ip_by_sdk()
host_ip = host_ip_by_env or host_ip_by_sdk
return host_ip
def _get_free_port(self):
with socket.socket() as sock:
sock.bind(('', 0))
return sock.getsockname()[1]
def get_availale_master_addr_port(self):
return self._get_node_ip(), str(self._get_free_port())
def _get_pid(self):
return
class WorkerMeta:
keys = [
"WORLD_SIZE", "RANK", "LOCAL_WORLD_SIZE", "LOCAL_RANK", "MASTER_ADDR", "MASTER_PORT", "CUDA_VISIBLE_DEVICES"
]
def __init__(self, store) -> None:
self._store = store
def to_dict(self):
return {f"_{key.lower()}": self._store.get(f"_{key.lower()}", None) for key in WorkerMeta.keys}
# we assume that in each WorkerGroup, there is a Master Worker
class Worker(WorkerHelper):
def __new__(cls, *args, **kwargs):
instance = super().__new__(cls)
# note that here we use int to distinguish
disable_worker_init = int(os.environ.get('DISABLE_WORKER_INIT', 0))
if disable_worker_init:
return instance
rank = os.environ.get("RANK", None)
worker_group_prefix = os.environ.get("WG_PREFIX", None)
# when decorator @ray.remote applies, __new__ will be called while we don't want to apply _configure_before_init
if None not in [rank, worker_group_prefix] and 'ActorClass(' not in cls.__name__:
instance._configure_before_init(f"{worker_group_prefix}_register_center", int(rank))
return instance
def _configure_before_init(self, register_center_name: str, rank: int):
assert isinstance(rank, int), f"rank must be int, instead of {type(rank)}"
if rank == 0:
master_addr, master_port = self.get_availale_master_addr_port()
rank_zero_info = {
"MASTER_ADDR": master_addr,
"MASTER_PORT": master_port,
}
if os.getenv("WG_BACKEND", None) == "ray":
from verl.single_controller.base.register_center.ray import create_worker_group_register_center
self.register_center = create_worker_group_register_center(name=register_center_name,
info=rank_zero_info)
os.environ.update(rank_zero_info)
def __init__(self, cuda_visible_devices=None) -> None:
# construct a meta from envrionment variable. Note that the import must be inside the class because it is executed remotely
import os
world_size = int(os.environ['WORLD_SIZE'])
rank = int(os.environ['RANK'])
self._rank = rank
self._world_size = world_size
master_addr = os.environ["MASTER_ADDR"]
master_port = os.environ["MASTER_PORT"]
local_world_size = int(os.getenv("LOCAL_WORLD_SIZE", "1"))
local_rank = int(os.getenv("LOCAL_RANK", "0"))
store = {
'_world_size': world_size,
'_rank': rank,
'_local_world_size': local_world_size,
'_local_rank': local_rank,
'_master_addr': master_addr,
'_master_port': master_port
}
if cuda_visible_devices is not None:
store['_cuda_visible_devices'] = cuda_visible_devices
meta = WorkerMeta(store=store)
self._configure_with_meta(meta=meta)
def _configure_with_meta(self, meta: WorkerMeta):
"""
This function should only be called inside by WorkerGroup
"""
assert isinstance(meta, WorkerMeta)
self.__dict__.update(meta.to_dict()) # this is hacky
# print(f"__dict__: {self.__dict__}")
for key in WorkerMeta.keys:
val = self.__dict__.get(f"_{key.lower()}", None)
if val is not None:
# print(f"set {key} to {val}")
os.environ[key] = str(val)
os.environ["REDIS_STORE_SERVER_HOST"] = str(self._master_addr).replace("[", "").replace(
"]", "") if self._master_addr else ""
def get_master_addr_port(self):
return self._master_addr, self._master_port
def get_cuda_visible_devices(self):
import os
cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", "not set")
return cuda_visible_devices
@property
def world_size(self):
return self._world_size
@property
def rank(self):
return self._rank
@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO_WITH_FUNC)
def execute_with_func_generator(self, func, *args, **kwargs):
ret_proto = func(self, *args, **kwargs)
return ret_proto
@register(dispatch_mode=Dispatch.ALL_TO_ALL, execute_mode=Execute.RANK_ZERO)
def execute_func_rank_zero(self, func, *args, **kwargs):
result = func(*args, **kwargs)
return result

View File

@@ -0,0 +1,196 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
the class of WorkerGroup
"""
import logging
import threading
import signal
import time
from typing import List, Any, Callable, Dict
from verl.single_controller.base.decorator import MAGIC_ATTR, Dispatch, get_predefined_dispatch_fn, get_predefined_execute_fn
class ResourcePool:
def __init__(self, process_on_nodes=None, max_collocate_count: int = 10, n_gpus_per_node=8) -> None:
if process_on_nodes is None:
process_on_nodes = []
self._store = process_on_nodes
self.max_collocate_count = max_collocate_count
self.n_gpus_per_node = n_gpus_per_node # this is left for future huawei GPU that contains 16 GPUs per node
def add_node(self, process_count):
self._store.append(process_count)
@property
def world_size(self):
return sum(self._store)
def __call__(self) -> Any:
return self._store
@property
def store(self):
return self._store
def local_world_size_list(self) -> List[int]:
nested_local_world_size_list = [
[local_world_size for _ in range(local_world_size)] for local_world_size in self._store
]
return [item for row in nested_local_world_size_list for item in row]
def local_rank_list(self) -> List[int]:
nested_local_rank_list = [[i for i in range(local_world_size)] for local_world_size in self._store]
return [item for row in nested_local_rank_list for item in row]
class ClassWithInitArgs:
"""
This class stores a class constructor and the args/kwargs to construct the class.
It is used to instantiate the remote class.
"""
def __init__(self, cls, *args, **kwargs) -> None:
self.cls = cls
self.args = args
self.kwargs = kwargs
# def add_arg(self, arg):
# self.args += (arg,)
# def add_kwarg(self, key, value):
# self.kwargs[key] = value
def __call__(self) -> Any:
return self.cls(*self.args, **self.kwargs)
def check_workers_alive(workers: List, is_alive: Callable, gap_time: float = 1) -> None:
import time
while True:
for worker in workers:
if not is_alive(worker):
logging.warning(f"worker {worker} is not alive" + " sending signal to main thread")
signal.raise_signal(signal.SIGABRT)
time.sleep(gap_time)
class WorkerGroup:
def __init__(self, resource_pool: ResourcePool, **kwargs) -> None:
self._is_init_with_detached_workers = True if resource_pool is None else False
if resource_pool is not None:
# handle the case when WorkGroup is attached to an existing one
self._procecss_dispatch_config = resource_pool()
else:
self._procecss_dispatch_config = None
self._workers = []
self._worker_names = []
self._master_addr = None
self._master_port = None
self._checker_thread: threading.Thread = None
def _is_worker_alive(self, worker):
raise NotImplementedError(f"WorkerGroup._is_worker_alive called, should be implemented in derived class.")
def _block_until_all_workers_alive(self) -> None:
while True:
all_state = [self._is_worker_alive(worker) for worker in self._workers]
if False in all_state:
time.sleep(1)
else:
break
def start_worker_aliveness_check(self, every_n_seconds=1) -> None:
# before starting checking worker aliveness, make sure all workers are already alive
self._block_until_all_workers_alive()
self._checker_thread = threading.Thread(target=check_workers_alive,
args=(self._workers, self._is_worker_alive, every_n_seconds))
self._checker_thread.start()
@property
def world_size(self):
return len(self._workers)
# execute_all_async and execute_rank_zero_async should be implemented by RayWorkerGroup, TorchRPCWorkerGroup,
# MegatronWorkerGroup, XperfWorkerGroup should skip
def _bind_worker_method(self, user_defined_cls, func_generator):
"""
Bind the worker method to the WorkerGroup
"""
for method_name in dir(user_defined_cls):
try:
method = getattr(user_defined_cls, method_name)
assert callable(method), f"{method_name} in {user_defined_cls} is not callable"
except Exception as e:
# if it is a property, it will fail because Class doesn't have instance property
continue
if hasattr(method, MAGIC_ATTR):
# this method is decorated by register
attribute = getattr(method, MAGIC_ATTR)
assert isinstance(attribute, Dict), f'attribute must be a dictionary. Got {type(attribute)}'
assert 'dispatch_mode' in attribute, f'attribute must contain dispatch_mode in its key'
dispatch_mode = attribute['dispatch_mode']
execute_mode = attribute['execute_mode']
blocking = attribute['blocking']
# get dispatch fn
if isinstance(dispatch_mode, Dispatch):
# get default dispatch fn
fn = get_predefined_dispatch_fn(dispatch_mode=dispatch_mode)
dispatch_fn = fn['dispatch_fn']
collect_fn = fn['collect_fn']
else:
assert isinstance(dispatch_mode, dict)
assert 'dispatch_fn' in dispatch_mode
assert 'collect_fn' in dispatch_mode
dispatch_fn = dispatch_mode['dispatch_fn']
collect_fn = dispatch_mode['collect_fn']
# get execute_fn_name
execute_mode = get_predefined_execute_fn(execute_mode=execute_mode)
wg_execute_fn_name = execute_mode['execute_fn_name']
# get execute_fn from string
try:
execute_fn = getattr(self, wg_execute_fn_name)
assert callable(execute_fn), 'execute_fn must be callable'
except Exception as e:
print(f'execute_fn {wg_execute_fn_name} is invalid')
raise
# bind a new method to the RayWorkerGroup
func = func_generator(self,
method_name,
dispatch_fn=dispatch_fn,
collect_fn=collect_fn,
execute_fn=execute_fn,
blocking=blocking)
try:
setattr(self, method_name, func)
except Exception as e:
raise ValueError(f'Fail to set method_name {method_name}')

View File

@@ -0,0 +1,16 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .base import RayResourcePool, RayClassWithInitArgs, RayWorkerGroup, create_colocated_worker_cls
from .megatron import (MegatronRayWorkerGroup, DistRankInfo, DistGlobalInfo)

View File

@@ -0,0 +1,459 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
from typing import Dict, List, Any, Tuple
import ray
from ray.util import list_named_actors
from ray.util.placement_group import placement_group, PlacementGroup
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy, NodeAffinitySchedulingStrategy
from ray.experimental.state.api import get_actor
from verl.single_controller.base import WorkerGroup, ResourcePool, ClassWithInitArgs, Worker
__all__ = ['Worker']
def get_random_string(length: int) -> str:
import random
import string
letters_digits = string.ascii_letters + string.digits
return ''.join(random.choice(letters_digits) for _ in range(length))
def func_generator(self, method_name, dispatch_fn, collect_fn, execute_fn, blocking):
def func(*args, **kwargs):
args, kwargs = dispatch_fn(self, *args, **kwargs)
output = execute_fn(method_name, *args, **kwargs)
if blocking:
output = ray.get(output)
output = collect_fn(self, output)
return output
return func
class RayResourcePool(ResourcePool):
def __init__(self,
process_on_nodes: List[int] = None,
use_gpu: bool = True,
name_prefix: str = "",
max_colocate_count: int = 5,
detached=False) -> None:
super().__init__(process_on_nodes, max_colocate_count)
self.use_gpu = use_gpu
# print(f"in RayProcessDispatchConfiguration: name_prefix = {name_prefix}")
self.name_prefix = name_prefix
self.pgs = None
self.detached = detached
def get_placement_groups(self, strategy="STRICT_PACK", name=None):
if self.pgs is not None:
return self.pgs
pg_name_prefix = name if name else \
f"{self.name_prefix}verl_group_{'_'.join([str(count) for count in self._store])}:"
# print(f"pg_name_prefix = {pg_name_prefix}")
pg_scheme = [[{
"CPU": self.max_collocate_count,
"GPU": 1
} if self.use_gpu else {
"CPU": self.max_collocate_count
} for _ in range(process_count)] for process_count in self._store]
lifetime = 'detached' if self.detached else None
pgs = [
placement_group(bundles=bundles, strategy=strategy, name=pg_name_prefix + str(idx), lifetime=lifetime)
for idx, bundles in enumerate(pg_scheme)
]
ray.get([pg.ready() for pg in pgs])
self.pgs = pgs
return pgs
def extract_pg_from_exist(resource_pools: Dict[str, RayResourcePool], src_role_names: List[str],
resource_pool: RayResourcePool) -> List:
src_pgs = [
pg for role_name, resource_pool in resource_pools.items() for pg in resource_pool.get_placement_groups()
if role_name in src_role_names
]
sorted_src_pgs = sorted(src_pgs, key=lambda pg: pg.bundle_count, reverse=True)
sorted_process_on_nodes = sorted([(val, idx) for idx, val in enumerate(resource_pool.store)], reverse=True)
unsorted_pgs: List[Tuple[int, PlacementGroup]] = []
searching_idx = 0
for request_process, original_idx in sorted_process_on_nodes:
assert searching_idx < len(sorted_src_pgs), f"no enough nodes for request: searching {searching_idx} th node"
assert request_process <= sorted_src_pgs[searching_idx].bundle_count, \
f"requesting {request_process} processes, bundle count cannot satisfy"
unsorted_pgs.append((original_idx, sorted_src_pgs[searching_idx]))
searching_idx += 1
return [pg for _, pg in sorted(unsorted_pgs)]
def merge_resource_pool(rp1: RayResourcePool, rp2: RayResourcePool) -> RayResourcePool:
assert rp1.use_gpu == rp2.use_gpu, 'Both RayResourcePool must either use_gpu or not'
assert rp1.max_collocate_count == rp2.max_collocate_count, 'Both RayResourcePool must has the same max_collocate_count'
assert rp1.n_gpus_per_node == rp2.n_gpus_per_node, 'Both RayResourcePool must has the same n_gpus_per_node'
assert rp1.detached == rp2.detached, 'Detached ResourcePool cannot be merged with non-detached ResourcePool'
new_store = rp1.store + rp2.store
merged = RayResourcePool(new_store, rp1.use_gpu, f"{rp1.name_prefix}_{rp2.name_prefix}")
merged.pgs = rp1.get_placement_groups() + rp2.get_placement_groups()
return merged
class RayClassWithInitArgs(ClassWithInitArgs):
def __init__(self, cls, *args, **kwargs) -> None:
# self._options = kwargs.pop('options', dict())
super().__init__(cls, *args, **kwargs)
self._options = {}
self._additional_resource = {}
def set_additional_resource(self, additional_resource):
self._additional_resource = additional_resource
def update_options(self, options: Dict):
self._options.update(options)
def __call__(self,
placement_group,
placement_group_bundle_idx,
use_gpu: bool = True,
num_gpus=1,
sharing_with=None) -> Any:
if sharing_with is not None:
target_node_id = ray.get(sharing_with.get_node_id.remote())
cuda_visible_devices = ray.get(sharing_with.get_cuda_visible_devices.remote())
options = {"scheduling_strategy": NodeAffinitySchedulingStrategy(node_id=target_node_id, soft=False)}
return self.cls.options(**options).remote(*self.args,
cuda_visible_devices=cuda_visible_devices,
**self.kwargs)
options = {
"scheduling_strategy":
PlacementGroupSchedulingStrategy(placement_group=placement_group,
placement_group_bundle_index=placement_group_bundle_idx)
}
options.update(self._options)
if use_gpu:
options["num_gpus"] = num_gpus
if len(self._additional_resource) > 1:
for k, v in self._additional_resource.items():
options[k] = v
# print("cls:", self.cls)
# print("args: ", self.args)
# print("kwargs: ", self.kwargs)
return self.cls.options(**options).remote(*self.args, **self.kwargs)
class RayWorkerGroup(WorkerGroup):
def __init__(self,
resource_pool: RayResourcePool = None,
ray_cls_with_init: RayClassWithInitArgs = None,
bin_pack: bool = True,
name_prefix: str = None,
detached=False,
worker_names=None,
**kwargs) -> None:
super().__init__(resource_pool=resource_pool, **kwargs)
self.ray_cls_with_init = ray_cls_with_init
self.name_prefix = get_random_string(length=6) if name_prefix is None else name_prefix
if worker_names is not None:
assert self._is_init_with_detached_workers
self._worker_names = worker_names
if self._is_init_with_detached_workers:
self._init_with_detached_workers(worker_names=worker_names)
else:
self._init_with_resource_pool(resource_pool=resource_pool,
ray_cls_with_init=ray_cls_with_init,
bin_pack=bin_pack,
detached=detached)
if ray_cls_with_init is not None:
self._bind_worker_method(self.ray_cls_with_init.cls, func_generator)
def _is_worker_alive(self, worker: ray.actor.ActorHandle):
worker_state_dict = get_actor(worker._actor_id.hex())
return worker_state_dict.get("state", "undefined") == "ALIVE" if worker_state_dict is not None else False
def _init_with_detached_workers(self, worker_names):
workers = [ray.get_actor(name=name) for name in worker_names]
self._workers = workers
self._world_size = len(worker_names)
def _init_with_resource_pool(self, resource_pool, ray_cls_with_init, bin_pack, detached):
use_gpu = resource_pool.use_gpu
strategy = "PACK"
if bin_pack:
strategy = "STRICT_PACK"
pgs = resource_pool.get_placement_groups(strategy=strategy)
world_size = resource_pool.world_size
self._world_size = world_size
# cia.add_kwarg("_world_size", world_size)
num_gpus = 1 / resource_pool.max_collocate_count
rank = -1
for pg_idx, local_world_size in enumerate(resource_pool.store):
pg = pgs[pg_idx]
assert local_world_size <= pg.bundle_count, \
f"when generating for {self.name_prefix}, for the "
for local_rank in range(local_world_size):
rank += 1
# we pass in environment variable at option so that Worker can use environment variable to set
env_vars = {
'WORLD_SIZE': str(world_size),
'RANK': str(rank),
'WG_PREFIX': self.name_prefix,
'WG_BACKEND': 'ray',
'RAY_LOCAL_WORLD_SIZE': str(local_world_size),
'RAY_LOCAL_RANK': str(local_rank),
}
if rank != 0:
env_vars['MASTER_ADDR'] = self._master_addr
env_vars['MASTER_PORT'] = self._master_port
import re
cia_name = type(ray_cls_with_init.cls).__name__
match = re.search(r"ActorClass\(([^)]+)\)", cia_name) # ray.remote(Obj) -> "ActorClass(Obj)"
cia_name = match.group(1) if match else cia_name # "ActorClass(Obj)" -> "Obj"
name = f"{self.name_prefix}{cia_name}_{pg_idx}:{local_rank}" # e.g. Worker_2:5
ray_cls_with_init.update_options({'runtime_env': {'env_vars': env_vars}, 'name': name})
if detached:
ray_cls_with_init.update_options({'lifetime': 'detached'})
# create a worker
worker = ray_cls_with_init(placement_group=pg,
placement_group_bundle_idx=local_rank,
use_gpu=use_gpu,
num_gpus=num_gpus)
self._workers.append(worker)
self._worker_names.append(name)
if rank == 0:
register_center_actor = None
for _ in range(120):
if f"{self.name_prefix}_register_center" not in list_named_actors():
time.sleep(1)
else:
register_center_actor = ray.get_actor(f"{self.name_prefix}_register_center")
break
assert register_center_actor is not None, f"failed to get register_center_actor: {self.name_prefix}_register_center in {list_named_actors(all_namespaces=True)}"
rank_zero_info = ray.get(register_center_actor.get_rank_zero_info.remote())
self._master_addr, self._master_port = rank_zero_info['MASTER_ADDR'], rank_zero_info['MASTER_PORT']
# print(f"rank_zero_info: {rank_zero_info}")
# print(f"master_addr: {self._master_addr}, master_port: {self._master_port}")
@property
def worker_names(self):
return self._worker_names
@classmethod
def from_detached(cls, worker_names=None, ray_cls_with_init=None):
worker_group = cls(resource_pool=None,
ray_cls_with_init=ray_cls_with_init,
name_prefix=None,
worker_names=worker_names)
return worker_group
def spawn(self, prefix_set):
"""
spawn to a dictionary of worker groups, each with a subset of method with prefix.
"""
def _rebind_actor_methods(worker_group, actor_name):
"""
bind the method with actor_prefix to its original name
"""
prefix: str = actor_name + '_'
for method_name in dir(worker_group):
if method_name.startswith(prefix):
# only valid when Python >= 3.9
original_method_name = method_name.removeprefix(prefix)
method = getattr(worker_group, method_name)
setattr(worker_group, original_method_name, method)
new_worker_group_dict = {}
for prefix in prefix_set:
new_worker_group = self.from_detached(worker_names=self._worker_names,
ray_cls_with_init=self.ray_cls_with_init)
_rebind_actor_methods(new_worker_group, prefix)
new_worker_group_dict[prefix] = new_worker_group
return new_worker_group_dict
def execute_rank_zero_sync(self, method_name: str, *args, **kwargs):
return ray.get(self.execute_all_async(method_name, **args, **kwargs))
def execute_rank_zero_async(self, method_name: str, *args, **kwargs):
remote_call = getattr(self._workers[0], method_name)
return remote_call.remote(*args, **kwargs)
def execute_rank_zero(self, method_name: str, *args, **kwargs):
return self.execute_rank_zero_async(method_name, *args, **kwargs)
def execute_all(self, method_name: str, *args, **kwargs):
return self.execute_all_async(method_name, *args, **kwargs)
def execute_all_sync(self, method_name: str, *args, **kwargs):
return ray.get(self.execute_all_async(method_name, *args, **kwargs))
def execute_all_async(self, method_name: str, *args, **kwargs):
# 这里我们假设,如果 args 和 kwargs 里面所有的参数都是 list且所有的 list 长度都与 len(self._workers) 一致的话,我们会把
# list 中的每一个分别发到对应的 worker 上去
# print(f"execute_all_async: method {method_name}({args}, {kwargs})")
length = len(self._workers)
if all(isinstance(arg, list) for arg in args) and all(isinstance(kwarg, list) for kwarg in kwargs.values()):
if all(len(arg) == length for arg in args) and all(len(kwarg) == length for kwarg in kwargs.values()):
# print(f"splitting args and kwargs into {length} shards")
result = []
for i in range(length):
sliced_args = tuple(arg[i] for arg in args)
sliced_kwargs = {k: v[i] for k, v in kwargs.items()}
remote_call = getattr(self._workers[i], method_name)
result.append(remote_call.remote(*sliced_args, **sliced_kwargs))
return result
return [getattr(worker, method_name).remote(*args, **kwargs) for worker in self._workers]
@property
def master_address(self):
return self._master_addr
@property
def master_port(self):
return self._master_port
@property
def workers(self):
return self._workers
@property
def world_size(self):
return self._world_size
"""
Utilities that enables creating workers inside the same ray.Actor,
with code written in separate ray.Actors.
"""
from unittest.mock import patch
from verl.single_controller.base.decorator import MAGIC_ATTR
import os
def _bind_workers_method_to_parent(cls, key, user_defined_cls):
"""
Binds the methods of each worker to the WorkerDict.
Note that we only bind public methods that are decorated by register
"""
for method_name in dir(user_defined_cls):
try:
method = getattr(user_defined_cls, method_name)
assert callable(method), f"{method_name} in {user_defined_cls} is not callable"
except Exception as e:
# if it is a property, it will fail because Class doesn't have instance property
continue
if hasattr(method, MAGIC_ATTR):
def generate_function(name):
def func(self, *args, **kwargs):
# dispatch to the actual worker
return getattr(self.worker_dict[key], name)(*args, **kwargs)
return func
func = generate_function(method_name)
# pass MAGIC_ATTR for outer worker group
setattr(func, MAGIC_ATTR, getattr(method, MAGIC_ATTR))
try:
method_name_with_prefix = key + '_' + method_name
setattr(cls, method_name_with_prefix, func)
# print(f'Binding {method_name_with_prefix}')
except Exception as e:
raise ValueError(f'Fail to set method_name {method_name}')
def _unwrap_ray_remote(cls):
if hasattr(cls, '__ray_actor_class__'):
cls = cls.__ray_actor_class__
return cls
def create_colocated_worker_cls(class_dict: dict[str, RayClassWithInitArgs]):
"""
This function should return a class instance that delegates the calls to every
cls in cls_dict
"""
cls_dict = {}
init_args_dict = {}
worker_cls = None
for key, cls in class_dict.items():
if worker_cls == None:
worker_cls = cls.cls.__ray_actor_class__.__base__
else:
assert worker_cls == cls.cls.__ray_actor_class__.__base__, \
'the worker class should be the same when share the same process'
cls_dict[key] = cls.cls
init_args_dict[key] = {'args': cls.args, 'kwargs': cls.kwargs}
assert cls_dict.keys() == init_args_dict.keys()
# TODO: create a class with customizable name
class WorkerDict(worker_cls):
def __init__(self):
super().__init__()
self.worker_dict = {}
for key, user_defined_cls in cls_dict.items():
user_defined_cls = _unwrap_ray_remote(user_defined_cls)
# directly instantiate the class without remote
with patch.dict(os.environ, {'DISABLE_WORKER_INIT': '1'}):
self.worker_dict[key] = user_defined_cls(*init_args_dict[key].get('args', ()),
**init_args_dict[key].get('kwargs', {}))
# now monkey-patch the methods from inner class to WorkerDict
for key, user_defined_cls in cls_dict.items():
user_defined_cls = _unwrap_ray_remote(user_defined_cls)
_bind_workers_method_to_parent(WorkerDict, key, user_defined_cls)
remote_cls = ray.remote(WorkerDict)
remote_cls = RayClassWithInitArgs(cls=remote_cls)
return remote_cls

View File

@@ -0,0 +1,62 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, Optional
import ray
from .base import RayWorkerGroup, RayResourcePool, RayClassWithInitArgs
from verl.single_controller.base.megatron.worker import DistRankInfo, DistGlobalInfo
from verl.single_controller.base.megatron.worker_group import MegatronWorkerGroup
# NOTE(sgm): for opensource megatron-core
class NVMegatronRayWorkerGroup(RayWorkerGroup, MegatronWorkerGroup):
"""
MegatronWorkerGroup will query each worker of its megatron rank info and store it inside the WorkerGroup
so that the dispatcher can use it to dispatch data.
"""
def __init__(self, resource_pool: RayResourcePool, ray_cls_with_init: RayClassWithInitArgs, **kwargs):
super().__init__(resource_pool=resource_pool, ray_cls_with_init=ray_cls_with_init, **kwargs)
self._megatron_rank_info: DistRankInfo = self.execute_all_sync(method_name='get_megatron_rank_info')
self._megatron_global_info: DistGlobalInfo = ray.get(
self.execute_rank_zero_async(method_name='get_megatron_global_info'))
class MegatronRayWorkerGroup(RayWorkerGroup, MegatronWorkerGroup):
"""
MegatronWorkerGroup will query each worker of its megatron rank info and store it inside the WorkerGroup
so that the dispatcher can use it to dispatch data.
"""
def __init__(self,
resource_pool: RayResourcePool,
ray_cls_with_init: RayClassWithInitArgs,
default_megatron_kwargs: Dict = None,
**kwargs):
super().__init__(resource_pool=resource_pool,
ray_cls_with_init=ray_cls_with_init,
default_megatron_kwargs=default_megatron_kwargs,
**kwargs)
self.init_megatron(default_megatron_kwargs=default_megatron_kwargs)
self._megatron_rank_info: DistRankInfo = self.execute_all_sync(method_name='get_megatron_rank_info')
self._megatron_global_info: DistGlobalInfo = ray.get(
self.execute_rank_zero_async(method_name='get_megatron_global_info'))
def init_megatron(self, default_megatron_kwargs: Optional[Dict] = None):
# after super, we will call init of each worker
if not self._is_init_with_detached_workers:
# only init_megatron if the WorkerGroup is created from scratch
self.execute_all_sync(method_name='init_megatron', default_megatron_kwargs=default_megatron_kwargs)

View File

@@ -0,0 +1 @@
0.0.2

13
verl/third_party/__init__.py vendored Normal file
View File

@@ -0,0 +1,13 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

51
verl/third_party/vllm/__init__.py vendored Normal file
View File

@@ -0,0 +1,51 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from importlib.metadata import version, PackageNotFoundError
def get_version(pkg):
try:
return version(pkg)
except PackageNotFoundError:
return None
package_name = 'vllm'
package_version = get_version(package_name)
if package_version == '0.3.1':
vllm_version = '0.3.1'
from .vllm_v_0_3_1.llm import LLM
from .vllm_v_0_3_1.llm import LLMEngine
from .vllm_v_0_3_1 import parallel_state
elif package_version == '0.4.2':
vllm_version = '0.4.2'
from .vllm_v_0_4_2.llm import LLM
from .vllm_v_0_4_2.llm import LLMEngine
from .vllm_v_0_4_2 import parallel_state
elif package_version == '0.5.4':
vllm_version = '0.5.4'
from .vllm_v_0_5_4.llm import LLM
from .vllm_v_0_5_4.llm import LLMEngine
from .vllm_v_0_5_4 import parallel_state
elif package_version == '0.6.3':
vllm_version = '0.6.3'
from .vllm_v_0_6_3.llm import LLM
from .vllm_v_0_6_3.llm import LLMEngine
from .vllm_v_0_6_3 import parallel_state
else:
raise ValueError(
f'vllm version {package_version} not supported. Currently supported versions are 0.3.1, 0.4.2, 0.5.4 and 0.6.3.'
)

View File

@@ -0,0 +1,13 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@@ -0,0 +1,228 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023 The vLLM team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/engine/arg_utils.py
import argparse
import dataclasses
from dataclasses import dataclass
from typing import Dict, Optional, Tuple
import torch.nn as nn
from vllm.config import (CacheConfig, DeviceConfig, ModelConfig, ParallelConfig, SchedulerConfig, LoRAConfig)
from transformers import PretrainedConfig
from .config import ModelConfig
@dataclass
class EngineArgs:
"""Arguments for vLLM engine."""
model_hf_config: PretrainedConfig = None
dtype: str = 'auto'
kv_cache_dtype: str = 'auto'
seed: int = 0
max_model_len: Optional[int] = None
worker_use_ray: bool = False
pipeline_parallel_size: int = 1
tensor_parallel_size: int = 1
max_parallel_loading_workers: Optional[int] = None
block_size: int = 16
swap_space: int = 4 # GiB
gpu_memory_utilization: float = 0.90
max_num_batched_tokens: Optional[int] = None
max_num_seqs: int = 256
max_paddings: int = 256
disable_log_stats: bool = False
revision: Optional[str] = None
tokenizer_revision: Optional[str] = None
quantization: Optional[str] = None
load_format: str = 'model'
enforce_eager: bool = False
max_context_len_to_capture: int = 8192
disable_custom_all_reduce: bool = False
enable_lora: bool = False
max_loras: int = 1
max_lora_rank: int = 16
lora_extra_vocab_size: int = 256
lora_dtype = 'auto'
max_cpu_loras: Optional[int] = None
device: str = 'cuda'
@staticmethod
def add_cli_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
"""Shared CLI arguments for vLLM engine."""
# Model arguments
# TODO(shengguangming): delete the unused args
parser.add_argument('--model',
type=str,
default='facebook/opt-125m',
help='name or path of the huggingface model to use')
parser.add_argument('--tokenizer',
type=str,
default=EngineArgs.tokenizer,
help='name or path of the huggingface tokenizer to use')
parser.add_argument('--revision',
type=str,
default=None,
help='the specific model version to use. It can be a branch '
'name, a tag name, or a commit id. If unspecified, will use '
'the default version.')
parser.add_argument('--tokenizer-revision',
type=str,
default=None,
help='the specific tokenizer version to use. It can be a branch '
'name, a tag name, or a commit id. If unspecified, will use '
'the default version.')
parser.add_argument('--tokenizer-mode',
type=str,
default=EngineArgs.tokenizer_mode,
choices=['auto', 'slow'],
help='tokenizer mode. "auto" will use the fast '
'tokenizer if available, and "slow" will '
'always use the slow tokenizer.')
parser.add_argument('--trust-remote-code', action='store_true', help='trust remote code from huggingface')
parser.add_argument('--download-dir',
type=str,
default=EngineArgs.download_dir,
help='directory to download and load the weights, '
'default to the default cache dir of '
'huggingface')
parser.add_argument('--load-format',
type=str,
default=EngineArgs.load_format,
choices=['auto', 'pt', 'safetensors', 'npcache', 'dummy'],
help='The format of the model weights to load. '
'"auto" will try to load the weights in the safetensors format '
'and fall back to the pytorch bin format if safetensors format '
'is not available. '
'"pt" will load the weights in the pytorch bin format. '
'"safetensors" will load the weights in the safetensors format. '
'"npcache" will load the weights in pytorch format and store '
'a numpy cache to speed up the loading. '
'"dummy" will initialize the weights with random values, '
'which is mainly for profiling.')
parser.add_argument('--dtype',
type=str,
default=EngineArgs.dtype,
choices=['auto', 'half', 'float16', 'bfloat16', 'float', 'float32'],
help='data type for model weights and activations. '
'The "auto" option will use FP16 precision '
'for FP32 and FP16 models, and BF16 precision '
'for BF16 models.')
parser.add_argument('--max-model-len',
type=int,
default=None,
help='model context length. If unspecified, '
'will be automatically derived from the model.')
# Parallel arguments
parser.add_argument('--worker-use-ray',
action='store_true',
help='use Ray for distributed serving, will be '
'automatically set when using more than 1 GPU')
parser.add_argument('--pipeline-parallel-size',
'-pp',
type=int,
default=EngineArgs.pipeline_parallel_size,
help='number of pipeline stages')
parser.add_argument('--tensor-parallel-size',
'-tp',
type=int,
default=EngineArgs.tensor_parallel_size,
help='number of tensor parallel replicas')
# KV cache arguments
parser.add_argument('--block-size',
type=int,
default=EngineArgs.block_size,
choices=[8, 16, 32],
help='token block size')
# TODO(woosuk): Support fine-grained seeds (e.g., seed per request).
parser.add_argument('--seed', type=int, default=EngineArgs.seed, help='random seed')
parser.add_argument('--swap-space',
type=int,
default=EngineArgs.swap_space,
help='CPU swap space size (GiB) per GPU')
parser.add_argument('--gpu-memory-utilization',
type=float,
default=EngineArgs.gpu_memory_utilization,
help='the percentage of GPU memory to be used for'
'the model executor')
parser.add_argument('--max-num-batched-tokens',
type=int,
default=EngineArgs.max_num_batched_tokens,
help='maximum number of batched tokens per '
'iteration')
parser.add_argument('--max-num-seqs',
type=int,
default=EngineArgs.max_num_seqs,
help='maximum number of sequences per iteration')
parser.add_argument('--disable-log-stats', action='store_true', help='disable logging statistics')
# Quantization settings.
parser.add_argument('--quantization',
'-q',
type=str,
choices=['awq', None],
default=None,
help='Method used to quantize the weights')
return parser
@classmethod
def from_cli_args(cls, args: argparse.Namespace) -> 'EngineArgs':
# Get the list of attributes of this dataclass.
attrs = [attr.name for attr in dataclasses.fields(cls)]
# Set the attributes from the parsed arguments.
engine_args = cls(**{attr: getattr(args, attr) for attr in attrs})
return engine_args
def create_engine_configs(
self,
) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig]:
device_config = DeviceConfig(self.device)
model_config = ModelConfig(self.model_hf_config, self.dtype, self.seed, self.load_format, self.revision,
self.tokenizer_revision, self.max_model_len, self.quantization, self.enforce_eager,
self.max_context_len_to_capture)
cache_config = CacheConfig(self.block_size, self.gpu_memory_utilization, self.swap_space, self.kv_cache_dtype,
model_config.get_sliding_window())
parallel_config = ParallelConfig(self.pipeline_parallel_size, self.tensor_parallel_size, self.worker_use_ray,
self.max_parallel_loading_workers, self.disable_custom_all_reduce)
scheduler_config = SchedulerConfig(self.max_num_batched_tokens, self.max_num_seqs, model_config.max_model_len,
self.max_paddings)
lora_config = LoRAConfig(max_lora_rank=self.max_lora_rank,
max_loras=self.max_loras,
lora_extra_vocab_size=self.lora_extra_vocab_size,
lora_dtype=self.lora_dtype,
max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras and self.max_cpu_loras > 0 else
None) if self.enable_lora else None
return (model_config, cache_config, parallel_config, scheduler_config, device_config, lora_config)
@dataclass
class AsyncEngineArgs(EngineArgs):
"""Arguments for asynchronous vLLM engine."""
engine_use_ray: bool = False
disable_log_requests: bool = False
max_log_len: Optional[int] = None
@staticmethod
def add_cli_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
parser = EngineArgs.add_cli_args(parser)
parser.add_argument('--engine-use-ray',
action='store_true',
help='use Ray to start the LLM engine in a '
'separate process as the server process.')
parser.add_argument('--disable-log-requests', action='store_true', help='disable logging requests')
parser.add_argument('--max-log-len',
type=int,
default=None,
help='max number of prompt characters or prompt '
'ID numbers being printed in log. '
'Default: unlimited.')
return parser

View File

@@ -0,0 +1,577 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023 The vLLM team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py
from typing import Optional, Union, ClassVar
from dataclasses import dataclass
import torch
from transformers import PretrainedConfig
from packaging.version import Version
from vllm.logger import init_logger
from vllm.transformers_utils.config import get_config
from vllm.utils import get_cpu_memory, is_hip, get_nvcc_cuda_version
logger = init_logger(__name__)
_GB = 1 << 30
class ModelConfig:
"""Configuration for the model.
Args:
model: Name or path of the huggingface model to use.
tokenizer: Name or path of the huggingface tokenizer to use.
tokenizer_mode: Tokenizer mode. "auto" will use the fast tokenizer if
available, and "slow" will always use the slow tokenizer.
trust_remote_code: Trust remote code (e.g., from HuggingFace) when
downloading the model and tokenizer.
download_dir: Directory to download and load the weights, default to the
default cache directory of huggingface.
load_format: The format of the model weights to load:
"auto" will try to load the weights in the safetensors format and
fall back to the pytorch bin format if safetensors format is
not available.
"pt" will load the weights in the pytorch bin format.
"safetensors" will load the weights in the safetensors format.
"npcache" will load the weights in pytorch format and store
a numpy cache to speed up the loading.
"dummy" will initialize the weights with random values, which is
mainly for profiling.
dtype: Data type for model weights and activations. The "auto" option
will use FP16 precision for FP32 and FP16 models, and BF16 precision
for BF16 models.
seed: Random seed for reproducibility.
revision: The specific model version to use. It can be a branch name,
a tag name, or a commit id. If unspecified, will use the default
version.
tokenizer_revision: The specific tokenizer version to use. It can be a
branch name, a tag name, or a commit id. If unspecified, will use
the default version.
max_model_len: Maximum length of a sequence (including prompt and
output). If None, will be derived from the model.
quantization: Quantization method that was used to quantize the model
weights. If None, we assume the model weights are not quantized.
enforce_eager: Whether to enforce eager execution. If True, we will
disable CUDA graph and always execute the model in eager mode.
If False, we will use CUDA graph and eager execution in hybrid.
max_context_len_to_capture: Maximum context len covered by CUDA graphs.
When a sequence has context length larger than this, we fall back
to eager mode.
"""
def __init__(
self,
hf_config: PretrainedConfig,
dtype: str,
seed: int,
load_format: str = 'model',
revision: Optional[str] = None,
tokenizer_revision: Optional[str] = None,
max_model_len: Optional[int] = None,
quantization: Optional[str] = None,
trust_remote_code: Optional[bool] = True,
enforce_eager: bool = False,
max_context_len_to_capture: Optional[int] = None,
) -> None:
self.model = hf_config._name_or_path
self.tokenizer = hf_config._name_or_path
self.load_format = load_format
self.seed = seed
self.revision = revision
self.tokenizer_revision = tokenizer_revision
self.quantization = quantization
self.trust_remote_code = trust_remote_code
self.enforce_eager = enforce_eager
self.max_context_len_to_capture = max_context_len_to_capture
# self.hf_config = get_config(model, trust_remote_code, revision)
self.hf_config = hf_config
self.dtype = _get_and_verify_dtype(self.hf_config, dtype)
self.max_model_len = _get_and_verify_max_len(self.hf_config, max_model_len)
# self._verify_load_format()
# self._verify_tokenizer_mode()
self._verify_quantization()
self._verify_cuda_graph()
def _verify_load_format(self) -> None:
load_format = self.load_format.lower()
if load_format not in ["auto", "pt", "safetensors", "npcache", "dummy", "model"]:
raise ValueError(f"Unknown load format: {self.load_format}. Must be one of "
"'auto', 'pt', 'safetensors', 'npcache', 'dummy' or 'model'.")
self.load_format = load_format
# def _verify_tokenizer_mode(self) -> None:
# tokenizer_mode = self.tokenizer_mode.lower()
# if tokenizer_mode not in ["auto", "slow"]:
# raise ValueError(
# f"Unknown tokenizer mode: {self.tokenizer_mode}. Must be "
# "either 'auto' or 'slow'.")
# self.tokenizer_mode = tokenizer_mode
def _verify_quantization(self) -> None:
supported_quantization = ["awq", "gptq", "squeezellm"]
rocm_not_supported_quantization = ["awq", "gptq"]
if self.quantization is not None:
self.quantization = self.quantization.lower()
# Parse quantization method from the HF model config, if available.
hf_quant_config = getattr(self.hf_config, "quantization_config", None)
if hf_quant_config is not None:
hf_quant_method = str(hf_quant_config["quant_method"]).lower()
if self.quantization is None:
self.quantization = hf_quant_method
elif self.quantization != hf_quant_method:
raise ValueError("Quantization method specified in the model config "
f"({hf_quant_method}) does not match the quantization "
f"method specified in the `quantization` argument "
f"({self.quantization}).")
if self.quantization is not None:
if self.quantization not in supported_quantization:
raise ValueError(f"Unknown quantization method: {self.quantization}. Must "
f"be one of {supported_quantization}.")
if is_hip() and self.quantization in rocm_not_supported_quantization:
raise ValueError(f"{self.quantization} quantization is currently not supported "
f"in ROCm.")
logger.warning(f"{self.quantization} quantization is not fully "
"optimized yet. The speed can be slower than "
"non-quantized models.")
def _verify_cuda_graph(self) -> None:
if self.max_context_len_to_capture is None:
self.max_context_len_to_capture = self.max_model_len
self.max_context_len_to_capture = min(self.max_context_len_to_capture, self.max_model_len)
if (self.quantization in ["gptq", "squeezellm"] and not self.enforce_eager):
# Related issue: https://github.com/vllm-project/vllm/issues/2147
logger.warning(f"{self.quantization} does not support CUDA graph "
"yet. Disabling CUDA graph.")
self.enforce_eager = True
def verify_with_parallel_config(
self,
parallel_config: "ParallelConfig",
) -> None:
total_num_attention_heads = self.hf_config.num_attention_heads
tensor_parallel_size = parallel_config.tensor_parallel_size
if total_num_attention_heads % tensor_parallel_size != 0:
raise ValueError(f"Total number of attention heads ({total_num_attention_heads})"
" must be divisible by tensor parallel size "
f"({tensor_parallel_size}).")
total_num_hidden_layers = self.hf_config.num_hidden_layers
pipeline_parallel_size = parallel_config.pipeline_parallel_size
if total_num_hidden_layers % pipeline_parallel_size != 0:
raise ValueError(f"Total number of hidden layers ({total_num_hidden_layers}) "
"must be divisible by pipeline parallel size "
f"({pipeline_parallel_size}).")
def get_sliding_window(self) -> Optional[int]:
return getattr(self.hf_config, "sliding_window", None)
def get_vocab_size(self) -> int:
return self.hf_config.vocab_size
def get_hidden_size(self) -> int:
return self.hf_config.hidden_size
def get_head_size(self) -> int:
# FIXME(woosuk): This may not be true for all models.
return self.hf_config.hidden_size // self.hf_config.num_attention_heads
def get_total_num_kv_heads(self) -> int:
"""Returns the total number of KV heads."""
# For GPTBigCode & Falcon:
# NOTE: for falcon, when new_decoder_architecture is True, the
# multi_query flag is ignored and we use n_head_kv for the number of
# KV heads.
falcon_model_types = ["falcon", "RefinedWeb", "RefinedWebModel"]
new_decoder_arch_falcon = (self.hf_config.model_type in falcon_model_types and
getattr(self.hf_config, "new_decoder_architecture", False))
if not new_decoder_arch_falcon and getattr(self.hf_config, "multi_query", False):
# Multi-query attention, only one KV head.
# Currently, tensor parallelism is not supported in this case.
return 1
attributes = [
# For Falcon:
"n_head_kv",
"num_kv_heads",
# For LLaMA-2:
"num_key_value_heads",
# For ChatGLM:
"multi_query_group_num",
]
for attr in attributes:
num_kv_heads = getattr(self.hf_config, attr, None)
if num_kv_heads is not None:
return num_kv_heads
# For non-grouped-query attention models, the number of KV heads is
# equal to the number of attention heads.
return self.hf_config.num_attention_heads
def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int:
"""Returns the number of KV heads per GPU."""
total_num_kv_heads = self.get_total_num_kv_heads()
# If tensor parallelism is used, we divide the number of KV heads by
# the tensor parallel size. We will replicate the KV heads in the
# case where the number of KV heads is smaller than the tensor
# parallel size so each GPU has at least one KV head.
return max(1, total_num_kv_heads // parallel_config.tensor_parallel_size)
def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
total_num_hidden_layers = self.hf_config.num_hidden_layers
return total_num_hidden_layers // parallel_config.pipeline_parallel_size
class CacheConfig:
"""Configuration for the KV cache.
Args:
block_size: Size of a cache block in number of tokens.
gpu_memory_utilization: Fraction of GPU memory to use for the
vLLM execution.
swap_space: Size of the CPU swap space per GPU (in GiB).
cache_dtype: Data type for kv cache storage.
"""
def __init__(
self,
block_size: int,
gpu_memory_utilization: float,
swap_space: int,
cache_dtype: str,
sliding_window: Optional[int] = None,
) -> None:
self.block_size = block_size
self.gpu_memory_utilization = gpu_memory_utilization
self.swap_space_bytes = swap_space * _GB
self.cache_dtype = cache_dtype
self.sliding_window = sliding_window
self._verify_args()
self._verify_cache_dtype()
# Will be set after profiling.
self.num_gpu_blocks = None
self.num_cpu_blocks = None
def _verify_args(self) -> None:
if self.gpu_memory_utilization > 1.0:
raise ValueError("GPU memory utilization must be less than 1.0. Got "
f"{self.gpu_memory_utilization}.")
def _verify_cache_dtype(self) -> None:
if self.cache_dtype == "auto":
pass
elif self.cache_dtype == "fp8_e5m2":
nvcc_cuda_version = get_nvcc_cuda_version()
if nvcc_cuda_version < Version("11.8"):
raise ValueError("FP8 is not supported when cuda version is lower than 11.8.")
device_name = torch.cuda.get_device_name()
if "AMD" in device_name:
raise NotImplementedError("FP8_E5M2 KV Cache on AMD GPU has not been supported yet.")
logger.info("Using fp8_e5m2 data type to store kv cache. It reduces "
"the GPU memory footprint and boosts the performance. "
"But it may cause slight accuracy drop. "
"Currently we only support fp8 without scaling factors and "
"make e5m2 as a default format.")
else:
raise ValueError(f"Unknown kv cache dtype: {self.cache_dtype}")
def verify_with_parallel_config(
self,
parallel_config: "ParallelConfig",
) -> None:
total_cpu_memory = get_cpu_memory()
# FIXME(woosuk): Here, it is assumed that the GPUs in a tensor parallel
# group are in the same node. However, the GPUs may span multiple nodes.
num_gpus_per_node = parallel_config.tensor_parallel_size
cpu_memory_usage = self.swap_space_bytes * num_gpus_per_node
msg = (f"{cpu_memory_usage / _GB:.2f} GiB out of "
f"the {total_cpu_memory / _GB:.2f} GiB total CPU memory is "
"allocated for the swap space.")
if cpu_memory_usage > 0.7 * total_cpu_memory:
raise ValueError("Too large swap space. " + msg)
elif cpu_memory_usage > 0.4 * total_cpu_memory:
logger.warning("Possibly too large swap space. " + msg)
class ParallelConfig:
"""Configuration for the distributed execution.
Args:
pipeline_parallel_size: Number of pipeline parallel groups.
tensor_parallel_size: Number of tensor parallel groups.
worker_use_ray: Whether to use Ray for model workers. Will be set to
True if either pipeline_parallel_size or tensor_parallel_size is
greater than 1.
max_parallel_loading_workers: Maximum number of multiple batches
when load model sequentially. To avoid RAM OOM when using tensor
parallel and large models.
disable_custom_all_reduce: Disable the custom all-reduce kernel and
fall back to NCCL.
"""
def __init__(
self,
pipeline_parallel_size: int,
tensor_parallel_size: int,
worker_use_ray: bool,
max_parallel_loading_workers: Optional[int] = None,
disable_custom_all_reduce: bool = False,
) -> None:
self.pipeline_parallel_size = pipeline_parallel_size
self.tensor_parallel_size = tensor_parallel_size
self.worker_use_ray = worker_use_ray
self.max_parallel_loading_workers = max_parallel_loading_workers
self.disable_custom_all_reduce = disable_custom_all_reduce
self.world_size = pipeline_parallel_size * tensor_parallel_size
if self.world_size > 1:
self.worker_use_ray = True
self._verify_args()
def _verify_args(self) -> None:
if self.pipeline_parallel_size > 1:
raise NotImplementedError("Pipeline parallelism is not supported yet.")
if not self.disable_custom_all_reduce and self.world_size > 1:
if is_hip():
self.disable_custom_all_reduce = True
logger.info("Disabled the custom all-reduce kernel because it is not "
"supported on AMD GPUs.")
elif self.pipeline_parallel_size > 1:
self.disable_custom_all_reduce = True
logger.info("Disabled the custom all-reduce kernel because it is not "
"supported with pipeline parallelism.")
# FIXME(woosuk): Fix the stability issues and re-enable the custom
# all-reduce kernel.
if not self.disable_custom_all_reduce and self.world_size > 1:
self.disable_custom_all_reduce = True
logger.info("Custom all-reduce kernels are temporarily disabled due to "
"stability issues. We will re-enable them once the issues are "
"resolved.")
class SchedulerConfig:
"""Scheduler configuration.
Args:
max_num_batched_tokens: Maximum number of tokens to be processed in
a single iteration.
max_num_seqs: Maximum number of sequences to be processed in a single
iteration.
max_model_len: Maximum length of a sequence (including prompt
and generated text).
max_paddings: Maximum number of paddings to be added to a batch.
"""
def __init__(
self,
max_num_batched_tokens: Optional[int],
max_num_seqs: int,
max_model_len: int,
max_paddings: int,
) -> None:
if max_num_batched_tokens is not None:
self.max_num_batched_tokens = max_num_batched_tokens
else:
# If max_model_len is too short, use 2048 as the default value for
# higher throughput.
self.max_num_batched_tokens = max(max_model_len, 2048)
self.max_num_seqs = max_num_seqs
self.max_model_len = max_model_len
self.max_paddings = max_paddings
self._verify_args()
def _verify_args(self) -> None:
if self.max_num_batched_tokens < self.max_model_len:
raise ValueError(f"max_num_batched_tokens ({self.max_num_batched_tokens}) is "
f"smaller than max_model_len ({self.max_model_len}). "
"This effectively limits the maximum sequence length to "
"max_num_batched_tokens and makes vLLM reject longer "
"sequences. Please increase max_num_batched_tokens or "
"decrease max_model_len.")
if self.max_num_batched_tokens < self.max_num_seqs:
raise ValueError(f"max_num_batched_tokens ({self.max_num_batched_tokens}) must "
"be greater than or equal to max_num_seqs "
f"({self.max_num_seqs}).")
class DeviceConfig:
def __init__(self, device: str = "cuda") -> None:
self.device = torch.device(device)
@dataclass
class LoRAConfig:
max_lora_rank: int
max_loras: int
max_cpu_loras: Optional[int] = None
lora_dtype: Optional[torch.dtype] = None
lora_extra_vocab_size: int = 256
# This is a constant.
lora_vocab_padding_size: ClassVar[int] = 256
def __post_init__(self):
# Keep this in sync with csrc/punica/bgmv/bgmv_config.h
possible_max_ranks = (8, 16, 32, 64)
possible_lora_extra_vocab_size = (0, 256, 512)
if self.max_lora_rank not in possible_max_ranks:
raise ValueError(f"max_lora_rank ({self.max_lora_rank}) must be one of "
f"{possible_max_ranks}.")
if self.lora_extra_vocab_size not in possible_lora_extra_vocab_size:
raise ValueError(f"lora_extra_vocab_size ({self.lora_extra_vocab_size}) "
f"must be one of {possible_lora_extra_vocab_size}.")
if self.max_loras < 1:
raise ValueError(f"max_loras ({self.max_loras}) must be >= 1.")
if self.max_cpu_loras is None:
self.max_cpu_loras = self.max_loras
elif self.max_cpu_loras < self.max_loras:
raise ValueError(f"max_cpu_loras ({self.max_cpu_loras}) must be >= "
f"max_loras ({self.max_loras})")
def verify_with_model_config(self, model_config: ModelConfig):
if self.lora_dtype in (None, "auto"):
self.lora_dtype = model_config.dtype
elif isinstance(self.lora_dtype, str):
self.lora_dtype = getattr(torch, self.lora_dtype)
if model_config.quantization is not None:
raise ValueError("LoRA is not supported with quantized models yet.")
def verify_with_scheduler_config(self, scheduler_config: SchedulerConfig):
if scheduler_config.max_num_batched_tokens > 65528:
raise ValueError("Due to limitations of the custom LoRA CUDA kernel, "
"max_num_batched_tokens must be <= 65528 when "
"LoRA is enabled.")
_STR_DTYPE_TO_TORCH_DTYPE = {
"half": torch.float16,
"float16": torch.float16,
"float": torch.float32,
"float32": torch.float32,
"bfloat16": torch.bfloat16,
}
_ROCM_NOT_SUPPORTED_DTYPE = ["float", "float32"]
def _get_and_verify_dtype(
config: PretrainedConfig,
dtype: Union[str, torch.dtype],
) -> torch.dtype:
# NOTE: getattr(config, "torch_dtype", torch.float32) is not correct
# because config.torch_dtype can be None.
config_dtype = getattr(config, "torch_dtype", None)
if config_dtype is None:
config_dtype = torch.float32
if isinstance(dtype, str):
dtype = dtype.lower()
if dtype == "auto":
if config_dtype == torch.float32:
# Following the common practice, we use float16 for float32
# models.
torch_dtype = torch.float16
else:
torch_dtype = config_dtype
else:
if dtype not in _STR_DTYPE_TO_TORCH_DTYPE:
raise ValueError(f"Unknown dtype: {dtype}")
torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]
elif isinstance(dtype, torch.dtype):
torch_dtype = dtype
else:
raise ValueError(f"Unknown dtype: {dtype}")
if is_hip() and torch_dtype == torch.float32:
rocm_supported_dtypes = [
k for k, v in _STR_DTYPE_TO_TORCH_DTYPE.items() if (k not in _ROCM_NOT_SUPPORTED_DTYPE)
]
raise ValueError(f"dtype \'{dtype}\' is not supported in ROCm. "
f"Supported dtypes are {rocm_supported_dtypes}")
# Verify the dtype.
if torch_dtype != config_dtype:
if torch_dtype == torch.float32:
# Upcasting to float32 is allowed.
pass
elif config_dtype == torch.float32:
# Downcasting from float32 to float16 or bfloat16 is allowed.
pass
else:
# Casting between float16 and bfloat16 is allowed with a warning.
logger.warning(f"Casting {config_dtype} to {torch_dtype}.")
return torch_dtype
def _get_and_verify_max_len(
hf_config: PretrainedConfig,
max_model_len: Optional[int],
) -> int:
"""Get and verify the model's maximum length."""
derived_max_model_len = float("inf")
possible_keys = [
# OPT
"max_position_embeddings",
# GPT-2
"n_positions",
# MPT
"max_seq_len",
# ChatGLM2
"seq_length",
# Others
"max_sequence_length",
"max_seq_length",
"seq_len",
]
for key in possible_keys:
max_len_key = getattr(hf_config, key, None)
if max_len_key is not None:
derived_max_model_len = min(derived_max_model_len, max_len_key)
if derived_max_model_len == float("inf"):
if max_model_len is not None:
# If max_model_len is specified, we use it.
return max_model_len
default_max_len = 2048
logger.warning("The model's config.json does not contain any of the following "
"keys to determine the original maximum length of the model: "
f"{possible_keys}. Assuming the model's maximum length is "
f"{default_max_len}.")
derived_max_model_len = default_max_len
rope_scaling = getattr(hf_config, "rope_scaling", None)
if rope_scaling is not None:
assert "factor" in rope_scaling
scaling_factor = rope_scaling["factor"]
if rope_scaling["type"] == "yarn":
derived_max_model_len = rope_scaling["original_max_position_embeddings"]
derived_max_model_len *= scaling_factor
if max_model_len is None:
max_model_len = derived_max_model_len
elif max_model_len > derived_max_model_len:
raise ValueError(f"User-specified max_model_len ({max_model_len}) is greater than "
f"the derived max_model_len ({max_len_key}={derived_max_model_len}"
" in model's config.json). This may lead to incorrect model "
"outputs or CUDA errors. Make sure the value is correct and "
"within the model context size.")
return int(max_model_len)

View File

@@ -0,0 +1,275 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023 The vLLM team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/llm.py
from typing import Dict, List, Optional, Tuple, Union
from tqdm import tqdm
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from transformers import PretrainedConfig
import torch.nn as nn
from .arg_utils import EngineArgs
from .llm_engine_sp import LLMEngine
from vllm.lora.request import LoRARequest
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams
from vllm.utils import Counter
import torch
from torch.nn.utils.rnn import pad_sequence
from verl.workers.rollout.tokenizer import HybridEngineBaseTokenizer
class LLM:
"""An LLM for generating texts from given prompts and sampling parameters.
This class includes a tokenizer, a language model (possibly distributed
across multiple GPUs), and GPU memory space allocated for intermediate
states (aka KV cache). Given a batch of prompts and sampling parameters,
this class generates texts from the model, using an intelligent batching
mechanism and efficient memory management.
NOTE: This class is intended to be used for offline inference. For online
serving, use the `AsyncLLMEngine` class instead.
NOTE: For the comprehensive list of arguments, see `EngineArgs`.
Args:
model: A HuggingFace Transformers model instance.
tokenizer: A HuggingFace Transformers tokenizer instance.
tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
if available, and "slow" will always use the slow tokenizer.
trust_remote_code: Trust remote code (e.g., from HuggingFace) when
downloading the model and tokenizer.
tensor_parallel_size: The number of GPUs to use for distributed
execution with tensor parallelism.
dtype: The data type for the model weights and activations. Currently,
we support `float32`, `float16`, and `bfloat16`. If `auto`, we use
the `torch_dtype` attribute specified in the model config file.
However, if the `torch_dtype` in the config is `float32`, we will
use `float16` instead.
quantization: The method used to quantize the model weights. Currently,
we support "awq". If None, we assume the model weights are not
quantized and use `dtype` to determine the data type of the weights.
revision: The specific model version to use. It can be a branch name,
a tag name, or a commit id.
tokenizer_revision: The specific tokenizer version to use. It can be a
branch name, a tag name, or a commit id.
seed: The seed to initialize the random number generator for sampling.
gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
reserve for the model weights, activations, and KV cache. Higher
values will increase the KV cache size and thus improve the model's
throughput. However, if the value is too high, it may cause out-of-
memory (OOM) errors.
swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
This can be used for temporarily storing the states of the requests
when their `best_of` sampling parameters are larger than 1. If all
requests will have `best_of=1`, you can safely set this to 0.
Otherwise, too small values may cause out-of-memory (OOM) errors.
enforce_eager: Whether to enforce eager execution. If True, we will
disable CUDA graph and always execute the model in eager mode.
If False, we will use CUDA graph and eager execution in hybrid.
max_context_len_to_capture: Maximum context len covered by CUDA graphs.
When a sequence has context length larger than this, we fall back
to eager mode.
disable_custom_all_reduce: See ParallelConfig
"""
def __init__(
self,
model: Union[nn.Module, Dict], # model itself or its parameter dict
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast, HybridEngineBaseTokenizer],
model_hf_config: PretrainedConfig,
tokenizer_mode: str = "auto",
trust_remote_code: bool = False,
tensor_parallel_size: int = 1,
dtype: str = "auto",
quantization: Optional[str] = None,
revision: Optional[str] = None,
tokenizer_revision: Optional[str] = None,
seed: int = 0,
gpu_memory_utilization: float = 0.9,
swap_space: int = 4,
enforce_eager: bool = False,
max_context_len_to_capture: int = 8192,
disable_custom_all_reduce: bool = False,
**kwargs,
) -> None:
if "disable_log_stats" not in kwargs:
kwargs["disable_log_stats"] = True
engine_args = EngineArgs(
model_hf_config=model_hf_config,
tensor_parallel_size=tensor_parallel_size,
dtype=dtype,
quantization=quantization,
revision=revision,
tokenizer_revision=tokenizer_revision,
seed=seed,
gpu_memory_utilization=gpu_memory_utilization,
swap_space=swap_space,
enforce_eager=enforce_eager,
max_context_len_to_capture=max_context_len_to_capture,
disable_custom_all_reduce=disable_custom_all_reduce,
**kwargs,
)
tokenizer_cls = (PreTrainedTokenizer, PreTrainedTokenizerFast, HybridEngineBaseTokenizer)
if not isinstance(tokenizer, tokenizer_cls):
raise ValueError(
f"Unexpected tokenizer type: {type(tokenizer)}. Must be"
"one of the following: PreTrainedTokenizer, PreTrainedTokenizerFast, verl.workers.rollout.HybridEngineBaseTokenizer"
)
self.llm_engine = LLMEngine.from_engine_args(model, tokenizer, engine_args)
self.request_counter = Counter()
def init_cache_engine(self):
self.llm_engine.init_cache_engine()
def free_cache_engine(self):
self.llm_engine.free_cache_engine()
def get_tokenizer(self) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
return self.llm_engine.tokenizer
def set_tokenizer(
self,
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
) -> None:
self.llm_engine.tokenizer = tokenizer
def generate(
self,
prompts: Optional[Union[str, List[str]]] = None,
sampling_params: Optional[SamplingParams] = None,
prompt_token_ids: Optional[List[List[int]]] = None,
prefix_pos: Optional[Union[int, List[int]]] = None,
use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None,
) -> List[RequestOutput]:
"""Generates the completions for the input prompts.
NOTE: This class automatically batches the given prompts, considering
the memory constraint. For the best performance, put all of your prompts
into a single list and pass it to this method.
Args:
prompts: A list of prompts to generate completions for.
sampling_params: The sampling parameters for text generation. If
None, we use the default sampling parameters.
prompt_token_ids: A list of token IDs for the prompts. If None, we
use the tokenizer to convert the prompts to token IDs.
use_tqdm: Whether to use tqdm to display the progress bar.
Returns:
A list of `RequestOutput` objects containing the generated
completions in the same order as the input prompts.
"""
if prompts is None and prompt_token_ids is None:
raise ValueError("Either prompts or prompt_token_ids must be "
"provided.")
if isinstance(prompts, str):
# Convert a single prompt to a list.
prompts = [prompts]
if prompts is not None and prompt_token_ids is not None:
if len(prompts) != len(prompt_token_ids):
raise ValueError("The lengths of prompts and prompt_token_ids "
"must be the same.")
if sampling_params is None:
# Use default sampling params.
sampling_params = SamplingParams()
# Add requests to the engine.
num_requests = len(prompts) if prompts is not None else len(prompt_token_ids)
for i in range(num_requests):
prompt = prompts[i] if prompts is not None else None
prefix_pos_i = prefix_pos[i] if prefix_pos is not None else None
token_ids = None if prompt_token_ids is None else prompt_token_ids[i]
if not isinstance(token_ids, list):
# NOTE(shengguangming): convert the rollout input into List[str]
token_ids = self._pre_process_inputs(token_ids)
self._add_request(prompt, sampling_params, token_ids, lora_request=lora_request, prefix_pos=prefix_pos_i)
return self._run_engine(use_tqdm)
def _add_request(
self,
prompt: Optional[str],
sampling_params: SamplingParams,
prompt_token_ids: Optional[List[int]],
lora_request: Optional[LoRARequest] = None,
prefix_pos: Optional[int] = None,
) -> None:
request_id = str(next(self.request_counter))
self.llm_engine.add_request(request_id,
prompt,
sampling_params,
prompt_token_ids,
lora_request=lora_request,
prefix_pos=prefix_pos)
def _run_engine(self, use_tqdm: bool) -> List[RequestOutput]:
# Initialize tqdm.
if use_tqdm:
num_requests = self.llm_engine.get_num_unfinished_requests()
pbar = tqdm(total=num_requests, desc="Processed prompts")
# Run the engine.
outputs: List[RequestOutput] = []
while self.llm_engine.has_unfinished_requests():
step_outputs = self.llm_engine.step()
for output in step_outputs:
if output.finished:
outputs.append(output)
if use_tqdm:
pbar.update(1)
if use_tqdm:
pbar.close()
# Sort the outputs by request ID.
# This is necessary because some requests may be finished earlier than
# its previous requests.
outputs = sorted(outputs, key=lambda x: int(x.request_id))
# TODO(shengguangming): maybe we can hack the autoregressive logics without only apply post process for better performance
return self._post_process_outputs(outputs)
# NOTE(shengguangming): add for verl
# TODO(sgm): we can optimize it by making the dataloader yield List[int] without padding.
def _pre_process_inputs(self, prompt_token_ids: torch.Tensor) -> List[int]:
# remove the left padding in the prompt token_id
pad_token_id = self.llm_engine.tokenizer.pad_token_id if self.llm_engine.tokenizer.pad_token_id is not None else self.llm_engine.tokenizer.eos_token_id
non_pad_index = torch.nonzero(prompt_token_ids != pad_token_id, as_tuple=False)[0][0]
token_ids = prompt_token_ids[non_pad_index:].tolist()
return token_ids
# NOTE(shengguangming): add for verl
def _post_process_outputs(self, outputs: List[RequestOutput]) -> Tuple[torch.Tensor, torch.Tensor]:
output_token_ids = []
logprobs = []
for output in outputs: # List[RequestOutput]
output = output.outputs
for output in output: # List[CompletionOutput], usually len == 1
output_token_ids.append(torch.tensor(output.token_ids))
# TODO(shengguangming): can be optimzied by rewrite the Sampler._get_logprobs() logits
logprobs_dicts = output.logprobs
if logprobs_dicts is not None:
logprob = []
for logprobs_dict, id in zip(logprobs_dicts, output.token_ids):
logprob.append(logprobs_dict[id])
logprobs.append(torch.tensor(logprob))
pad_token_id = self.llm_engine.tokenizer.pad_token_id if self.llm_engine.tokenizer.pad_token_id is not None else self.llm_engine.tokenizer.eos_token_id
output_token_ids = pad_sequence(output_token_ids, batch_first=True, padding_value=pad_token_id)
if len(logprobs) > 0:
logprobs = pad_sequence(logprobs, batch_first=True, padding_value=pad_token_id)
return output_token_ids, logprobs
def sync_model_weights(self, actor_weights: Dict[str, torch.Tensor]) -> None:
self.llm_engine.sync_model_weights(actor_weights=actor_weights)
def offload_model_weights(self) -> None:
self.llm_engine.offload_model_weights()

View File

@@ -0,0 +1,765 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023 The vLLM team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/engine/llm_engine.py
import os
import socket
import time
import torch
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Union
from vllm.lora.request import LoRARequest
from vllm.config import (CacheConfig, DeviceConfig, ModelConfig, ParallelConfig, SchedulerConfig, LoRAConfig)
from vllm.core.scheduler import Scheduler, SchedulerOutputs
from vllm.logger import init_logger
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams
from vllm.sequence import (SamplerOutput, Sequence, SequenceGroup, SequenceGroupMetadata, SequenceGroupOutput,
SequenceOutput, SequenceStatus)
from vllm.transformers_utils.tokenizer import detokenize_incrementally
from vllm.engine.metrics import StatLogger, Stats
from vllm.utils import Counter
import torch.nn as nn
from .arg_utils import EngineArgs
from .tokenizer import TokenizerGroup
logger = init_logger(__name__)
_LOCAL_LOGGING_INTERVAL_SEC = 5
class LLMEngine:
"""An LLM engine that receives requests and generates texts.
This is the main class for the vLLM engine. It receives requests
from clients and generates texts from the LLM. It includes a tokenizer, a
language model (possibly distributed across multiple GPUs), and GPU memory
space allocated for intermediate states (aka KV cache). This class utilizes
iteration-level scheduling and efficient memory management to maximize the
serving throughput.
The `LLM` class wraps this class for offline batched inference and the
`AsyncLLMEngine` class wraps this class for online serving.
NOTE: The config arguments are derived from the `EngineArgs` class. For the
comprehensive list of arguments, see `EngineArgs`.
Args:
model_config: The configuration related to the LLM model.
cache_config: The configuration related to the KV cache memory
management.
parallel_config: The configuration related to distributed execution.
scheduler_config: The configuration related to the request scheduler.
distributed_init_method: The initialization method for distributed
execution. See `torch.distributed.init_process_group` for details.
placement_group: Ray placement group for distributed execution.
Required for distributed execution.
log_stats: Whether to log statistics.
"""
def __init__(
self,
model: Union[nn.Module, Dict], # model itself or its parameter dict
tokenizer: nn.Module,
model_config: ModelConfig,
cache_config: CacheConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
lora_config: Optional[LoRAConfig],
distributed_init_method: str,
placement_group: Optional[None],
log_stats: bool,
) -> None:
logger.info("Initializing an LLM engine with config: "
f"model={model_config.model!r}, "
f"tokenizer={model_config.tokenizer!r}, "
# f"tokenizer_mode={model_config.tokenizer_mode}, "
f"revision={model_config.revision}, "
f"tokenizer_revision={model_config.tokenizer_revision}, "
# f"trust_remote_code={model_config.trust_remote_code}, "
f"dtype={model_config.dtype}, "
f"max_seq_len={model_config.max_model_len}, "
# f"download_dir={model_config.download_dir!r}, "
# f"load_format={model_config.load_format}, "
f"disable_custom_all_reduce={parallel_config.disable_custom_all_reduce}, "
f"tensor_parallel_size={parallel_config.tensor_parallel_size}, "
f"quantization={model_config.quantization}, "
f"seed={model_config.seed})")
# TODO(woosuk): Print more configs in debug mode.
self.model_config = model_config # TODO: currently is hfconfig
self.cache_config = cache_config
self.lora_config = lora_config
assert self.cache_config.sliding_window == getattr(self.model_config.hf_config, "sliding_window", None)
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
self.device_config = device_config
self.log_stats = log_stats
self._verify_args()
# self.model = model # should not store the model, it should be deleted
# TODO(shengguangming): maybe we can choose init here or from arguments
self._init_tokenizer(tokenizer)
self.seq_counter = Counter()
# Create the parallel GPU workers.
self._init_workers_sp(model, distributed_init_method)
# Profile the memory usage and initialize the cache.
self._init_cache_sp()
# Create the scheduler.
# NOTE(shengguangming): each process will have independent scheduler
self.scheduler = Scheduler(scheduler_config, cache_config, lora_config)
# Metric Logging.
if self.log_stats:
self.stat_logger = StatLogger(local_interval=_LOCAL_LOGGING_INTERVAL_SEC)
# Logging.
self.last_logging_time = 0.0
# List of (timestamp, num_tokens)
self.num_prompt_tokens: List[Tuple[float, int]] = []
# List of (timestamp, num_tokens)
self.num_generation_tokens: List[Tuple[float, int]] = []
def _init_tokenizer(self, tokenizer, **tokenizer_init_kwargs):
init_kwargs = dict(enable_lora=bool(self.lora_config),
max_num_seqs=self.scheduler_config.max_num_seqs,
max_input_length=None)
init_kwargs.update(tokenizer_init_kwargs)
self.tokenizer: TokenizerGroup = TokenizerGroup(tokenizer, **init_kwargs)
# TODO: check get_lora_tokenizer func
def get_tokenizer_for_seq(self, sequence: Sequence):
return self.tokenizer.get_lora_tokenizer(sequence.lora_request)
def _init_workers_sp(self, model, distributed_init_method: str):
# Lazy import the Worker to avoid importing torch.cuda/xformers
# before CUDA_VISIBLE_DEVICES is set in the Worker
from .worker import Worker # pylint: disable=import-outside-toplevel
rank = int(os.getenv("RANK"))
self.worker = Worker(
model,
self.model_config,
self.parallel_config,
self.scheduler_config,
self.device_config,
rank,
distributed_init_method,
lora_config=self.lora_config,
kv_cache_dtype=self.cache_config.cache_dtype,
)
# NOTE(shengguangming): torch.distributed.init_process_group will be called inside the init_model()
self.worker.init_model()
self.worker.load_model()
def _verify_args(self) -> None:
self.model_config.verify_with_parallel_config(self.parallel_config)
self.cache_config.verify_with_parallel_config(self.parallel_config)
def _init_cache_sp(self) -> None:
"""Profiles the memory usage and initializes the KV cache."""
# Get the maximum number of blocks that can be allocated on GPU and CPU.
num_blocks = self.worker.profile_num_available_blocks(
block_size=self.cache_config.block_size,
gpu_memory_utilization=self.cache_config.gpu_memory_utilization,
cpu_swap_space=self.cache_config.swap_space_bytes,
cache_dtype=self.cache_config.cache_dtype,
)
# NOTE(shengguangming): Now we don't use a shared centralized controler but each process will
# have its own scheduler
num_gpu_blocks = num_blocks[0]
num_cpu_blocks = num_blocks[1]
# FIXME(woosuk): Change to debug log.
logger.info(f"# GPU blocks: {num_gpu_blocks}, "
f"# CPU blocks: {num_cpu_blocks}")
if num_gpu_blocks <= 0:
raise ValueError("No available memory for the cache blocks. "
"Try increasing `gpu_memory_utilization` when "
"initializing the engine.")
max_seq_len = self.cache_config.block_size * num_gpu_blocks
if self.model_config.max_model_len > max_seq_len:
raise ValueError(f"The model's max seq len ({self.model_config.max_model_len}) "
"is larger than the maximum number of tokens that can be "
f"stored in KV cache ({max_seq_len}). Try increasing "
"`gpu_memory_utilization` or decreasing `max_model_len` when "
"initializing the engine.")
self.cache_config.num_gpu_blocks = num_gpu_blocks
self.cache_config.num_cpu_blocks = num_cpu_blocks
# Initialize the cache.
self.worker.init_cache_engine(cache_config=self.cache_config)
self.worker.warm_up_model()
def init_cache_engine(self):
self.worker.init_cache_engine(cache_config=self.cache_config)
def free_cache_engine(self):
self.worker.free_cache_engine()
@classmethod
def from_engine_args(cls, model, tokenizer, engine_args: EngineArgs) -> "LLMEngine":
"""Creates an LLM engine from the engine arguments."""
# Create the engine configs.
engine_configs = engine_args.create_engine_configs()
parallel_config = engine_configs[2]
# Initialize the cluster.
distributed_init_method, placement_group = initialize_cluster(parallel_config)
# Create the LLM engine.
engine = cls(model,
tokenizer,
*engine_configs,
distributed_init_method,
placement_group,
log_stats=not engine_args.disable_log_stats)
return engine
def add_request(
self,
request_id: str,
prompt: Optional[str],
sampling_params: SamplingParams,
prompt_token_ids: Optional[List[int]] = None,
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
prefix_pos: Optional[int] = None,
) -> None:
"""Add a request to the engine's request pool.
The request is added to the request pool and will be processed by the
scheduler as `engine.step()` is called. The exact scheduling policy is
determined by the scheduler.
Args:
request_id: The unique ID of the request.
prompt: The prompt string. Can be None if prompt_token_ids is
provided.
sampling_params: The sampling parameters for text generation.
prompt_token_ids: The token IDs of the prompt. If None, we
use the tokenizer to convert the prompts to token IDs.
arrival_time: The arrival time of the request. If None, we use
the current monotonic time.
prefix_pos: If not None, we use the given position as the prefix
position for each prompt. We will cache the prefix's KV
cache and reuse it for the next request with the same prefix.
This is an experimental feature, and may be replaced with
automatic prefix caching in the future.
Details:
- Set arrival_time to the current time if it is None.
- Set prompt_token_ids to the encoded prompt if it is None.
- Create `best_of` number of :class:`~vllm.Sequence` objects.
- Create a :class:`~vllm.SequenceGroup` object
from the list of :class:`~vllm.Sequence`.
- Add the :class:`~vllm.SequenceGroup` object to the scheduler.
Example:
>>> # initialize engine
>>> engine = LLMEngine.from_engine_args(engine_args)
>>> # set request arguments
>>> example_prompt = "Who is the president of the United States?"
>>> sampling_params = SamplingParams(temperature=0.0)
>>> request_id = 0
>>>
>>> # add the request to the engine
>>> engine.add_request(
>>> str(request_id),
>>> example_prompt,
>>> SamplingParams(temperature=0.0))
>>> # continue the request processing
>>> ...
"""
if lora_request is not None and not self.lora_config:
raise ValueError(f"Got lora_request {lora_request} but LoRA is "
"not enabled!")
if arrival_time is None:
arrival_time = time.monotonic()
if prompt_token_ids is None:
assert prompt is not None
prompt_token_ids = self.tokenizer.encode(prompt)
# Create the sequences.
block_size = self.cache_config.block_size
seq_id = next(self.seq_counter)
seq = Sequence(seq_id, prompt, prompt_token_ids, block_size, lora_request)
# Check whether the input specifies prefix
prefix = self.scheduler.prefix_pool.add_or_get_prefix(prompt_token_ids[:prefix_pos], lora_request.lora_int_id if
lora_request else 0) if prefix_pos is not None else None
# Create the sequence group.
seq_group = SequenceGroup(request_id, [seq], sampling_params, arrival_time, lora_request, prefix)
# Add the sequence group to the scheduler.
self.scheduler.add_seq_group(seq_group)
def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:
"""Aborts a request(s) with the given ID.
Args:
request_id: The ID(s) of the request to abort.
Details:
- Refer to the
:meth:`~vllm.core.scheduler.Scheduler.abort_seq_group`
from class :class:`~vllm.core.scheduler.Scheduler`.
Example:
>>> # initialize engine and add a request with request_id
>>> request_id = str(0)
>>> # abort the request
>>> engine.abort_request(request_id)
"""
self.scheduler.abort_seq_group(request_id)
def get_model_config(self) -> ModelConfig:
"""Gets the model configuration."""
return self.model_config
def get_num_unfinished_requests(self) -> int:
"""Gets the number of unfinished requests."""
return self.scheduler.get_num_unfinished_seq_groups()
def has_unfinished_requests(self) -> bool:
"""Returns True if there are unfinished requests."""
return self.scheduler.has_unfinished_seqs()
def _check_beam_search_early_stopping(
self,
early_stopping: Union[bool, str],
sampling_params: SamplingParams,
best_running_seq: Sequence,
current_worst_seq: Sequence,
) -> bool:
assert sampling_params.use_beam_search
length_penalty = sampling_params.length_penalty
if early_stopping is True:
return True
current_worst_score = (current_worst_seq.get_beam_search_score(
length_penalty=length_penalty, eos_token_id=self.get_tokenizer_for_seq(current_worst_seq).eos_token_id))
if early_stopping is False:
highest_attainable_score = (best_running_seq.get_beam_search_score(
length_penalty=length_penalty, eos_token_id=self.get_tokenizer_for_seq(best_running_seq).eos_token_id))
else:
assert early_stopping == "never"
if length_penalty > 0.0:
# If length_penalty > 0.0, beam search will prefer longer
# sequences. The highest attainable score calculation is
# based on the longest possible sequence length in this case.
max_possible_length = max(best_running_seq.get_prompt_len() + sampling_params.max_tokens,
self.scheduler_config.max_model_len)
highest_attainable_score = (best_running_seq.get_beam_search_score(
length_penalty=length_penalty,
eos_token_id=self.get_tokenizer_for_seq(best_running_seq).eos_token_id,
seq_len=max_possible_length))
else:
# Otherwise, beam search will prefer shorter sequences. The
# highest attainable score calculation is based on the current
# sequence length.
highest_attainable_score = (best_running_seq.get_beam_search_score(
length_penalty=length_penalty,
eos_token_id=self.get_tokenizer_for_seq(best_running_seq).eos_token_id))
def _process_sequence_group_outputs(self, seq_group: SequenceGroup, outputs: SequenceGroupOutput) -> None:
# Process prompt logprobs
prompt_logprobs = outputs.prompt_logprobs
if prompt_logprobs is not None:
seq_group.prompt_logprobs = prompt_logprobs
# Process samples
samples = outputs.samples
parent_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
existing_finished_seqs = seq_group.get_finished_seqs()
parent_child_dict = {parent_seq.seq_id: [] for parent_seq in parent_seqs}
for sample in samples:
parent_child_dict[sample.parent_seq_id].append(sample)
# List of (child, parent)
child_seqs: List[Tuple[Sequence, Sequence]] = []
# Process the child samples for each parent sequence
for parent in parent_seqs:
child_samples: List[SequenceOutput] = parent_child_dict[parent.seq_id]
if len(child_samples) == 0:
# This parent sequence has no children samples. Remove
# the parent sequence from the sequence group since it will
# not be used in the future iterations.
parent.status = SequenceStatus.FINISHED_ABORTED
seq_group.remove(parent.seq_id)
self.scheduler.free_seq(parent)
continue
# Fork the parent sequence if there are multiple child samples.
for child_sample in child_samples[:-1]:
new_child_seq_id = next(self.seq_counter)
child = parent.fork(new_child_seq_id)
child.append_token_id(child_sample.output_token, child_sample.logprobs)
child_seqs.append((child, parent))
# Continue the parent sequence for the last child sample.
# We reuse the parent sequence here to reduce redundant memory
# copies, especially when using non-beam search sampling methods.
last_child_sample = child_samples[-1]
parent.append_token_id(last_child_sample.output_token, last_child_sample.logprobs)
child_seqs.append((parent, parent))
for seq, _ in child_seqs:
# self._decode_sequence(seq, seq_group.sampling_params)
self._check_stop(seq, seq_group.sampling_params)
# Non-beam search case
if not seq_group.sampling_params.use_beam_search:
# For newly created child sequences, add them to the sequence group
# and fork them in block manager if they are not finished.
for seq, parent in child_seqs:
if seq is not parent:
seq_group.add(seq)
if not seq.is_finished():
self.scheduler.fork_seq(parent, seq)
# Free the finished and selected parent sequences' memory in block
# manager. Keep them in the sequence group as candidate output.
# NOTE: we need to fork the new sequences before freeing the
# old sequences.
for seq, parent in child_seqs:
if seq is parent and seq.is_finished():
self.scheduler.free_seq(seq)
return
# Beam search case
# Select the child sequences to keep in the sequence group.
selected_child_seqs = []
unselected_child_seqs = []
beam_width = seq_group.sampling_params.best_of
length_penalty = seq_group.sampling_params.length_penalty
# Select the newly finished sequences with the highest scores
# to replace existing finished sequences.
# Tuple of (seq, parent, is_new)
existing_finished_seqs = [(seq, None, False) for seq in existing_finished_seqs]
new_finished_seqs = [(seq, parent, True) for seq, parent in child_seqs if seq.is_finished()]
all_finished_seqs = existing_finished_seqs + new_finished_seqs
# Sort the finished sequences by their scores.
all_finished_seqs.sort(key=lambda x: x[0].get_beam_search_score(
length_penalty=length_penalty, eos_token_id=self.get_tokenizer_for_seq(x[0]).eos_token_id),
reverse=True)
for seq, parent, is_new in all_finished_seqs[:beam_width]:
if is_new:
# A newly generated child sequence finishes and has a high
# score, so we will add it into the sequence group.
selected_child_seqs.append((seq, parent))
for seq, parent, is_new in all_finished_seqs[beam_width:]:
if is_new:
# A newly generated child sequence finishes but has a low
# score, so we will not add it into the sequence group.
# Additionally, if this sequence is a continuation of a
# parent sequence, we will need remove the parent sequence
# from the sequence group.
unselected_child_seqs.append((seq, parent))
else:
# An existing finished sequence has a low score, so we will
# remove it from the sequence group.
seq_group.remove(seq.seq_id)
# select the top beam_width sequences from the running
# sequences for the next iteration to continue the beam
# search.
running_child_seqs = [(seq, parent) for seq, parent in child_seqs if not seq.is_finished()]
# Sort the running sequences by their scores.
running_child_seqs.sort(key=lambda x: x[0].get_beam_search_score(
length_penalty=length_penalty, eos_token_id=self.get_tokenizer_for_seq(x[0]).eos_token_id),
reverse=True)
# Check if we can stop the beam search.
if len(running_child_seqs) == 0:
# No running sequences, stop the beam search.
stop_beam_search = True
elif len(all_finished_seqs) < beam_width:
# Not enough finished sequences, continue the beam search.
stop_beam_search = False
else:
# Check the early stopping criteria
best_running_seq = running_child_seqs[0][0]
current_worst_seq = all_finished_seqs[beam_width - 1][0]
stop_beam_search = self._check_beam_search_early_stopping(seq_group.sampling_params.early_stopping,
seq_group.sampling_params, best_running_seq,
current_worst_seq)
if stop_beam_search:
# Stop the beam search and remove all the running sequences from
# the sequence group.
unselected_child_seqs.extend(running_child_seqs)
else:
# Continue the beam search and select the top beam_width sequences
# to continue the beam search.
selected_child_seqs.extend(running_child_seqs[:beam_width])
# The remaining running sequences will not be used in the next
# iteration. Again, if these sequences are continuations of
# parent sequences, we will need to remove the parent sequences
# from the sequence group.
unselected_child_seqs.extend(running_child_seqs[beam_width:])
# For newly created child sequences, add them to the sequence group
# and fork them in block manager if they are not finished.
for seq, parent in selected_child_seqs:
if seq is not parent:
seq_group.add(seq)
if not seq.is_finished():
self.scheduler.fork_seq(parent, seq)
# Free the finished and selected parent sequences' memory in block
# manager. Keep them in the sequence group as candidate output.
for seq, parent in selected_child_seqs:
if seq is parent and seq.is_finished():
self.scheduler.free_seq(seq)
# Remove the unselected parent sequences from the sequence group and
# free their memory in block manager.
for seq, parent in unselected_child_seqs:
if seq is parent:
# Remove the parent sequence if it is not selected for next
# iteration
seq_group.remove(seq.seq_id)
self.scheduler.free_seq(seq)
def _process_model_outputs(self, output: SamplerOutput, scheduler_outputs: SchedulerOutputs) -> List[RequestOutput]:
# Update the scheduled sequence groups with the model outputs.
scheduled_seq_groups = scheduler_outputs.scheduled_seq_groups
for seq_group, outputs in zip(scheduled_seq_groups, output):
self._process_sequence_group_outputs(seq_group, outputs)
# Free the finished sequence groups.
self.scheduler.free_finished_seq_groups()
# Create the outputs.
request_outputs: List[RequestOutput] = []
for seq_group in scheduled_seq_groups:
request_output = RequestOutput.from_seq_group(seq_group)
request_outputs.append(request_output)
for seq_group in scheduler_outputs.ignored_seq_groups:
request_output = RequestOutput.from_seq_group(seq_group)
request_outputs.append(request_output)
# Update prefix state, now all the uncomputed prefixes are computed.
for seq_group in scheduled_seq_groups:
if (seq_group.prefix is not None and seq_group.prefix.allocated and not seq_group.prefix.computed):
seq_group.prefix.computed = True
# Log stats.
if self.log_stats:
self.stat_logger.log(self._get_stats(scheduler_outputs))
return request_outputs
def step(self) -> List[RequestOutput]:
"""Performs one decoding iteration and returns newly generated results.
This function performs one decoding iteration of the engine. It first
schedules the sequences to be executed in the next iteration and the
token blocks to be swapped in/out/copy. Then, it executes the model
and updates the scheduler with the model outputs. Finally, it decodes
the sequences and returns the newly generated results.
"""
seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()
if not scheduler_outputs.is_empty():
output = self.worker.execute_model(
seq_group_metadata_list=seq_group_metadata_list, # TODO: check this input
blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
blocks_to_copy=scheduler_outputs.blocks_to_copy,)
else:
return [RequestOutput.from_seq_group(seq_group) for seq_group in scheduler_outputs.ignored_seq_groups]
return self._process_model_outputs(output, scheduler_outputs)
def do_log_stats(self) -> None:
"""Forced log when no requests active."""
if self.log_stats:
self.stat_logger.log(self._get_stats(scheduler_outputs=None))
def _get_stats(self, scheduler_outputs: Optional[SchedulerOutputs]) -> Stats:
"""Get Stats to be Logged to Prometheus."""
now = time.monotonic()
# KV Cache Usage in %.
num_total_gpu = self.cache_config.num_gpu_blocks
num_free_gpu = self.scheduler.block_manager.get_num_free_gpu_blocks()
gpu_cache_usage = 1.0 - (num_free_gpu / num_total_gpu)
num_total_cpu = self.cache_config.num_cpu_blocks
cpu_cache_usage = 0.
if num_total_cpu > 0:
num_free_cpu = self.scheduler.block_manager.get_num_free_cpu_blocks()
cpu_cache_usage = 1.0 - (num_free_cpu / num_total_cpu)
# Scheduler State
num_running = len(self.scheduler.running)
num_swapped = len(self.scheduler.swapped)
num_waiting = len(self.scheduler.waiting)
# Iteration stats if we have scheduler output.
num_prompt_tokens = 0
num_generation_tokens = 0
time_to_first_tokens = []
time_per_output_tokens = []
time_e2e_requests = []
if scheduler_outputs is not None:
prompt_run = scheduler_outputs.prompt_run
# Number of Tokens.
if prompt_run:
num_prompt_tokens = scheduler_outputs.num_batched_tokens
else:
num_generation_tokens = scheduler_outputs.num_batched_tokens
# Latency Timings.
time_last_iters = []
for seq_group in scheduler_outputs.scheduled_seq_groups:
# Time since last token. (n.b. updates seq_group.last_token_time)
time_last_iters.append(seq_group.get_last_latency(now))
# Time since arrival for all finished requests.
if seq_group.is_finished():
time_e2e_requests.append(now - seq_group.arrival_time)
time_to_first_tokens = time_last_iters if prompt_run else []
time_per_output_tokens = [] if prompt_run else time_last_iters
return Stats(
now=now,
num_running=num_running,
num_swapped=num_swapped,
num_waiting=num_waiting,
gpu_cache_usage=gpu_cache_usage,
cpu_cache_usage=cpu_cache_usage,
num_prompt_tokens=num_prompt_tokens,
num_generation_tokens=num_generation_tokens,
time_to_first_tokens=time_to_first_tokens,
time_per_output_tokens=time_per_output_tokens,
time_e2e_requests=time_e2e_requests,
)
# TODO: we may not need to decode
def _decode_sequence(self, seq: Sequence, prms: SamplingParams) -> None:
"""Decodes the new token for a sequence."""
(new_tokens, new_output_text, prefix_offset, read_offset) = detokenize_incrementally(
self.get_tokenizer_for_seq(seq),
all_input_ids=seq.get_token_ids(),
prev_tokens=seq.tokens,
prefix_offset=seq.prefix_offset,
read_offset=seq.read_offset,
skip_special_tokens=prms.skip_special_tokens,
spaces_between_special_tokens=prms.spaces_between_special_tokens,
)
if seq.tokens is None:
seq.tokens = new_tokens
else:
seq.tokens.extend(new_tokens)
seq.prefix_offset = prefix_offset
seq.read_offset = read_offset
seq.output_text += new_output_text
def _check_stop(self, seq: Sequence, sampling_params: SamplingParams) -> None:
"""Stop the finished sequences."""
# for stop_str in sampling_params.stop:
# if seq.output_text.endswith(stop_str):
# self._finalize_sequence(seq, sampling_params, stop_str)
# seq.status = SequenceStatus.FINISHED_STOPPED
# return
# if seq.get_last_token_id() in sampling_params.stop_token_ids:
# stop_str = self.get_tokenizer_for_seq(seq).convert_ids_to_tokens(seq.get_last_token_id())
# self._finalize_sequence(seq, sampling_params, stop_str)
# seq.status = SequenceStatus.FINISHED_STOPPED
# return
# Check if the sequence has reached max_model_len.
if seq.get_len() > self.scheduler_config.max_model_len:
seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
return
# Check if the sequence has reached max_tokens.
if seq.get_output_len() == sampling_params.max_tokens:
seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
return
# Check if the sequence has generated the EOS token.
if ((not sampling_params.ignore_eos) and
seq.get_last_token_id() == self.get_tokenizer_for_seq(seq).eos_token_id):
seq.status = SequenceStatus.FINISHED_STOPPED
return
def _finalize_sequence(self, seq: Sequence, sampling_params: SamplingParams, stop_string: str) -> None:
if not sampling_params.include_stop_str_in_output and stop_string:
# Truncate the output text so that the stop string is
# not included in the output.
seq.output_text = seq.output_text[:-len(stop_string)]
def add_lora(self, lora_request: LoRARequest) -> bool:
assert lora_request.lora_int_id > 0, "lora_id must be greater than 0."
return self.worker.add_lora(lora_request)
def remove_lora(self, lora_id: int) -> bool:
assert lora_id > 0, "lora_id must be greater than 0."
return self.worker.remove_lora(lora_id)
def list_loras(self) -> List[int]:
return self.worker.list_loras()
def sync_model_weights(self, actor_weights: Dict[str, torch.Tensor]) -> None:
self.worker.sync_model_weights(actor_weights=actor_weights)
def offload_model_weights(self) -> None:
self.worker.offload_model_weights()
def initialize_cluster(
parallel_config: ParallelConfig,
engine_use_ray: bool = False,
ray_address: Optional[str] = None,
) -> Tuple[str, Optional[None]]:
"""Initialize the distributed cluster probably with Ray.
Args:
parallel_config: The configurations for parallel execution.
engine_use_ray: Whether to use Ray for async engine.
ray_address: The address of the Ray cluster. If None, uses
the default Ray cluster address.
Returns:
A tuple of (`distributed_init_method`, `placement_group`). The
`distributed_init_method` is the address for initializing the
distributed backend. `placement_group` includes the specification
of the resources for each distributed worker.
"""
# Initialize cluster locally.
port = get_open_port()
# We need to setup the distributed init method to make sure
# the distributed megatron code (e.g., get world size) works correctly.
distributed_init_method = f"tcp://localhost:{port}"
return distributed_init_method, None
def get_open_port():
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(("", 0))
return s.getsockname()[1]

View File

@@ -0,0 +1,275 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023 The vLLM team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/model_loader
"""Utilities for selecting and loading models."""
import contextlib
from typing import Dict, Type, Union
import torch
import torch.nn as nn
from transformers import PretrainedConfig, PreTrainedModel
from megatron.core.tensor_parallel.utils import VocabUtility
from vllm.model_executor.models import ModelRegistry
from vllm.model_executor.weight_utils import (get_quant_config, initialize_dummy_weights)
from .config import ModelConfig
from vllm.config import DeviceConfig, LoRAConfig
from .weight_loaders import *
from vllm.model_executor.sampling_metadata import SamplingMetadata, SamplingTensors
from vllm.sequence import SamplerOutput
from typing import Optional
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.sampler import _prune_hidden_states, _apply_logits_processors, _apply_penalties, _apply_top_k_top_p, _apply_min_p, _apply_penalties, _sample, _get_logprobs, _build_sampler_output
@contextlib.contextmanager
def _set_default_torch_dtype(dtype: torch.dtype):
"""Sets the default torch dtype to the given dtype."""
old_dtype = torch.get_default_dtype()
torch.set_default_dtype(dtype)
yield
torch.set_default_dtype(old_dtype)
def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
architectures = getattr(config, "architectures", [])
for arch in architectures:
model_cls = ModelRegistry.load_model_cls(arch)
if model_cls is not None:
return model_cls
raise ValueError(f"Model architectures {architectures} are not supported for now. "
f"Supported architectures: {ModelRegistry.get_supported_archs()}")
from vllm.model_executor.layers.linear import *
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding, ParallelLMHead
from vllm.model_executor.layers.activation import ScaledActivation
__LAYER_WEIGHT_LOADER_REGISTRY__ = {
ColumnParallelLinear: parallel_weight_loader,
MergedColumnParallelLinear: parallel_weight_loader,
QKVParallelLinear: parallel_weight_loader,
RowParallelLinear: parallel_weight_loader,
VocabParallelEmbedding: parallel_weight_loader,
ParallelLMHead: parallel_weight_loader
# "ScaledActivation.weight_loader": ScaledActivation, # TODO(shengguangming): latest commit in vllm fix awq for this function and add load_weights
# "default_weight_loader": default_weight_loader
}
# NOTE(gmsheng): change the weight_loader function in runtime
for layer_class, weight_loader in __LAYER_WEIGHT_LOADER_REGISTRY__.items():
layer_class.weight_loader = weight_loader
__MODEL_WEIGHT_LOADER_REGISTRY__ = {
'GPT2LMHeadModel': gpt2_weight_loader,
'LlamaForCausalLM': llama_weight_loader,
'LLaMAForCausalLM': llama_weight_loader,
'MistralForCausalLM': mistral_weight_loader,
}
# FIXME(shengguangming): the vLLM vocab will pad to 64, which may incur out of bounds
# so we need to rewrite the init function of vocab
DEFAULT_VOCAB_PADDING_SIZE = 64
def vocab_init(self,
num_embeddings: int,
embedding_dim: int,
params_dtype: Optional[torch.dtype] = None,
org_num_embeddings: Optional[int] = None,
padding_size: int = DEFAULT_VOCAB_PADDING_SIZE):
super(VocabParallelEmbedding, self).__init__()
# Keep the input dimensions.
# TODO (pad to be divided by 4)
self.num_embeddings = num_embeddings
self.org_vocab_size = org_num_embeddings or num_embeddings
# self.num_embeddings_padded = pad_vocab_size(num_embeddings,
# padding_size)
self.embedding_dim = embedding_dim
if params_dtype is None:
params_dtype = torch.get_default_dtype()
self.tp_size = get_tensor_model_parallel_world_size()
# Divide the weight matrix along the vocaburaly dimension.
self.vocab_start_index, self.vocab_end_index = (VocabUtility.vocab_range_from_global_vocab_size(
self.num_embeddings, get_tensor_model_parallel_rank(), self.tp_size))
self.num_embeddings_per_partition = (self.vocab_end_index - self.vocab_start_index)
self.weight = Parameter(
torch.empty(
self.num_embeddings_per_partition,
self.embedding_dim,
# device=torch.cuda.current_device(),
dtype=params_dtype))
set_weight_attrs(self.weight, {"parallel_dim": 0, "weight_loader": self.weight_loader})
VocabParallelEmbedding.__init__ = vocab_init
def _get_model_weight_loader(arch: str):
if arch in __MODEL_WEIGHT_LOADER_REGISTRY__:
return __MODEL_WEIGHT_LOADER_REGISTRY__[arch]
raise ValueError(f"Model architectures {arch} are not supported for now. "
f"Supported architectures: {ModelRegistry.get_supported_archs()}")
def get_model(actor_model: Union[PreTrainedModel, Dict],
model_config: ModelConfig,
device_config: DeviceConfig,
lora_config: Optional[LoRAConfig] = None) -> nn.Module:
model_class = _get_model_architecture(model_config.hf_config)
# Get the quantization config.
linear_method = None
quant_config = None
if model_config.quantization is not None:
quant_config = get_quant_config(model_config.quantization, model_config.model, model_config.hf_config,
model_config.download_dir)
capability = torch.cuda.get_device_capability()
capability = capability[0] * 10 + capability[1]
if capability < quant_config.get_min_capability():
raise ValueError(f"The quantization method {model_config.quantization} is not "
"supported for the current GPU. "
f"Minimum capability: {quant_config.get_min_capability()}. "
f"Current capability: {capability}.")
supported_dtypes = quant_config.get_supported_act_dtypes()
if model_config.dtype not in supported_dtypes:
raise ValueError(f"{model_config.dtype} is not supported for quantization "
f"method {model_config.quantization}. Supported dtypes: "
f"{supported_dtypes}")
linear_method = quant_config.get_linear_method()
with _set_default_torch_dtype(model_config.dtype):
# Create a model instance.
# The weights will be initialized as empty tensors.
# with torch.device(device_config.device):
# NOTE(sgm): init the model in cpu
model = model_class(model_config.hf_config, linear_method)
if model_config.load_format == "dummy":
model = model.cuda()
# NOTE(woosuk): For accurate performance evaluation, we assign
# random values to the weights.
initialize_dummy_weights(model)
elif model_config.load_format == 'model' or model_config.load_format == 'auto':
# NOTE(shengguangming) Load the weights from the actor model
if isinstance(actor_model, nn.Module):
load_weights(actor_weights=dict(actor_model.named_parameters(remove_duplicate=False)), vllm_model=model)
else:
load_weights(actor_weights=actor_model, vllm_model=model)
# NOTE(sgm) Some weights are point to gpu, but still need this.
model = model.cuda() # NOTE (zhangchi.usc1992) We need this for vllm to profile memory usage
return model.eval()
# the actor model is .state_dict()
def load_weights(actor_weights: Dict, vllm_model: nn.Module):
weight_loader = _get_model_weight_loader(vllm_model.__class__.__name__)
weight_loader(actor_weights, vllm_model)
# NOTE(sgm) to reduce peak memory usage, we offload vllm model to cpu
# after init, and we need this after sync model weights for in first iter.
vllm_model = vllm_model.cuda()
# FIXME(sgm): hack the Sampler function in vllm v0.3.1
# as they use ray, the sampler result will only need to return to the driver node,
# therefore gather is enough. However, we use SPMD instead of a central scheduler,
# all_gather is required (aligned with v0.2.6)
def _get_logits(self, hidden_states: torch.Tensor, embedding: torch.Tensor,
embedding_bias: Optional[torch.Tensor]) -> torch.Tensor:
# Get the logits for the next tokens.
logits = torch.matmul(hidden_states, embedding.t())
if embedding_bias is not None:
logits += embedding_bias
logits = tensor_model_parallel_all_gather(logits)
# Remove paddings in vocab (if any).
if logits is not None:
logits = logits[:, :self.org_vocab_size]
return logits
def forward(
self,
embedding: torch.Tensor,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
embedding_bias: Optional[torch.Tensor] = None,
) -> Optional[SamplerOutput]:
# Get the hidden states that we use for sampling.
hidden_states = _prune_hidden_states(hidden_states, sampling_metadata)
# Get the logits for the next tokens.
logits = self._get_logits(hidden_states, embedding, embedding_bias)
# save origin logprobs for sampler_output
origin_logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)
# Only perform sampling in the driver worker.
# Note: `_get_logits` is still distributed across TP workers because
# the `embedding` weight is distributed across TP workers.
# TODO(zhuohan): Change the get_logits part to a separate stage.
if not sampling_metadata.perform_sampling:
return None
assert logits is not None
_, vocab_size = logits.shape
# Apply logits processors (if any).
logits = _apply_logits_processors(logits, sampling_metadata)
# Prepare sampling tensors with pinned memory to avoid blocking.
(sampling_tensors, do_penalties, do_top_p_top_k,
do_min_p) = SamplingTensors.from_sampling_metadata(sampling_metadata, vocab_size, logits.device, logits.dtype)
# Apply presence and frequency penalties.
if do_penalties:
logits = _apply_penalties(logits, sampling_tensors.prompt_tokens, sampling_tensors.output_tokens,
sampling_tensors.presence_penalties, sampling_tensors.frequency_penalties,
sampling_tensors.repetition_penalties)
# Apply temperature scaling.
# Use in-place division to avoid creating a new tensor.
logits.div_(sampling_tensors.temperatures.unsqueeze_(dim=1))
if do_top_p_top_k:
logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps, sampling_tensors.top_ks)
if do_min_p:
logits = _apply_min_p(logits, sampling_tensors.min_ps)
# We use float32 for probabilities and log probabilities.
# Compute the probabilities.
probs = torch.softmax(logits, dim=-1, dtype=torch.float)
# Compute the log probabilities.
# Use log_softmax to ensure numerical stability.
logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)
# Sample the next tokens.
sample_results = _sample(probs, logprobs, sampling_metadata)
# Get the logprobs query results.
# prompt_logprobs, sample_logprobs = _get_logprobs(
# logprobs, sampling_metadata, sample_results)
prompt_logprobs, sample_logprobs = _get_logprobs(origin_logprobs, sampling_metadata, sample_results)
return _build_sampler_output(sample_results, sampling_metadata, prompt_logprobs, sample_logprobs)
from vllm.model_executor.layers.sampler import Sampler
Sampler._get_logits = _get_logits
Sampler.forward = forward

View File

@@ -0,0 +1,285 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023 The vLLM team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/worker/model_runner.py
from typing import Dict, List, Optional, Tuple, Set, Union
import contextlib
import time
import numpy as np
import torch
import torch.nn as nn
from vllm.config import (DeviceConfig, ModelConfig, LoRAConfig, ParallelConfig, SchedulerConfig)
from vllm.logger import init_logger
from vllm.model_executor import InputMetadata, SamplingMetadata
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
from vllm.lora.layers import LoRAMapping
from vllm.lora.request import LoRARequest
from vllm.utils import in_wsl
from vllm.worker.model_runner import ModelRunner, CUDAGraphRunner, _async_h2d
from .model_loader import get_model
logger = init_logger(__name__)
KVCache = Tuple[torch.Tensor, torch.Tensor]
_PAD_SLOT_ID = -1
LORA_WARMUP_RANK = 8
# Capture graphs for batch size 1, 2, 4, 8, 16, 24, 32, 40, ..., 256.
# NOTE: _get_graph_batch_size needs to be updated if this list is changed.
_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [8 * i for i in range(1, 33)]
class ModelRunner(ModelRunner):
def __init__(
self,
model: Union[nn.Module, Dict], # model itself or its parameter dict
model_config: ModelConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
lora_config: Optional[LoRAConfig],
kv_cache_dtype: Optional[str] = "auto",
):
self.model_config = model_config
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
self.lora_config = lora_config
# model_config can be None in tests/samplers/test_sampler.py.
# FIXME(woosuk): This is a hack to make the tests work. Refactor this.
self.sliding_window = (model_config.get_sliding_window() if model_config is not None else None)
self.device_config = (device_config if device_config is not None else DeviceConfig())
self.device = self.device_config.device
self.model = model # this will be replaced by get_model()
self.block_size = None # Set after initial profiling.
self.lora_manager = None
self.graph_runners: Dict[int, CUDAGraphRunner] = {}
self.graph_memory_pool = None # Set during graph capture.
self.max_context_len_to_capture = (self.model_config.max_context_len_to_capture
if self.model_config is not None else 0)
# When using CUDA graph, the input block tables must be padded to
# max_context_len_to_capture. However, creating the block table in
# Python can be expensive. To optimize this, we cache the block table
# in numpy and only copy the actual input content at every iteration.
# The shape of the cached block table will be
# (max batch size to capture, max context len to capture / block size).
self.graph_block_tables = None # Set after initial profiling.
# cache in_wsl result
self.in_wsl = in_wsl()
self.kv_cache_dtype = kv_cache_dtype
def load_model(self) -> None:
self.model = get_model(actor_model=self.model,
model_config=self.model_config,
device_config=self.device_config,
lora_config=self.lora_config)
vocab_size = self.model.config.vocab_size
if self.lora_config:
assert hasattr(
self.model,
"supported_lora_modules") and self.model.supported_lora_modules, "Model does not support LoRA"
assert hasattr(self.model, "embedding_modules"), "Model does not have embedding_modules"
assert hasattr(self.model, "embedding_padding_modules"), "Model does not have embedding_padding_modules"
self.lora_manager = LRUCacheWorkerLoRAManager(
self.scheduler_config.max_num_seqs,
self.scheduler_config.max_num_batched_tokens + self.scheduler_config.max_paddings, vocab_size,
self.lora_config, self.device, self.model.embedding_modules, self.model.embedding_padding_modules)
self.model = self.lora_manager.create_lora_manager(self.model)
def _prepare_sample(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
prompt_lens: List[int],
subquery_lens: Optional[List[int]],
) -> SamplingMetadata:
seq_groups: List[Tuple[List[int], SamplingParams]] = []
selected_token_indices: List[int] = []
selected_token_start_idx = 0
categorized_sample_indices = {t: [] for t in SamplingType}
categorized_sample_indices_start_idx = 0
max_subquery_len = max(subquery_lens) if subquery_lens else 1
for i, seq_group_metadata in enumerate(seq_group_metadata_list):
seq_ids = list(seq_group_metadata.seq_data.keys())
sampling_params = seq_group_metadata.sampling_params
seq_groups.append((seq_ids, sampling_params))
if seq_group_metadata.is_prompt:
assert len(seq_ids) == 1
assert subquery_lens is not None
subquery_len = subquery_lens[i]
if sampling_params.prompt_logprobs is not None:
# NOTE: prompt token positions do not need sample, skip
categorized_sample_indices_start_idx += subquery_len - 1
categorized_sample_indices[sampling_params.sampling_type].append(categorized_sample_indices_start_idx)
categorized_sample_indices_start_idx += 1
if sampling_params.prompt_logprobs is not None:
selected_token_indices.extend(
range(selected_token_start_idx, selected_token_start_idx + subquery_len - 1))
selected_token_indices.append(selected_token_start_idx + subquery_len - 1)
selected_token_start_idx += max_subquery_len
else:
num_seqs = len(seq_ids)
selected_token_indices.extend(range(selected_token_start_idx, selected_token_start_idx + num_seqs))
selected_token_start_idx += num_seqs
categorized_sample_indices[sampling_params.sampling_type].extend(
range(categorized_sample_indices_start_idx, categorized_sample_indices_start_idx + num_seqs))
categorized_sample_indices_start_idx += num_seqs
selected_token_indices = _async_h2d(selected_token_indices,
dtype=torch.long,
target_device=self.device,
pin_memory=not self.in_wsl)
categorized_sample_indices = {
t: _async_h2d(seq_ids, dtype=torch.int, target_device=self.device, pin_memory=not self.in_wsl)
for t, seq_ids in categorized_sample_indices.items()
}
seq_data: Dict[int, SequenceData] = {}
for seq_group_metadata in seq_group_metadata_list:
seq_data.update(seq_group_metadata.seq_data)
sampling_metadata = SamplingMetadata(
seq_groups=seq_groups,
seq_data=seq_data,
prompt_lens=prompt_lens,
selected_token_indices=selected_token_indices,
categorized_sample_indices=categorized_sample_indices,
)
return sampling_metadata
def prepare_input_tensors(
self,
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, SamplingMetadata, Set[int], LoRAMapping]:
# NOTE: We assume that all sequences in the group are all prompts or
# all decodes.
is_prompt = seq_group_metadata_list[0].is_prompt
# Prepare input tensors.
if is_prompt:
(input_tokens, input_positions, input_metadata, prompt_lens, subquery_lens, lora_index_mapping,
lora_prompt_mapping, lora_requests) = self._prepare_prompt(seq_group_metadata_list)
else:
(input_tokens, input_positions, input_metadata, lora_index_mapping, lora_prompt_mapping,
lora_requests) = self._prepare_decode(seq_group_metadata_list)
prompt_lens = []
subquery_lens = None
sampling_metadata = self._prepare_sample(seq_group_metadata_list, prompt_lens, subquery_lens)
if self.lora_config:
flat_lora_index_mapping = [item for sublist in lora_index_mapping for item in sublist]
lora_mapping = LoRAMapping(
flat_lora_index_mapping,
lora_prompt_mapping,
)
else:
lora_mapping = None
return (input_tokens, input_positions, input_metadata, sampling_metadata, lora_requests, lora_mapping)
@torch.inference_mode()
def execute_model(
self,
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
) -> Optional[SamplerOutput]:
(input_tokens, input_positions, input_metadata, sampling_metadata, lora_requests,
lora_mapping) = self.prepare_input_tensors(seq_group_metadata_list)
if self.lora_config:
self.set_active_loras(lora_requests, lora_mapping)
# Execute the model.
if input_metadata.use_cuda_graph:
graph_batch_size = input_tokens.shape[0]
model_executable = self.graph_runners[graph_batch_size]
else:
model_executable = self.model
hidden_states = model_executable(
input_ids=input_tokens,
positions=input_positions,
kv_caches=kv_caches,
input_metadata=input_metadata,
)
# Sample the next token.
output = self.model.sample(
hidden_states=hidden_states,
sampling_metadata=sampling_metadata,
)
return output
@torch.inference_mode()
def profile_run(self) -> None:
# Enable top-k sampling to reflect the accurate memory usage.
vocab_size = self.model_config.get_vocab_size()
# FIXME(sgm): this sampling params will call cumsum(), causing the
# deterministic cumsum throw error
sampling_params = SamplingParams(top_p=0.99, top_k=vocab_size - 1)
max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
max_num_seqs = self.scheduler_config.max_num_seqs
# This represents the maximum number of different requests
# that will have unique loras, an therefore the max amount of memory
# consumption create dummy lora request copies from the lora request
# passed in, which contains a lora from the lora warmup path.
dummy_lora_requests = []
dummy_lora_requests_per_seq = []
if self.lora_config:
for idx in range(self.lora_config.max_loras):
lora_id = idx + 1
dummy_lora_request = LoRARequest(
lora_name=f"warmup_{lora_id}",
lora_int_id=lora_id,
lora_local_path="/not/a/real/path",
)
self.lora_manager.add_dummy_lora(dummy_lora_request, rank=LORA_WARMUP_RANK)
dummy_lora_requests.append(dummy_lora_request)
dummy_lora_requests_per_seq = [
dummy_lora_requests[idx % len(dummy_lora_requests)] for idx in range(max_num_seqs)
]
# Profile memory usage with max_num_sequences sequences and the total
# number of tokens equal to max_num_batched_tokens.
seqs: List[SequenceGroupMetadata] = []
for group_id in range(max_num_seqs):
seq_len = (max_num_batched_tokens // max_num_seqs + (group_id < max_num_batched_tokens % max_num_seqs))
seq_data = SequenceData([0] * seq_len)
seq = SequenceGroupMetadata(
request_id=str(group_id),
is_prompt=True,
seq_data={group_id: seq_data},
sampling_params=sampling_params,
block_tables=None,
lora_request=dummy_lora_requests_per_seq[group_id] if dummy_lora_requests_per_seq else None,
)
seqs.append(seq)
# Run the model with the dummy inputs.
num_layers = self.model_config.get_num_layers(self.parallel_config)
kv_caches = [(None, None)] * num_layers
self.execute_model(seqs, kv_caches)
torch.cuda.synchronize()
return

View File

@@ -0,0 +1,147 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023 The vLLM team.
# Adapted from
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""Model and data parallel groups."""
import torch
import torch.distributed
import vllm.model_executor.parallel_utils.parallel_state as ps
"""
This version is strongly tied with Megatron to implement HybridEngine and weight sharing between vllm and Megatron.
- We assume the Megatron tp+dp+pp world is already established before calling this function.
"""
# Tensor model parallel group that the current rank belongs to.
_TENSOR_MODEL_PARALLEL_GROUP = None
# Micro Data parallel group. Micro data parallel group is additional dp group that origins from splitting training tp
# into infer_tp and micro_tp. By default, we use order micro_dp - tp
_MICRO_DATA_PARALLEL_GROUP = None
def initialize_model_parallel_from_megatron(
tensor_model_parallel_size=None # we set None for backward compatibility to set infer_tp = train_tp
) -> None:
from megatron.core import parallel_state as mpu
from megatron.distributed import new_group
# Get world size and rank. Ensure some consistencies.
assert torch.distributed.is_initialized()
if tensor_model_parallel_size is None:
tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size()
else:
assert isinstance(tensor_model_parallel_size, int)
# Build the tensor model-parallel groups.
assert ps._TENSOR_MODEL_PARALLEL_GROUP is None, ("tensor model parallel group is already initialized")
assert tensor_model_parallel_size <= mpu.get_tensor_model_parallel_world_size(
), 'Not implemented for infer_tp > train_tp'
global _TENSOR_MODEL_PARALLEL_GROUP
global _MICRO_DATA_PARALLEL_GROUP
assert mpu.get_tensor_model_parallel_world_size() % tensor_model_parallel_size == 0
micro_dp_size = mpu.get_tensor_model_parallel_world_size() // tensor_model_parallel_size
world_size: int = torch.distributed.get_world_size()
num_micro_dp_groups = world_size // micro_dp_size
rank = torch.distributed.get_rank()
# Build the micro dp groups.
assert _MICRO_DATA_PARALLEL_GROUP is None, ("micro data parallel group is already initialized")
for i in range(num_micro_dp_groups):
ranks = range(i * micro_dp_size, (i + 1) * micro_dp_size)
group = new_group(rank=rank, ranks=ranks, group_type='micro_dp')
if rank in ranks:
_MICRO_DATA_PARALLEL_GROUP = group
if tensor_model_parallel_size == mpu.get_tensor_model_parallel_world_size():
# using the same tp group as Megatron
ps._TENSOR_MODEL_PARALLEL_GROUP = mpu.get_tensor_model_parallel_group()
_TENSOR_MODEL_PARALLEL_GROUP = mpu.get_tensor_model_parallel_group()
# no _MICRO_DATA_PARALLEL_GROUP
else:
# initialize a micro_dp group and a tp group
# assume training tp=4, infer tp=2, then, weight is partitioned as
# [1], [2], [3], [4] for training and [1,2], [1,2], [3,4], [3,4] for inference
# Build the inference tp groups
train_tp = mpu.get_tensor_model_parallel_world_size()
num_tensor_model_parallel_groups_per_train_tp = train_tp // tensor_model_parallel_size
num_tensor_model_parallel_groups = world_size // tensor_model_parallel_size
assert _TENSOR_MODEL_PARALLEL_GROUP is None, ("tensor model parallel group is already initialized")
for i in range(num_tensor_model_parallel_groups // num_tensor_model_parallel_groups_per_train_tp):
start = train_tp * i
end = train_tp * (i + 1)
for j in range(num_tensor_model_parallel_groups_per_train_tp):
ranks = list(range(start, end, num_tensor_model_parallel_groups_per_train_tp))
for i in range(len(ranks)):
ranks[i] += j
# group = torch.distributed.new_group(ranks)
group = new_group(rank=rank, ranks=ranks, group_type='infer_tp')
if rank in ranks:
_TENSOR_MODEL_PARALLEL_GROUP = group
ps._TENSOR_MODEL_PARALLEL_GROUP = _TENSOR_MODEL_PARALLEL_GROUP
# Build the pipeline model-parallel groups.
# global _PIPELINE_MODEL_PARALLEL_GROUP
# global _PIPELINE_GLOBAL_RANKS
# assert ps._PIPELINE_MODEL_PARALLEL_GROUP is None, ("pipeline model parallel group is already initialized")
# ps._PIPELINE_MODEL_PARALLEL_GROUP = mpu.get_pipeline_model_parallel_group()
# ps._PIPELINE_GLOBAL_RANKS = mpu.get_pipeline_model_parallel_ranks()
"""
Tensor model parallel utilities
"""
def get_tensor_model_parallel_group():
"""Get the tensor model parallel group the caller rank belongs to."""
assert _TENSOR_MODEL_PARALLEL_GROUP is not None, ("tensor model parallel group is not initialized")
return _TENSOR_MODEL_PARALLEL_GROUP
def get_tensor_model_parallel_world_size():
"""Return world size for the tensor model parallel group."""
return torch.distributed.get_world_size(group=get_tensor_model_parallel_group())
def get_tensor_model_parallel_rank():
"""Return my rank for the tensor model parallel group."""
return torch.distributed.get_rank(group=get_tensor_model_parallel_group())
def get_tensor_model_parallel_src_rank():
"""Calculate the global rank corresponding to the first local rank
in the tensor model parallel group."""
global_rank = torch.distributed.get_rank()
local_world_size = get_tensor_model_parallel_world_size()
return (global_rank // local_world_size) * local_world_size
"""
Micro Data parallel group
"""
def get_micro_data_parallel_group():
assert _MICRO_DATA_PARALLEL_GROUP is not None
return _MICRO_DATA_PARALLEL_GROUP
def get_micro_data_parallel_world_size():
return torch.distributed.get_world_size(group=get_micro_data_parallel_group())
def get_micro_data_parallel_rank():
return torch.distributed.get_rank(group=get_micro_data_parallel_group())

View File

@@ -0,0 +1,72 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023 The vLLM team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/transformers_utils/tokenizer_group/tokenizer_group.py
from typing import List, Optional, Tuple, Union
from transformers import (AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast)
from vllm.lora.request import LoRARequest
from vllm.utils import make_async, LRUCache
from vllm.transformers_utils.tokenizers import *
class TokenizerGroup:
"""A group of tokenizers that can be used for LoRA adapters."""
def __init__(self, tokenizer: PreTrainedTokenizer, enable_lora: bool, max_num_seqs: int,
max_input_length: Optional[int]):
self.enable_lora = enable_lora
self.max_input_length = max_input_length
self.tokenizer = tokenizer
if enable_lora:
self.lora_tokenizers = LRUCache(capacity=max_num_seqs)
else:
self.lora_tokenizers = None
def encode(self,
prompt: str,
request_id: Optional[str] = None,
lora_request: Optional[LoRARequest] = None) -> List[int]:
tokenizer = self.get_lora_tokenizer(lora_request)
return tokenizer.encode(prompt)
async def encode_async(self,
prompt: str,
request_id: Optional[str] = None,
lora_request: Optional[LoRARequest] = None) -> List[int]:
tokenizer = await self.get_lora_tokenizer_async(lora_request)
return tokenizer.encode(prompt)
def get_lora_tokenizer(self, lora_request: Optional[LoRARequest]) -> "PreTrainedTokenizer":
if not lora_request or not self.enable_lora:
return self.tokenizer
if lora_request.lora_int_id not in self.lora_tokenizers:
# TODO(sgm): the lora tokenizer is also passed, but may be different
tokenizer = self.tokenizer
# tokenizer = (get_lora_tokenizer(
# lora_request, **self.tokenizer_config) or self.tokenizer)
self.lora_tokenizers.put(lora_request.lora_int_id, tokenizer)
return tokenizer
else:
return self.lora_tokenizers.get(lora_request.lora_int_id)
# FIXME(sgm): for simplicity, we assign the special token here
@property
def pad_token_id(self):
return self.tokenizer.pad_token_id
@property
def eos_token_id(self):
return self.tokenizer.eos_token_id

View File

@@ -0,0 +1,95 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023 The vLLM team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/models
from typing import Dict
import torch
import torch.nn as nn
# NOTE(shengguangming): replace the origin weight loader function in the class
def parallel_weight_loader(self, param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
"""Parallel Linear weight loader."""
assert param.size() == loaded_weight.size(
), 'the parameter size is not align with the loaded weight size, param size: {}, loaded_weight size: {}'.format(
param.size(), loaded_weight.size())
assert param.data.dtype == loaded_weight.data.dtype, "if we want to shared weights, the data type should also be the same"
param.data = loaded_weight.data
def default_weight_loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
"""Default weight loader."""
assert param.size() == loaded_weight.size()
assert param.data.dtype == loaded_weight.data.dtype, "if we want to shared weights, the data type should also be the same"
param.data = loaded_weight.data
def gpt2_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module:
params_dict = dict(vllm_model.named_parameters(remove_duplicate=False))
for name, loaded_weight in actor_weights.items():
if "lm_head.weight" in name:
# GPT-2 ties the weights of the embedding layer and the final
# linear layer.
continue
if ".attn.bias" in name or ".attn.masked_bias" in name:
# Skip attention mask.
# NOTE: "c_attn.bias" should not be skipped.
continue
if not name.startswith("transformer."):
name = "transformer." + name
param = params_dict[name]
# The HF's GPT-2 implementation uses Conv1D instead of Linear.
# Because of this, we need to transpose the weights.
# Note(zhuohan): the logic below might break quantized models.
for conv1d_weight_name in ["c_attn", "c_proj", "c_fc"]:
if conv1d_weight_name not in name:
continue
if not name.endswith(".weight"):
continue
# TODO: check megatron
loaded_weight = loaded_weight.t()
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
def llama_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module:
# NOTE(shengguangming): the megatron llama may have this prefix
prefix = '0.module.module.'
params_dict = dict(vllm_model.named_parameters())
for name, loaded_weight in actor_weights.items():
if name[:len(prefix)] == prefix:
name = name[len(prefix):]
if "rotary_emb.inv_freq" in name:
continue
else:
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
def mistral_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module:
# TODO: need to implement a general way to deal with prefix
prefix = '0.module.module.'
params_dict = dict(vllm_model.named_parameters())
for name, loaded_weight in actor_weights.items():
if name[:len(prefix)] == prefix:
name = name[len(prefix):]
if "rotary_emb.inv_freq" in name:
continue
else:
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)

View File

@@ -0,0 +1,314 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023 The vLLM team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/worker/worker.py
"""A GPU worker class."""
import os
import gc
from typing import Dict, List, Tuple, Optional, Union, Set
import torch
import torch.distributed
import torch.nn as nn
from vllm.config import (CacheConfig, DeviceConfig, ModelConfig, ParallelConfig, SchedulerConfig, LoRAConfig)
from vllm.model_executor import InputMetadata, set_random_seed
from vllm.model_executor.parallel_utils.parallel_state import (initialize_model_parallel)
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
from vllm.worker.cache_engine import CacheEngine
from vllm.model_executor.parallel_utils.custom_all_reduce import init_custom_ar
from vllm.model_executor.parallel_utils.parallel_state import get_tensor_model_parallel_group
from .model_runner import ModelRunner
from .model_loader import load_weights
from .parallel_state import initialize_model_parallel_from_megatron
from vllm.lora.request import LoRARequest
class Worker:
"""A worker class that executes (a partition of) the model on a GPU.
Each worker is associated with a single GPU. The worker is responsible for
maintaining the KV cache and executing the model on the GPU. In case of
distributed inference, each worker is assigned a partition of the model.
"""
def __init__(
self,
model: Union[nn.Module, Dict], # model itself or its parameter dict
model_config: ModelConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
rank: Optional[int] = None,
distributed_init_method: Optional[str] = None,
lora_config: Optional[LoRAConfig] = None,
kv_cache_dtype: Optional[str] = "auto",
) -> None:
# self.model = model # will be replaced in the init_model
self.model_config = model_config
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
self.rank = rank
self.distributed_init_method = distributed_init_method
self.lora_config = lora_config
self.model_runner = ModelRunner(
model,
model_config,
parallel_config,
scheduler_config,
device_config,
lora_config=self.lora_config,
kv_cache_dtype=kv_cache_dtype,
)
# Uninitialized cache engine. Will be initialized by
# self.init_cache_engine().
self.cache_config = None
self.block_size = None
self.sliding_window = None
self.cache_engine = None
self.cache_events = None
self.gpu_cache = None
# For offloading inference engine params
self.cpu_model = None
def init_model(self, cupy_port: Optional[int] = None):
# torch.distributed.all_reduce does not free the input tensor until
# the synchronization point. This causes the memory usage to grow
# as the number of all_reduce calls increases. This env var disables
# this behavior.
# Related issue:
# https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573
os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
# Env vars will be set by TORCHRUN.
self.rank = self.rank if self.rank is not None else int(os.getenv("RANK", "-1"))
local_rank = int(os.getenv("LOCAL_RANK", "0"))
self.device = torch.device(f"cuda:{local_rank}")
if self.rank < 0:
raise ValueError("Invalid or unspecified rank.")
torch.cuda.set_device(self.device)
_check_if_gpu_supports_dtype(self.model_config.dtype)
# Initialize the distributed environment.
# TODO: do not use cupy
_init_distributed_environment(self.parallel_config, self.rank, self.distributed_init_method)
if not self.parallel_config.disable_custom_all_reduce:
init_custom_ar()
# Initialize the model.
set_random_seed(self.model_config.seed)
# self.model = get_model(actor_model=self.model, model_config=self.model_config)
def load_model(self):
self.model_runner.load_model()
@torch.inference_mode()
def profile_num_available_blocks(
self,
block_size: int,
gpu_memory_utilization: float,
cpu_swap_space: int,
cache_dtype: str,
) -> Tuple[int, int]:
# Profile the memory usage of the model and get the maximum number of
# cache blocks that can be allocated with the remaining free memory.
torch.cuda.empty_cache()
# torch.cuda.reset_peak_memory_stats()
# Execute a forward pass with dummy inputs to profile the memory usage
# of the model.
self.model_runner.profile_run()
# Calculate the number of blocks that can be allocated with the
# profiled peak memory.
torch.cuda.synchronize()
free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info()
peak_memory = total_gpu_memory - free_gpu_memory
cache_block_size = CacheEngine.get_cache_block_size(block_size, cache_dtype, self.model_config,
self.parallel_config)
# NOTE(sgm) use the remaining memory
num_gpu_blocks = int((free_gpu_memory * gpu_memory_utilization) // cache_block_size)
# num_gpu_blocks = int((total_gpu_memory * gpu_memory_utilization - peak_memory) // cache_block_size)
num_cpu_blocks = int(cpu_swap_space // cache_block_size)
num_gpu_blocks = max(num_gpu_blocks, 0)
num_cpu_blocks = max(num_cpu_blocks, 0)
if self.model_runner.lora_manager:
self.model_runner.remove_all_loras()
gc.collect()
torch.cuda.empty_cache()
# Synchronize number of blocks with all the rank
num_gpu_blocks = torch.tensor([num_gpu_blocks], device='cuda')
num_cpu_blocks = torch.tensor([num_cpu_blocks], device='cuda')
torch.distributed.all_reduce(num_gpu_blocks,
op=torch.distributed.ReduceOp.MIN,
group=get_tensor_model_parallel_group())
torch.distributed.all_reduce(num_cpu_blocks,
op=torch.distributed.ReduceOp.MIN,
group=get_tensor_model_parallel_group())
num_gpu_blocks = num_gpu_blocks.item()
num_cpu_blocks = num_cpu_blocks.item()
return num_gpu_blocks, num_cpu_blocks
def init_cache_engine(self, cache_config: CacheConfig) -> None:
if self.cache_engine is None and self.gpu_cache is None:
self.cache_config = cache_config
self.cache_engine = CacheEngine(self.cache_config, self.model_config, self.parallel_config)
self.cache_events = self.cache_engine.events
self.gpu_cache = self.cache_engine.gpu_cache
self.model_runner.set_block_size(self.cache_engine.block_size)
def free_cache_engine(self):
# ensure `enforce_eager=True`
self.cache_engine = None
self.gpu_cache = None
def warm_up_model(self) -> None:
if not self.model_config.enforce_eager:
self.model_runner.capture_model(self.gpu_cache)
# Reset the seed to ensure that the random state is not affected by
# the model initialization and profiling.
set_random_seed(self.model_config.seed)
def cache_swap(
self,
blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, List[int]],
) -> None:
# Issue cache operations.
issued_cache_op = False
if blocks_to_swap_in:
self.cache_engine.swap_in(blocks_to_swap_in)
issued_cache_op = True
if blocks_to_swap_out:
self.cache_engine.swap_out(blocks_to_swap_out)
issued_cache_op = True
if blocks_to_copy:
self.cache_engine.copy(blocks_to_copy)
issued_cache_op = True
cache_events = self.cache_events if issued_cache_op else None
# Wait for cache operations to finish.
# TODO(woosuk): Profile swapping overhead and optimize if needed.
if cache_events is not None:
for event in cache_events:
event.wait()
@torch.inference_mode()
def execute_model(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, List[int]],
) -> SamplerOutput:
num_seq_groups = len(seq_group_metadata_list)
self.cache_swap(blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy)
# If there is no input, we don't need to execute the model.
if num_seq_groups == 0:
return {}
output = self.model_runner.execute_model(seq_group_metadata_list, self.gpu_cache)
return output
# # Prepare input tensors.
# # NOTE(shengguangming): currently we pad in our dataloader and unpad it in pre_process_input, j
# # we can just input un-padded sequence for better performance
# input_tokens, input_positions, input_metadata = self._prepare_inputs(seq_group_metadata_list)
# # Execute the model.
# output = self.model(
# input_ids=input_tokens,
# positions=input_positions,
# kv_caches=self.gpu_cache,
# input_metadata=input_metadata,
# cache_events=cache_events,
# )
# return output
# assume the input is .state_dict()
def sync_model_weights(self, actor_weights: Dict):
load_weights(actor_weights, self.model_runner.model)
def offload_model_weights(self) -> None:
if self.cpu_model == None:
self.cpu_model = {}
for name, params in self.model_runner.model.named_parameters():
self.cpu_model[name] = torch.empty_like(params, device='cpu')
params.data = self.cpu_model[name]
else:
for name, params in self.model_runner.model.named_parameters():
params.data = self.cpu_model[name]
def add_lora(self, lora_request: LoRARequest) -> bool:
return self.model_runner.add_lora(lora_request)
def remove_lora(self, lora_id: int) -> bool:
return self.model_runner.remove_lora(lora_id)
def list_loras(self) -> Set[int]:
return self.model_runner.list_loras()
def _init_distributed_environment(
parallel_config: ParallelConfig,
rank: int,
distributed_init_method: Optional[str] = None,
) -> None:
"""Initialize the distributed environment."""
if torch.distributed.is_initialized():
print('The distributed environment has been initialized before vLLM')
elif not distributed_init_method:
raise ValueError("distributed_init_method must be set if torch.distributed "
"is not already initialized")
else:
torch.distributed.init_process_group(
backend="nccl",
world_size=parallel_config.world_size,
rank=rank,
# init_method=distributed_init_method,
)
# A small all_reduce for warmup.
torch.distributed.all_reduce(torch.zeros(1).cuda())
# TODO (shengguangming): maybe we should also flag the megatron is initialized
if torch.distributed.get_world_size() > 1:
initialize_model_parallel_from_megatron(tensor_model_parallel_size=parallel_config.tensor_parallel_size)
else:
initialize_model_parallel()
def _pad_to_alignment(x: List[int], multiple_of: int, pad: int) -> List[int]:
return x + [pad] * ((-len(x)) % multiple_of)
def _pad_to_max(x: List[int], max_len: int, pad: int) -> List[int]:
return x + [pad] * (max_len - len(x))
def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
# Check if the GPU supports the dtype.
if torch_dtype == torch.bfloat16:
compute_capability = torch.cuda.get_device_capability()
if compute_capability[0] < 8:
gpu_name = torch.cuda.get_device_name()
raise ValueError("Bfloat16 is only supported on GPUs with compute capability "
f"of at least 8.0. Your {gpu_name} GPU has compute capability "
f"{compute_capability[0]}.{compute_capability[1]}.")

View File

@@ -0,0 +1,13 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@@ -0,0 +1,320 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023 The vLLM team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/engine/arg_utils.py
import os
import argparse
import dataclasses
from dataclasses import dataclass
from typing import List, Optional, Union
import torch.nn as nn
from transformers import PretrainedConfig
from .config import ModelConfig, LoadConfig
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, EngineConfig, LoRAConfig, ParallelConfig,
SchedulerConfig, SpeculativeConfig, TokenizerPoolConfig, VisionLanguageConfig)
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.utils import str_to_int_tuple
def nullable_str(val: str):
if not val or val == "None":
return None
return val
@dataclass
class EngineArgs:
"""Arguments for vLLM engine."""
model_hf_config: PretrainedConfig = None
skip_tokenizer_init: bool = False
served_model_name: Optional[Union[str, List[str]]] = None # TODO
download_dir: Optional[str] = None
load_format: str = 'auto'
dtype: str = 'auto'
kv_cache_dtype: str = 'auto'
quantization_param_path: Optional[str] = None
seed: int = 0
max_model_len: Optional[int] = None
worker_use_ray: bool = False
pipeline_parallel_size: int = 1
tensor_parallel_size: int = 1
max_parallel_loading_workers: Optional[int] = None
block_size: int = 16
enable_prefix_caching: bool = False
use_v2_block_manager: bool = False
swap_space: int = 4 # GiB
gpu_memory_utilization: float = 0.90
max_num_batched_tokens: Optional[int] = None
max_num_seqs: int = 256
max_logprobs: int = 5 # OpenAI default value
disable_log_stats: bool = False
revision: Optional[str] = None
code_revision: Optional[str] = None
tokenizer_revision: Optional[str] = None
quantization: Optional[str] = None
enforce_eager: bool = False
max_context_len_to_capture: Optional[int] = None
max_seq_len_to_capture: int = 8192
disable_custom_all_reduce: bool = False
tokenizer_pool_size: int = 0
tokenizer_pool_type: str = "ray"
tokenizer_pool_extra_config: Optional[dict] = None
enable_lora: bool = False
max_loras: int = 1
max_lora_rank: int = 16
fully_sharded_loras: bool = False
lora_extra_vocab_size: int = 256
lora_dtype = 'auto'
max_cpu_loras: Optional[int] = None
device: str = 'auto'
ray_workers_use_nsight: bool = False
num_gpu_blocks_override: Optional[int] = None
num_lookahead_slots: int = 0
model_loader_extra_config: Optional[dict] = None
# Related to Vision-language models such as llava
image_input_type: Optional[str] = None
image_token_id: Optional[int] = None
image_input_shape: Optional[str] = None
image_feature_size: Optional[int] = None
scheduler_delay_factor: float = 0.0
enable_chunked_prefill: bool = False
guided_decoding_backend: str = 'outlines'
# Speculative decoding configuration.
speculative_model: Optional[str] = None
num_speculative_tokens: Optional[int] = None
speculative_max_model_len: Optional[int] = None
ngram_prompt_lookup_max: Optional[int] = None
ngram_prompt_lookup_min: Optional[int] = None
@staticmethod
def add_cli_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
"""Shared CLI arguments for vLLM engine."""
# Model arguments
# TODO(shengguangming): delete the unused args
parser.add_argument('--model',
type=str,
default='facebook/opt-125m',
help='name or path of the huggingface model to use')
parser.add_argument('--tokenizer',
type=str,
default=EngineArgs.tokenizer,
help='name or path of the huggingface tokenizer to use')
parser.add_argument('--revision',
type=str,
default=None,
help='the specific model version to use. It can be a branch '
'name, a tag name, or a commit id. If unspecified, will use '
'the default version.')
parser.add_argument('--tokenizer-revision',
type=str,
default=None,
help='the specific tokenizer version to use. It can be a branch '
'name, a tag name, or a commit id. If unspecified, will use '
'the default version.')
parser.add_argument('--tokenizer-mode',
type=str,
default=EngineArgs.tokenizer_mode,
choices=['auto', 'slow'],
help='tokenizer mode. "auto" will use the fast '
'tokenizer if available, and "slow" will '
'always use the slow tokenizer.')
parser.add_argument('--trust-remote-code', action='store_true', help='trust remote code from huggingface')
parser.add_argument('--download-dir',
type=str,
default=EngineArgs.download_dir,
help='directory to download and load the weights, '
'default to the default cache dir of '
'huggingface')
parser.add_argument('--load-format',
type=str,
default=EngineArgs.load_format,
choices=['auto', 'pt', 'safetensors', 'npcache', 'dummy'],
help='The format of the model weights to load. '
'"auto" will try to load the weights in the safetensors format '
'and fall back to the pytorch bin format if safetensors format '
'is not available. '
'"pt" will load the weights in the pytorch bin format. '
'"safetensors" will load the weights in the safetensors format. '
'"npcache" will load the weights in pytorch format and store '
'a numpy cache to speed up the loading. '
'"dummy" will initialize the weights with random values, '
'which is mainly for profiling.')
parser.add_argument('--dtype',
type=str,
default=EngineArgs.dtype,
choices=['auto', 'half', 'float16', 'bfloat16', 'float', 'float32'],
help='data type for model weights and activations. '
'The "auto" option will use FP16 precision '
'for FP32 and FP16 models, and BF16 precision '
'for BF16 models.')
parser.add_argument('--max-model-len',
type=int,
default=None,
help='model context length. If unspecified, '
'will be automatically derived from the model.')
# Parallel arguments
parser.add_argument('--worker-use-ray',
action='store_true',
help='use Ray for distributed serving, will be '
'automatically set when using more than 1 GPU')
parser.add_argument('--pipeline-parallel-size',
'-pp',
type=int,
default=EngineArgs.pipeline_parallel_size,
help='number of pipeline stages')
parser.add_argument('--tensor-parallel-size',
'-tp',
type=int,
default=EngineArgs.tensor_parallel_size,
help='number of tensor parallel replicas')
# KV cache arguments
parser.add_argument('--block-size',
type=int,
default=EngineArgs.block_size,
choices=[8, 16, 32],
help='token block size')
# TODO(woosuk): Support fine-grained seeds (e.g., seed per request).
parser.add_argument('--seed', type=int, default=EngineArgs.seed, help='random seed')
parser.add_argument('--swap-space',
type=int,
default=EngineArgs.swap_space,
help='CPU swap space size (GiB) per GPU')
parser.add_argument('--gpu-memory-utilization',
type=float,
default=EngineArgs.gpu_memory_utilization,
help='the percentage of GPU memory to be used for'
'the model executor')
parser.add_argument('--max-num-batched-tokens',
type=int,
default=EngineArgs.max_num_batched_tokens,
help='maximum number of batched tokens per '
'iteration')
parser.add_argument('--max-num-seqs',
type=int,
default=EngineArgs.max_num_seqs,
help='maximum number of sequences per iteration')
parser.add_argument('--disable-log-stats', action='store_true', help='disable logging statistics')
# Quantization settings.
parser.add_argument('--quantization',
'-q',
type=str,
choices=['awq', None],
default=None,
help='Method used to quantize the weights')
return parser
@classmethod
def from_cli_args(cls, args: argparse.Namespace) -> 'EngineArgs':
# Get the list of attributes of this dataclass.
attrs = [attr.name for attr in dataclasses.fields(cls)]
# Set the attributes from the parsed arguments.
engine_args = cls(**{attr: getattr(args, attr) for attr in attrs})
return engine_args
def create_engine_config(
self,
) -> EngineConfig:
device_config = DeviceConfig(self.device)
# NOTE(sgm): we only modify ModelConfig, other configs are import from vllm
model_config = ModelConfig(self.model_hf_config, self.dtype, self.seed, self.revision, self.code_revision,
self.tokenizer_revision, self.max_model_len, self.quantization,
self.quantization_param_path, self.enforce_eager, self.max_context_len_to_capture,
self.max_seq_len_to_capture, self.max_logprobs, self.skip_tokenizer_init,
self.served_model_name)
cache_config = CacheConfig(self.block_size, self.gpu_memory_utilization,
self.swap_space, self.kv_cache_dtype, self.num_gpu_blocks_override,
model_config.get_sliding_window(), self.enable_prefix_caching)
parallel_config = ParallelConfig(
self.pipeline_parallel_size, self.tensor_parallel_size, self.worker_use_ray,
self.max_parallel_loading_workers, self.disable_custom_all_reduce,
TokenizerPoolConfig.create_config(
self.tokenizer_pool_size,
self.tokenizer_pool_type,
self.tokenizer_pool_extra_config,
), self.ray_workers_use_nsight)
# Use the world_size set by TORCHRUN
world_size = int(os.getenv("WORLD_SIZE", "-1"))
assert world_size != -1, "The world_size is set to -1, not initialized by TORCHRUN"
parallel_config.world_size = world_size
# TODO: spec config
speculative_config = SpeculativeConfig.maybe_create_spec_config(
target_model_config=model_config,
target_parallel_config=parallel_config,
target_dtype=self.dtype,
speculative_model=self.speculative_model,
num_speculative_tokens=self.num_speculative_tokens,
speculative_max_model_len=self.speculative_max_model_len,
enable_chunked_prefill=self.enable_chunked_prefill,
use_v2_block_manager=self.use_v2_block_manager,
ngram_prompt_lookup_max=self.ngram_prompt_lookup_max,
ngram_prompt_lookup_min=self.ngram_prompt_lookup_min,
)
scheduler_config = SchedulerConfig(
self.max_num_batched_tokens,
self.max_num_seqs,
model_config.max_model_len,
self.use_v2_block_manager,
num_lookahead_slots=(self.num_lookahead_slots
if speculative_config is None else speculative_config.num_lookahead_slots),
delay_factor=self.scheduler_delay_factor,
enable_chunked_prefill=self.enable_chunked_prefill,
)
lora_config = LoRAConfig(max_lora_rank=self.max_lora_rank,
max_loras=self.max_loras,
fully_sharded_loras=self.fully_sharded_loras,
lora_extra_vocab_size=self.lora_extra_vocab_size,
lora_dtype=self.lora_dtype,
max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras and self.max_cpu_loras > 0 else
None) if self.enable_lora else None
load_config = LoadConfig(
load_format=self.load_format,
download_dir=self.download_dir,
model_loader_extra_config=self.model_loader_extra_config,
)
if self.image_input_type:
if (not self.image_token_id or not self.image_input_shape or not self.image_feature_size):
raise ValueError('Specify `image_token_id`, `image_input_shape` and '
'`image_feature_size` together with `image_input_type`.')
vision_language_config = VisionLanguageConfig(
image_input_type=VisionLanguageConfig.get_image_input_enum_type(self.image_input_type),
image_token_id=self.image_token_id,
image_input_shape=str_to_int_tuple(self.image_input_shape),
image_feature_size=self.image_feature_size,
)
else:
vision_language_config = None
decoding_config = DecodingConfig(guided_decoding_backend=self.guided_decoding_backend)
return EngineConfig(model_config=model_config,
cache_config=cache_config,
parallel_config=parallel_config,
scheduler_config=scheduler_config,
device_config=device_config,
lora_config=lora_config,
vision_language_config=vision_language_config,
speculative_config=speculative_config,
load_config=load_config,
decoding_config=decoding_config)

View File

@@ -0,0 +1,200 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023 The vLLM team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py
import enum
import json
from typing import List, Optional, Union
from dataclasses import dataclass, field, fields
from transformers import PretrainedConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import get_quantization_config
from vllm.transformers_utils.config import get_hf_text_config
from vllm.utils import is_hip
# Add for verl
from vllm.config import ModelConfig, _get_and_verify_dtype, _get_and_verify_max_len
GPTQMarlinConfig = get_quantization_config("gptq_marlin")
logger = init_logger(__name__)
_GB = 1 << 30
class ModelConfig(ModelConfig):
"""Configuration for the model.
Args:
model: Name or path of the huggingface model to use.
tokenizer: Name or path of the huggingface tokenizer to use.
tokenizer_mode: Tokenizer mode. "auto" will use the fast tokenizer if
available, and "slow" will always use the slow tokenizer.
trust_remote_code: Trust remote code (e.g., from HuggingFace) when
downloading the model and tokenizer.
download_dir: Directory to download and load the weights, default to the
default cache directory of huggingface.
load_format: The format of the model weights to load:
"auto" will try to load the weights in the safetensors format and
fall back to the pytorch bin format if safetensors format is
not available.
"pt" will load the weights in the pytorch bin format.
"safetensors" will load the weights in the safetensors format.
"npcache" will load the weights in pytorch format and store
a numpy cache to speed up the loading.
"dummy" will initialize the weights with random values, which is
mainly for profiling.
dtype: Data type for model weights and activations. The "auto" option
will use FP16 precision for FP32 and FP16 models, and BF16 precision
for BF16 models.
seed: Random seed for reproducibility.
revision: The specific model version to use. It can be a branch name,
a tag name, or a commit id. If unspecified, will use the default
version.
code_revision: The specific revision to use for the model code on
Hugging Face Hub. It can be a branch name, a tag name, or a
commit id. If unspecified, will use the default version.
tokenizer_revision: The specific tokenizer version to use. It can be a
branch name, a tag name, or a commit id. If unspecified, will use
the default version.
max_model_len: Maximum length of a sequence (including prompt and
output). If None, will be derived from the model.
quantization: Quantization method that was used to quantize the model
weights. If None, we assume the model weights are not quantized.
quantization_param_path: Path to JSON file containing scaling factors.
Used to load KV cache scaling factors into the model when KV cache
type is FP8_E4M3 on ROCm (AMD GPU). In the future these will also
be used to load activation and weight scaling factors when the
model dtype is FP8_E4M3 on ROCm.
enforce_eager: Whether to enforce eager execution. If True, we will
disable CUDA graph and always execute the model in eager mode.
If False, we will use CUDA graph and eager execution in hybrid.
max_context_len_to_capture: Maximum context len covered by CUDA graphs.
When a sequence has context length larger than this, we fall back
to eager mode (DEPRECATED. Use max_seq_len_to_capture instead).
max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs.
When a sequence has context length larger than this, we fall back
to eager mode
skip_tokenizer_init: If true, skip initialization of tokenizer and
detokenizer.
served_model_name: The model name used in metrics tag `model_name`,
matches the model name exposed via the APIs. If multiple model
names provided, the first name will be used. If not specified,
the model name will be the same as `model`.
"""
def __init__(
self,
hf_config: PretrainedConfig,
dtype: str,
seed: int,
revision: Optional[str] = None,
code_revision: Optional[str] = None,
tokenizer_revision: Optional[str] = None,
max_model_len: Optional[int] = None,
quantization: Optional[str] = None,
quantization_param_path: Optional[str] = None,
enforce_eager: bool = False,
max_context_len_to_capture: Optional[int] = None,
max_seq_len_to_capture: Optional[int] = None,
max_logprobs: int = 5,
skip_tokenizer_init: bool = False,
served_model_name: Optional[Union[str, List[str]]] = None,
) -> None:
self.model = hf_config._name_or_path
self.tokenizer = hf_config._name_or_path
self.seed = seed
self.revision = revision
self.code_revision = code_revision
self.tokenizer_revision = tokenizer_revision
self.quantization = quantization
self.quantization_param_path = quantization_param_path
self.enforce_eager = enforce_eager
self.max_context_len_to_capture = max_context_len_to_capture
if self.max_context_len_to_capture is not None:
raise ValueError("`max_context_len_to_capture` is deprecated. "
"Use `max_seq_len_to_capture` instead.")
self.max_seq_len_to_capture = (max_seq_len_to_capture or max_context_len_to_capture)
self.max_logprobs = max_logprobs
self.skip_tokenizer_init = skip_tokenizer_init
# self.hf_config = get_config(model, trust_remote_code, revision)
self.hf_config = hf_config
self.hf_text_config = get_hf_text_config(hf_config)
# TODO: for multimodal model
self.dtype = _get_and_verify_dtype(self.hf_config, dtype)
self.max_model_len = _get_and_verify_max_len(self.hf_config, max_model_len)
# self.served_model_name = get_served_model_name(model,
# served_model_name)
# self._verify_load_format()
# self._verify_tokenizer_mode()
self._verify_quantization()
self._verify_cuda_graph()
class LoadFormat(str, enum.Enum):
AUTO = 'auto'
MEGATRON = "megatron"
HF = "hf"
DTENSOR = 'dtensor'
DUMMY_HF = 'dummy_hf'
DUMMY_MEGATRON = 'dummy_megatron'
DUMMY_DTENSOR = 'dummy_dtensor'
@dataclass
class LoadConfig:
"""
download_dir: Directory to download and load the weights, default to the
default cache directory of huggingface.
load_format: The format of the model weights to load:
"auto" will try to load the weights in the safetensors format and
fall back to the pytorch bin format if safetensors format is
not available.
"pt" will load the weights in the pytorch bin format.
"safetensors" will load the weights in the safetensors format.
"npcache" will load the weights in pytorch format and store
a numpy cache to speed up the loading.
"dummy" will initialize the weights with random values, which is
mainly for profiling.
"tensorizer" will use CoreWeave's tensorizer library for
fast weight loading.
"""
load_format: Union[str, LoadFormat, "BaseModelLoader"] = LoadFormat.AUTO
download_dir: Optional[str] = None
model_loader_extra_config: Optional[Union[str, dict]] = field(default_factory=dict)
def __post_init__(self):
model_loader_extra_config = self.model_loader_extra_config or {}
if isinstance(model_loader_extra_config, str):
self.model_loader_extra_config = json.loads(model_loader_extra_config)
self._verify_load_format()
def _verify_load_format(self) -> None:
if not isinstance(self.load_format, str):
return
load_format = self.load_format.lower()
self.load_format = LoadFormat(load_format)
rocm_not_supported_load_format: List[str] = []
if is_hip() and load_format in rocm_not_supported_load_format:
rocm_supported_load_format = [
f for f in LoadFormat.__members__ if (f not in rocm_not_supported_load_format)
]
raise ValueError(f"load format '{load_format}' is not supported in ROCm. "
f"Supported load formats are "
f"{rocm_supported_load_format}")

View File

@@ -0,0 +1,269 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023 The vLLM team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/models
from typing import Dict, Iterable, Tuple
import torch
import torch.nn as nn
from torch.distributed._tensor import DTensor, Shard, Replicate
from vllm.model_executor.layers.linear import *
from vllm.model_executor.models import ModelRegistry
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
def gemma_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(vllm_model.named_parameters())
for name, loaded_weight in actor_weights.items():
for (param_name, shard_name, shard_id) in stacked_params_mapping:
if shard_name not in name:
continue
stacked_name = name.replace(shard_name, param_name)
# Skip loading extra bias for GPTQ models.
if stacked_name.endswith(".bias") and stacked_name not in params_dict:
continue
local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight)
param = params_dict[stacked_name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, local_loaded_weight.to(dtype=param.dtype), shard_id)
break
else:
# lm_head is not used in vllm as it is tied with embed_token.
# To prevent errors, skip loading lm_head.weight.
if "lm_head.weight" in name:
continue
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
# GemmaRMSNorm is different from Llama's in that it multiplies
# (1 + weight) to the output, instead of just weight.
if "norm.weight" in name:
local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight)
norm_weight = local_loaded_weight + 1.0
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, norm_weight.to(dtype=param.dtype))
else:
local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight)
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, local_loaded_weight.to(dtype=param.dtype))
def gptbigcode_dtensor_load_weights(actor_weights: Dict, vllm_model: nn.Module):
params_dict = dict(vllm_model.named_parameters(remove_duplicate=False))
for name, loaded_weight in actor_weights.items():
if "lm_head.weight" in name:
continue
if ".attn.bias" in name:
# Skip attention mask.
# NOTE: "c_attn.bias" should not be skipped.
continue
local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight)
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, local_loaded_weight.to(dtype=param.dtype))
def starcoder2_dtensor_load_weights(actor_weights: Dict, vllm_model: nn.Module):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
]
params_dict = dict(vllm_model.named_parameters(remove_duplicate=False))
for name, loaded_weight in actor_weights.items():
if "rotary_emb.inv_freq" in name:
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight)
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, local_loaded_weight.to(dtype=param.dtype), shard_id)
break
else:
if vllm_model.config.tie_word_embeddings and "lm_head.weight" in name:
continue
param = params_dict[name]
local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight)
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, local_loaded_weight.to(dtype=param.dtype))
def llama_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"),
(".qkv_proj", ".k_proj", "k"),
(".qkv_proj", ".v_proj", "v"),
(".gate_up_proj", ".gate_proj", 0),
(".gate_up_proj", ".up_proj", 1),
]
params_dict = dict(vllm_model.named_parameters())
for name, loaded_weight in actor_weights.items():
if "rotary_emb.inv_freq" in name:
continue
if ("rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name):
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
continue
# With tie_word_embeddings, we can skip lm_head.weight
# The weight might appear unnecessarily in the files if the model is
# processed with quantization, LoRA, fine-tuning, etc.
if vllm_model.config.tie_word_embeddings and "lm_head.weight" in name:
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight)
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, local_loaded_weight.to(dtype=param.dtype), shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight)
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, local_loaded_weight)
def qwen2_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(vllm_model.named_parameters(remove_duplicate=False))
for name, loaded_weight in actor_weights.items():
if "rotary_emb.inv_freq" in name:
continue
if vllm_model.config.tie_word_embeddings and "lm_head.weight" in name:
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight)
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, local_loaded_weight.to(dtype=param.dtype), shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight)
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, local_loaded_weight.to(dtype=param.dtype))
def gpt2_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module:
pass
def redistribute_dtensor(param_name: str, loaded_weights: DTensor, parallelize_plan: Dict = None):
param_name = _process_parameter_names(name=param_name)
if parallelize_plan is not None:
assert param_name in parallelize_plan.keys(), \
f"param name: {param_name} not in parallelize_plan :{parallelize_plan.keys()}"
placement = parallelize_plan[param_name]
local_loaded_weights = loaded_weights.redistribute(device_mesh=loaded_weights.device_mesh,
placements=placement).to_local()
else:
local_loaded_weights = loaded_weights.full_tensor()
return local_loaded_weights
def _process_parameter_names(name):
# Remove '.weight' if it exists at the end of the string
if name.endswith(".weight"):
name = name[:-7]
# Remove 'model.layers.x.' or 'model.' prefix
if "model.layers" in name:
parts = name.split('.')
# Reconstruct the string without 'model.layers.x.'
name = '.'.join(parts[3:]) # parts[0] is 'model', parts[1] is 'layers', parts[2] is 'x'
elif name.startswith("model."):
name = name[6:] # Remove 'model.'
return name
__MODEL_DTENSOR_WEIGHT_LOADER_REGISTRY__ = {
'GPT2LMHeadModel': gpt2_dtensor_weight_loader,
'LlamaForCausalLM': llama_dtensor_weight_loader,
'LLaMAForCausalLM': llama_dtensor_weight_loader,
'MistralForCausalLM': llama_dtensor_weight_loader, # mistral is the same as llama in vLLM
'InternLMForCausalLM': llama_dtensor_weight_loader,
'AquilaModel': llama_dtensor_weight_loader,
'AquilaForCausalLM': llama_dtensor_weight_loader,
'Phi3ForCausalLM': llama_dtensor_weight_loader,
'GemmaForCausalLM': gemma_dtensor_weight_loader,
'GPTBigCodeForCausalLM': gptbigcode_dtensor_load_weights,
'Starcoder2ForCausalLM': starcoder2_dtensor_load_weights,
'Qwen2ForCausalLM': qwen2_dtensor_weight_loader
}
# the actor model is .state_dict()
# Load dtensor weights
def load_dtensor_weights(actor_weights: Dict, vllm_model: nn.Module):
weight_loader = _get_model_weight_loader(vllm_model.__class__.__name__)
weight_loader(actor_weights, vllm_model)
# NOTE(sgm) to reduce peak memory usage, we offload vllm model to cpu
# after init, and we need this after sync model weights for in first iter.
vllm_model = vllm_model.cuda()
def _get_model_weight_loader(arch: str):
if arch in __MODEL_DTENSOR_WEIGHT_LOADER_REGISTRY__:
return __MODEL_DTENSOR_WEIGHT_LOADER_REGISTRY__[arch]
raise ValueError(f"Model architectures {arch} are not supported for now. "
f"Supported architectures: {__MODEL_DTENSOR_WEIGHT_LOADER_REGISTRY__.keys()}")
# NOTE(sgm): we use per-parameter weight loader in each vllm sub
def update_dtensor_weight_loader():
pass

View File

@@ -0,0 +1,91 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023 The vLLM team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/models
from typing import Dict, Union, Optional, Iterable, Tuple
import torch
import torch.nn as nn
from vllm.model_executor.model_loader.utils import set_default_torch_dtype
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
def update_hf_weight_loader():
from vllm.model_executor.models.gemma import GemmaForCausalLM
GemmaForCausalLM.load_weights = gemma_load_weights
def gemma_load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters())
loaded_params = set()
for name, loaded_weight in weights:
for (param_name, shard_name, shard_id) in stacked_params_mapping:
if shard_name not in name:
continue
name = name.replace(shard_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# lm_head is not used in vllm as it is tied with embed_token.
# To prevent errors, skip loading lm_head.weight.
if "lm_head.weight" in name:
continue
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
# GemmaRMSNorm is different from Llama's in that it multiplies
# (1 + weight) to the output, instead of just weight.
if "norm.weight" in name:
norm_weight = loaded_weight + 1.0 # prevent inplace modify actor weights
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, norm_weight)
else:
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
unloaded_params = params_dict.keys() - loaded_params
if unloaded_params:
raise RuntimeError("Some weights are not initialized from checkpoints: "
f"{unloaded_params}")
def load_hf_weights(actor_weights: Dict, vllm_model: nn.Module):
assert isinstance(actor_weights, Dict)
with set_default_torch_dtype(next(vllm_model.parameters()).dtype): # TODO
vllm_model.load_weights(actor_weights.items())
for _, module in vllm_model.named_modules():
quant_method = getattr(module, "quant_method", None)
if quant_method is not None:
quant_method.process_weights_after_loading(module)
# FIXME: Remove this after Mixtral is updated
# to use quant_method.
if hasattr(module, "process_weights_after_loading"):
module.process_weights_after_loading()
vllm_model = vllm_model.cuda()

View File

@@ -0,0 +1,306 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023 The vLLM team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/llm.py
from typing import Dict, List, Optional, Tuple, Union
from tqdm import tqdm
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from transformers import PretrainedConfig
import torch.nn as nn
from .arg_utils import EngineArgs
from .llm_engine_sp import LLMEngine
from vllm.lora.request import LoRARequest
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams
from vllm.sequence import MultiModalData
from vllm.usage.usage_lib import UsageContext
from vllm.utils import Counter
import torch
from torch.nn.utils.rnn import pad_sequence
from verl.workers.rollout.tokenizer import HybridEngineBaseTokenizer
class LLM:
"""An LLM for generating texts from given prompts and sampling parameters.
This class includes a tokenizer, a language model (possibly distributed
across multiple GPUs), and GPU memory space allocated for intermediate
states (aka KV cache). Given a batch of prompts and sampling parameters,
this class generates texts from the model, using an intelligent batching
mechanism and efficient memory management.
NOTE: This class is intended to be used for offline inference. For online
serving, use the `AsyncLLMEngine` class instead.
NOTE: For the comprehensive list of arguments, see `EngineArgs`.
Args:
model: A HuggingFace Transformers model instance.
tokenizer: A HuggingFace Transformers tokenizer instance.
tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
if available, and "slow" will always use the slow tokenizer.
trust_remote_code: Trust remote code (e.g., from HuggingFace) when
downloading the model and tokenizer.
tensor_parallel_size: The number of GPUs to use for distributed
execution with tensor parallelism.
dtype: The data type for the model weights and activations. Currently,
we support `float32`, `float16`, and `bfloat16`. If `auto`, we use
the `torch_dtype` attribute specified in the model config file.
However, if the `torch_dtype` in the config is `float32`, we will
use `float16` instead.
quantization: The method used to quantize the model weights. Currently,
we support "awq". If None, we assume the model weights are not
quantized and use `dtype` to determine the data type of the weights.
revision: The specific model version to use. It can be a branch name,
a tag name, or a commit id.
tokenizer_revision: The specific tokenizer version to use. It can be a
branch name, a tag name, or a commit id.
seed: The seed to initialize the random number generator for sampling.
gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
reserve for the model weights, activations, and KV cache. Higher
values will increase the KV cache size and thus improve the model's
throughput. However, if the value is too high, it may cause out-of-
memory (OOM) errors.
swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
This can be used for temporarily storing the states of the requests
when their `best_of` sampling parameters are larger than 1. If all
requests will have `best_of=1`, you can safely set this to 0.
Otherwise, too small values may cause out-of-memory (OOM) errors.
enforce_eager: Whether to enforce eager execution. If True, we will
disable CUDA graph and always execute the model in eager mode.
If False, we will use CUDA graph and eager execution in hybrid.
max_context_len_to_capture: Maximum context len covered by CUDA graphs.
When a sequence has context length larger than this, we fall back
to eager mode.
disable_custom_all_reduce: See ParallelConfig
"""
def __init__(
self,
model: Union[nn.Module, Dict], # model itself or its parameter dict
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast, HybridEngineBaseTokenizer],
model_hf_config: PretrainedConfig,
tokenizer_mode: str = "auto",
trust_remote_code: bool = False,
tensor_parallel_size: int = 1,
dtype: str = "auto",
quantization: Optional[str] = None,
revision: Optional[str] = None,
tokenizer_revision: Optional[str] = None,
seed: int = 0,
gpu_memory_utilization: float = 0.9,
swap_space: int = 4,
enforce_eager: bool = False,
max_context_len_to_capture: int = None,
disable_custom_all_reduce: bool = False,
load_format = 'auto',
**kwargs,
) -> None:
if "disable_log_stats" not in kwargs:
kwargs["disable_log_stats"] = True
engine_args = EngineArgs(
model_hf_config=model_hf_config,
tensor_parallel_size=tensor_parallel_size,
dtype=dtype,
quantization=quantization,
revision=revision,
tokenizer_revision=tokenizer_revision,
seed=seed,
gpu_memory_utilization=gpu_memory_utilization,
swap_space=swap_space,
enforce_eager=enforce_eager,
max_context_len_to_capture=max_context_len_to_capture,
disable_custom_all_reduce=disable_custom_all_reduce,
load_format=load_format,
**kwargs,
)
tokenizer_cls = (PreTrainedTokenizer, PreTrainedTokenizerFast, HybridEngineBaseTokenizer)
if not isinstance(tokenizer, tokenizer_cls):
raise ValueError(
f"Unexpected tokenizer type: {type(tokenizer)}. Must be"
"one of the following: PreTrainedTokenizer, PreTrainedTokenizerFast, verl.workers.rollout.HybridEngineBaseTokenizer"
)
self.llm_engine = LLMEngine.from_engine_args(model, tokenizer, engine_args)
self.request_counter = Counter()
def init_cache_engine(self):
self.llm_engine.init_cache_engine()
def free_cache_engine(self):
self.llm_engine.free_cache_engine()
def get_tokenizer(self) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
return self.llm_engine.tokenizer
def set_tokenizer(
self,
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
) -> None:
self.llm_engine.tokenizer = tokenizer
def generate(
self,
prompts: Optional[Union[str, List[str]]] = None,
sampling_params: Optional[Union[SamplingParams, List[SamplingParams]]] = None,
prompt_token_ids: Optional[List[List[int]]] = None,
use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> List[RequestOutput]:
"""Generates the completions for the input prompts.
NOTE: This class automatically batches the given prompts, considering
the memory constraint. For the best performance, put all of your prompts
into a single list and pass it to this method.
Args:
prompts: A list of prompts to generate completions for.
sampling_params: The sampling parameters for text generation. If
None, we use the default sampling parameters.
When it is a single value, it is applied to every prompt.
When it is a list, the list must have the same length as the
prompts and it is paired one by one with the prompt.
prompt_token_ids: A list of token IDs for the prompts. If None, we
use the tokenizer to convert the prompts to token IDs.
use_tqdm: Whether to use tqdm to display the progress bar.
lora_request: LoRA request to use for generation, if any.
multi_modal_data: Multi modal data.
Returns:
A list of `RequestOutput` objects containing the generated
completions in the same order as the input prompts.
"""
if prompts is None and prompt_token_ids is None:
raise ValueError("Either prompts or prompt_token_ids must be "
"provided.")
if self.llm_engine.model_config.skip_tokenizer_init \
and prompts is not None:
raise ValueError("prompts must be None if skip_tokenizer_init "
"is True")
if isinstance(prompts, str):
# Convert a single prompt to a list.
prompts = [prompts]
if (prompts is not None and prompt_token_ids is not None and len(prompts) != len(prompt_token_ids)):
raise ValueError("The lengths of prompts and prompt_token_ids "
"must be the same.")
if prompts is not None:
num_requests = len(prompts)
else:
assert prompt_token_ids is not None
num_requests = len(prompt_token_ids)
if sampling_params is None:
# Use default sampling params.
sampling_params = SamplingParams()
elif isinstance(sampling_params, list) and len(sampling_params) != num_requests:
raise ValueError("The lengths of prompts and sampling_params "
"must be the same.")
if multi_modal_data:
multi_modal_data.data = multi_modal_data.data.to(torch.float16)
# Add requests to the engine.
for i in range(num_requests):
prompt = prompts[i] if prompts is not None else None
token_ids = None if prompt_token_ids is None else prompt_token_ids[i]
if not isinstance(token_ids, list):
# NOTE(shengguangming): convert the rollout input into List[str]
token_ids = self._pre_process_inputs(token_ids)
self._add_request(
prompt,
sampling_params[i] if isinstance(sampling_params, list) else sampling_params,
token_ids,
lora_request=lora_request,
# Get ith image while maintaining the batch dim.
multi_modal_data=MultiModalData(type=multi_modal_data.type, data=multi_modal_data.data[i].unsqueeze(0))
if multi_modal_data else None,
)
return self._run_engine(use_tqdm)
def _add_request(
self,
prompt: Optional[str],
sampling_params: SamplingParams,
prompt_token_ids: Optional[List[int]],
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> None:
request_id = str(next(self.request_counter))
self.llm_engine.add_request(request_id,
prompt,
sampling_params,
prompt_token_ids,
lora_request=lora_request,
multi_modal_data=multi_modal_data)
def _run_engine(self, use_tqdm: bool) -> List[RequestOutput]:
# Initialize tqdm.
if use_tqdm:
num_requests = self.llm_engine.get_num_unfinished_requests()
pbar = tqdm(total=num_requests, desc="Processed prompts", dynamic_ncols=True)
# Run the engine.
outputs: List[RequestOutput] = []
while self.llm_engine.has_unfinished_requests():
step_outputs = self.llm_engine.step()
for output in step_outputs:
if output.finished:
outputs.append(output)
if use_tqdm:
pbar.update(1)
if use_tqdm:
pbar.close()
# Sort the outputs by request ID.
# This is necessary because some requests may be finished earlier than
# its previous requests.
outputs = sorted(outputs, key=lambda x: int(x.request_id))
# TODO(shengguangming): maybe we can hack the autoregressive logics without only apply post process for better performance
return self._post_process_outputs(outputs)
# NOTE(shengguangming): add for verl
# TODO(sgm): we can optimize it by making the dataloader yield List[int] without padding.
def _pre_process_inputs(self, prompt_token_ids: torch.Tensor) -> List[int]:
# remove the left padding in the prompt token_id
pad_token_id = self.llm_engine.tokenizer.pad_token_id if self.llm_engine.tokenizer.pad_token_id is not None else self.llm_engine.tokenizer.eos_token_id
non_pad_index = torch.nonzero(prompt_token_ids != pad_token_id, as_tuple=False)[0][0]
token_ids = prompt_token_ids[non_pad_index:].tolist()
return token_ids
# NOTE(shengguangming): add for verl
def _post_process_outputs(self, request_outputs: List[RequestOutput]) -> Tuple[torch.Tensor, torch.Tensor]:
output_token_ids = []
logprobs = []
for request_output in request_outputs: # List[RequestOutput]
outputs = request_output.outputs
for output in outputs: # List[CompletionOutput], usually len == 1
output_token_ids.append(torch.tensor(output.token_ids))
# TODO(shengguangming): can be optimzied by rewrite the Sampler._get_logprobs() logits
logprobs_dicts = output.logprobs
if logprobs_dicts is not None:
logprob = []
for logprobs_dict, id in zip(logprobs_dicts, output.token_ids):
logprob.append(logprobs_dict[id].logprob)
logprobs.append(torch.tensor(logprob))
pad_token_id = self.llm_engine.tokenizer.pad_token_id if self.llm_engine.tokenizer.pad_token_id is not None else self.llm_engine.tokenizer.eos_token_id
output_token_ids = pad_sequence(output_token_ids, batch_first=True, padding_value=pad_token_id)
if len(logprobs) > 0:
logprobs = pad_sequence(logprobs, batch_first=True, padding_value=pad_token_id)
return output_token_ids, logprobs
def sync_model_weights(self, actor_weights: Dict[str, torch.Tensor], load_format: str) -> None:
self.llm_engine.sync_model_weights(actor_weights=actor_weights, load_format=load_format)
def offload_model_weights(self) -> None:
self.llm_engine.offload_model_weights()

View File

@@ -0,0 +1,283 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023 The vLLM team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/engine/llm_engine.py
import torch
from typing import Dict, Optional, Union, Type
import vllm
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, LoRAConfig, ParallelConfig, SchedulerConfig,
SpeculativeConfig, VisionLanguageConfig)
from vllm.core.scheduler import Scheduler
from vllm.engine.output_processor.interfaces import (SequenceGroupOutputProcessor)
from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.executor.executor_base import ExecutorBase
from vllm.logger import init_logger
from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.engine.metrics import StatLogger
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled, usage_message)
from vllm.utils import Counter
from vllm.engine.llm_engine import _load_generation_config_dict
from vllm.engine.llm_engine import LLMEngine
import torch.nn as nn
from .arg_utils import EngineArgs
from .tokenizer import TokenizerGroup
from .config import ModelConfig, LoadConfig
logger = init_logger(__name__)
_LOCAL_LOGGING_INTERVAL_SEC = 5
class LLMEngine(LLMEngine):
"""An LLM engine that receives requests and generates texts.
This is the main class for the vLLM engine. It receives requests
from clients and generates texts from the LLM. It includes a tokenizer, a
language model (possibly distributed across multiple GPUs), and GPU memory
space allocated for intermediate states (aka KV cache). This class utilizes
iteration-level scheduling and efficient memory management to maximize the
serving throughput.
The `LLM` class wraps this class for offline batched inference and the
`AsyncLLMEngine` class wraps this class for online serving.
NOTE: The config arguments are derived from the `EngineArgs` class. For the
comprehensive list of arguments, see `EngineArgs`.
Args:
model: the actor model initialize outside vllm (add for verl)
tokenizer: the initialized tokenizer (add for verl)
model_config: The configuration related to the LLM model.
cache_config: The configuration related to the KV cache memory
management.
parallel_config: The configuration related to distributed execution.
scheduler_config: The configuration related to the request scheduler.
distributed_init_method: The initialization method for distributed
execution. See `torch.distributed.init_process_group` for details.
placement_group: Ray placement group for distributed execution.
Required for distributed execution.
log_stats: Whether to log statistics.
"""
def __init__(
self,
# NOTE(sgm): first two arguments are added for verl
model: Union[nn.Module, Dict], # model itself or its parameter dict
tokenizer: nn.Module,
# NOTE(sgm): vllm original arguments
model_config: ModelConfig,
cache_config: CacheConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
load_config: LoadConfig,
lora_config: Optional[LoRAConfig],
vision_language_config: Optional[VisionLanguageConfig],
speculative_config: Optional[SpeculativeConfig],
decoding_config: Optional[DecodingConfig],
executor_class: Type[ExecutorBase],
log_stats: bool,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
) -> None:
logger.info(
"Initializing an LLM engine (v%s) with config: "
"model=%r, speculative_config=%r, tokenizer=%r, "
"skip_tokenizer_init=%s, tokenizer_mode=%s, revision=%s, "
"tokenizer_revision=%s, trust_remote_code=%s, dtype=%s, "
"max_seq_len=%d, download_dir=%r, load_format=%s, "
"tensor_parallel_size=%d, disable_custom_all_reduce=%s, "
"quantization=%s, enforce_eager=%s, kv_cache_dtype=%s, "
"quantization_param_path=%s, device_config=%s, "
"decoding_config=%r, seed=%d, served_model_name=%s)",
vllm.__version__,
model_config.model,
speculative_config,
model_config.tokenizer,
model_config.skip_tokenizer_init,
# model_config.tokenizer_mode,
model_config.revision,
model_config.tokenizer_revision,
# model_config.trust_remote_code,
model_config.dtype,
model_config.max_model_len,
load_config.download_dir,
load_config.load_format,
parallel_config.tensor_parallel_size,
parallel_config.disable_custom_all_reduce,
model_config.quantization,
model_config.enforce_eager,
cache_config.cache_dtype,
model_config.quantization_param_path,
device_config.device,
decoding_config,
model_config.seed,
# model_config.served_model_name,
)
# TODO(woosuk): Print more configs in debug mode.
self.model_config = model_config # TODO: currently is hfconfig
self.cache_config = cache_config
self.lora_config = lora_config
self.vision_language_config = vision_language_config
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
self.device_config = device_config
self.speculative_config = speculative_config
self.load_config = load_config
self.decoding_config = decoding_config or DecodingConfig()
self.log_stats = log_stats
# self.model = model # should not store the model, it should be deleted
# TODO(shengguangming): maybe we can choose init here or from arguments
if not self.model_config.skip_tokenizer_init:
# TODO: check tokenizer class
self._init_tokenizer(tokenizer)
self.detokenizer = Detokenizer(self.tokenizer)
else:
self.detokenizer = None
self.tokenizer = None
self.seq_counter = Counter()
# TODO: don't know what's the usage
self.generation_config_fields = _load_generation_config_dict(model_config)
self.model_executor = executor_class(
model=model, # add for spmd_gpu_executor
model_config=model_config,
cache_config=cache_config,
parallel_config=parallel_config,
scheduler_config=scheduler_config,
device_config=device_config,
lora_config=lora_config,
vision_language_config=vision_language_config,
speculative_config=speculative_config,
load_config=load_config,
)
# Profile the memory usage and initialize the cache.
self._initialize_kv_caches()
# If usage stat is enabled, collect relevant info.
if is_usage_stats_enabled():
from vllm.model_executor.model_loader import (get_architecture_class_name)
usage_message.report_usage(
get_architecture_class_name(model_config),
usage_context,
extra_kvs={
# Common configuration
"dtype": str(model_config.dtype),
"tensor_parallel_size": parallel_config.tensor_parallel_size,
"block_size": cache_config.block_size,
"gpu_memory_utilization": cache_config.gpu_memory_utilization,
# Quantization
"quantization": model_config.quantization,
"kv_cache_dtype": cache_config.cache_dtype,
# Feature flags
"enable_lora": bool(lora_config),
"enable_prefix_caching": cache_config.enable_prefix_caching,
"enforce_eager": model_config.enforce_eager,
"disable_custom_all_reduce": parallel_config.disable_custom_all_reduce,
})
if self.tokenizer:
# Ping the tokenizer to ensure liveness if it runs in a
# different process.
self.tokenizer.ping()
# Create the scheduler.
# NOTE: the cache_config here have been updated with the numbers of
# GPU and CPU blocks, which are profiled in the distributed executor.
# NOTE(shengguangming): each process will have independent scheduler
self.scheduler = Scheduler(scheduler_config, cache_config, lora_config)
# Metric Logging.
if self.log_stats:
self.stat_logger = StatLogger(local_interval=_LOCAL_LOGGING_INTERVAL_SEC,
labels=dict(model_name=model_config.served_model_name),
max_model_len=self.model_config.max_model_len)
self.stat_logger.info("cache_config", self.cache_config)
# Create sequence output processor, e.g. for beam search or
# speculative decoding.
self.output_processor = (SequenceGroupOutputProcessor.create_output_processor(
self.scheduler_config,
self.detokenizer,
self.scheduler,
self.seq_counter,
self.get_tokenizer_for_seq,
stop_checker=StopChecker(
self.scheduler_config.max_model_len,
self.get_tokenizer_for_seq,
),
))
# TODO(sgm): add for verl but we may not tokenizer in Rollout
def _init_tokenizer(self, tokenizer, **tokenizer_init_kwargs):
init_kwargs = dict(enable_lora=bool(self.lora_config),
max_num_seqs=self.scheduler_config.max_num_seqs,
max_input_length=None)
init_kwargs.update(tokenizer_init_kwargs)
self.tokenizer: TokenizerGroup = TokenizerGroup(tokenizer, **init_kwargs)
def init_cache_engine(self):
# TODO: check whether we should rebuild the CUDAGraph every iter when offload/load KVCache
# Re-capture CUDAGraph would be time-consuming
self.model_executor.init_cache_engine()
def free_cache_engine(self):
self.model_executor.free_cache_engine()
# NOTE(sgm): currently, we only support GPU executor
# The GPUExecutor remove the Ray dependency
@classmethod
def from_engine_args(
cls,
model,
tokenizer,
engine_args: EngineArgs,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
) -> "LLMEngine":
"""Creates an LLM engine from the engine arguments."""
# Create the engine configs.
engine_config = engine_args.create_engine_config()
# Initialize the cluster and specify the executor class.
assert engine_config.device_config.device_type == "cuda", \
"Currently, the vllm in verl only support running on GPU"
if engine_config.parallel_config.world_size == 1:
engine_config.load_config.load_format = "dummy_hf"
from .spmd_gpu_executor import SPMDGPUExecutor
executor_class = SPMDGPUExecutor
# Create the LLM engine.
engine = cls(
model,
tokenizer,
**engine_config.to_dict(),
executor_class=executor_class,
log_stats=not engine_args.disable_log_stats,
usage_context=usage_context,
)
return engine
def sync_model_weights(self, actor_weights: Dict[str, torch.Tensor], load_format: str) -> None:
self.model_executor.sync_model_weights(actor_weights=actor_weights, load_format=load_format)
def offload_model_weights(self) -> None:
self.model_executor.offload_model_weights()

View File

@@ -0,0 +1,348 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023 The vLLM team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/models
from typing import Dict
import torch
import torch.nn as nn
from vllm.model_executor.layers.linear import *
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding, ParallelLMHead
from vllm.model_executor.layers.activation import ScaledActivation
from vllm.model_executor.models import ModelRegistry
# NOTE(shengguangming): replace the origin weight loader function in the class
def parallel_weight_loader(self, param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
"""Parallel Linear weight loader."""
assert param.size() == loaded_weight.size(
), 'the parameter size is not align with the loaded weight size, param size: {}, loaded_weight size: {}'.format(
param.size(), loaded_weight.size())
assert param.data.dtype == loaded_weight.data.dtype, "if we want to shared weights, the data type should also be the same"
param.data = loaded_weight.data
def default_weight_loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
"""Default weight loader."""
assert param.size() == loaded_weight.size()
assert param.data.dtype == loaded_weight.data.dtype, "if we want to shared weights, the data type should also be the same"
param.data = loaded_weight.data
def gpt2_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module:
params_dict = dict(vllm_model.named_parameters(remove_duplicate=False))
for name, loaded_weight in actor_weights.items():
if "lm_head.weight" in name:
# GPT-2 ties the weights of the embedding layer and the final
# linear layer.
continue
if ".attn.bias" in name or ".attn.masked_bias" in name:
# Skip attention mask.
# NOTE: "c_attn.bias" should not be skipped.
continue
if not name.startswith("transformer."):
name = "transformer." + name
param = params_dict[name]
# The HF's GPT-2 implementation uses Conv1D instead of Linear.
# Because of this, we need to transpose the weights.
# Note(zhuohan): the logic below might break quantized models.
for conv1d_weight_name in ["c_attn", "c_proj", "c_fc"]:
if conv1d_weight_name not in name:
continue
if not name.endswith(".weight"):
continue
# TODO: check megatron
loaded_weight = loaded_weight.t()
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
def llama_megatron_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module:
# NOTE(shengguangming): the megatron llama may have this prefix
params_dict = dict(vllm_model.named_parameters())
for name, loaded_weight in actor_weights.items():
if "rotary_emb.inv_freq" in name:
continue
else:
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
def llama_megatron_core_te_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module:
params_mapping = [
# (megatron core gpt model name, vllm model name)
("embedding.word_embeddings", "model.embed_tokens"),
("self_attention.linear_qkv.layer_norm_weight", "input_layernorm.weight"),
("self_attention.linear_qkv.layer_norm_bias", "input_layernorm.bias"),
("self_attention.linear_qkv", "self_attn.qkv_proj"),
("self_attention.linear_qkv", "self_attn.qkv_proj"),
("self_attention.linear_proj", 'self_attn.o_proj'),
('pre_mlp_layernorm', 'post_attention_layernorm'),
('mlp.linear_fc1.layer_norm_weight', 'post_attention_layernorm.weight'),
('mlp.linear_fc1.layer_norm_bias', 'post_attention_layernorm.bias'),
('mlp.linear_fc1', 'mlp.gate_up_proj'),
('mlp.linear_fc2', 'mlp.down_proj'),
('decoder.final_layernorm', 'model.norm'),
('output_layer', 'lm_head'),
]
# NOTE(shengguangming): the megatron llama may have this prefix
params_dict = dict(vllm_model.named_parameters())
for name, loaded_weight in actor_weights.items():
name = _replace_name(name, params_mapping)
if name.endswith('.bias') and name not in params_dict:
continue
if "rotary_emb.inv_freq" in name:
continue
else:
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
def llama_megatron_core_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module:
params_mapping = [
# (megatron core gpt model name, vllm model name)
("embedding.word_embeddings", "model.embed_tokens"),
("self_attention.linear_qkv", "self_attn.qkv_proj"),
("self_attention.linear_proj", 'self_attn.o_proj'),
(
'input_layernorm',
'input_layernorm',
),
('pre_mlp_layernorm', 'post_attention_layernorm'),
('mlp.linear_fc1', 'mlp.gate_up_proj'),
('mlp.linear_fc2', 'mlp.down_proj'),
('decoder.final_layernorm', 'model.norm'),
('output_layer', 'lm_head'),
]
# NOTE(shengguangming): the megatron llama may have this prefix
params_dict = dict(vllm_model.named_parameters())
for name, loaded_weight in actor_weights.items():
name = _replace_name(name, params_mapping)
if name.endswith('.bias') and name not in params_dict:
continue
if "rotary_emb.inv_freq" in name:
continue
else:
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
def _replace_name(megatron_name, name_mapping):
for m_name, v_name in name_mapping:
if m_name not in megatron_name:
continue
if 'layers' in megatron_name: # deal with decoder layers
megatron_name = megatron_name.replace('decoder', 'model')
megatron_name_list = megatron_name.split('.')
if 'layer_norm_weight' in megatron_name_list or 'layer_norm_bias' in megatron_name_list:
param_name_list = megatron_name_list[:3]
param_name_list.append(v_name)
param_name = '.'.join(param_name_list)
else:
param_name_list = megatron_name_list[:3]
weight_or_bias = megatron_name_list[-1]
param_name_list.append(v_name)
param_name_list.append(weight_or_bias)
param_name = '.'.join(param_name_list)
return param_name
else:
param_name = megatron_name.replace(m_name, v_name)
return param_name
def llama_megatron_core_te_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module:
params_mapping = [
# (megatron core gpt model name, vllm model name)
("embedding.word_embeddings", "model.embed_tokens"),
("self_attention.linear_qkv.layer_norm_weight", "input_layernorm.weight"),
("self_attention.linear_qkv.layer_norm_bias", "input_layernorm.bias"),
("self_attention.linear_qkv", "self_attn.qkv_proj"),
("self_attention.linear_qkv", "self_attn.qkv_proj"),
("self_attention.linear_proj", 'self_attn.o_proj'),
('pre_mlp_layernorm', 'post_attention_layernorm'),
('mlp.linear_fc1.layer_norm_weight', 'post_attention_layernorm.weight'),
('mlp.linear_fc1.layer_norm_bias', 'post_attention_layernorm.bias'),
('mlp.linear_fc1', 'mlp.gate_up_proj'),
('mlp.linear_fc2', 'mlp.down_proj'),
('decoder.final_layernorm', 'model.norm'),
('output_layer', 'lm_head'),
]
# NOTE(shengguangming): the megatron llama may have this prefix
params_dict = dict(vllm_model.named_parameters())
for name, loaded_weight in actor_weights.items():
name = _replace_name(name, params_mapping)
if name.endswith('.bias') and name not in params_dict:
continue
if "rotary_emb.inv_freq" in name:
continue
else:
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
def llama_megatron_core_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module:
params_mapping = [
# (megatron core gpt model name, vllm model name)
("embedding.word_embeddings", "model.embed_tokens"),
("self_attention.linear_qkv", "self_attn.qkv_proj"),
("self_attention.linear_proj", 'self_attn.o_proj'),
(
'input_layernorm',
'input_layernorm',
),
('pre_mlp_layernorm', 'post_attention_layernorm'),
('mlp.linear_fc1', 'mlp.gate_up_proj'),
('mlp.linear_fc2', 'mlp.down_proj'),
('decoder.final_layernorm', 'model.norm'),
('output_layer', 'lm_head'),
]
# NOTE(shengguangming): the megatron llama may have this prefix
params_dict = dict(vllm_model.named_parameters())
for name, loaded_weight in actor_weights.items():
name = _replace_name(name, params_mapping)
if name.endswith('.bias') and name not in params_dict:
continue
if "rotary_emb.inv_freq" in name:
continue
else:
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
def _replace_name(megatron_name, name_mapping):
for m_name, v_name in name_mapping:
if m_name not in megatron_name:
continue
if 'layers' in megatron_name: # deal with decoder layers
megatron_name = megatron_name.replace('decoder', 'model')
megatron_name_list = megatron_name.split('.')
if 'layer_norm_weight' in megatron_name_list or 'layer_norm_bias' in megatron_name_list:
param_name_list = megatron_name_list[:3]
param_name_list.append(v_name)
param_name = '.'.join(param_name_list)
else:
param_name_list = megatron_name_list[:3]
weight_or_bias = megatron_name_list[-1]
param_name_list.append(v_name)
param_name_list.append(weight_or_bias)
param_name = '.'.join(param_name_list)
return param_name
else:
param_name = megatron_name.replace(m_name, v_name)
return param_name
def mistral_megatron_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module:
# TODO: need to implement a general way to deal with prefix
params_dict = dict(vllm_model.named_parameters())
for name, loaded_weight in actor_weights.items():
if "rotary_emb.inv_freq" in name:
continue
else:
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
__LAYER_WEIGHT_MEGATRON_LOADER_REGISTRY__ = {
ColumnParallelLinear: parallel_weight_loader,
MergedColumnParallelLinear: parallel_weight_loader,
QKVParallelLinear: parallel_weight_loader,
RowParallelLinear: parallel_weight_loader,
VocabParallelEmbedding: parallel_weight_loader,
ParallelLMHead: parallel_weight_loader
# "ScaledActivation.weight_loader": ScaledActivation, # TODO(shengguangming): latest commit in vllm fix awq for this function and add load_weights
# "default_weight_loader": default_weight_loader
}
# for layer_class, weight_loader in __LAYER_WEIGHT_MEGATRON_LOADER_REGISTRY__.items():
# # setattr(layer_class, 'megatron_weight_loader', weight_loader)
# layer_class.weight_loader = weight_loader
__MODEL_MEGATRON_WEIGHT_LOADER_REGISTRY__ = {
'GPT2LMHeadModel': gpt2_weight_loader,
'LlamaForCausalLM': llama_megatron_core_te_weight_loader, # use te backend for open-source megatron
'LLaMAForCausalLM': llama_megatron_core_te_weight_loader,
'MistralForCausalLM': mistral_megatron_weight_loader,
}
# the actor model is .state_dict()
# Load megatron weights
def load_megatron_weights(actor_weights: Dict, vllm_model: nn.Module):
weight_loader = _get_model_weight_loader(vllm_model.__class__.__name__)
weight_loader(actor_weights, vllm_model)
# NOTE(sgm) to reduce peak memory usage, we offload vllm model to cpu
# after init, and we need this after sync model weights for in first iter.
vllm_model = vllm_model.cuda()
def _get_model_weight_loader(arch: str):
if arch in __MODEL_MEGATRON_WEIGHT_LOADER_REGISTRY__:
return __MODEL_MEGATRON_WEIGHT_LOADER_REGISTRY__[arch]
raise ValueError(f"Model architectures {arch} are not supported for now. "
f"Supported architectures: {ModelRegistry.get_supported_archs()}")
def update_megatron_weight_loader():
for layer_class, weight_loader in __LAYER_WEIGHT_MEGATRON_LOADER_REGISTRY__.items():
layer_class.weight_loader = weight_loader
VocabParallelEmbedding.__init__ = vocab_init
# FIXME(shengguangming): the vLLM vocab will pad to 64, which may incur out of bounds
# so we need to rewrite the init function of vocab
DEFAULT_VOCAB_PADDING_SIZE = 64
def vocab_init(self,
num_embeddings: int,
embedding_dim: int,
params_dtype: Optional[torch.dtype] = None,
org_num_embeddings: Optional[int] = None,
padding_size: int = DEFAULT_VOCAB_PADDING_SIZE):
super(VocabParallelEmbedding, self).__init__()
# Keep the input dimensions.
# TODO (pad to be divided by 4)
self.num_embeddings = num_embeddings
self.org_vocab_size = org_num_embeddings or num_embeddings
# self.num_embeddings_padded = pad_vocab_size(num_embeddings,
# padding_size)
self.embedding_dim = embedding_dim
if params_dtype is None:
params_dtype = torch.get_default_dtype()
self.tp_size = get_tensor_model_parallel_world_size()
# Divide the weight matrix along the vocaburaly dimension.
# TODO: remove dependencies from megatron
from megatron.core.tensor_parallel.utils import VocabUtility
self.vocab_start_index, self.vocab_end_index = (VocabUtility.vocab_range_from_global_vocab_size(
self.num_embeddings, get_tensor_model_parallel_rank(), self.tp_size))
self.num_embeddings_per_partition = (self.vocab_end_index - self.vocab_start_index)
self.weight = Parameter(
torch.empty(
self.num_embeddings_per_partition,
self.embedding_dim,
# device=torch.cuda.current_device(),
dtype=params_dtype))
set_weight_attrs(self.weight, {"parallel_dim": 0, "weight_loader": self.weight_loader})

View File

@@ -0,0 +1,265 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023 The vLLM team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/model_loader
"""Utilities for selecting and loading models."""
from typing import Dict, Union, Optional, Iterable, Tuple
import torch
import torch.nn as nn
from transformers import PreTrainedModel
from vllm.config import (DeviceConfig, LoRAConfig, ParallelConfig, SchedulerConfig, VisionLanguageConfig)
from vllm.model_executor.model_loader import BaseModelLoader
from vllm.model_executor.model_loader.loader import _initialize_model
from vllm.model_executor.model_loader.utils import set_default_torch_dtype
from vllm.distributed.communication_op import tensor_model_parallel_all_gather
from .config import ModelConfig, LoadFormat, LoadConfig
from .megatron_weight_loaders import load_megatron_weights, update_megatron_weight_loader
from .dtensor_weight_loaders import load_dtensor_weights, update_dtensor_weight_loader
from .hf_weight_loader import update_hf_weight_loader
def get_model(actor_model: Union[PreTrainedModel, Dict], model_config: ModelConfig, load_config: LoadConfig,
device_config: DeviceConfig, parallel_config: ParallelConfig, scheduler_config: SchedulerConfig,
lora_config: Optional[LoRAConfig], vision_language_config: Optional[VisionLanguageConfig]) -> nn.Module:
loader = get_model_loader(load_config)
if load_config.load_format.startswith('dummy'):
return loader.load_model(model_config=model_config,
device_config=device_config,
lora_config=lora_config,
vision_language_config=vision_language_config,
parallel_config=parallel_config,
scheduler_config=scheduler_config)
else:
return loader.load_model(actor_model=actor_model,
model_config=model_config,
device_config=device_config,
lora_config=lora_config,
vision_language_config=vision_language_config,
parallel_config=parallel_config,
scheduler_config=scheduler_config)
def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
"""Get a model loader based on the load format."""
if isinstance(load_config.load_format, type):
return load_config.load_format(load_config)
if load_config.load_format == LoadFormat.AUTO:
update_megatron_weight_loader()
return MegatronLoader(load_config)
# NOTE(sgm): change the weight_loader function in runtime
if load_config.load_format == LoadFormat.MEGATRON:
update_megatron_weight_loader()
return MegatronLoader(load_config)
if load_config.load_format == LoadFormat.HF:
update_hf_weight_loader()
return HFLoader(load_config)
if load_config.load_format == LoadFormat.DTENSOR:
update_dtensor_weight_loader()
return DTensorLoader(load_config)
if load_config.load_format == LoadFormat.DUMMY_HF:
update_hf_weight_loader()
return DummyModelLoader(load_config)
if load_config.load_format == LoadFormat.DUMMY_MEGATRON:
update_megatron_weight_loader()
return DummyModelLoader(load_config)
if load_config.load_format == LoadFormat.DUMMY_DTENSOR:
update_dtensor_weight_loader()
return DummyModelLoader(load_config)
raise ValueError('load format not supported in verl: {}, only support {} and {}'.format(
load_config.load_format, LoadFormat.MEGATRON, LoadFormat.HF))
class DummyModelLoader(BaseModelLoader):
"""Model loader that will set model weights to random values."""
def __init__(self, load_config: LoadConfig):
super().__init__(load_config)
if load_config.model_loader_extra_config:
raise ValueError(f"Model loader extra config is not supported for "
f"load format {load_config.load_format}")
def load_model(self, *, model_config: ModelConfig, device_config: DeviceConfig, lora_config: Optional[LoRAConfig],
vision_language_config: Optional[VisionLanguageConfig], parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig) -> nn.Module:
with set_default_torch_dtype(model_config.dtype):
with torch.device(device_config.device):
model = _initialize_model(model_config, self.load_config, lora_config, vision_language_config)
# NOTE(woosuk): For accurate performance evaluation, we assign
# random values to the weights.
# initialize_dummy_weights(model)
return model.eval()
class MegatronLoader(BaseModelLoader):
"""Model loader that can load the model weights from partitioned megatron model."""
def __init__(self, load_config: LoadConfig):
super().__init__(load_config)
if load_config.model_loader_extra_config:
raise ValueError(f"Model loader extra config is not supported for "
f"load format {load_config.load_format}")
def _get_weights_iterator(actor_model: Union[PreTrainedModel, Dict]):
# NOTE(shengguangming) Load the weights from the actor model
pass
# if isinstance(actor_model, nn.Module):
# load_weights(actor_weights=dict(actor_model.named_parameters(remove_duplicate=False)), vllm_model=model)
# else:
# load_weights(actor_weights=actor_model, vllm_model=model)
# return actor_model
def load_model(self, actor_model: Union[PreTrainedModel,
Dict], model_config: ModelConfig, device_config: DeviceConfig,
lora_config: Optional[LoRAConfig], vision_language_config: Optional[VisionLanguageConfig],
parallel_config: ParallelConfig, scheduler_config: SchedulerConfig) -> nn.Module:
with set_default_torch_dtype(model_config.dtype):
with torch.device(device_config.device):
model = _initialize_model(model_config, self.load_config, lora_config, vision_language_config)
# TODO(sgm): This is a hack, we need to register the load_weight() func for each model in vllm
if isinstance(actor_model, nn.Module):
load_megatron_weights(actor_weights=dict(actor_model.named_parameters(remove_duplicate=False)),
vllm_model=model)
else:
load_megatron_weights(actor_weights=actor_model, vllm_model=model)
for _, module in model.named_modules():
quant_method = getattr(module, "quant_method", None)
if quant_method is not None:
quant_method.process_weights_after_loading(module)
# FIXME: Remove this after Mixtral is updated
# to use quant_method.
if hasattr(module, "process_weights_after_loading"):
module.process_weights_after_loading()
# NOTE(sgm) Some weights are point to gpu, but still need this.
model = model.cuda() # NOTE (zhangchi.usc1992) We need this for vllm to profile memory usage
return model.eval()
class HFLoader(BaseModelLoader):
"""Model loader that can load the model weights from model's full params."""
def __init__(self, load_config: LoadConfig):
super().__init__(load_config)
if load_config.model_loader_extra_config:
raise ValueError(f"Model loader extra config is not supported for "
f"load format {load_config.load_format}")
def _get_weights_iterator(self, actor_model: Union[PreTrainedModel, Dict]):
if isinstance(actor_model, Dict):
return actor_model.items()
elif isinstance(actor_model, nn.Module):
return dict(actor_model.named_parameters()).items()
else:
raise ValueError(f'actor model should be Dict or nn.Module, but get {type(actor_model)}')
def load_model(self, actor_model: Union[PreTrainedModel,
Dict], model_config: ModelConfig, device_config: DeviceConfig,
lora_config: Optional[LoRAConfig], vision_language_config: Optional[VisionLanguageConfig],
parallel_config: ParallelConfig, scheduler_config: SchedulerConfig) -> nn.Module:
with set_default_torch_dtype(model_config.dtype):
# with torch.device(device_config.device):
# NOTE(sgm): init the model in cpu
model = _initialize_model(model_config, self.load_config, lora_config, vision_language_config)
model.load_weights(self._get_weights_iterator(actor_model))
for _, module in model.named_modules():
quant_method = getattr(module, "quant_method", None)
if quant_method is not None:
quant_method.process_weights_after_loading(module)
# FIXME: Remove this after Mixtral is updated
# to use quant_method.
if hasattr(module, "process_weights_after_loading"):
module.process_weights_after_loading()
# NOTE(sgm) Some weights are point to gpu, but still need this.
model = model.cuda() # NOTE (zhangchi.usc1992) We need this for vllm to profile memory usage
return model.eval()
class DTensorLoader(BaseModelLoader):
"""Model loader that can load the model weights from partitioned megatron model."""
def __init__(self, load_config: LoadConfig):
super().__init__(load_config)
if load_config.model_loader_extra_config:
raise ValueError(f"Model loader extra config is not supported for "
f"load format {load_config.load_format}")
def _get_weights_iterator(actor_model: Union[PreTrainedModel, Dict]):
# NOTE(shengguangming) Load the weights from the actor model
pass
# if isinstance(actor_model, nn.Module):
# load_weights(actor_weights=dict(actor_model.named_parameters(remove_duplicate=False)), vllm_model=model)
# else:
# load_weights(actor_weights=actor_model, vllm_model=model)
# return actor_model
def load_model(self, actor_model: Union[PreTrainedModel,
Dict], model_config: ModelConfig, device_config: DeviceConfig,
lora_config: Optional[LoRAConfig], vision_language_config: Optional[VisionLanguageConfig],
parallel_config: ParallelConfig, scheduler_config: SchedulerConfig) -> nn.Module:
with set_default_torch_dtype(model_config.dtype):
with torch.device(device_config.device):
model = _initialize_model(model_config, self.load_config, lora_config, vision_language_config)
# TODO(sgm): This is a hack, we need to register the load_weight() func for each model in vllm
if isinstance(actor_model, nn.Module):
load_dtensor_weights(actor_weights=dict(actor_model.named_parameters(remove_duplicate=False)),
vllm_model=model)
else:
load_dtensor_weights(actor_weights=actor_model, vllm_model=model)
for _, module in model.named_modules():
quant_method = getattr(module, "quant_method", None)
if quant_method is not None:
quant_method.process_weights_after_loading(module)
# FIXME: Remove this after Mixtral is updated
# to use quant_method.
if hasattr(module, "process_weights_after_loading"):
module.process_weights_after_loading()
# NOTE(sgm) Some weights are point to gpu, but still need this.
model = model.cuda() # NOTE (zhangchi.usc1992) We need this for vllm to profile memory usage
return model.eval()
# FIXME(sgm): hack the _get_logits function in vllm v0.4.2
# as they use ray, the _get_logits result will only need to return to the driver node,
# therefore gather is enough. However, we use SPMD instead of a central scheduler,
# all_gather is required (aligned with v0.2.6)
def _get_logits(self, hidden_states: torch.Tensor, embedding: torch.Tensor,
embedding_bias: Optional[torch.Tensor]) -> torch.Tensor:
# Get the logits for the next tokens.
logits = torch.matmul(hidden_states, embedding.t())
if embedding_bias is not None:
logits += embedding_bias
logits = tensor_model_parallel_all_gather(logits)
# Remove paddings in vocab (if any).
if logits is not None:
logits = logits[:, :self.org_vocab_size]
return logits
from vllm.model_executor.layers.logits_processor import LogitsProcessor
LogitsProcessor._get_logits = _get_logits

View File

@@ -0,0 +1,281 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023 The vLLM team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/worker/model_runner.py
import torch
import torch.nn as nn
from enum import IntEnum
from typing import Dict, List, Optional, Set, Tuple, Union
from vllm.attention import (AttentionMetadata, get_attn_backend)
from vllm.config import (DeviceConfig, LoRAConfig, ParallelConfig, SchedulerConfig, VisionLanguageConfig)
from vllm.logger import init_logger
from vllm.lora.layers import LoRAMapping
from vllm.lora.request import LoRARequest
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
from vllm.model_executor import SamplingMetadata
from vllm.sequence import (MultiModalData, SamplerOutput, SequenceData, SequenceGroupMetadata)
from vllm.utils import (CudaMemoryProfiler, is_hip, is_pin_memory_available)
from vllm.worker.model_runner import ModelRunner, CUDAGraphRunner
from .model_loader import get_model
from .config import ModelConfig, LoadConfig
logger = init_logger(__name__)
# How batches are constructed.
class BatchType(IntEnum):
# Every batch is prefill.
PREFILL = 0
# Every batch is decode.
DECODE = 1
# Batch is a mixture of prefill and decode.
MIXED = 2
class ModelRunner(ModelRunner):
def __init__(
self,
model: Union[nn.Module, Dict], # model itself or its parameter dict
model_config: ModelConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
load_config: LoadConfig,
lora_config: Optional[LoRAConfig],
kv_cache_dtype: Optional[str] = "auto",
vision_language_config: Optional[VisionLanguageConfig] = None,
):
self.model_config = model_config
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
self.lora_config = lora_config
self.load_config = load_config
# model_config can be None in tests/samplers/test_sampler.py.
# FIXME(woosuk): This is a hack to make the tests work. Refactor this.
self.sliding_window = (model_config.get_sliding_window() if model_config is not None else None)
self.device_config = (device_config if device_config is not None else DeviceConfig())
self.device = self.device_config.device
# NOTE(sgm): add for verl
self.model = model # this will be replaced by get_model()
# Set after load_model.
self.lora_manager: LRUCacheWorkerLoRAManager = None
self.graph_runners: Dict[int, CUDAGraphRunner] = {}
self.graph_memory_pool: Optional[Tuple[int, int]] = None # Set during graph capture.
self.max_seq_len_to_capture = (self.model_config.max_seq_len_to_capture if self.model_config is not None else 0)
self.pin_memory = is_pin_memory_available()
self.kv_cache_dtype = kv_cache_dtype
self.vision_language_config = vision_language_config
self.attn_backend = get_attn_backend(self.model_config.dtype if model_config is not None else None)
# Lazy initialization
self.block_size: int # Set after initial profiling.
# When using CUDA graph, the input block tables must be padded to
# max_seq_len_to_capture. However, creating the block table in
# Python can be expensive. To optimize this, we cache the block table
# in numpy and only copy the actual input content at every iteration.
# The shape of the cached block table will be
# (max batch size to capture, max context len to capture / block size).
self.graph_block_tables: torch.Tensor # Set after initial profiling.
# Set if the backend is flashinfer.
self.flashinfer_workspace_buffer: torch.Tensor
# NOTE(sgm): initialize model using the actor model
def load_model(self) -> None:
with CudaMemoryProfiler() as m:
self.model = get_model(actor_model=self.model,
model_config=self.model_config,
device_config=self.device_config,
lora_config=self.lora_config,
load_config=self.load_config,
parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config,
vision_language_config=self.vision_language_config)
self.model_memory_usage = m.consumed_memory
logger.info("Loading model weights took %.4f GB", self.model_memory_usage / float(2**30))
if self.lora_config:
assert hasattr(self.model, "supported_lora_modules") and self.model.supported_lora_modules, (
"Model does not support LoRA")
assert hasattr(self.model, "embedding_modules"), "Model does not have embedding_modules"
assert hasattr(self.model, "embedding_padding_modules"), "Model does not have embedding_padding_modules"
self.lora_manager = LRUCacheWorkerLoRAManager(self.scheduler_config.max_num_seqs,
self.scheduler_config.max_num_batched_tokens, self.vocab_size,
self.lora_config, self.device, self.model.embedding_modules,
self.model.embedding_padding_modules)
self.model = self.lora_manager.create_lora_manager(self.model)
if self.kv_cache_dtype == "fp8" and is_hip():
# Currently scaled KV cache is only enabled on ROCm
if self.model_config.quantization_param_path is not None:
if callable(getattr(self.model, "load_kv_cache_scales", None)):
self.model.load_kv_cache_scales(self.model_config.quantization_param_path)
else:
raise RuntimeError(
"Using FP8 KV cache and scaling factors provided but "
"model %s does not support loading scaling factors.", self.model.__class__)
else:
logger.warning("Using FP8 KV cache but no scaling factors "
"provided. Defaulting to scaling factors of 1.0. "
"This may lead to less accurate results!")
elif self.model_config.quantization_param_path is not None:
logger.warning("KV cache scaling factors provided, "
"but the KV cache data type is not FP8. "
"KV cache scaling factors will not be used.")
def prepare_input_tensors(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata, Set[LoRARequest], LoRAMapping,
torch.Tensor]:
# NOTE(sgm): all workers prepare the input in the same way
prefill_reqs = []
decode_reqs = []
for seq_group_meta in seq_group_metadata_list:
if seq_group_meta.is_prompt:
prefill_reqs.append(seq_group_meta)
else:
decode_reqs.append(seq_group_meta)
# Prepare input tensors.
(
input_tokens,
input_positions,
prefill_attn_metadata,
seq_lens,
query_lens,
lora_index_mapping,
lora_prompt_mapping,
lora_requests,
multi_modal_input,
slot_mapping,
) = self._prepare_prompt(prefill_reqs)
(
decode_input_tokens,
decode_input_positions,
decode_attn_metadata,
decode_lora_index_mapping,
decode_lora_prompt_mapping,
decode_lora_requests,
decode_slot_mapping,
) = self._prepare_decode(decode_reqs)
sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list, seq_lens, query_lens, self.device,
self.pin_memory)
if not self.scheduler_config.chunked_prefill_enabled:
assert (len(prefill_reqs) and len(decode_reqs)) == 0
num_prefills = len(seq_lens)
num_prefill_tokens = len(input_tokens)
num_decode_tokens = len(decode_input_tokens)
# Coalesce tensors. Note that attn_metadata is currently not
# coalesced for simplicity.
input_tokens.extend(decode_input_tokens)
input_positions.extend(decode_input_positions)
slot_mapping.extend(decode_slot_mapping)
lora_index_mapping.extend(decode_lora_index_mapping)
lora_prompt_mapping.extend(decode_lora_prompt_mapping)
lora_requests.update(decode_lora_requests)
input_tokens = torch.tensor(input_tokens, dtype=torch.long, device=self.device)
input_positions = torch.tensor(input_positions, dtype=torch.long, device=self.device)
slot_mapping = torch.tensor(slot_mapping, dtype=torch.long, device=self.device)
if self.lora_config:
lora_mapping = LoRAMapping(
lora_index_mapping,
lora_prompt_mapping,
)
else:
lora_mapping = None
# Broadcast the metadata.
# If batch contains both prefill and decode, it sends 2 broadcasts.
# If it only contains 1 type, it triggers a single broadcast.
if (prefill_attn_metadata is not None and decode_attn_metadata is not None):
batch_type = BatchType.MIXED
elif prefill_attn_metadata is not None:
batch_type = BatchType.PREFILL
else:
batch_type = BatchType.DECODE
attn_metadata = AttentionMetadata(
num_prefills=num_prefills,
slot_mapping=slot_mapping,
num_prefill_tokens=num_prefill_tokens,
num_decode_tokens=num_decode_tokens,
prefill_metadata=prefill_attn_metadata,
decode_metadata=decode_attn_metadata,
kv_cache_dtype=self.kv_cache_dtype,
)
return (input_tokens, input_positions, attn_metadata, sampling_metadata, lora_requests, lora_mapping,
multi_modal_input)
@torch.inference_mode()
def execute_model(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
kv_caches: List[torch.Tensor],
) -> Optional[SamplerOutput]:
(input_tokens, input_positions, attn_metadata, sampling_metadata, lora_requests, lora_mapping,
multi_modal_input) = self.prepare_input_tensors(seq_group_metadata_list)
if self.lora_config:
self.set_active_loras(lora_requests, lora_mapping)
# Currently cuda graph is only supported by the decode phase.
prefill_meta = attn_metadata.prefill_metadata
decode_meta = attn_metadata.decode_metadata
if prefill_meta is None and decode_meta.use_cuda_graph:
graph_batch_size = input_tokens.shape[0]
model_executable = self.graph_runners[graph_batch_size]
else:
model_executable = self.model
execute_model_kwargs = {
"input_ids": input_tokens,
"positions": input_positions,
"kv_caches": kv_caches,
"attn_metadata": attn_metadata,
}
if self.vision_language_config:
execute_model_kwargs.update({"image_input": multi_modal_input})
hidden_states = model_executable(**execute_model_kwargs)
# Compute the logits.
logits = self.model.compute_logits(hidden_states, sampling_metadata)
# Only perform sampling in the driver worker.
# if not self.is_driver_worker:
# return None
# TODO(sgm): perform sampling on rank 0
# Sample the next token.
output = self.model.sample(
logits=logits,
sampling_metadata=sampling_metadata,
)
return output

View File

@@ -0,0 +1,294 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023 The vLLM team.
# Adapted from
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""Model and data parallel groups."""
import os
import torch
import torch.distributed
from typing import Optional
import vllm.distributed.parallel_state as ps
import vllm.envs as envs
from vllm.logger import init_logger
from torch.distributed.device_mesh import init_device_mesh
logger = init_logger(__name__)
"""
This version is strongly tied with Megatron to implement HybridEngine and weight sharing between vllm and Megatron.
- We assume the Megatron tp+dp+pp world is already established before calling this function.
"""
# Device mesh for using DTensor
_DEVICE_MESH = None
# Tensor model parallel group that the current rank belongs to.
_TP_DEVICE_GROUP = None
_TP_CPU_GROUP = None
# This method is for initializing the ParallelGroup when using HybridEngine
def initialize_parallel_state(
distributed_init_method: str = "env://",
backend: str = "nccl",
tensor_model_parallel_size: int = 1,
num_tp_per_train_tp: int = 1,
pipeline_model_parallel_size: int = 1,
):
# torch.distributed.all_reduce does not free the input tensor until
# the synchronization point. This causes the memory usage to grow
# as the number of all_reduce calls increases. This env var disables
# this behavior.
# Related issue:
# https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573
os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
# NOTE(sgm): Modify for verl, Env vars will be set by TORCHRUN.
rank = int(os.getenv("RANK", "-1"))
local_rank = int(os.getenv("LOCAL_RANK", "0"))
# Use the world_size set by TORCHRUN
world_size = int(os.getenv("WORLD_SIZE", "-1"))
assert world_size != -1, "The world_size is set to -1, not initialized by TORCHRUN"
ps.init_distributed_environment(world_size, rank, distributed_init_method, local_rank, backend)
if torch.distributed.get_world_size() > 1:
# NOTE: build a sepearate inference group with infer tp & micro dp
initialize_model_parallel_for_vllm(tensor_model_parallel_size=tensor_model_parallel_size,
num_tensor_model_parallel_groups_per_train_tp=num_tp_per_train_tp)
else:
initialize_model_parallel(tensor_model_parallel_size, pipeline_model_parallel_size, backend)
def ensure_model_parallel_initialized(
tensor_model_parallel_size: int,
pipeline_model_parallel_size: int = 1,
backend: Optional[str] = None,
) -> None:
"""Helper to initialize model parallel groups if they are not initialized,
or ensure tensor-parallel and pipeline-parallel sizes are equal to expected
values if the model parallel groups are initialized.
"""
# get the backend of _DEVICE_WORLD_GROUP
backend = backend or torch.distributed.get_backend()
if not model_parallel_is_initialized():
initialize_model_parallel(tensor_model_parallel_size, pipeline_model_parallel_size, backend)
return
assert (get_tensor_model_parallel_world_size() == tensor_model_parallel_size), (
"tensor parallel group already initialized, but of unexpected size: "
f"{get_tensor_model_parallel_world_size()=} vs. "
f"{tensor_model_parallel_size=}")
# assert (get_pipeline_model_parallel_world_size(
# ) == pipeline_model_parallel_size), (
# "pipeline parallel group already initialized, but of unexpected size: "
# f"{get_pipeline_model_parallel_world_size()=} vs. "
# f"{pipeline_model_parallel_size=}")
def model_parallel_is_initialized():
"""Check if tensor and pipeline parallel groups are initialized."""
return (ps._TP_DEVICE_GROUP is not None)
# and _PIPELINE_MODEL_PARALLEL_GROUP is not None)
def initialize_model_parallel_for_vllm(tensor_model_parallel_size: int,
num_tensor_model_parallel_groups_per_train_tp: int = 1) -> None:
from torch.distributed import new_group
# Get world size and rank. Ensure some consistencies.
assert torch.distributed.is_initialized()
assert isinstance(tensor_model_parallel_size, int)
# assert num_tensor_model_parallel_groups_per_train_tp == 1 and not different_tp_group
# assert num_tensor_model_parallel_groups_per_train_tp > 1 and different_tp_group
# Build the tensor model-parallel groups.
assert ps._TP_DEVICE_GROUP is None, ("tensor model parallel group is already initialized")
global _TP_DEVICE_GROUP
global _TP_CPU_GROUP
global _DEVICE_MESH
world_size: int = torch.distributed.get_world_size()
rank = torch.distributed.get_rank()
backend = torch.distributed.get_backend()
num_tensor_model_parallel_groups = world_size // tensor_model_parallel_size
if num_tensor_model_parallel_groups_per_train_tp == 1:
# if tensor_model_parallel_size == train_tensor_parallel_size:
# using the same tp group as Megatron/vllm
for i in range(num_tensor_model_parallel_groups):
ranks = range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size)
group = torch.distributed.new_group(ranks, backend=backend)
cpu_group = torch.distributed.new_group(ranks, backend="gloo")
if rank in ranks:
_TP_DEVICE_GROUP = group
_TP_CPU_GROUP = cpu_group
ps._TP_DEVICE_GROUP = group
ps._TP_CPU_GROUP = cpu_group
# no _MICRO_DATA_PARALLEL_GROUP
else:
# initialize a micro_dp group and a tp group
# assume training tp=4, infer tp=2, then, weight is partitioned as
# [1], [2], [3], [4] for training and [1,2], [1,2], [3,4], [3,4] for inference
# Build the inference tp groups
# train_tp = train_tensor_parallel_size
train_tp = num_tensor_model_parallel_groups_per_train_tp * tensor_model_parallel_size
# num_tensor_model_parallel_groups_per_train_tp = train_tp // tensor_model_parallel_size
assert _TP_DEVICE_GROUP is None, ("tensor model parallel group is already initialized")
for i in range(num_tensor_model_parallel_groups // num_tensor_model_parallel_groups_per_train_tp):
start = train_tp * i
end = train_tp * (i + 1)
for j in range(num_tensor_model_parallel_groups_per_train_tp):
ranks = list(range(start, end, num_tensor_model_parallel_groups_per_train_tp))
for i in range(len(ranks)):
ranks[i] += j
group = torch.distributed.new_group(ranks)
cpu_group = torch.distributed.new_group(ranks, backend='gloo')
if rank in ranks:
_TP_DEVICE_GROUP = group
_TP_CPU_GROUP = cpu_group
ps._TP_DEVICE_GROUP = _TP_DEVICE_GROUP
ps._TP_CPU_GROUP = cpu_group
# Build the pipeline model-parallel groups.
# global _PIPELINE_MODEL_PARALLEL_GROUP
# global _PIPELINE_GLOBAL_RANKS
# assert ps._PIPELINE_MODEL_PARALLEL_GROUP is None, ("pipeline model parallel group is already initialized")
# ps._PIPELINE_MODEL_PARALLEL_GROUP = mpu.get_pipeline_model_parallel_group()
# ps._PIPELINE_GLOBAL_RANKS = mpu.get_pipeline_model_parallel_ranks()
def initialize_model_parallel(
tensor_model_parallel_size: int = 1,
pipeline_model_parallel_size: int = 1,
backend: Optional[str] = None,
) -> None:
"""
NOTE: This method is a hack from the open-sourced version without
asertion of world_size = tp * pp
Initialize model parallel groups.
Arguments:
tensor_model_parallel_size: number of GPUs used for tensor model
parallelism.
pipeline_model_parallel_size: number of GPUs used for pipeline model
parallelism.
Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we
use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize
the model pipeline. The present function will
create 4 tensor model-parallel groups and 2 pipeline model-parallel groups:
4 tensor model-parallel groups:
[g0, g1], [g2, g3], [g4, g5], [g6, g7]
2 pipeline model-parallel groups:
[g0, g2, g4, g6], [g1, g3, g5, g7]
Note that for efficiency, the caller should make sure adjacent ranks
are on the same DGX box. For example if we are using 2 DGX-1 boxes
with a total of 16 GPUs, rank 0 to 7 belong to the first box and
ranks 8 to 15 belong to the second box.
"""
# Get world size and rank. Ensure some consistencies.
assert torch.distributed.is_initialized()
world_size: int = torch.distributed.get_world_size()
# get the backend of _DEVICE_WORLD_GROUP
backend = backend or torch.distributed.get_backend()
# NOTE(sgm) we don't assert world_size == tp * pp
# DP is not managed by vllm but by the veRL WorkerGroup
num_tensor_model_parallel_groups: int = (world_size // tensor_model_parallel_size)
num_pipeline_model_parallel_groups: int = (world_size // pipeline_model_parallel_size)
rank = torch.distributed.get_rank()
# Build device mesh for TP
if num_tensor_model_parallel_groups > 1:
device_mesh = init_device_mesh("cuda", (num_tensor_model_parallel_groups, tensor_model_parallel_size),
mesh_dim_names=("replicate", "tp_shard"))
else:
device_mesh = init_device_mesh("cuda", (tensor_model_parallel_size,), mesh_dim_names=["tp_shard"])
shard_group = device_mesh.get_group(mesh_dim="tp_shard")
# Build the tensor model-parallel groups.
global _TP_DEVICE_GROUP, _TP_CPU_GROUP
global _DEVICE_MESH
assert _TP_DEVICE_GROUP is None, ("tensor model parallel group is already initialized")
assert _DEVICE_MESH is None, ("device mesh in vllm is already initialized")
_DEVICE_MESH = device_mesh
# for i in range(num_tensor_model_parallel_groups):
# ranks = range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size)
# group = torch.distributed.new_group(ranks, backend=backend)
# cpu_group = torch.distributed.new_group(ranks, backend="gloo")
# assert torch.distributed.get_process_group_ranks(shard_group) == torch.distributed.get_process_group_ranks(cpu_group)
# ranks = torch.distributed.get_process_group_ranks(shard_group)
# cpu_group = torch.distributed.new_group(ranks, backend="gloo") # TODO: this will hang
# cpu_group = torch.distributed.new_group(, backend="gloo")
# if rank == 0:
# print(f'rank: {rank}')
# print(f'ranks: {ranks}')
# print(f'torch.distributed.get_process_group_ranks(shard_group): {torch.distributed.get_process_group_ranks(shard_group)}')
# if rank in ranks:
_TP_DEVICE_GROUP = shard_group
ps._TP_DEVICE_GROUP = _TP_DEVICE_GROUP
# ps._TP_CPU_GROUP = cpu_group # TODO: will hang when used with device mesh
# TODO: init using device mesh
# Build the pipeline model-parallel groups.
assert ps._PIPELINE_MODEL_PARALLEL_GROUP is None, ("pipeline model parallel group is already initialized")
for i in range(num_pipeline_model_parallel_groups):
ranks = range(i, world_size, num_pipeline_model_parallel_groups)
group = torch.distributed.new_group(ranks, backend=backend)
if rank in ranks:
ps._PIPELINE_MODEL_PARALLEL_GROUP = group
ps._PIPELINE_GLOBAL_RANKS = ranks
"""
Device mesh utilities
"""
def get_device_mesh():
assert _DEVICE_MESH is not None, ("device mesh is not initialized")
return _DEVICE_MESH
"""
Tensor model parallel utilities
"""
def get_tensor_model_parallel_group():
"""Get the tensor model parallel group the caller rank belongs to."""
assert _TP_DEVICE_GROUP is not None, ("tensor model parallel group is not initialized")
return _TP_DEVICE_GROUP
def get_tensor_model_parallel_world_size():
"""Return world size for the tensor model parallel group."""
return torch.distributed.get_world_size(group=get_tensor_model_parallel_group())
def get_tensor_model_parallel_rank():
"""Return my rank for the tensor model parallel group."""
return torch.distributed.get_rank(group=get_tensor_model_parallel_group())
def get_tensor_model_parallel_src_rank():
"""Calculate the global rank corresponding to the first local rank
in the tensor model parallel group."""
global_rank = torch.distributed.get_rank()
local_world_size = get_tensor_model_parallel_world_size()
return (global_rank // local_world_size) * local_world_size

View File

@@ -0,0 +1,218 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023 The vLLM team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/executor/gpu_executor.py
import os
import socket
from typing import Any, Dict, List, Optional, Set, Tuple
import torch
import vllm.envs as envs
from vllm.executor.executor_base import ExecutorBase, ExecutorAsyncBase
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.sequence import SamplerOutput, ExecuteModelRequest
from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ParallelConfig, SchedulerConfig, SpeculativeConfig,
VisionLanguageConfig)
from .config import ModelConfig, LoadConfig
logger = init_logger(__name__)
class SPMDGPUExecutor(ExecutorBase):
"""SPMD-based multi-GPU executor implementations."""
def __init__(
self,
model, # pytorch model itself or its parameter dict
model_config: ModelConfig,
cache_config: CacheConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
load_config: LoadConfig,
lora_config: Optional[LoRAConfig],
vision_language_config: Optional[VisionLanguageConfig],
speculative_config: Optional[SpeculativeConfig],
) -> None:
self.model_config = model_config
self.cache_config = cache_config
self.lora_config = lora_config
self.load_config = load_config
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
self.device_config = device_config
self.vision_language_config = vision_language_config
self.speculative_config = speculative_config
distributed_init_method = initialize_cluster(parallel_config)
self._init_executor(model, distributed_init_method)
# TODO(sgm): verl not support speculative decode now
def _init_executor(self, model, distributed_init_method) -> None:
assert (not self.speculative_config), "Speculative decoding not yet supported for multi-GPU backend."
# Create the parallel worker for each GPU.
self._init_workers_sp(model, distributed_init_method)
def _init_workers_sp(self, model, distributed_init_method: str):
# Lazy import the Worker to avoid importing torch.cuda/xformers
# before CUDA_VISIBLE_DEVICES is set in the Worker
from .worker import Worker # pylint: disable=import-outside-toplevel
rank = int(os.getenv("RANK"))
local_rank = int(os.getenv("LOCAL_RANK"))
print(f'local rank {local_rank}')
self.worker = Worker(
model,
self.model_config,
self.parallel_config,
self.scheduler_config,
self.device_config,
self.cache_config,
self.load_config,
local_rank,
rank,
distributed_init_method,
lora_config=self.lora_config,
vision_language_config=self.vision_language_config,
)
# NOTE(shengguangming): torch.distributed.init_process_group will be called inside the init_model()
self.worker.init_device()
self.worker.load_model()
def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Determine the number of available KV blocks.
This invokes `determine_num_available_blocks` on each worker and takes
the min of the results, guaranteeing that the selected cache sizes are
compatible with all workers.
Returns:
- tuple[num_gpu_blocks, num_cpu_blocks]
"""
# Get the maximum number of blocks that can be allocated on GPU and CPU.
num_blocks = self.worker.determine_num_available_blocks()
# NOTE(shengguangming): Now we don't use a shared centralized controler but each process will
# have its own scheduler
num_gpu_blocks = num_blocks[0]
num_cpu_blocks = num_blocks[1]
return num_gpu_blocks, num_cpu_blocks
def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None:
"""Initialize the KV cache in all workers.
"""
# NOTE: We log here to avoid multiple logs when number of workers is
# greater than one. We could log in the engine, but not all executors
# have GPUs.
logger.info("# GPU blocks: %d, # CPU blocks: %d", num_gpu_blocks, num_cpu_blocks)
self.cache_config.num_gpu_blocks = num_gpu_blocks
self.cache_config.num_cpu_blocks = num_cpu_blocks
if torch.distributed.get_rank() == 0:
print(
f'before init cache memory allocated: {torch.cuda.memory_allocated() / 1e9}GB, reserved: {torch.cuda.memory_reserved() / 1e9}GB'
)
self.worker.initialize_cache(num_gpu_blocks=num_gpu_blocks, num_cpu_blocks=num_cpu_blocks)
if torch.distributed.get_rank() == 0:
print(
f'after init cache memory allocated: {torch.cuda.memory_allocated() / 1e9}GB, reserved: {torch.cuda.memory_reserved() / 1e9}GB'
)
# NOTE(sgm): This will not profile & capture the model(CUDAGraph) when rebuilding KVCache
def init_cache_engine(self) -> None:
self.worker._init_cache_engine()
def free_cache_engine(self) -> None:
self.worker.free_cache_engine()
def execute_model(self, execute_model_req) -> List[SamplerOutput]:
all_outputs = self.worker.execute_model(execute_model_req=execute_model_req)
# NOTE(sgm):
# Each GPU in vllm under verl has its own spmd_gpu_executor, therefore all GPUs should return the outputs
# In vllm with ray, only the driver worker returns the sampling results.
return all_outputs
def add_lora(self, lora_request: LoRARequest) -> bool:
assert lora_request.lora_int_id > 0, "lora_id must be greater than 0."
return self.worker.add_lora(lora_request=lora_request)
def remove_lora(self, lora_id: int) -> bool:
assert lora_id > 0, "lora_id must be greater than 0."
return self.worker.remove_lora(lora_id=lora_id)
def list_loras(self) -> Set[int]:
return self.worker.list_loras()
def check_health(self) -> None:
# SPMDExecutor will always be healthy as long as
# it's running.
return
# NOTE(sgm): add for verl
def offload_model_weights(self) -> None:
self.worker.offload_model_weights()
def sync_model_weights(self, actor_weights: Dict[str, torch.Tensor], load_format: str) -> None:
self.worker.sync_model_weights(actor_weights=actor_weights, load_format=load_format)
def initialize_cluster(
parallel_config: ParallelConfig,
engine_use_ray: bool = False,
ray_address: Optional[str] = None,
) -> Tuple[str, Optional[None]]:
"""Initialize the distributed cluster probably with Ray.
Args:
parallel_config: The configurations for parallel execution.
Returns:
The `distributed_init_method` is the address for initializing the
distributed backend.
"""
# Initialize cluster locally.
port = get_open_port()
# We need to setup the distributed init method to make sure
# the distributed megatron code (e.g., get world size) works correctly.
# distributed_init_method = f"tcp://localhost:{port}"
distributed_init_method = 'env://'
return distributed_init_method
def get_open_port():
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(("", 0))
return s.getsockname()[1]
# TODO(sgm): not implemented async executor yet
class SPMDGPUExecutorAsync(SPMDGPUExecutor, ExecutorAsyncBase):
async def execute_model_async(self, execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
"""Executes one model step on the given sequences."""
raise NotImplementedError
async def check_health_async(self) -> None:
"""Checks if the executor is healthy. If not, it should raise an
exception."""
self.check_health()

View File

@@ -0,0 +1,77 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023 The vLLM team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/transformers_utils/tokenizer_group/tokenizer_group.py
from typing import List, Optional, Tuple, Union
from transformers import (AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast)
from vllm.lora.request import LoRARequest
from vllm.utils import make_async, LRUCache
from vllm.transformers_utils.tokenizers import *
class TokenizerGroup:
"""A group of tokenizers that can be used for LoRA adapters."""
def __init__(self, tokenizer: PreTrainedTokenizer, enable_lora: bool, max_num_seqs: int,
max_input_length: Optional[int]):
self.enable_lora = enable_lora
self.max_input_length = max_input_length
self.tokenizer = tokenizer
self.lora_tokenizers = LRUCache[PreTrainedTokenizer](capacity=max_num_seqs) if enable_lora else None
def ping(self) -> bool:
"""Check if the tokenizer group is alive."""
return True
def get_max_input_len(self, lora_request: Optional[LoRARequest] = None) -> Optional[int]:
"""Get the maximum input length for the LoRA request."""
return self.max_input_length
def encode(self,
prompt: str,
request_id: Optional[str] = None,
lora_request: Optional[LoRARequest] = None) -> List[int]:
tokenizer = self.get_lora_tokenizer(lora_request)
return tokenizer.encode(prompt)
async def encode_async(self,
prompt: str,
request_id: Optional[str] = None,
lora_request: Optional[LoRARequest] = None) -> List[int]:
tokenizer = await self.get_lora_tokenizer_async(lora_request)
return tokenizer.encode(prompt)
def get_lora_tokenizer(self, lora_request: Optional[LoRARequest]) -> "PreTrainedTokenizer":
if not lora_request or not self.enable_lora:
return self.tokenizer
if lora_request.lora_int_id not in self.lora_tokenizers:
# TODO(sgm): the lora tokenizer is also passed, but may be different
tokenizer = self.tokenizer
# tokenizer = (get_lora_tokenizer(
# lora_request, **self.tokenizer_config) or self.tokenizer)
self.lora_tokenizers.put(lora_request.lora_int_id, tokenizer)
return tokenizer
else:
return self.lora_tokenizers.get(lora_request.lora_int_id)
# FIXME(sgm): for simplicity, we assign the special token here
@property
def pad_token_id(self):
return self.tokenizer.pad_token_id
@property
def eos_token_id(self):
return self.tokenizer.eos_token_id

View File

@@ -0,0 +1,292 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023 The vLLM team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/worker/worker.py
"""A GPU worker class."""
import os
import gc
from typing import Dict, List, Tuple, Optional, Union
import torch
import torch.distributed
import torch.nn as nn
from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ParallelConfig, SchedulerConfig, VisionLanguageConfig)
from vllm.model_executor import set_random_seed
from vllm.sequence import SamplerOutput, ExecuteModelRequest
from vllm.worker.cache_engine import CacheEngine
from vllm.distributed.device_communicators import pynccl_utils
from vllm.distributed.device_communicators.custom_all_reduce import (init_custom_ar)
# TODO(sgm): check why vllm has similar file in vllm.model_executor.parallel_utils.parallel_state
from vllm.distributed import get_tensor_model_parallel_cpu_group, init_distributed_environment, get_tensor_model_parallel_group
from vllm.worker.worker import Worker, _check_if_gpu_supports_dtype
from .model_runner import ModelRunner
from .megatron_weight_loaders import load_megatron_weights
from .hf_weight_loader import load_hf_weights
from .dtensor_weight_loaders import load_dtensor_weights
from .parallel_state import (ensure_model_parallel_initialized)
from .config import ModelConfig, LoadConfig, LoadFormat
class Worker(Worker):
"""A worker class that executes (a partition of) the model on a GPU.
Each worker is associated with a single GPU. The worker is responsible for
maintaining the KV cache and executing the model on the GPU. In case of
distributed inference, each worker is assigned a partition of the model.
"""
def __init__(
self,
model: Union[nn.Module, Dict], # model itself or its parameter dict
model_config: ModelConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
cache_config: CacheConfig,
load_config: LoadConfig,
local_rank: int,
rank: int,
distributed_init_method: str,
lora_config: Optional[LoRAConfig] = None,
vision_language_config: Optional[VisionLanguageConfig] = None,
is_driver_worker: bool = False,
) -> None:
# self.model = model # will be replaced in the init_model
self.model_config = model_config
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
self.device_config = device_config
self.cache_config = cache_config
self.local_rank = local_rank
self.rank = rank
self.distributed_init_method = distributed_init_method
self.lora_config = lora_config
self.load_config = load_config
self.is_driver_worker = is_driver_worker
if self.is_driver_worker:
assert self.rank == 0, "The driver worker must have rank 0."
self.vision_language_config = vision_language_config
if self.vision_language_config:
assert not self.lora_config, ("To be tested: vision language model with LoRA settings.")
self.model_runner = ModelRunner(
model,
model_config,
parallel_config,
scheduler_config,
device_config,
load_config=load_config,
lora_config=self.lora_config,
kv_cache_dtype=self.cache_config.cache_dtype,
vision_language_config=vision_language_config,
)
# Uninitialized cache engine. Will be initialized by
# init_cache_engine.
self.cache_engine: CacheEngine = None
self.gpu_cache: List[torch.Tensor] = None
# NOTE(sgm): For offloading inference engine params
self.cpu_model = None
def init_device(self) -> None:
if self.device_config.device.type == "cuda":
# torch.distributed.all_reduce does not free the input tensor until
# the synchronization point. This causes the memory usage to grow
# as the number of all_reduce calls increases. This env var disables
# this behavior.
# Related issue:
# https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573
os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
# NOTE(sgm): Modify for verl, Env vars will be set by TORCHRUN.
self.rank = self.rank if self.rank is not None else int(os.getenv("RANK", "-1"))
local_rank = int(os.getenv("LOCAL_RANK", "0"))
self.device = torch.device(f"cuda:{local_rank}")
if self.rank < 0:
raise ValueError("Invalid or unspecified rank.")
torch.cuda.set_device(self.device)
# Use the world_size set by TORCHRUN
world_size = int(os.getenv("WORLD_SIZE", "-1"))
assert world_size != -1, "The world_size is set to -1, not initialized by TORCHRUN"
self.parallel_config.world_size = world_size
_check_if_gpu_supports_dtype(self.model_config.dtype)
torch.cuda.empty_cache()
self.init_gpu_memory = torch.cuda.mem_get_info()[0]
else:
raise RuntimeError(f"Not support device type: {self.device_config.device}")
# Initialize the distributed environment.
init_worker_distributed_environment(self.parallel_config, self.rank, self.distributed_init_method,
self.local_rank)
# Set random seed.
set_random_seed(self.model_config.seed)
# self.model = get_model(actor_model=self.model, model_config=self.model_config)
@torch.inference_mode()
def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Profiles the peak memory usage of the model to determine how many
KV blocks may be allocated without OOMs.
The engine will first conduct a profiling of the existing memory usage.
Then, it calculate the maximum possible number of GPU and CPU blocks
that can be allocated with the remaining free memory.
.. tip::
You may limit the usage of GPU memory
by adjusting the `gpu_memory_utilization` parameter.
"""
# Profile the memory usage of the model and get the maximum number of
# cache blocks that can be allocated with the remaining free memory.
torch.cuda.empty_cache()
# torch.cuda.reset_peak_memory_stats()
# Execute a forward pass with dummy inputs to profile the memory usage
# of the model.
self.model_runner.profile_run()
# Calculate the number of blocks that can be allocated with the
# profiled peak memory.
torch.cuda.synchronize()
free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info()
peak_memory = total_gpu_memory - free_gpu_memory
assert peak_memory > 0, ("Error in memory profiling. This happens when the GPU memory was "
"not properly cleaned up before initializing the vLLM instance.")
cache_block_size = self.get_cache_block_size_bytes()
# NOTE(sgm) use the remaining memory
num_gpu_blocks = int((free_gpu_memory * self.cache_config.gpu_memory_utilization) // cache_block_size)
# num_gpu_blocks = int((total_gpu_memory * self.cache_config.gpu_memory_utilization - peak_memory) // cache_block_size)
num_cpu_blocks = int(self.cache_config.swap_space_bytes // cache_block_size)
num_gpu_blocks = max(num_gpu_blocks, 0)
num_cpu_blocks = max(num_cpu_blocks, 0)
if self.model_runner.lora_manager:
self.model_runner.remove_all_loras()
# NOTE(sgm): Add for verl, synchronize number of blocks with all the rank
num_gpu_blocks = torch.tensor([num_gpu_blocks], device='cuda')
num_cpu_blocks = torch.tensor([num_cpu_blocks], device='cuda')
torch.distributed.all_reduce(num_gpu_blocks,
op=torch.distributed.ReduceOp.MIN,
group=get_tensor_model_parallel_group())
torch.distributed.all_reduce(num_cpu_blocks,
op=torch.distributed.ReduceOp.MIN,
group=get_tensor_model_parallel_group())
num_gpu_blocks = num_gpu_blocks.item()
num_cpu_blocks = num_cpu_blocks.item()
gc.collect()
torch.cuda.empty_cache()
return num_gpu_blocks, num_cpu_blocks
def _init_cache_engine(self):
if self.cache_engine is None and self.gpu_cache is None:
super()._init_cache_engine()
def free_cache_engine(self):
# ensure `enforce_eager=True`
self.cache_engine = None
self.gpu_cache = None
@torch.inference_mode()
def execute_model(self, execute_model_req: Optional[ExecuteModelRequest] = None) -> List[SamplerOutput]:
if execute_model_req is None:
seq_group_metadata_list = None
else:
seq_group_metadata_list = execute_model_req.seq_group_metadata_list
# NOTE(sgm): each SPMD rank will have identical input
assert seq_group_metadata_list is not None
assert execute_model_req is not None
num_seq_groups = len(seq_group_metadata_list)
blocks_to_swap_in = execute_model_req.blocks_to_swap_in
blocks_to_swap_out = execute_model_req.blocks_to_swap_out
blocks_to_copy = execute_model_req.blocks_to_copy
self.cache_swap(blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy)
# If there is no input, we don't need to execute the model.
if num_seq_groups == 0:
return []
output = self.model_runner.execute_model(seq_group_metadata_list, self.gpu_cache)
# Worker only supports single-step execution. Wrap the output in a list
# to conform to interface.
return [output]
# assume the input is .state_dict()
def sync_model_weights(self, actor_weights: Dict, load_format: str):
if load_format in [LoadFormat.MEGATRON, LoadFormat.AUTO]:
load_megatron_weights(actor_weights, self.model_runner.model)
elif load_format == LoadFormat.HF:
# full model state dict without no sharding
load_hf_weights(actor_weights, self.model_runner.model)
elif load_format == LoadFormat.DTENSOR:
load_dtensor_weights(actor_weights, self.model_runner.model)
def offload_model_weights(self) -> None:
if self.cpu_model == None:
self.cpu_model = {}
for name, params in self.model_runner.model.named_parameters():
self.cpu_model[name] = torch.empty_like(params, device='cpu')
params.data = self.cpu_model[name]
else:
for name, params in self.model_runner.model.named_parameters():
params.data = self.cpu_model[name]
def init_worker_distributed_environment(
parallel_config: ParallelConfig,
rank: int,
distributed_init_method: Optional[str] = "env://",
local_rank: int = -1,
) -> None:
"""Initialize the distributed environment."""
# NOTE(sgm) use tcp://localhost:xxxx will hang in HF setting without megatron
init_distributed_environment(parallel_config.world_size, rank, distributed_init_method, local_rank)
ensure_model_parallel_initialized(tensor_model_parallel_size=parallel_config.tensor_parallel_size,
pipeline_model_parallel_size=parallel_config.pipeline_parallel_size)
# TODO(sgm): check whether need this
# if pynccl_utils.is_initialized():
# pynccl_world_size = pynccl_utils.get_world_size()
# if pynccl_world_size != parallel_config.world_size:
# raise RuntimeError(
# "pynccl is already initialized but the pynccl world "
# "size does not match parallel_config.world_size "
# f"({pynccl_world_size} vs. {parallel_config.world_size}).")
# elif parallel_config.world_size > 1:
# # NOTE(woosuk): We don't initialize pynccl process group when world size
# # is 1.
# # NOTE(kaichao): By default, pynccl is initialized for tp group.
# pynccl_utils.init_process_group(
# group=get_tensor_model_parallel_cpu_group())
# # Initialize a custom fast all-reduce implementation.
# if not parallel_config.disable_custom_all_reduce:
# init_custom_ar()
# A small all_reduce for warmup.
torch.distributed.all_reduce(torch.zeros(1).cuda())
# if pynccl_utils.is_initialized():
# pynccl_utils.all_reduce(torch.zeros(1).cuda())

View File

@@ -0,0 +1,13 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@@ -0,0 +1,453 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023 The vLLM team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/engine/arg_utils.py
import os
import argparse
import dataclasses
import json
from dataclasses import dataclass
from typing import TYPE_CHECKING, List, Optional, Tuple, Type, Union
import torch.nn as nn
from transformers import PretrainedConfig
from .config import ModelConfig, LoadConfig
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, EngineConfig, LoRAConfig, MultiModalConfig,
ObservabilityConfig, ParallelConfig, PromptAdapterConfig, SchedulerConfig, SpeculativeConfig,
TokenizerPoolConfig)
from vllm.executor.executor_base import ExecutorBase
from vllm.logger import init_logger
from vllm.utils import FlexibleArgumentParser
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.utils import str_to_int_tuple
if TYPE_CHECKING:
from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import (BaseTokenizerGroup)
logger = init_logger(__name__)
def nullable_str(val: str):
if not val or val == "None":
return None
return val
@dataclass
class EngineArgs:
"""Arguments for vLLM engine."""
model_hf_config: PretrainedConfig = None # for verl
served_model_name = None # TODO(sgm): check this
# tokenizer: Optional[str] = None # TODO(sgm): check this
skip_tokenizer_init: bool = False
tokenizer_mode: str = 'auto'
trust_remote_code: bool = False
download_dir: Optional[str] = None
load_format: str = 'auto'
dtype: str = 'auto'
kv_cache_dtype: str = 'auto'
quantization_param_path: Optional[str] = None
seed: int = 0
max_model_len: Optional[int] = None
worker_use_ray: bool = False
# Note: Specifying a custom executor backend by passing a class
# is intended for expert use only. The API may change without
# notice.
distributed_executor_backend: Optional[Union[str, Type[ExecutorBase]]] = None
pipeline_parallel_size: int = 1
tensor_parallel_size: int = 1
max_parallel_loading_workers: Optional[int] = None
block_size: int = 16
enable_prefix_caching: bool = False
disable_sliding_window: bool = False
use_v2_block_manager: bool = False
swap_space: int = 4 # GiB
cpu_offload_gb: int = 0 # GiB
gpu_memory_utilization: float = 0.90
max_num_batched_tokens: Optional[int] = None
max_num_seqs: int = 256
max_logprobs: int = 20 # Default value for OpenAI Chat Completions API
disable_log_stats: bool = False
revision: Optional[str] = None
code_revision: Optional[str] = None
rope_scaling: Optional[dict] = None
rope_theta: Optional[float] = None
tokenizer_revision: Optional[str] = None
quantization: Optional[str] = None
enforce_eager: bool = False
max_context_len_to_capture: Optional[int] = None
max_seq_len_to_capture: int = 8192
disable_custom_all_reduce: bool = False
tokenizer_pool_size: int = 0
# Note: Specifying a tokenizer pool by passing a class
# is intended for expert use only. The API may change without
# notice.
tokenizer_pool_type: Union[str, Type["BaseTokenizerGroup"]] = "ray"
tokenizer_pool_extra_config: Optional[dict] = None
enable_lora: bool = False
max_loras: int = 1
max_lora_rank: int = 16
enable_prompt_adapter: bool = False
max_prompt_adapters: int = 1
max_prompt_adapter_token: int = 0
fully_sharded_loras: bool = False
lora_extra_vocab_size: int = 256
long_lora_scaling_factors: Optional[Tuple[float]] = None
lora_dtype: str = 'auto'
max_cpu_loras: Optional[int] = None
device: str = 'auto'
ray_workers_use_nsight: bool = False
num_gpu_blocks_override: Optional[int] = None
num_lookahead_slots: int = 0
model_loader_extra_config: Optional[dict] = None
ignore_patterns: Optional[Union[str, List[str]]] = None
preemption_mode: Optional[str] = None
scheduler_delay_factor: float = 0.0
enable_chunked_prefill: Optional[bool] = None
guided_decoding_backend: str = 'outlines'
# Speculative decoding configuration.
speculative_model: Optional[str] = None
speculative_draft_tensor_parallel_size: Optional[int] = None
num_speculative_tokens: Optional[int] = None
speculative_max_model_len: Optional[int] = None
speculative_disable_by_batch_size: Optional[int] = None
ngram_prompt_lookup_max: Optional[int] = None
ngram_prompt_lookup_min: Optional[int] = None
spec_decoding_acceptance_method: str = 'rejection_sampler'
typical_acceptance_sampler_posterior_threshold: Optional[float] = None
typical_acceptance_sampler_posterior_alpha: Optional[float] = None
qlora_adapter_name_or_path: Optional[str] = None
disable_logprobs_during_spec_decoding: Optional[bool] = None
otlp_traces_endpoint: Optional[str] = None
@staticmethod
def add_cli_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
"""Shared CLI arguments for vLLM engine."""
# Model arguments
# TODO(shengguangming): delete the unused args
parser.add_argument('--model',
type=str,
default='facebook/opt-125m',
help='name or path of the huggingface model to use')
parser.add_argument('--tokenizer',
type=str,
default=EngineArgs.tokenizer,
help='name or path of the huggingface tokenizer to use')
parser.add_argument('--revision',
type=str,
default=None,
help='the specific model version to use. It can be a branch '
'name, a tag name, or a commit id. If unspecified, will use '
'the default version.')
parser.add_argument('--tokenizer-revision',
type=str,
default=None,
help='the specific tokenizer version to use. It can be a branch '
'name, a tag name, or a commit id. If unspecified, will use '
'the default version.')
parser.add_argument('--tokenizer-mode',
type=str,
default=EngineArgs.tokenizer_mode,
choices=['auto', 'slow'],
help='tokenizer mode. "auto" will use the fast '
'tokenizer if available, and "slow" will '
'always use the slow tokenizer.')
parser.add_argument('--trust-remote-code', action='store_true', help='trust remote code from huggingface')
parser.add_argument('--download-dir',
type=str,
default=EngineArgs.download_dir,
help='directory to download and load the weights, '
'default to the default cache dir of '
'huggingface')
parser.add_argument('--load-format',
type=str,
default=EngineArgs.load_format,
choices=['auto', 'pt', 'safetensors', 'npcache', 'dummy'],
help='The format of the model weights to load. '
'"auto" will try to load the weights in the safetensors format '
'and fall back to the pytorch bin format if safetensors format '
'is not available. '
'"pt" will load the weights in the pytorch bin format. '
'"safetensors" will load the weights in the safetensors format. '
'"npcache" will load the weights in pytorch format and store '
'a numpy cache to speed up the loading. '
'"dummy" will initialize the weights with random values, '
'which is mainly for profiling.')
parser.add_argument('--dtype',
type=str,
default=EngineArgs.dtype,
choices=['auto', 'half', 'float16', 'bfloat16', 'float', 'float32'],
help='data type for model weights and activations. '
'The "auto" option will use FP16 precision '
'for FP32 and FP16 models, and BF16 precision '
'for BF16 models.')
parser.add_argument('--max-model-len',
type=int,
default=None,
help='model context length. If unspecified, '
'will be automatically derived from the model.')
# Parallel arguments
parser.add_argument('--worker-use-ray',
action='store_true',
help='use Ray for distributed serving, will be '
'automatically set when using more than 1 GPU')
parser.add_argument('--pipeline-parallel-size',
'-pp',
type=int,
default=EngineArgs.pipeline_parallel_size,
help='number of pipeline stages')
parser.add_argument('--tensor-parallel-size',
'-tp',
type=int,
default=EngineArgs.tensor_parallel_size,
help='number of tensor parallel replicas')
# KV cache arguments
parser.add_argument('--block-size',
type=int,
default=EngineArgs.block_size,
choices=[8, 16, 32],
help='token block size')
# TODO(woosuk): Support fine-grained seeds (e.g., seed per request).
parser.add_argument('--seed', type=int, default=EngineArgs.seed, help='random seed')
parser.add_argument('--swap-space',
type=int,
default=EngineArgs.swap_space,
help='CPU swap space size (GiB) per GPU')
parser.add_argument('--gpu-memory-utilization',
type=float,
default=EngineArgs.gpu_memory_utilization,
help='the percentage of GPU memory to be used for'
'the model executor')
parser.add_argument('--max-num-batched-tokens',
type=int,
default=EngineArgs.max_num_batched_tokens,
help='maximum number of batched tokens per '
'iteration')
parser.add_argument('--max-num-seqs',
type=int,
default=EngineArgs.max_num_seqs,
help='maximum number of sequences per iteration')
parser.add_argument('--disable-log-stats', action='store_true', help='disable logging statistics')
# Quantization settings.
parser.add_argument('--quantization',
'-q',
type=str,
choices=['awq', None],
default=None,
help='Method used to quantize the weights')
return parser
@classmethod
def from_cli_args(cls, args: argparse.Namespace) -> 'EngineArgs':
# Get the list of attributes of this dataclass.
attrs = [attr.name for attr in dataclasses.fields(cls)]
# Set the attributes from the parsed arguments.
engine_args = cls(**{attr: getattr(args, attr) for attr in attrs})
return engine_args
def create_engine_config(
self,
) -> EngineConfig:
# bitsandbytes quantization needs a specific model loader
# so we make sure the quant method and the load format are consistent
if (self.quantization == "bitsandbytes" or
self.qlora_adapter_name_or_path is not None) and \
self.load_format != "bitsandbytes":
raise ValueError("BitsAndBytes quantization and QLoRA adapter only support "
f"'bitsandbytes' load format, but got {self.load_format}")
if (self.load_format == "bitsandbytes" or
self.qlora_adapter_name_or_path is not None) and \
self.quantization != "bitsandbytes":
raise ValueError("BitsAndBytes load format and QLoRA adapter only support "
f"'bitsandbytes' quantization, but got {self.quantization}")
assert self.cpu_offload_gb >= 0, ("CPU offload space must be non-negative"
f", but got {self.cpu_offload_gb}")
multimodal_config = MultiModalConfig()
device_config = DeviceConfig(self.device)
# NOTE(sgm): we only modify ModelConfig, other configs are import from vllm
model_config = ModelConfig(hf_config=self.model_hf_config,
tokenizer_mode=self.tokenizer_mode,
trust_remote_code=self.trust_remote_code,
dtype=self.dtype,
seed=self.seed,
revision=self.revision,
code_revision=self.code_revision,
rope_scaling=self.rope_scaling,
rope_theta=self.rope_theta,
tokenizer_revision=self.tokenizer_revision,
max_model_len=self.max_model_len,
quantization=self.quantization,
quantization_param_path=self.quantization_param_path,
enforce_eager=self.enforce_eager,
max_context_len_to_capture=self.max_context_len_to_capture,
max_seq_len_to_capture=self.max_seq_len_to_capture,
max_logprobs=self.max_logprobs,
disable_sliding_window=self.disable_sliding_window,
skip_tokenizer_init=self.skip_tokenizer_init,
served_model_name=self.served_model_name,
multimodal_config=multimodal_config)
cache_config = CacheConfig(
block_size=self.block_size,
gpu_memory_utilization=self.gpu_memory_utilization,
swap_space=self.swap_space,
cache_dtype=self.kv_cache_dtype,
num_gpu_blocks_override=self.num_gpu_blocks_override,
sliding_window=model_config.get_sliding_window(),
enable_prefix_caching=self.enable_prefix_caching,
cpu_offload_gb=self.cpu_offload_gb,
)
parallel_config = ParallelConfig(pipeline_parallel_size=self.pipeline_parallel_size,
tensor_parallel_size=self.tensor_parallel_size,
worker_use_ray=self.worker_use_ray,
max_parallel_loading_workers=self.max_parallel_loading_workers,
disable_custom_all_reduce=self.disable_custom_all_reduce,
tokenizer_pool_config=TokenizerPoolConfig.create_config(
self.tokenizer_pool_size,
self.tokenizer_pool_type,
self.tokenizer_pool_extra_config,
),
ray_workers_use_nsight=self.ray_workers_use_nsight,
distributed_executor_backend=self.distributed_executor_backend)
# NOTE[VERL]: Use the world_size set by TORCHRUN
world_size = int(os.getenv("WORLD_SIZE", "-1"))
assert world_size != -1, "The world_size is set to -1, not initialized by TORCHRUN"
parallel_config.world_size = world_size
max_model_len = model_config.max_model_len
use_long_context = max_model_len > 32768
if self.enable_chunked_prefill is None:
# If not explicitly set, enable chunked prefill by default for
# long context (> 32K) models. This is to avoid OOM errors in the
# initial memory profiling phase.
if use_long_context:
is_gpu = device_config.device_type == "cuda"
use_sliding_window = (model_config.get_sliding_window() is not None)
use_spec_decode = self.speculative_model is not None
has_seqlen_agnostic_layers = (model_config.contains_seqlen_agnostic_layers(parallel_config))
if (is_gpu and not use_sliding_window and not use_spec_decode and not self.enable_lora and
not self.enable_prompt_adapter and not self.enable_prefix_caching and
not has_seqlen_agnostic_layers):
self.enable_chunked_prefill = True
logger.warning("Chunked prefill is enabled by default for models with "
"max_model_len > 32K. Currently, chunked prefill might "
"not work with some features or models. If you "
"encounter any issues, please disable chunked prefill "
"by setting --enable-chunked-prefill=False.")
if self.enable_chunked_prefill is None:
self.enable_chunked_prefill = False
if not self.enable_chunked_prefill and use_long_context:
logger.warning(
"The model has a long context length (%s). This may cause OOM "
"errors during the initial memory profiling phase, or result "
"in low performance due to small KV cache space. Consider "
"setting --max-model-len to a smaller value.", max_model_len)
# TODO: spec config
speculative_config = SpeculativeConfig.maybe_create_spec_config(
target_model_config=model_config,
target_parallel_config=parallel_config,
target_dtype=self.dtype,
speculative_model=self.speculative_model,
speculative_draft_tensor_parallel_size = \
self.speculative_draft_tensor_parallel_size,
num_speculative_tokens=self.num_speculative_tokens,
speculative_disable_by_batch_size=self.
speculative_disable_by_batch_size,
speculative_max_model_len=self.speculative_max_model_len,
enable_chunked_prefill=self.enable_chunked_prefill,
use_v2_block_manager=self.use_v2_block_manager,
disable_log_stats=self.disable_log_stats,
ngram_prompt_lookup_max=self.ngram_prompt_lookup_max,
ngram_prompt_lookup_min=self.ngram_prompt_lookup_min,
draft_token_acceptance_method=\
self.spec_decoding_acceptance_method,
typical_acceptance_sampler_posterior_threshold=self.
typical_acceptance_sampler_posterior_threshold,
typical_acceptance_sampler_posterior_alpha=self.
typical_acceptance_sampler_posterior_alpha,
disable_logprobs=self.disable_logprobs_during_spec_decoding,
)
scheduler_config = SchedulerConfig(
max_num_batched_tokens=self.max_num_batched_tokens,
max_num_seqs=self.max_num_seqs,
max_model_len=model_config.max_model_len,
use_v2_block_manager=self.use_v2_block_manager,
num_lookahead_slots=(self.num_lookahead_slots
if speculative_config is None else speculative_config.num_lookahead_slots),
delay_factor=self.scheduler_delay_factor,
enable_chunked_prefill=self.enable_chunked_prefill,
embedding_mode=model_config.embedding_mode,
preemption_mode=self.preemption_mode,
)
lora_config = LoRAConfig(max_lora_rank=self.max_lora_rank,
max_loras=self.max_loras,
fully_sharded_loras=self.fully_sharded_loras,
lora_extra_vocab_size=self.lora_extra_vocab_size,
long_lora_scaling_factors=self.long_lora_scaling_factors,
lora_dtype=self.lora_dtype,
max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras and self.max_cpu_loras > 0 else
None) if self.enable_lora else None
if self.qlora_adapter_name_or_path is not None and \
self.qlora_adapter_name_or_path != "":
if self.model_loader_extra_config is None:
self.model_loader_extra_config = {}
self.model_loader_extra_config["qlora_adapter_name_or_path"] = self.qlora_adapter_name_or_path
load_config = LoadConfig(
load_format=self.load_format,
download_dir=self.download_dir,
model_loader_extra_config=self.model_loader_extra_config,
ignore_patterns=self.ignore_patterns,
)
prompt_adapter_config = PromptAdapterConfig(
max_prompt_adapters=self.max_prompt_adapters,
max_prompt_adapter_token=self.max_prompt_adapter_token) \
if self.enable_prompt_adapter else None
decoding_config = DecodingConfig(guided_decoding_backend=self.guided_decoding_backend)
observability_config = ObservabilityConfig(otlp_traces_endpoint=self.otlp_traces_endpoint)
if (model_config.get_sliding_window() is not None and scheduler_config.chunked_prefill_enabled and
not scheduler_config.use_v2_block_manager):
raise ValueError("Chunked prefill is not supported with sliding window. "
"Set --disable-sliding-window to disable sliding window.")
return EngineConfig(
model_config=model_config,
cache_config=cache_config,
parallel_config=parallel_config,
scheduler_config=scheduler_config,
device_config=device_config,
lora_config=lora_config,
multimodal_config=multimodal_config,
speculative_config=speculative_config,
load_config=load_config,
decoding_config=decoding_config,
observability_config=observability_config,
prompt_adapter_config=prompt_adapter_config,
)

View File

@@ -0,0 +1,246 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023 The vLLM team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py
import enum
import json
from typing import List, Optional, Union
from dataclasses import dataclass, field, fields
import torch
from transformers import PretrainedConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import get_quantization_config
from vllm.transformers_utils.config import get_hf_text_config
from vllm.utils import is_hip, print_warning_once
# Add for verl
from vllm.config import ModelConfig, _get_and_verify_dtype, _get_and_verify_max_len, get_served_model_name
GPTQMarlinConfig = get_quantization_config("gptq_marlin")
logger = init_logger(__name__)
_GB = 1 << 30
class ModelConfig(ModelConfig):
"""Configuration for the model.
Args:
model: Name or path of the huggingface model to use.
tokenizer: Name or path of the huggingface tokenizer to use.
tokenizer_mode: Tokenizer mode. "auto" will use the fast tokenizer if
available, and "slow" will always use the slow tokenizer.
trust_remote_code: Trust remote code (e.g., from HuggingFace) when
downloading the model and tokenizer.
download_dir: Directory to download and load the weights, default to the
default cache directory of huggingface.
load_format: The format of the model weights to load:
"auto" will try to load the weights in the safetensors format and
fall back to the pytorch bin format if safetensors format is
not available.
"pt" will load the weights in the pytorch bin format.
"safetensors" will load the weights in the safetensors format.
"npcache" will load the weights in pytorch format and store
a numpy cache to speed up the loading.
"dummy" will initialize the weights with random values, which is
mainly for profiling.
dtype: Data type for model weights and activations. The "auto" option
will use FP16 precision for FP32 and FP16 models, and BF16 precision
for BF16 models.
seed: Random seed for reproducibility.
revision: The specific model version to use. It can be a branch name,
a tag name, or a commit id. If unspecified, will use the default
version.
code_revision: The specific revision to use for the model code on
Hugging Face Hub. It can be a branch name, a tag name, or a
commit id. If unspecified, will use the default version.
tokenizer_revision: The specific tokenizer version to use. It can be a
branch name, a tag name, or a commit id. If unspecified, will use
the default version.
max_model_len: Maximum length of a sequence (including prompt and
output). If None, will be derived from the model.
quantization: Quantization method that was used to quantize the model
weights. If None, we assume the model weights are not quantized.
quantization_param_path: Path to JSON file containing scaling factors.
Used to load KV cache scaling factors into the model when KV cache
type is FP8_E4M3 on ROCm (AMD GPU). In the future these will also
be used to load activation and weight scaling factors when the
model dtype is FP8_E4M3 on ROCm.
enforce_eager: Whether to enforce eager execution. If True, we will
disable CUDA graph and always execute the model in eager mode.
If False, we will use CUDA graph and eager execution in hybrid.
max_context_len_to_capture: Maximum context len covered by CUDA graphs.
When a sequence has context length larger than this, we fall back
to eager mode (DEPRECATED. Use max_seq_len_to_capture instead).
max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs.
When a sequence has context length larger than this, we fall back
to eager mode
skip_tokenizer_init: If true, skip initialization of tokenizer and
detokenizer.
served_model_name: The model name used in metrics tag `model_name`,
matches the model name exposed via the APIs. If multiple model
names provided, the first name will be used. If not specified,
the model name will be the same as `model`.
"""
def __init__(
self,
hf_config: PretrainedConfig,
tokenizer_mode: str,
trust_remote_code: bool,
dtype: Union[str, torch.dtype],
seed: int,
revision: Optional[str] = None,
code_revision: Optional[str] = None,
rope_scaling: Optional[dict] = None,
rope_theta: Optional[float] = None,
tokenizer_revision: Optional[str] = None,
max_model_len: Optional[int] = None,
quantization: Optional[str] = None,
quantization_param_path: Optional[str] = None,
enforce_eager: bool = False,
max_context_len_to_capture: Optional[int] = None,
max_seq_len_to_capture: Optional[int] = None,
max_logprobs: int = 20,
disable_sliding_window: bool = False,
skip_tokenizer_init: bool = False,
served_model_name: Optional[Union[str, List[str]]] = None,
multimodal_config: Optional["MultiModalConfig"] = None,
) -> None:
self.model = hf_config._name_or_path
self.tokenizer = hf_config._name_or_path
# NOTE(sgm): same as open-sourced
self.tokenizer_mode = tokenizer_mode
self.trust_remote_code = trust_remote_code
self.seed = seed
self.revision = revision
self.code_revision = code_revision
self.rope_scaling = rope_scaling
self.rope_theta = rope_theta
# The tokenizer version is consistent with the model version by default.
if tokenizer_revision is None:
self.tokenizer_revision = revision
else:
self.tokenizer_revision = tokenizer_revision
self.quantization = quantization
self.quantization_param_path = quantization_param_path
self.enforce_eager = enforce_eager
if max_context_len_to_capture is not None:
raise ValueError("`max_context_len_to_capture` is deprecated. "
"Use `max_seq_len_to_capture` instead.")
self.max_seq_len_to_capture = max_seq_len_to_capture
self.max_logprobs = max_logprobs
self.disable_sliding_window = disable_sliding_window
self.skip_tokenizer_init = skip_tokenizer_init
# self.hf_config = get_config(model, trust_remote_code, revision)
self.hf_config = hf_config
self.hf_text_config = get_hf_text_config(hf_config)
self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
# self.served_model_name = get_served_model_name(model,
# served_model_name)
# self._verify_load_format()
# self._verify_tokenizer_mode()
if (not self.disable_sliding_window and self.hf_text_config.model_type == "gemma2" and
self.hf_text_config.sliding_window is not None):
print_warning_once("Gemma 2 uses sliding window attention for every odd layer, "
"which is currently not supported by vLLM. Disabling sliding "
"window and capping the max length to the sliding window size "
f"({self.hf_text_config.sliding_window}).")
self.disable_sliding_window = True
self.max_model_len = _get_and_verify_max_len(hf_config=self.hf_text_config,
max_model_len=max_model_len,
disable_sliding_window=self.disable_sliding_window,
sliding_window_len=self.get_hf_config_sliding_window())
self.served_model_name = get_served_model_name(
self.model, # str
served_model_name)
self.multimodal_config = multimodal_config
if not self.skip_tokenizer_init:
self._verify_tokenizer_mode()
self._verify_embedding_mode()
self._verify_quantization()
self._verify_cuda_graph()
class LoadFormat(str, enum.Enum):
AUTO = 'auto'
MEGATRON = "megatron"
HF = "hf"
DTENSOR = 'dtensor'
DUMMY_HF = 'dummy_hf'
DUMMY_MEGATRON = 'dummy_megatron'
DUMMY_DTENSOR = 'dummy_dtensor'
# TODO: check whether this is necessary
@dataclass
class LoadConfig:
"""
download_dir: Directory to download and load the weights, default to the
default cache directory of huggingface.
load_format: The format of the model weights to load:
"auto" will try to load the weights in the safetensors format and
fall back to the pytorch bin format if safetensors format is
not available.
"pt" will load the weights in the pytorch bin format.
"safetensors" will load the weights in the safetensors format.
"npcache" will load the weights in pytorch format and store
a numpy cache to speed up the loading.
"dummy" will initialize the weights with random values, which is
mainly for profiling.
"tensorizer" will use CoreWeave's tensorizer library for
fast weight loading.
"bitsandbytes" will load nf4 type weights.
ignore_patterns: The list of patterns to ignore when loading the model.
Default to "original/**/*" to avoid repeated loading of llama's
checkpoints.
"""
load_format: Union[str, LoadFormat, "BaseModelLoader"] = LoadFormat.AUTO
download_dir: Optional[str] = None
model_loader_extra_config: Optional[Union[str, dict]] = field(default_factory=dict)
ignore_patterns: Optional[Union[List[str], str]] = None
def __post_init__(self):
model_loader_extra_config = self.model_loader_extra_config or {}
if isinstance(model_loader_extra_config, str):
self.model_loader_extra_config = json.loads(model_loader_extra_config)
self._verify_load_format()
if self.ignore_patterns is not None and len(self.ignore_patterns) > 0:
logger.info("Ignoring the following patterns when downloading weights: %s", self.ignore_patterns)
else:
self.ignore_patterns = ["original/**/*"]
def _verify_load_format(self) -> None:
if not isinstance(self.load_format, str):
return
load_format = self.load_format.lower()
self.load_format = LoadFormat(load_format)
rocm_not_supported_load_format: List[str] = []
if is_hip() and load_format in rocm_not_supported_load_format:
rocm_supported_load_format = [
f for f in LoadFormat.__members__ if (f not in rocm_not_supported_load_format)
]
raise ValueError(f"load format '{load_format}' is not supported in ROCm. "
f"Supported load formats are "
f"{rocm_supported_load_format}")

View File

@@ -0,0 +1,340 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023 The vLLM team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/models
from typing import Dict, Iterable, Tuple
import torch
import torch.nn as nn
from torch.distributed._tensor import DTensor, Shard, Replicate
from vllm.model_executor.layers.linear import *
from vllm.model_executor.models import ModelRegistry
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.utils import is_pp_missing_parameter
def gemma_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(vllm_model.named_parameters())
for name, loaded_weight in actor_weights.items():
for (param_name, shard_name, shard_id) in stacked_params_mapping:
if shard_name not in name:
continue
stacked_name = name.replace(shard_name, param_name)
# Skip loading extra bias for GPTQ models.
if stacked_name.endswith(".bias") and stacked_name not in params_dict:
continue
local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight)
param = params_dict[stacked_name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, local_loaded_weight.to(dtype=param.dtype), shard_id)
break
else:
# lm_head is not used in vllm as it is tied with embed_token.
# To prevent errors, skip loading lm_head.weight.
if "lm_head.weight" in name:
continue
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight)
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, local_loaded_weight.to(dtype=param.dtype))
def gptbigcode_dtensor_load_weights(actor_weights: Dict, vllm_model: nn.Module):
params_dict = dict(vllm_model.named_parameters(remove_duplicate=False))
for name, loaded_weight in actor_weights.items():
if "lm_head.weight" in name:
continue
if ".attn.bias" in name:
# Skip attention mask.
# NOTE: "c_attn.bias" should not be skipped.
continue
local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight)
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, local_loaded_weight.to(dtype=param.dtype))
def starcoder2_dtensor_load_weights(actor_weights: Dict, vllm_model: nn.Module):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
]
params_dict = dict(vllm_model.named_parameters(remove_duplicate=False))
for name, loaded_weight in actor_weights.items():
if "rotary_emb.inv_freq" in name:
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight)
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, local_loaded_weight.to(dtype=param.dtype), shard_id)
break
else:
if vllm_model.config.tie_word_embeddings and "lm_head.weight" in name:
continue
param = params_dict[name]
local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight)
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, local_loaded_weight.to(dtype=param.dtype))
def llama_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"),
(".qkv_proj", ".k_proj", "k"),
(".qkv_proj", ".v_proj", "v"),
(".gate_up_proj", ".gate_proj", 0),
(".gate_up_proj", ".up_proj", 1),
]
params_dict = dict(vllm_model.named_parameters())
for name, loaded_weight in actor_weights.items():
if "rotary_emb.inv_freq" in name:
continue
if ("rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name):
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
continue
# With tie_word_embeddings, we can skip lm_head.weight
# The weight might appear unnecessarily in the files if the model is
# processed with quantization, LoRA, fine-tuning, etc.
if vllm_model.config.tie_word_embeddings and "lm_head.weight" in name:
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight)
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, local_loaded_weight.to(dtype=param.dtype), shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight)
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, local_loaded_weight)
def qwen2_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(vllm_model.named_parameters(remove_duplicate=False))
for name, loaded_weight in actor_weights.items():
if "rotary_emb.inv_freq" in name:
continue
if vllm_model.config.tie_word_embeddings and "lm_head.weight" in name:
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight)
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, local_loaded_weight.to(dtype=param.dtype), shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight)
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, local_loaded_weight.to(dtype=param.dtype))
from vllm.model_executor.layers.fused_moe import FusedMoE
def deepseekv2_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
expert_params_mapping = FusedMoE.make_expert_params_mapping(ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
num_experts=vllm_model.config.n_routed_experts)
params_dict = dict(vllm_model.named_parameters(remove_duplicate=False))
for name, loaded_weight in actor_weights.items():
if "rotary_emb.inv_freq" in name:
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
# Skip non-stacked layers and experts (experts handled below).
if weight_name not in name:
continue
# We have mlp.experts[0].gate_proj in the checkpoint.
# Since we handle the experts below in expert_params_mapping,
# we need to skip here BEFORE we update the name, otherwise
# name will be updated to mlp.experts[0].gate_up_proj, which
# will then be updated below in expert_params_mapping
# for mlp.experts[0].gate_gate_up_proj, which breaks load.
if (("mlp.experts." in name) and name not in params_dict):
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, vllm_model):
continue
param = params_dict[name]
local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight)
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, local_loaded_weight.to(dtype=param.dtype), shard_id)
break
else:
for mapping in expert_params_mapping:
param_name, weight_name, expert_id, shard_id = mapping
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
if is_pp_missing_parameter(name, vllm_model):
continue
param = params_dict[name]
local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight)
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param,
local_loaded_weight.to(dtype=param.dtype),
weight_name,
shard_id=shard_id,
expert_id=expert_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, vllm_model):
continue
param = params_dict[name]
local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight)
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, local_loaded_weight.to(dtype=param.dtype))
def gpt2_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module:
pass
def redistribute_dtensor(param_name: str, loaded_weights: DTensor, parallelize_plan: Dict = None):
param_name = _process_parameter_names(name=param_name)
if parallelize_plan is not None:
assert param_name in parallelize_plan.keys(), \
f"param name: {param_name} not in parallelize_plan :{parallelize_plan.keys()}"
placement = parallelize_plan[param_name]
local_loaded_weights = loaded_weights.redistribute(device_mesh=loaded_weights.device_mesh,
placements=placement).to_local()
else:
local_loaded_weights = loaded_weights.full_tensor()
return local_loaded_weights
def _process_parameter_names(name):
# Remove '.weight' if it exists at the end of the string
if name.endswith(".weight"):
name = name[:-7]
# Remove 'model.layers.x.' or 'model.' prefix
if "model.layers" in name:
parts = name.split('.')
# Reconstruct the string without 'model.layers.x.'
name = '.'.join(parts[3:]) # parts[0] is 'model', parts[1] is 'layers', parts[2] is 'x'
elif name.startswith("model."):
name = name[6:] # Remove 'model.'
return name
__MODEL_DTENSOR_WEIGHT_LOADER_REGISTRY__ = {
'GPT2LMHeadModel': gpt2_dtensor_weight_loader,
'LlamaForCausalLM': llama_dtensor_weight_loader,
'LLaMAForCausalLM': llama_dtensor_weight_loader,
'MistralForCausalLM': llama_dtensor_weight_loader, # mistral is the same as llama in vLLM
'InternLMForCausalLM': llama_dtensor_weight_loader,
'AquilaModel': llama_dtensor_weight_loader,
'AquilaForCausalLM': llama_dtensor_weight_loader,
'Phi3ForCausalLM': llama_dtensor_weight_loader,
'GemmaForCausalLM': gemma_dtensor_weight_loader,
'Gemma2ForCausalLM': gemma_dtensor_weight_loader,
'GPTBigCodeForCausalLM': gptbigcode_dtensor_load_weights,
'Starcoder2ForCausalLM': starcoder2_dtensor_load_weights,
'Qwen2ForCausalLM': qwen2_dtensor_weight_loader,
'DeepseekV2ForCausalLM': deepseekv2_dtensor_weight_loader
}
# the actor model is .state_dict()
# Load dtensor weights
def load_dtensor_weights(actor_weights: Dict, vllm_model: nn.Module):
weight_loader = _get_model_weight_loader(vllm_model.__class__.__name__)
weight_loader(actor_weights, vllm_model)
# NOTE(sgm) to reduce peak memory usage, we offload vllm model to cpu
# after init, and we need this after sync model weights for in first iter.
vllm_model = vllm_model.cuda()
def _get_model_weight_loader(arch: str):
if arch in __MODEL_DTENSOR_WEIGHT_LOADER_REGISTRY__:
return __MODEL_DTENSOR_WEIGHT_LOADER_REGISTRY__[arch]
raise ValueError(f"Model architectures {arch} are not supported for now. "
f"Supported architectures: {__MODEL_DTENSOR_WEIGHT_LOADER_REGISTRY__.keys()}")
# NOTE(sgm): we use per-parameter weight loader in each vllm sub
def update_dtensor_weight_loader():
pass

View File

@@ -0,0 +1,44 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023 The vLLM team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/models
from typing import Dict, Union, Optional, Iterable, Tuple
import torch
import torch.nn as nn
from vllm.model_executor.model_loader.utils import set_default_torch_dtype
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
def update_hf_weight_loader():
print('no hf weight loader need to be updated')
return
def load_hf_weights(actor_weights: Dict, vllm_model: nn.Module):
assert isinstance(actor_weights, Dict)
with set_default_torch_dtype(next(vllm_model.parameters()).dtype): # TODO
if vllm_model.config.tie_word_embeddings and "lm_head.weight" in actor_weights.keys():
del actor_weights["lm_head.weight"]
vllm_model.load_weights(actor_weights.items())
for _, module in vllm_model.named_modules():
quant_method = getattr(module, "quant_method", None)
if quant_method is not None:
quant_method.process_weights_after_loading(module)
# FIXME: Remove this after Mixtral is updated
# to use quant_method.
if hasattr(module, "process_weights_after_loading"):
module.process_weights_after_loading()
vllm_model = vllm_model.cuda()

View File

@@ -0,0 +1,239 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023 The vLLM team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/llm.py
from contextlib import contextmanager
from typing import ClassVar, List, Optional, Sequence, Union, cast, overload, Dict, Tuple
from tqdm import tqdm
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from transformers import PretrainedConfig
import torch.nn as nn
from .arg_utils import EngineArgs
from .llm_engine_sp import LLMEngine
from vllm import LLM
from vllm.inputs import (PromptInputs, TextPrompt, TokensPrompt, parse_and_batch_prompt)
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.guided_decoding import (GuidedDecodingRequest, get_local_guided_decoding_logits_processor)
from vllm.model_executor.guided_decoding.guided_fields import LLMGuidedOptions
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer import get_cached_tokenizer
from vllm.usage.usage_lib import UsageContext
from vllm.utils import Counter, deprecate_kwargs
import torch
from torch.nn.utils.rnn import pad_sequence
from verl.workers.rollout.tokenizer import HybridEngineBaseTokenizer
class LLM(LLM):
"""An LLM for generating texts from given prompts and sampling parameters.
This class includes a tokenizer, a language model (possibly distributed
across multiple GPUs), and GPU memory space allocated for intermediate
states (aka KV cache). Given a batch of prompts and sampling parameters,
this class generates texts from the model, using an intelligent batching
mechanism and efficient memory management.
NOTE: This class is intended to be used for offline inference. For online
serving, use the `AsyncLLMEngine` class instead.
NOTE: For the comprehensive list of arguments, see `EngineArgs`.
Args:
model: A HuggingFace Transformers model instance.
tokenizer: A HuggingFace Transformers tokenizer instance.
tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
if available, and "slow" will always use the slow tokenizer.
trust_remote_code: Trust remote code (e.g., from HuggingFace) when
downloading the model and tokenizer.
tensor_parallel_size: The number of GPUs to use for distributed
execution with tensor parallelism.
dtype: The data type for the model weights and activations. Currently,
we support `float32`, `float16`, and `bfloat16`. If `auto`, we use
the `torch_dtype` attribute specified in the model config file.
However, if the `torch_dtype` in the config is `float32`, we will
use `float16` instead.
quantization: The method used to quantize the model weights. Currently,
we support "awq". If None, we assume the model weights are not
quantized and use `dtype` to determine the data type of the weights.
revision: The specific model version to use. It can be a branch name,
a tag name, or a commit id.
tokenizer_revision: The specific tokenizer version to use. It can be a
branch name, a tag name, or a commit id.
seed: The seed to initialize the random number generator for sampling.
gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
reserve for the model weights, activations, and KV cache. Higher
values will increase the KV cache size and thus improve the model's
throughput. However, if the value is too high, it may cause out-of-
memory (OOM) errors.
swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
This can be used for temporarily storing the states of the requests
when their `best_of` sampling parameters are larger than 1. If all
requests will have `best_of=1`, you can safely set this to 0.
Otherwise, too small values may cause out-of-memory (OOM) errors.
enforce_eager: Whether to enforce eager execution. If True, we will
disable CUDA graph and always execute the model in eager mode.
If False, we will use CUDA graph and eager execution in hybrid.
max_context_len_to_capture: Maximum context len covered by CUDA graphs.
When a sequence has context length larger than this, we fall back
to eager mode.
disable_custom_all_reduce: See ParallelConfig
"""
def __init__(
self,
model: Union[nn.Module, Dict], # model itself or its parameter dict
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast, HybridEngineBaseTokenizer],
model_hf_config: PretrainedConfig,
tokenizer_mode: str = "auto",
trust_remote_code: bool = False,
skip_tokenizer_init: bool = False,
tensor_parallel_size: int = 1,
dtype: str = "auto",
quantization: Optional[str] = None,
revision: Optional[str] = None,
tokenizer_revision: Optional[str] = None,
seed: int = 0,
gpu_memory_utilization: float = 0.9,
swap_space: int = 4,
cpu_offload_gb: float = 0,
enforce_eager: bool = False,
max_context_len_to_capture: Optional[int] = None,
max_seq_len_to_capture: int = 8192,
disable_custom_all_reduce: bool = False,
load_format = 'auto',
**kwargs,
) -> None:
if "disable_log_stats" not in kwargs:
kwargs["disable_log_stats"] = True
engine_args = EngineArgs(
model_hf_config=model_hf_config,
tensor_parallel_size=tensor_parallel_size,
dtype=dtype,
quantization=quantization,
revision=revision,
tokenizer_revision=tokenizer_revision,
seed=seed,
gpu_memory_utilization=gpu_memory_utilization,
swap_space=swap_space,
cpu_offload_gb=cpu_offload_gb,
enforce_eager=enforce_eager,
max_context_len_to_capture=max_context_len_to_capture,
max_seq_len_to_capture=max_seq_len_to_capture,
disable_custom_all_reduce=disable_custom_all_reduce,
load_format=load_format,
skip_tokenizer_init=skip_tokenizer_init,
**kwargs,
)
tokenizer_cls = (PreTrainedTokenizer, PreTrainedTokenizerFast, HybridEngineBaseTokenizer)
if not isinstance(tokenizer, tokenizer_cls):
raise ValueError(
f"Unexpected tokenizer type: {type(tokenizer)}. Must be"
"one of the following: PreTrainedTokenizer, PreTrainedTokenizerFast, verl.workers.rollout.HybridEngineBaseTokenizer"
)
self.llm_engine = LLMEngine.from_engine_args(model, tokenizer, engine_args) # TODO: check usagecontext
self.request_counter = Counter()
def init_cache_engine(self):
self.llm_engine.init_cache_engine()
def free_cache_engine(self):
self.llm_engine.free_cache_engine()
def get_tokenizer(self) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
return self.llm_engine.tokenizer
def set_tokenizer(
self,
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
) -> None:
self.llm_engine.tokenizer = tokenizer
def _run_engine(self, *, use_tqdm: bool) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
# Initialize tqdm.
if use_tqdm:
num_requests = self.llm_engine.get_num_unfinished_requests()
pbar = tqdm(
total=num_requests,
desc="Processed prompts",
dynamic_ncols=True,
postfix=(f"est. speed input: {0:.2f} toks/s, "
f"output: {0:.2f} toks/s"),
)
# Run the engine.
outputs: List[Union[RequestOutput, EmbeddingRequestOutput]] = []
total_in_toks = 0
total_out_toks = 0
while self.llm_engine.has_unfinished_requests():
step_outputs = self.llm_engine.step()
for output in step_outputs:
if output.finished:
outputs.append(output)
if use_tqdm:
if isinstance(output, RequestOutput):
# Calculate tokens only for RequestOutput
total_in_toks += len(output.prompt_token_ids)
in_spd = total_in_toks / pbar.format_dict["elapsed"]
total_out_toks += sum(len(stp.token_ids) for stp in output.outputs)
out_spd = total_out_toks / pbar.format_dict["elapsed"]
pbar.postfix = (f"est. speed input: {in_spd:.2f} toks/s, "
f"output: {out_spd:.2f} toks/s")
pbar.update(1)
if use_tqdm:
pbar.close()
# Sort the outputs by request ID.
# This is necessary because some requests may be finished earlier than
# its previous requests.
outputs = sorted(outputs, key=lambda x: int(x.request_id))
return self._post_process_outputs(outputs)
# # NOTE(shengguangming): add for verl
# # TODO(sgm): we can optimize it by making the dataloader yield List[int] without padding.
# def _pre_process_inputs(self, prompt_token_ids: torch.Tensor) -> List[int]:
# # remove the left padding in the prompt token_id
# pad_token_id = self.llm_engine.tokenizer.pad_token_id if self.llm_engine.tokenizer.pad_token_id is not None else self.llm_engine.tokenizer.eos_token_id
# non_pad_index = torch.nonzero(prompt_token_ids != pad_token_id, as_tuple=False)[0][0]
# token_ids = prompt_token_ids[non_pad_index:].tolist()
# return token_ids
# NOTE(shengguangming): add for verl
def _post_process_outputs(self, request_outputs: List[RequestOutput]) -> Tuple[torch.Tensor, torch.Tensor]:
output_token_ids = []
logprobs = []
for request_output in request_outputs: # List[RequestOutput]
outputs = request_output.outputs
for output in outputs: # List[CompletionOutput], usually len == 1
output_token_ids.append(torch.tensor(output.token_ids))
# TODO(shengguangming): can be optimzied by rewrite the Sampler._get_logprobs() logits
logprobs_dicts = output.logprobs
if logprobs_dicts is not None:
logprob = []
for logprobs_dict, id in zip(logprobs_dicts, output.token_ids):
logprob.append(logprobs_dict[id].logprob)
logprobs.append(torch.tensor(logprob))
pad_token_id = self.llm_engine.tokenizer.pad_token_id if self.llm_engine.tokenizer.pad_token_id is not None else self.llm_engine.tokenizer.eos_token_id
output_token_ids = pad_sequence(output_token_ids, batch_first=True, padding_value=pad_token_id)
if len(logprobs) > 0:
logprobs = pad_sequence(logprobs, batch_first=True, padding_value=pad_token_id)
return output_token_ids, logprobs
def sync_model_weights(self, actor_weights: Dict[str, torch.Tensor], load_format: str) -> None:
self.llm_engine.sync_model_weights(actor_weights=actor_weights, load_format=load_format)
def offload_model_weights(self) -> None:
self.llm_engine.offload_model_weights()

View File

@@ -0,0 +1,328 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023 The vLLM team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/engine/llm_engine.py
import torch
from typing import Dict, Optional, Union, Type
import vllm.envs as envs
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, EngineConfig, LoRAConfig, MultiModalConfig,
ObservabilityConfig, ParallelConfig, PromptAdapterConfig, SchedulerConfig, SpeculativeConfig)
from vllm.core.scheduler import Scheduler
from vllm.engine.output_processor.interfaces import (SequenceGroupOutputProcessor)
from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.executor.executor_base import ExecutorBase
from vllm.inputs import INPUT_REGISTRY, LLMInputs, PromptInputs
from vllm.logger import init_logger
from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.engine.metrics import (LoggingStatLogger, PrometheusStatLogger, StatLoggerBase, Stats)
from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context, init_tracer)
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled, usage_message)
from vllm.utils import Counter
from vllm.engine.llm_engine import _load_generation_config_dict
from vllm.engine.llm_engine import LLMEngine
from vllm.version import __version__ as VLLM_VERSION
import torch.nn as nn
from .arg_utils import EngineArgs
from .tokenizer import TokenizerGroup
from .config import ModelConfig, LoadConfig
logger = init_logger(__name__)
_LOCAL_LOGGING_INTERVAL_SEC = 5
class LLMEngine(LLMEngine):
"""An LLM engine that receives requests and generates texts.
This is the main class for the vLLM engine. It receives requests
from clients and generates texts from the LLM. It includes a tokenizer, a
language model (possibly distributed across multiple GPUs), and GPU memory
space allocated for intermediate states (aka KV cache). This class utilizes
iteration-level scheduling and efficient memory management to maximize the
serving throughput.
The `LLM` class wraps this class for offline batched inference and the
`AsyncLLMEngine` class wraps this class for online serving.
NOTE: The config arguments are derived from the `EngineArgs` class. For the
comprehensive list of arguments, see `EngineArgs`.
Args:
model: the actor model initialize outside vllm (add for verl)
tokenizer: the initialized tokenizer (add for verl)
model_config: The configuration related to the LLM model.
cache_config: The configuration related to the KV cache memory
management.
parallel_config: The configuration related to distributed execution.
scheduler_config: The configuration related to the request scheduler.
distributed_init_method: The initialization method for distributed
execution. See `torch.distributed.init_process_group` for details.
placement_group: Ray placement group for distributed execution.
Required for distributed execution.
log_stats: Whether to log statistics.
"""
def __init__(
self,
# NOTE(sgm): first two arguments are added for verl
model: Union[nn.Module, Dict], # model itself or its parameter dict
tokenizer: nn.Module,
# NOTE(sgm): vllm original arguments
model_config: ModelConfig,
cache_config: CacheConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
load_config: LoadConfig,
lora_config: Optional[LoRAConfig],
multimodal_config: Optional[MultiModalConfig],
speculative_config: Optional[SpeculativeConfig],
decoding_config: Optional[DecodingConfig],
observability_config: Optional[ObservabilityConfig],
prompt_adapter_config: Optional[PromptAdapterConfig],
executor_class: Type[ExecutorBase],
log_stats: bool,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
) -> None:
logger.info(
"Initializing an LLM engine (v%s) with config: "
"model=%r, speculative_config=%r, tokenizer=%r, "
"skip_tokenizer_init=%s, revision=%s, "
"rope_scaling=%r, rope_theta=%r, tokenizer_revision=%s, "
"trust_remote_code=%s, dtype=%s, max_seq_len=%d, "
"download_dir=%r, load_format=%s, tensor_parallel_size=%d, "
"pipeline_parallel_size=%d, "
"disable_custom_all_reduce=%s, quantization=%s, "
"enforce_eager=%s, kv_cache_dtype=%s, "
"quantization_param_path=%s, device_config=%s, "
"decoding_config=%r, observability_config=%r, "
"seed=%d, served_model_name=%s, use_v2_block_manager=%s, "
"enable_prefix_caching=%s)",
VLLM_VERSION,
model_config.model,
speculative_config,
model_config.tokenizer,
model_config.skip_tokenizer_init,
model_config.revision,
model_config.rope_scaling,
model_config.rope_theta,
model_config.tokenizer_revision,
model_config.trust_remote_code,
model_config.dtype,
model_config.max_model_len,
load_config.download_dir,
load_config.load_format,
parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size,
parallel_config.disable_custom_all_reduce,
model_config.quantization,
model_config.enforce_eager,
cache_config.cache_dtype,
model_config.quantization_param_path,
device_config.device,
decoding_config,
observability_config,
model_config.seed,
model_config.served_model_name,
scheduler_config.use_v2_block_manager,
cache_config.enable_prefix_caching,
)
# TODO(woosuk): Print more configs in debug mode.
self.model_config = model_config
self.cache_config = cache_config
self.lora_config = lora_config
self.multimodal_config = multimodal_config
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
self.device_config = device_config
self.speculative_config = speculative_config
self.load_config = load_config
self.decoding_config = decoding_config or DecodingConfig()
self.prompt_adapter_config = prompt_adapter_config
self.observability_config = observability_config or ObservabilityConfig()
self.log_stats = log_stats
# self.model = model # should not store the model, it should be deleted
# TODO(shengguangming): maybe we can choose init here or from arguments
if not self.model_config.skip_tokenizer_init:
self.tokenizer = self._init_tokenizer(tokenizer)
self.detokenizer = Detokenizer(self.tokenizer)
else:
self.tokenizer = None
self.detokenizer = None
self.seq_counter = Counter()
self.generation_config_fields = _load_generation_config_dict(model_config)
self.input_processor = INPUT_REGISTRY.create_input_processor(self.model_config)
self.model_executor = executor_class(
model=model, # add for spmd_gpu_executor
model_config=model_config,
cache_config=cache_config,
parallel_config=parallel_config,
scheduler_config=scheduler_config,
device_config=device_config,
lora_config=lora_config,
multimodal_config=multimodal_config,
speculative_config=speculative_config,
load_config=load_config,
prompt_adapter_config=prompt_adapter_config,
)
# Profile the memory usage and initialize the cache.
if not self.model_config.embedding_mode:
self._initialize_kv_caches()
# If usage stat is enabled, collect relevant info.
if is_usage_stats_enabled():
from vllm.model_executor.model_loader import (get_architecture_class_name)
usage_message.report_usage(
get_architecture_class_name(model_config),
usage_context,
extra_kvs={
# Common configuration
"dtype": str(model_config.dtype),
"tensor_parallel_size": parallel_config.tensor_parallel_size,
"block_size": cache_config.block_size,
"gpu_memory_utilization": cache_config.gpu_memory_utilization,
# Quantization
"quantization": model_config.quantization,
"kv_cache_dtype": str(cache_config.cache_dtype),
# Feature flags
"enable_lora": bool(lora_config),
"enable_prompt_adapter": bool(prompt_adapter_config),
"enable_prefix_caching": cache_config.enable_prefix_caching,
"enforce_eager": model_config.enforce_eager,
"disable_custom_all_reduce": parallel_config.disable_custom_all_reduce,
})
if self.tokenizer:
# Ping the tokenizer to ensure liveness if it runs in a
# different process.
self.tokenizer.ping()
# Create the scheduler.
# NOTE: the cache_config here have been updated with the numbers of
# GPU and CPU blocks, which are profiled in the distributed executor.
self.scheduler = [
Scheduler(scheduler_config, cache_config, lora_config, parallel_config.pipeline_parallel_size)
for _ in range(parallel_config.pipeline_parallel_size)
]
# Metric Logging.
if self.log_stats:
if stat_loggers is not None:
self.stat_loggers = stat_loggers
else:
self.stat_loggers = {
"logging":
LoggingStatLogger(local_interval=_LOCAL_LOGGING_INTERVAL_SEC),
"prometheus":
PrometheusStatLogger(local_interval=_LOCAL_LOGGING_INTERVAL_SEC,
labels=dict(model_name=model_config.served_model_name),
max_model_len=self.model_config.max_model_len),
}
self.stat_loggers["prometheus"].info("cache_config", self.cache_config)
self.tracer = None
if self.observability_config.otlp_traces_endpoint:
self.tracer = init_tracer("vllm.llm_engine", self.observability_config.otlp_traces_endpoint)
# Create sequence output processor, e.g. for beam search or
# speculative decoding.
self.output_processor = (SequenceGroupOutputProcessor.create_output_processor(
self.scheduler_config,
self.detokenizer,
self.scheduler,
self.seq_counter,
self.get_tokenizer_for_seq,
stop_checker=StopChecker(
self.scheduler_config.max_model_len,
self.get_tokenizer_for_seq,
),
))
# TODO(sgm): add for verl but we may not tokenizer in Rollout
def _init_tokenizer(self, tokenizer, **tokenizer_init_kwargs):
init_kwargs = dict(enable_lora=bool(self.lora_config),
max_num_seqs=self.scheduler_config.max_num_seqs,
max_input_length=None)
init_kwargs.update(tokenizer_init_kwargs)
return TokenizerGroup(tokenizer, **init_kwargs)
def init_cache_engine(self):
# TODO: check whether we should rebuild the CUDAGraph every iter when offload/load KVCache
# Re-capture CUDAGraph would be time-consuming
self.model_executor.init_cache_engine()
def free_cache_engine(self):
self.model_executor.free_cache_engine()
# NOTE(sgm): currently, we only support GPU executor
# The GPUExecutor remove the Ray dependency
@classmethod
def _get_executor_cls(cls, engine_config: EngineConfig) -> Type[ExecutorBase]:
assert engine_config.device_config.device_type == "cuda", \
"Currently, the vllm in verl only support running on GPU"
if engine_config.parallel_config.world_size == 1:
engine_config.load_config.load_format = "dummy_hf"
from .spmd_gpu_executor import SPMDGPUExecutor
executor_class = SPMDGPUExecutor
return executor_class
@classmethod
def from_engine_args(
cls,
model,
tokenizer,
engine_args: EngineArgs,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
) -> "LLMEngine":
"""Creates an LLM engine from the engine arguments."""
# Create the engine configs.
engine_config = engine_args.create_engine_config()
executor_class = cls._get_executor_cls(engine_config)
# Initialize the cluster and specify the executor class.
assert engine_config.device_config.device_type == "cuda", \
"Currently, the vllm in verl only support running on GPU"
from .spmd_gpu_executor import SPMDGPUExecutor
executor_class = SPMDGPUExecutor
# Create the LLM engine.
engine = cls(
model,
tokenizer,
**engine_config.to_dict(),
executor_class=executor_class,
log_stats=not engine_args.disable_log_stats,
usage_context=usage_context,
stat_loggers=stat_loggers,
)
return engine
def sync_model_weights(self, actor_weights: Dict[str, torch.Tensor], load_format: str) -> None:
self.model_executor.sync_model_weights(actor_weights=actor_weights, load_format=load_format)
def offload_model_weights(self) -> None:
self.model_executor.offload_model_weights()

View File

@@ -0,0 +1,307 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023 The vLLM team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/models
from typing import Dict
import torch
import torch.nn as nn
from vllm.model_executor.layers.linear import *
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding, ParallelLMHead
from vllm.model_executor.layers.activation import ScaledActivation
from vllm.model_executor.models import ModelRegistry
# NOTE(shengguangming): replace the origin weight loader function in the class
def parallel_weight_loader(self, param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
"""Parallel Linear weight loader."""
assert param.size() == loaded_weight.size(
), 'the parameter size is not align with the loaded weight size, param size: {}, loaded_weight size: {}'.format(
param.size(), loaded_weight.size())
assert param.data.dtype == loaded_weight.data.dtype, "if we want to shared weights, the data type should also be the same"
param.data = loaded_weight.data
def default_weight_loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
"""Default weight loader."""
assert param.size() == loaded_weight.size()
assert param.data.dtype == loaded_weight.data.dtype, "if we want to shared weights, the data type should also be the same"
param.data = loaded_weight.data
def gpt2_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module:
params_dict = dict(vllm_model.named_parameters(remove_duplicate=False))
for name, loaded_weight in actor_weights.items():
if "lm_head.weight" in name:
# GPT-2 ties the weights of the embedding layer and the final
# linear layer.
continue
if ".attn.bias" in name or ".attn.masked_bias" in name:
# Skip attention mask.
# NOTE: "c_attn.bias" should not be skipped.
continue
if not name.startswith("transformer."):
name = "transformer." + name
param = params_dict[name]
# The HF's GPT-2 implementation uses Conv1D instead of Linear.
# Because of this, we need to transpose the weights.
# Note(zhuohan): the logic below might break quantized models.
for conv1d_weight_name in ["c_attn", "c_proj", "c_fc"]:
if conv1d_weight_name not in name:
continue
if not name.endswith(".weight"):
continue
# TODO: check megatron
loaded_weight = loaded_weight.t()
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
def llama_megatron_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module:
# NOTE(shengguangming): the megatron llama may have this prefix
params_dict = dict(vllm_model.named_parameters())
for name, loaded_weight in actor_weights.items():
if "rotary_emb.inv_freq" in name:
continue
else:
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
def llama_megatron_core_te_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module:
params_mapping = [
# (megatron core gpt model name, vllm model name)
("embedding.word_embeddings", "model.embed_tokens"),
("self_attention.linear_qkv.layer_norm_weight", "input_layernorm.weight"),
("self_attention.linear_qkv.layer_norm_bias", "input_layernorm.bias"),
("self_attention.linear_qkv", "self_attn.qkv_proj"),
("self_attention.linear_qkv", "self_attn.qkv_proj"),
("self_attention.linear_proj", 'self_attn.o_proj'),
('pre_mlp_layernorm', 'post_attention_layernorm'),
('mlp.linear_fc1.layer_norm_weight', 'post_attention_layernorm.weight'),
('mlp.linear_fc1.layer_norm_bias', 'post_attention_layernorm.bias'),
('mlp.linear_fc1', 'mlp.gate_up_proj'),
('mlp.linear_fc2', 'mlp.down_proj'),
('decoder.final_layernorm', 'model.norm'),
('output_layer', 'lm_head'),
]
# NOTE(shengguangming): the megatron llama may have this prefix
params_dict = dict(vllm_model.named_parameters())
for name, loaded_weight in actor_weights.items():
name = _replace_name(name, params_mapping)
if name.endswith('.bias') and name not in params_dict:
continue
if "rotary_emb.inv_freq" in name:
continue
else:
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
def llama_megatron_core_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module:
params_mapping = [
# (megatron core gpt model name, vllm model name)
("embedding.word_embeddings", "model.embed_tokens"),
("self_attention.linear_qkv", "self_attn.qkv_proj"),
("self_attention.linear_proj", 'self_attn.o_proj'),
(
'input_layernorm',
'input_layernorm',
),
('pre_mlp_layernorm', 'post_attention_layernorm'),
('mlp.linear_fc1', 'mlp.gate_up_proj'),
('mlp.linear_fc2', 'mlp.down_proj'),
('decoder.final_layernorm', 'model.norm'),
('output_layer', 'lm_head'),
]
# NOTE(shengguangming): the megatron llama may have this prefix
params_dict = dict(vllm_model.named_parameters())
for name, loaded_weight in actor_weights.items():
name = _replace_name(name, params_mapping)
if name.endswith('.bias') and name not in params_dict:
continue
if "rotary_emb.inv_freq" in name:
continue
else:
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
def _replace_name(megatron_name, name_mapping):
for m_name, v_name in name_mapping:
if m_name not in megatron_name:
continue
if 'layers' in megatron_name: # deal with decoder layers
megatron_name = megatron_name.replace('decoder', 'model')
megatron_name_list = megatron_name.split('.')
if 'layer_norm_weight' in megatron_name_list or 'layer_norm_bias' in megatron_name_list:
param_name_list = megatron_name_list[:3]
param_name_list.append(v_name)
param_name = '.'.join(param_name_list)
else:
param_name_list = megatron_name_list[:3]
weight_or_bias = megatron_name_list[-1]
param_name_list.append(v_name)
param_name_list.append(weight_or_bias)
param_name = '.'.join(param_name_list)
return param_name
else:
param_name = megatron_name.replace(m_name, v_name)
return param_name
def llama_megatron_core_te_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module:
params_mapping = [
# (megatron core gpt model name, vllm model name)
("embedding.word_embeddings", "model.embed_tokens"),
("self_attention.linear_qkv.layer_norm_weight", "input_layernorm.weight"),
("self_attention.linear_qkv.layer_norm_bias", "input_layernorm.bias"),
("self_attention.linear_qkv", "self_attn.qkv_proj"),
("self_attention.linear_qkv", "self_attn.qkv_proj"),
("self_attention.linear_proj", 'self_attn.o_proj'),
('pre_mlp_layernorm', 'post_attention_layernorm'),
('mlp.linear_fc1.layer_norm_weight', 'post_attention_layernorm.weight'),
('mlp.linear_fc1.layer_norm_bias', 'post_attention_layernorm.bias'),
('mlp.linear_fc1', 'mlp.gate_up_proj'),
('mlp.linear_fc2', 'mlp.down_proj'),
('decoder.final_layernorm', 'model.norm'),
('output_layer', 'lm_head'),
]
# NOTE(shengguangming): the megatron llama may have this prefix
params_dict = dict(vllm_model.named_parameters())
for name, loaded_weight in actor_weights.items():
name = _replace_name(name, params_mapping)
if name.endswith('.bias') and name not in params_dict:
continue
if "rotary_emb.inv_freq" in name:
continue
else:
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
def llama_megatron_core_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module:
params_mapping = [
# (megatron core gpt model name, vllm model name)
("embedding.word_embeddings", "model.embed_tokens"),
("self_attention.linear_qkv", "self_attn.qkv_proj"),
("self_attention.linear_proj", 'self_attn.o_proj'),
(
'input_layernorm',
'input_layernorm',
),
('pre_mlp_layernorm', 'post_attention_layernorm'),
('mlp.linear_fc1', 'mlp.gate_up_proj'),
('mlp.linear_fc2', 'mlp.down_proj'),
('decoder.final_layernorm', 'model.norm'),
('output_layer', 'lm_head'),
]
# NOTE(shengguangming): the megatron llama may have this prefix
params_dict = dict(vllm_model.named_parameters())
for name, loaded_weight in actor_weights.items():
name = _replace_name(name, params_mapping)
if name.endswith('.bias') and name not in params_dict:
continue
if "rotary_emb.inv_freq" in name:
continue
else:
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
def _replace_name(megatron_name, name_mapping):
for m_name, v_name in name_mapping:
if m_name not in megatron_name:
continue
if 'layers' in megatron_name: # deal with decoder layers
megatron_name = megatron_name.replace('decoder', 'model')
megatron_name_list = megatron_name.split('.')
if 'layer_norm_weight' in megatron_name_list or 'layer_norm_bias' in megatron_name_list:
param_name_list = megatron_name_list[:3]
param_name_list.append(v_name)
param_name = '.'.join(param_name_list)
else:
param_name_list = megatron_name_list[:3]
weight_or_bias = megatron_name_list[-1]
param_name_list.append(v_name)
param_name_list.append(weight_or_bias)
param_name = '.'.join(param_name_list)
return param_name
else:
param_name = megatron_name.replace(m_name, v_name)
return param_name
def mistral_megatron_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module:
# TODO: need to implement a general way to deal with prefix
params_dict = dict(vllm_model.named_parameters())
for name, loaded_weight in actor_weights.items():
if "rotary_emb.inv_freq" in name:
continue
else:
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
__LAYER_WEIGHT_MEGATRON_LOADER_REGISTRY__ = {
ColumnParallelLinear: parallel_weight_loader,
MergedColumnParallelLinear: parallel_weight_loader,
QKVParallelLinear: parallel_weight_loader,
RowParallelLinear: parallel_weight_loader,
VocabParallelEmbedding: parallel_weight_loader,
ParallelLMHead: parallel_weight_loader
# "ScaledActivation.weight_loader": ScaledActivation, # TODO(shengguangming): latest commit in vllm fix awq for this function and add load_weights
# "default_weight_loader": default_weight_loader
}
# for layer_class, weight_loader in __LAYER_WEIGHT_MEGATRON_LOADER_REGISTRY__.items():
# # setattr(layer_class, 'megatron_weight_loader', weight_loader)
# layer_class.weight_loader = weight_loader
__MODEL_MEGATRON_WEIGHT_LOADER_REGISTRY__ = {
'GPT2LMHeadModel': gpt2_weight_loader,
'LlamaForCausalLM': llama_megatron_weight_loader, # use te backend for open-source megatron
'LLaMAForCausalLM': llama_megatron_weight_loader,
'MistralForCausalLM': mistral_megatron_weight_loader,
}
# the actor model is .state_dict()
# Load megatron weights
def load_megatron_weights(actor_weights: Dict, vllm_model: nn.Module):
weight_loader = _get_model_weight_loader(vllm_model.__class__.__name__)
weight_loader(actor_weights, vllm_model)
# NOTE(sgm) to reduce peak memory usage, we offload vllm model to cpu
# after init, and we need this after sync model weights for in first iter.
vllm_model = vllm_model.cuda()
def _get_model_weight_loader(arch: str):
if arch in __MODEL_MEGATRON_WEIGHT_LOADER_REGISTRY__:
return __MODEL_MEGATRON_WEIGHT_LOADER_REGISTRY__[arch]
raise ValueError(f"Model architectures {arch} are not supported for now. "
f"Supported architectures: {ModelRegistry.get_supported_archs()}")
def update_megatron_weight_loader():
for layer_class, weight_loader in __LAYER_WEIGHT_MEGATRON_LOADER_REGISTRY__.items():
layer_class.weight_loader = weight_loader

View File

@@ -0,0 +1,302 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023 The vLLM team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/model_loader
from typing import Dict, Union, Optional, Iterable, Tuple
import torch
import torch.nn as nn
from transformers import PreTrainedModel
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, MultiModalConfig,
ParallelConfig, SchedulerConfig)
from vllm.model_executor.model_loader import BaseModelLoader
from vllm.model_executor.model_loader.loader import _initialize_model
from vllm.model_executor.model_loader.utils import set_default_torch_dtype
from vllm.distributed.communication_op import tensor_model_parallel_all_gather
from .config import ModelConfig, LoadFormat, LoadConfig
from .megatron_weight_loaders import load_megatron_weights, update_megatron_weight_loader
from .dtensor_weight_loaders import load_dtensor_weights, update_dtensor_weight_loader
from .hf_weight_loader import update_hf_weight_loader
def get_model(actor_model: Union[PreTrainedModel, Dict],
model_config: ModelConfig,
load_config: LoadConfig,
device_config: DeviceConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
lora_config: Optional[LoRAConfig],
multimodal_config: Optional[MultiModalConfig],
cache_config: CacheConfig = None) -> nn.Module:
loader = get_model_loader(load_config)
if load_config.load_format.startswith('dummy'):
return loader.load_model(model_config=model_config,
device_config=device_config,
lora_config=lora_config,
multimodal_config=multimodal_config,
parallel_config=parallel_config,
scheduler_config=scheduler_config,
cache_config=cache_config)
else:
return loader.load_model(actor_model=actor_model,
model_config=model_config,
device_config=device_config,
lora_config=lora_config,
multimodal_config=multimodal_config,
parallel_config=parallel_config,
scheduler_config=scheduler_config,
cache_config=cache_config)
def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
"""Get a model loader based on the load format."""
if isinstance(load_config.load_format, type):
return load_config.load_format(load_config)
if load_config.load_format == LoadFormat.AUTO:
update_megatron_weight_loader()
return MegatronLoader(load_config)
# NOTE(sgm): change the weight_loader function in runtime
if load_config.load_format == LoadFormat.MEGATRON:
update_megatron_weight_loader()
return MegatronLoader(load_config)
if load_config.load_format == LoadFormat.HF:
update_hf_weight_loader()
return HFLoader(load_config)
if load_config.load_format == LoadFormat.DTENSOR:
update_dtensor_weight_loader()
return DTensorLoader(load_config)
if load_config.load_format == LoadFormat.DUMMY_HF:
update_hf_weight_loader()
return DummyModelLoader(load_config)
if load_config.load_format == LoadFormat.DUMMY_MEGATRON:
update_megatron_weight_loader()
return DummyModelLoader(load_config)
if load_config.load_format == LoadFormat.DUMMY_DTENSOR:
update_dtensor_weight_loader()
return DummyModelLoader(load_config)
raise ValueError('load format not supported in verl: {}, only support {} and {}'.format(
load_config.load_format, LoadFormat.MEGATRON, LoadFormat.HF))
class DummyModelLoader(BaseModelLoader):
"""Model loader that will set model weights to random values."""
def __init__(self, load_config: LoadConfig):
super().__init__(load_config)
if load_config.model_loader_extra_config:
raise ValueError(f"Model loader extra config is not supported for "
f"load format {load_config.load_format}")
def load_model(self, *, model_config: ModelConfig, device_config: DeviceConfig, lora_config: Optional[LoRAConfig],
multimodal_config: Optional[MultiModalConfig], parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig, cache_config: CacheConfig) -> nn.Module:
with set_default_torch_dtype(model_config.dtype):
with torch.device(device_config.device):
model = _initialize_model(model_config, self.load_config, lora_config, multimodal_config, cache_config,
scheduler_config)
# NOTE(woosuk): For accurate performance evaluation, we assign
# random values to the weights.
# initialize_dummy_weights(model)
return model.eval()
class MegatronLoader(BaseModelLoader):
"""Model loader that can load the model weights from partitioned megatron model."""
def __init__(self, load_config: LoadConfig):
super().__init__(load_config)
if load_config.model_loader_extra_config:
raise ValueError(f"Model loader extra config is not supported for "
f"load format {load_config.load_format}")
def _get_weights_iterator(actor_model: Union[PreTrainedModel, Dict]):
# NOTE(shengguangming) Load the weights from the actor model
pass
# if isinstance(actor_model, nn.Module):
# load_weights(actor_weights=dict(actor_model.named_parameters(remove_duplicate=False)), vllm_model=model)
# else:
# load_weights(actor_weights=actor_model, vllm_model=model)
# return actor_model
def load_model(self, actor_model: Union[PreTrainedModel, Dict], model_config: ModelConfig,
device_config: DeviceConfig, lora_config: Optional[LoRAConfig],
multimodal_config: Optional[MultiModalConfig], parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig, cache_config: CacheConfig) -> nn.Module:
with set_default_torch_dtype(model_config.dtype):
with torch.device(device_config.device):
model = _initialize_model(model_config, self.load_config, lora_config, multimodal_config, cache_config,
scheduler_config)
# TODO(sgm): This is a hack, we need to register the load_weight() func for each model in vllm
if isinstance(actor_model, nn.Module):
load_megatron_weights(actor_weights=dict(actor_model.named_parameters(remove_duplicate=False)),
vllm_model=model)
else:
load_megatron_weights(actor_weights=actor_model, vllm_model=model)
for _, module in model.named_modules():
quant_method = getattr(module, "quant_method", None)
if quant_method is not None:
quant_method.process_weights_after_loading(module)
# FIXME: Remove this after Mixtral is updated
# to use quant_method.
if hasattr(module, "process_weights_after_loading"):
module.process_weights_after_loading()
# NOTE(sgm) Some weights are point to gpu, but still need this.
model = model.cuda() # NOTE (zhangchi.usc1992) We need this for vllm to profile memory usage
return model.eval()
class HFLoader(BaseModelLoader):
"""Model loader that can load the model weights from model's full params."""
def __init__(self, load_config: LoadConfig):
super().__init__(load_config)
if load_config.model_loader_extra_config:
raise ValueError(f"Model loader extra config is not supported for "
f"load format {load_config.load_format}")
def _get_weights_iterator(self, actor_model: Union[PreTrainedModel, Dict]):
if isinstance(actor_model, Dict):
return actor_model.items()
elif isinstance(actor_model, nn.Module):
return dict(actor_model.named_parameters()).items()
else:
raise ValueError(f'actor model should be Dict or nn.Module, but get {type(actor_model)}')
def load_model(self, actor_model: Union[PreTrainedModel, Dict], model_config: ModelConfig,
device_config: DeviceConfig, lora_config: Optional[LoRAConfig],
multimodal_config: Optional[MultiModalConfig], parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig, cache_config: CacheConfig) -> nn.Module:
with set_default_torch_dtype(model_config.dtype):
# with torch.device(device_config.device):
# NOTE(sgm): init the model in cpu
model = _initialize_model(model_config, self.load_config, lora_config, multimodal_config, cache_config,
scheduler_config)
model.load_weights(self._get_weights_iterator(actor_model))
for _, module in model.named_modules():
quant_method = getattr(module, "quant_method", None)
if quant_method is not None:
quant_method.process_weights_after_loading(module)
# FIXME: Remove this after Mixtral is updated
# to use quant_method.
if hasattr(module, "process_weights_after_loading"):
module.process_weights_after_loading()
# NOTE(sgm) Some weights are point to gpu, but still need this.
model = model.cuda() # NOTE (zhangchi.usc1992) We need this for vllm to profile memory usage
return model.eval()
class DTensorLoader(BaseModelLoader):
"""Model loader that can load the model weights from partitioned megatron model."""
def __init__(self, load_config: LoadConfig):
super().__init__(load_config)
if load_config.model_loader_extra_config:
raise ValueError(f"Model loader extra config is not supported for "
f"load format {load_config.load_format}")
def _get_weights_iterator(actor_model: Union[PreTrainedModel, Dict]):
# NOTE(shengguangming) Load the weights from the actor model
pass
# if isinstance(actor_model, nn.Module):
# load_weights(actor_weights=dict(actor_model.named_parameters(remove_duplicate=False)), vllm_model=model)
# else:
# load_weights(actor_weights=actor_model, vllm_model=model)
# return actor_model
def load_model(self, actor_model: Union[PreTrainedModel, Dict], model_config: ModelConfig,
device_config: DeviceConfig, lora_config: Optional[LoRAConfig],
multimodal_config: Optional[MultiModalConfig], parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig, cache_config: CacheConfig) -> nn.Module:
with set_default_torch_dtype(model_config.dtype):
with torch.device(device_config.device):
model = _initialize_model(model_config, self.load_config, lora_config, multimodal_config, cache_config,
scheduler_config)
# TODO(sgm): This is a hack, we need to register the load_weight() func for each model in vllm
if isinstance(actor_model, nn.Module):
load_dtensor_weights(actor_weights=dict(actor_model.named_parameters(remove_duplicate=False)),
vllm_model=model)
else:
load_dtensor_weights(actor_weights=actor_model, vllm_model=model)
for _, module in model.named_modules():
quant_method = getattr(module, "quant_method", None)
if quant_method is not None:
quant_method.process_weights_after_loading(module)
# FIXME: Remove this after Mixtral is updated
# to use quant_method.
if hasattr(module, "process_weights_after_loading"):
module.process_weights_after_loading()
# NOTE(sgm) Some weights are point to gpu, but still need this.
model = model.cuda() # NOTE (zhangchi.usc1992) We need this for vllm to profile memory usage
return model.eval()
# FIXME(sgm): hack the _get_logits function in vllm v0.4.2
# as they use ray, the _get_logits result will only need to return to the driver node,
# therefore gather is enough. However, we use SPMD instead of a central scheduler,
# all_gather is required (aligned with v0.2.6)
def _get_logits(self, hidden_states: torch.Tensor, embedding: torch.Tensor,
embedding_bias: Optional[torch.Tensor]) -> torch.Tensor:
# Get the logits for the next tokens.
logits = torch.matmul(hidden_states, embedding.t())
if embedding_bias is not None:
logits += embedding_bias
logits = tensor_model_parallel_all_gather(logits)
# Remove paddings in vocab (if any).
if logits is not None:
logits = logits[:, :self.org_vocab_size]
return logits
from vllm.model_executor.layers.logits_processor import LogitsProcessor
def logitsprocessor_init(self,
vocab_size: int,
org_vocab_size: Optional[int] = None,
scale: float = 1.0,
logits_as_input: bool = False,
soft_cap: Optional[float] = None) -> None:
"""
Args:
scale: A scaling factor to apply to the logits.
"""
super(LogitsProcessor, self).__init__()
self.scale = scale
self.vocab_size = vocab_size
# Whether the input is logits (default is hidden states).
self.logits_as_input = logits_as_input
# original vocabulary size (without LoRA).
self.org_vocab_size = org_vocab_size or vocab_size
# Soft cap the logits. Used in Gemma 2.
self.soft_cap = soft_cap
# Whether to use gather or all-gather to gather the logits.
self.use_gather = False
LogitsProcessor.__init__ = logitsprocessor_init # use all_gather

View File

@@ -0,0 +1,150 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023 The vLLM team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/worker/model_runner.py
import torch
import torch.nn as nn
from enum import IntEnum
from typing import Dict, List, Optional, Set, Tuple, Union
import warnings
import vllm.envs as envs
from vllm.attention import (AttentionMetadata, get_attn_backend)
from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, MultiModalConfig, ParallelConfig, PromptAdapterConfig,
SchedulerConfig)
from vllm.logger import init_logger
from vllm.lora.layers import LoRAMapping
from vllm.lora.request import LoRARequest
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.models.interfaces import (supports_lora, supports_vision)
from vllm.utils import (CudaMemoryProfiler, is_hip, is_pin_memory_available)
from vllm.worker.model_runner import ModelRunner, CUDAGraphRunner
from vllm.prompt_adapter.worker_manager import (LRUCacheWorkerPromptAdapterManager)
from .model_loader import get_model
from .config import ModelConfig, LoadConfig
logger = init_logger(__name__)
# How batches are constructed.
class BatchType(IntEnum):
# Every batch is prefill.
PREFILL = 0
# Every batch is decode.
DECODE = 1
# Batch is a mixture of prefill and decode.
MIXED = 2
class ModelRunner(ModelRunner):
def __init__(
self,
model: Union[nn.Module, Dict], # [verl] model itself or its parameter dict
model_config: ModelConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
cache_config: CacheConfig,
load_config: LoadConfig,
lora_config: Optional[LoRAConfig],
kv_cache_dtype: Optional[str] = "auto",
is_driver_worker: bool = False,
prompt_adapter_config: Optional[PromptAdapterConfig] = None,
multimodal_config: Optional[MultiModalConfig] = None,
return_hidden_states: bool = False,
):
super().__init__(
model_config,
parallel_config,
scheduler_config,
device_config,
cache_config,
load_config,
lora_config,
kv_cache_dtype,
is_driver_worker=True, # a hack
prompt_adapter_config=prompt_adapter_config,
multimodal_config=multimodal_config,
return_hidden_states=return_hidden_states)
# NOTE(sgm): add for verl
self.model = model # this will be replaced by get_model()
# NOTE(sgm): initialize model using the actor model
def load_model(self) -> None:
logger.info("Starting to load model %s...", self.model_config.model)
with CudaMemoryProfiler() as m:
self.model = get_model(actor_model=self.model,
model_config=self.model_config,
device_config=self.device_config,
lora_config=self.lora_config,
load_config=self.load_config,
parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config,
multimodal_config=self.multimodal_config,
cache_config=self.cache_config)
self.model_memory_usage = m.consumed_memory
logger.info("Loading model weights took %.4f GB", self.model_memory_usage / float(2**30))
if self.lora_config:
assert supports_lora(self.model), "Model does not support LoRA"
assert not supports_vision(self.model), "To be tested: vision language model with LoRA settings."
self.lora_manager = LRUCacheWorkerLoRAManager(
self.scheduler_config.max_num_seqs,
self.scheduler_config.max_num_batched_tokens,
self.vocab_size,
self.lora_config,
self.device,
self.model.embedding_modules,
self.model.embedding_padding_modules,
max_position_embeddings=self.model.config.max_position_embeddings,
)
self.model = self.lora_manager.create_lora_manager(self.model)
if self.prompt_adapter_config:
self.prompt_adapter_manager = LRUCacheWorkerPromptAdapterManager(
self.scheduler_config.max_num_seqs, self.scheduler_config.max_num_batched_tokens, self.device,
self.prompt_adapter_config)
self.model = (self.prompt_adapter_manager.create_prompt_adapter_manager(self.model))
if self.kv_cache_dtype == "fp8" and is_hip():
# Currently only ROCm accepts kv-cache scaling factors
# via quantization_param_path and this will be deprecated
# in the future.
if self.model_config.quantization_param_path is not None:
if callable(getattr(self.model, "load_kv_cache_scales", None)):
warnings.warn(
"Loading kv cache scaling factor from JSON is "
"deprecated and will be removed. Please include "
"kv cache scaling factors in the model checkpoint.",
FutureWarning,
stacklevel=2)
self.model.load_kv_cache_scales(self.model_config.quantization_param_path)
logger.info("Loaded KV cache scaling factors from %s", self.model_config.quantization_param_path)
else:
raise RuntimeError(
"Using FP8 KV cache and scaling factors provided but "
"model %s does not support loading scaling factors.", self.model.__class__)
else:
logger.warning("Using FP8 KV cache but no scaling factors "
"provided. Defaulting to scaling factors of 1.0. "
"This may lead to less accurate results!")
if envs.VLLM_TEST_DYNAMO_GRAPH_CAPTURE:
self.model = torch.compile(self.model, fullgraph=True, backend="eager")

View File

@@ -0,0 +1,303 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023 The vLLM team.
# Adapted from
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""Model and data parallel groups."""
import os
import torch
import torch.distributed
from typing import Optional
import vllm.distributed.parallel_state as ps
from vllm.distributed.parallel_state import get_pp_group, get_world_group, init_distributed_environment, init_model_parallel_group
import vllm.envs as envs
from vllm.logger import init_logger
from torch.distributed.device_mesh import init_device_mesh
logger = init_logger(__name__)
"""
This version is strongly tied with Megatron to implement HybridEngine and weight sharing between vllm and Megatron.
- We assume the Megatron tp+dp+pp world is already established before calling this function.
"""
# Device mesh for using DTensor
_DEVICE_MESH = None
# Tensor model parallel group that the current rank belongs to.
_TP = None
# Pipeline model parallel group that the current rank belongs to.
_PP = None
# This method is for initializing the ParallelGroup when using HybridEngine
def initialize_parallel_state(
distributed_init_method: str = "env://",
backend: str = "nccl",
tensor_model_parallel_size: int = 1,
num_tp_per_train_tp: int = 1,
pipeline_model_parallel_size: int = 1,
):
# torch.distributed.all_reduce does not free the input tensor until
# the synchronization point. This causes the memory usage to grow
# as the number of all_reduce calls increases. This env var disables
# this behavior.
# Related issue:
# https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573
os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
# NOTE(sgm): Modify for verl, Env vars will be set by TORCHRUN.
rank = int(os.getenv("RANK", "-1"))
local_rank = int(os.getenv("LOCAL_RANK", "0"))
# Use the world_size set by TORCHRUN
world_size = int(os.getenv("WORLD_SIZE", "-1"))
assert world_size != -1, "The world_size is set to -1, not initialized by TORCHRUN"
init_distributed_environment(world_size, rank, distributed_init_method, local_rank, backend)
if torch.distributed.get_world_size() > 1:
# NOTE: build a sepearate inference group with infer tp & micro dp
initialize_model_parallel_for_vllm(tensor_model_parallel_size=tensor_model_parallel_size,
num_tensor_model_parallel_groups_per_train_tp=num_tp_per_train_tp)
else:
initialize_model_parallel(tensor_model_parallel_size, pipeline_model_parallel_size, backend)
def ensure_model_parallel_initialized(
tensor_model_parallel_size: int,
pipeline_model_parallel_size: int = 1,
backend: Optional[str] = None,
) -> None:
"""Helper to initialize model parallel groups if they are not initialized,
or ensure tensor-parallel and pipeline-parallel sizes are equal to expected
values if the model parallel groups are initialized.
"""
# get the backend of _DEVICE_WORLD_GROUP
backend = backend or torch.distributed.get_backend(get_world_group().device_group)
if not model_parallel_is_initialized():
initialize_model_parallel(tensor_model_parallel_size, pipeline_model_parallel_size, backend)
return
assert (get_tensor_model_parallel_world_size() == tensor_model_parallel_size), (
"tensor parallel group already initialized, but of unexpected size: "
f"{get_tensor_model_parallel_world_size()=} vs. "
f"{tensor_model_parallel_size=}")
pp_world_size = get_pp_group().world_size
assert (pp_world_size == pipeline_model_parallel_size), (
"pipeline parallel group already initialized, but of unexpected size: "
f"{pp_world_size=} vs. "
f"{pipeline_model_parallel_size=}")
# TODO(sgm): deviate from the v0.5.4, not pp now
def model_parallel_is_initialized():
"""Check if tensor and pipeline parallel groups are initialized."""
return (ps._TP is not None)
# and _PIPELINE_MODEL_PARALLEL_GROUP is not None)
def initialize_model_parallel_for_vllm(tensor_model_parallel_size: int,
num_tensor_model_parallel_groups_per_train_tp: int = 1,
pipeline_model_parallel_size: int = 1) -> None:
from torch.distributed import new_group
# Get world size and rank. Ensure some consistencies.
assert torch.distributed.is_initialized()
assert isinstance(tensor_model_parallel_size, int)
# assert num_tensor_model_parallel_groups_per_train_tp == 1 and not different_tp_group
# assert num_tensor_model_parallel_groups_per_train_tp > 1 and different_tp_group
# Build the tensor model-parallel groups.
assert ps._TP is None, ("tensor model parallel group is already initialized")
global _TP
world_size: int = torch.distributed.get_world_size()
rank = torch.distributed.get_rank()
backend = torch.distributed.get_backend()
num_tensor_model_parallel_groups = world_size // tensor_model_parallel_size
if num_tensor_model_parallel_groups_per_train_tp == 1:
# if tensor_model_parallel_size == train_tensor_parallel_size:
# using the same tp group as Megatron/vllm
assert _TP is None, ("tensor model parallel group is already initialized")
group_ranks = []
for i in range(num_tensor_model_parallel_groups):
ranks = range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size)
group_ranks.append(ranks)
_TP = init_model_parallel_group(
group_ranks=group_ranks,
local_rank=get_world_group().local_rank,
backend=backend,
use_custom_allreduce=False, # TODO: check why True is not work in Ray trainer
use_message_queue_broadcaster=True)
ps._TP = _TP
# _MICRO_DATA_PARALLEL_GROUP is move to hybrid engine
else:
# initialize a micro_dp group and a tp group
# assume training tp=4, infer tp=2, then, weight is partitioned as
# [1], [2], [3], [4] for training and [1,2], [1,2], [3,4], [3,4] for inference
# Build the inference tp groups
# train_tp = train_tensor_parallel_size
train_tp = num_tensor_model_parallel_groups_per_train_tp * tensor_model_parallel_size
# num_tensor_model_parallel_groups_per_train_tp = train_tp // tensor_model_parallel_size
assert _TP is None, ("tensor model parallel group is already initialized")
group_ranks = []
for i in range(num_tensor_model_parallel_groups // num_tensor_model_parallel_groups_per_train_tp):
start = train_tp * i
end = train_tp * (i + 1)
for j in range(num_tensor_model_parallel_groups_per_train_tp):
ranks = list(range(start, end, num_tensor_model_parallel_groups_per_train_tp))
for i in range(len(ranks)):
ranks[i] += j
group_ranks.append(ranks)
_TP = init_model_parallel_group(
group_ranks=group_ranks,
local_rank=get_world_group().local_rank,
backend=backend,
use_custom_allreduce=False, # TODO: check why True is not work in Ray trainer
use_message_queue_broadcaster=True)
ps._TP = _TP
# Build the pipeline model-parallel groups.
# global _PIPELINE_MODEL_PARALLEL_GROUP
# global _PIPELINE_GLOBAL_RANKS
# assert ps._PIPELINE_MODEL_PARALLEL_GROUP is None, ("pipeline model parallel group is already initialized")
# ps._PIPELINE_MODEL_PARALLEL_GROUP = mpu.get_pipeline_model_parallel_group()
# ps._PIPELINE_GLOBAL_RANKS = mpu.get_pipeline_model_parallel_ranks()
# TODO: init using device mesh (not support hybrid engine now)
# Build the pipeline model-parallel groups.
num_pipeline_model_parallel_groups: int = (world_size // pipeline_model_parallel_size)
global _PP
assert _PP is None, ("pipeline model parallel group is already initialized")
group_ranks = []
for i in range(num_pipeline_model_parallel_groups):
ranks = list(range(i, world_size, num_pipeline_model_parallel_groups))
group_ranks.append(ranks)
# pipeline parallel does not need custom allreduce
_PP = init_model_parallel_group(group_ranks, get_world_group().local_rank, backend, use_custom_allreduce=False)
ps._PP = _PP # for verl
def initialize_model_parallel(
tensor_model_parallel_size: int = 1,
pipeline_model_parallel_size: int = 1,
backend: Optional[str] = None,
) -> None:
"""
NOTE: This method is a hack from the open-sourced version without
asertion of world_size = tp * pp
Initialize model parallel groups.
Arguments:
tensor_model_parallel_size: number of GPUs used for tensor model
parallelism.
pipeline_model_parallel_size: number of GPUs used for pipeline model
parallelism.
Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we
use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize
the model pipeline. The present function will
create 4 tensor model-parallel groups and 2 pipeline model-parallel groups:
4 tensor model-parallel groups:
[g0, g1], [g2, g3], [g4, g5], [g6, g7]
2 pipeline model-parallel groups:
[g0, g2, g4, g6], [g1, g3, g5, g7]
Note that for efficiency, the caller should make sure adjacent ranks
are on the same DGX box. For example if we are using 2 DGX-1 boxes
with a total of 16 GPUs, rank 0 to 7 belong to the first box and
ranks 8 to 15 belong to the second box.
"""
# Get world size and rank. Ensure some consistencies.
assert torch.distributed.is_initialized()
world_size: int = torch.distributed.get_world_size()
backend = backend or torch.distributed.get_backend(ps.get_world_group().device_group)
# NOTE(sgm) we don't assert world_size == tp * pp
# DP is not managed by vllm but by the veRL WorkerGroup
# if (world_size !=
# tensor_model_parallel_size * pipeline_model_parallel_size):
# raise RuntimeError(
# f"world_size ({world_size}) is not equal to "
# f"tensor_model_parallel_size ({tensor_model_parallel_size}) x "
# f"pipeline_model_parallel_size ({pipeline_model_parallel_size})")
num_tensor_model_parallel_groups: int = (world_size // tensor_model_parallel_size)
rank = torch.distributed.get_rank()
global _TP
assert _TP is None, ("tensor model parallel group is already initialized")
group_ranks = []
for i in range(num_tensor_model_parallel_groups):
ranks = list(range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size))
group_ranks.append(ranks)
# message queue broadcaster is only used in tensor model parallel group
_TP = init_model_parallel_group(
group_ranks,
get_world_group().local_rank,
backend,
use_custom_allreduce=False, # TODO: check why True is not work in Ray trainer
use_message_queue_broadcaster=True)
ps._TP = _TP
# TODO: init using device mesh (not support hybrid engine now)
# Build the pipeline model-parallel groups.
num_pipeline_model_parallel_groups: int = (world_size // pipeline_model_parallel_size)
global _PP
assert _PP is None, ("pipeline model parallel group is already initialized")
group_ranks = []
for i in range(num_pipeline_model_parallel_groups):
ranks = list(range(i, world_size, num_pipeline_model_parallel_groups))
group_ranks.append(ranks)
# pipeline parallel does not need custom allreduce
_PP = init_model_parallel_group(group_ranks, get_world_group().local_rank, backend, use_custom_allreduce=False)
ps._PP = _PP # for verl
"""
Device mesh utilities
"""
def get_device_mesh():
assert _DEVICE_MESH is not None, ("device mesh is not initialized")
return _DEVICE_MESH
"""
Tensor model parallel utilities
"""
def get_tensor_model_parallel_group():
"""Get the tensor model parallel group the caller rank belongs to."""
assert _TP is not None, ("tensor model parallel group is not initialized")
return _TP.device_group
def get_tensor_model_parallel_world_size():
"""Return world size for the tensor model parallel group."""
return torch.distributed.get_world_size(group=get_tensor_model_parallel_group())
def get_tensor_model_parallel_rank():
"""Return my rank for the tensor model parallel group."""
return torch.distributed.get_rank(group=get_tensor_model_parallel_group())
def get_tensor_model_parallel_src_rank():
"""Calculate the global rank corresponding to the first local rank
in the tensor model parallel group."""
global_rank = torch.distributed.get_rank()
local_world_size = get_tensor_model_parallel_world_size()
return (global_rank // local_world_size) * local_world_size

View File

@@ -0,0 +1,253 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023 The vLLM team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/executor/gpu_executor.py
import os
import socket
from typing import Any, Dict, List, Optional, Set, Tuple
import torch
import vllm.envs as envs
from vllm.executor.executor_base import ExecutorBase, ExecutorAsyncBase
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.sequence import SamplerOutput, ExecuteModelRequest
from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, MultiModalConfig, ParallelConfig, PromptAdapterConfig,
SchedulerConfig, SpeculativeConfig)
from .config import ModelConfig, LoadConfig
logger = init_logger(__name__)
class SPMDGPUExecutor(ExecutorBase):
"""SPMD-based multi-GPU executor implementations."""
def __init__(
self,
model, # pytorch model itself or its parameter dict
model_config: ModelConfig,
cache_config: CacheConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
load_config: LoadConfig,
lora_config: Optional[LoRAConfig],
multimodal_config: Optional[MultiModalConfig],
speculative_config: Optional[SpeculativeConfig],
prompt_adapter_config: Optional[PromptAdapterConfig],
) -> None:
self.model_config = model_config
self.cache_config = cache_config
self.lora_config = lora_config
self.load_config = load_config
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
self.device_config = device_config
self.multimodal_config = multimodal_config
self.speculative_config = speculative_config
self.prompt_adapter_config = prompt_adapter_config
distributed_init_method = initialize_cluster(parallel_config)
self._init_executor(model, distributed_init_method)
# TODO(sgm): verl not support speculative decode now
def _init_executor(self, model, distributed_init_method) -> None:
assert (not self.speculative_config), "Speculative decoding not yet supported for multi-GPU backend."
# Create the parallel worker for each GPU.
self._init_workers_sp(model, distributed_init_method)
def _init_workers_sp(self, model, distributed_init_method: str):
# Lazy import the Worker to avoid importing torch.cuda/xformers
# before CUDA_VISIBLE_DEVICES is set in the Worker
from .worker import Worker # pylint: disable=import-outside-toplevel
rank = int(os.getenv("RANK"))
local_rank = int(os.getenv("LOCAL_RANK"))
print(f'local rank {local_rank}')
# see https://github.com/NVIDIA/nccl/issues/1234
os.environ['NCCL_CUMEM_ENABLE'] = '0'
self.worker = Worker(
model,
self.model_config,
self.parallel_config,
self.scheduler_config,
self.device_config,
self.cache_config,
self.load_config,
local_rank,
rank,
distributed_init_method,
lora_config=self.lora_config,
multimodal_config=self.multimodal_config,
speculative_config=None,
prompt_adapter_config=self.speculative_config,
is_driver_worker=True,
model_runner_cls=None, # use the default one
)
# NOTE(shengguangming): torch.distributed.init_process_group will be called inside the init_model()
self.worker.init_device()
self.worker.load_model()
def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Determine the number of available KV blocks.
This invokes `determine_num_available_blocks` on each worker and takes
the min of the results, guaranteeing that the selected cache sizes are
compatible with all workers.
Returns:
- tuple[num_gpu_blocks, num_cpu_blocks]
"""
# Get the maximum number of blocks that can be allocated on GPU and CPU.
num_blocks = self.worker.determine_num_available_blocks()
# NOTE(shengguangming): Now we don't use a shared centralized controler but each process will
# have its own scheduler
num_gpu_blocks = num_blocks[0]
num_cpu_blocks = num_blocks[1]
return num_gpu_blocks, num_cpu_blocks
def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None:
"""Initialize the KV cache in all workers.
"""
# NOTE: We log here to avoid multiple logs when number of workers is
# greater than one. We could log in the engine, but not all executors
# have GPUs.
logger.info("# GPU blocks: %d, # CPU blocks: %d", num_gpu_blocks, num_cpu_blocks)
self.cache_config.num_gpu_blocks = num_gpu_blocks
self.cache_config.num_cpu_blocks = num_cpu_blocks
if torch.distributed.get_rank() == 0:
print(
f'before init cache memory allocated: {torch.cuda.memory_allocated() / 1e9}GB, reserved: {torch.cuda.memory_reserved() / 1e9}GB'
)
self.worker.initialize_cache(num_gpu_blocks=num_gpu_blocks, num_cpu_blocks=num_cpu_blocks)
if torch.distributed.get_rank() == 0:
print(
f'after init cache memory allocated: {torch.cuda.memory_allocated() / 1e9}GB, reserved: {torch.cuda.memory_reserved() / 1e9}GB'
)
# NOTE(sgm): This will not profile & capture the model(CUDAGraph) when rebuilding KVCache
def init_cache_engine(self) -> None:
self.worker._init_cache_engine()
def free_cache_engine(self) -> None:
self.worker.free_cache_engine()
def execute_model(self, execute_model_req) -> List[SamplerOutput]:
all_outputs = self.worker.execute_model(execute_model_req=execute_model_req)
# NOTE(sgm):
# Each GPU in vllm under verl has its own spmd_gpu_executor, therefore all GPUs should return the outputs
# In vllm with ray, only the driver worker returns the sampling results.
return all_outputs
def add_lora(self, lora_request: LoRARequest) -> bool:
assert lora_request.lora_int_id > 0, "lora_id must be greater than 0."
return self.worker.add_lora(lora_request=lora_request)
def remove_lora(self, lora_id: int) -> bool:
assert lora_id > 0, "lora_id must be greater than 0."
return self.worker.remove_lora(lora_id=lora_id)
def list_loras(self) -> Set[int]:
return self.worker.list_loras()
def check_health(self) -> None:
# SPMDExecutor will always be healthy as long as
# it's running.
return
# NOTE(sgm) add for verl to pass the abstract class test, not used
from vllm.prompt_adapter.request import PromptAdapterRequest
def add_prompt_adapter(self, prompt_adapter_request: PromptAdapterRequest) -> bool:
assert prompt_adapter_request.prompt_adapter_id > 0, \
"prompt_adapter_id must be greater than 0."
return self.worker.add_prompt_adapter(prompt_adapter_request)
def list_prompt_adapters(self) -> Set[int]:
return self.worker.list_prompt_adapters()
def pin_lora(self, lora_id: int) -> bool:
assert lora_id > 0, "lora_id must be greater than 0."
return self.worker.pin_lora(lora_id)
def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool:
assert prompt_adapter_id > 0, \
"prompt_adapter_id must be greater than 0."
return self.worker.pin_prompt_adapter(prompt_adapter_id)
def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
assert prompt_adapter_id > 0, \
"prompt_adapter_id must be greater than 0."
return self.worker.remove_prompt_adapter(prompt_adapter_id)
# NOTE(sgm): add for verl
def offload_model_weights(self) -> None:
self.worker.offload_model_weights()
def sync_model_weights(self, actor_weights: Dict[str, torch.Tensor], load_format: str) -> None:
self.worker.sync_model_weights(actor_weights=actor_weights, load_format=load_format)
def initialize_cluster(
parallel_config: ParallelConfig,
engine_use_ray: bool = False,
ray_address: Optional[str] = None,
) -> Tuple[str, Optional[None]]:
"""Initialize the distributed cluster probably with Ray.
Args:
parallel_config: The configurations for parallel execution.
Returns:
The `distributed_init_method` is the address for initializing the
distributed backend.
"""
# Initialize cluster locally.
port = get_open_port()
# We need to setup the distributed init method to make sure
# the distributed megatron code (e.g., get world size) works correctly.
# distributed_init_method = f"tcp://localhost:{port}"
distributed_init_method = 'env://'
return distributed_init_method
def get_open_port():
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(("", 0))
return s.getsockname()[1]
# TODO(sgm): not implemented async executor yet
class SPMDGPUExecutorAsync(SPMDGPUExecutor, ExecutorAsyncBase):
async def execute_model_async(self, execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
"""Executes one model step on the given sequences."""
raise NotImplementedError
async def check_health_async(self) -> None:
"""Checks if the executor is healthy. If not, it should raise an
exception."""
self.check_health()

View File

@@ -0,0 +1,77 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023 The vLLM team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/transformers_utils/tokenizer_group/tokenizer_group.py
from typing import List, Optional, Tuple, Union
from transformers import (AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast)
from vllm.lora.request import LoRARequest
from vllm.utils import make_async, LRUCache
from vllm.transformers_utils.tokenizers import *
class TokenizerGroup:
"""A group of tokenizers that can be used for LoRA adapters."""
def __init__(self, tokenizer: PreTrainedTokenizer, enable_lora: bool, max_num_seqs: int,
max_input_length: Optional[int]):
self.enable_lora = enable_lora
self.max_input_length = max_input_length
self.tokenizer = tokenizer
self.lora_tokenizers = LRUCache[PreTrainedTokenizer](capacity=max_num_seqs) if enable_lora else None
def ping(self) -> bool:
"""Check if the tokenizer group is alive."""
return True
def get_max_input_len(self, lora_request: Optional[LoRARequest] = None) -> Optional[int]:
"""Get the maximum input length for the LoRA request."""
return self.max_input_length
def encode(self,
prompt: str,
request_id: Optional[str] = None,
lora_request: Optional[LoRARequest] = None) -> List[int]:
tokenizer = self.get_lora_tokenizer(lora_request)
return tokenizer.encode(prompt)
async def encode_async(self,
prompt: str,
request_id: Optional[str] = None,
lora_request: Optional[LoRARequest] = None) -> List[int]:
tokenizer = await self.get_lora_tokenizer_async(lora_request)
return tokenizer.encode(prompt)
def get_lora_tokenizer(self, lora_request: Optional[LoRARequest]) -> "PreTrainedTokenizer":
if not lora_request or not self.enable_lora:
return self.tokenizer
if lora_request.lora_int_id not in self.lora_tokenizers:
# TODO(sgm): the lora tokenizer is also passed, but may be different
tokenizer = self.tokenizer
# tokenizer = (get_lora_tokenizer(
# lora_request, **self.tokenizer_config) or self.tokenizer)
self.lora_tokenizers.put(lora_request.lora_int_id, tokenizer)
return tokenizer
else:
return self.lora_tokenizers.get(lora_request.lora_int_id)
# FIXME(sgm): for simplicity, we assign the special token here
@property
def pad_token_id(self):
return self.tokenizer.pad_token_id
@property
def eos_token_id(self):
return self.tokenizer.eos_token_id

View File

@@ -0,0 +1,323 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023 The vLLM team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/worker/worker.py
"""A GPU worker class."""
import os
import gc
from typing import Dict, List, Tuple, Optional, Union, Type
import torch
import torch.distributed
import torch.nn as nn
from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, MultiModalConfig, ParallelConfig, PromptAdapterConfig,
SchedulerConfig, SpeculativeConfig)
from vllm.model_executor import set_random_seed
from vllm.sequence import (ExecuteModelRequest, IntermediateTensors, SamplerOutput)
from vllm.worker.cache_engine import CacheEngine
# TODO(sgm): check why vllm has similar file in vllm.model_executor.parallel_utils.parallel_state
from vllm.distributed import (init_distributed_environment, set_custom_all_reduce, get_tensor_model_parallel_group)
from vllm.worker.worker_base import WorkerInput
from vllm.worker.worker import Worker, _check_if_gpu_supports_dtype
from vllm.worker.model_runner_base import ModelRunnerBase, ModelRunnerInputBase
from vllm.worker.embedding_model_runner import EmbeddingModelRunner
from vllm.worker.model_runner import GPUModelRunnerBase
from .model_runner import ModelRunner
from .megatron_weight_loaders import load_megatron_weights
from .hf_weight_loader import load_hf_weights
from .dtensor_weight_loaders import load_dtensor_weights
from .parallel_state import (ensure_model_parallel_initialized)
from .config import ModelConfig, LoadConfig, LoadFormat
class Worker(Worker):
"""A worker class that executes (a partition of) the model on a GPU.
Each worker is associated with a single GPU. The worker is responsible for
maintaining the KV cache and executing the model on the GPU. In case of
distributed inference, each worker is assigned a partition of the model.
"""
def __init__(
self,
model: Union[nn.Module, Dict], # model itself or its parameter dict
model_config: ModelConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
cache_config: CacheConfig,
load_config: LoadConfig,
local_rank: int,
rank: int,
distributed_init_method: str,
lora_config: Optional[LoRAConfig] = None,
multimodal_config: Optional[MultiModalConfig] = None,
speculative_config: Optional[SpeculativeConfig] = None,
prompt_adapter_config: Optional[PromptAdapterConfig] = None,
is_driver_worker: bool = False,
model_runner_cls: Optional[Type[GPUModelRunnerBase]] = None,
) -> None:
# self.model = model # will be replaced in the init_model
self.model_config = model_config
self.parallel_config = parallel_config
self.parallel_config.rank = rank
self.scheduler_config = scheduler_config
self.device_config = device_config
self.cache_config = cache_config
self.local_rank = local_rank
self.rank = rank
self.distributed_init_method = distributed_init_method
self.lora_config = lora_config
self.load_config = load_config
self.prompt_adapter_config = prompt_adapter_config
self.is_driver_worker = is_driver_worker # TODO: we don't need driver
# if parallel_config and is_driver_worker:
# assert rank % parallel_config.tensor_parallel_size == 0, \
# "Driver worker should be rank 0 of tensor parallel group."
if self.model_config.trust_remote_code:
# note: lazy import to avoid importing torch before initializing
from vllm.utils import init_cached_hf_modules
init_cached_hf_modules()
self.multimodal_config = multimodal_config
# Return hidden states from target model if the draft model is an
# mlp_speculator
speculative_args = {} if speculative_config is None \
or (speculative_config.draft_model_config.model ==
model_config.model) \
or (speculative_config.draft_model_config.hf_config.model_type
not in ["medusa", "mlp_speculator"]) \
else {"return_hidden_states": True}
# TODO(sgm): set correct model runner class
ModelRunnerClass: Type[GPUModelRunnerBase] = ModelRunner
if model_runner_cls is not None:
ModelRunnerClass = model_runner_cls
elif self.model_config.embedding_mode:
ModelRunnerClass = EmbeddingModelRunner
self.model_runner: GPUModelRunnerBase = ModelRunnerClass(
model, # [VERL]: add for verl
model_config,
parallel_config,
scheduler_config,
device_config,
cache_config,
load_config=load_config,
lora_config=self.lora_config,
kv_cache_dtype=self.cache_config.cache_dtype,
is_driver_worker=is_driver_worker,
prompt_adapter_config=prompt_adapter_config,
multimodal_config=multimodal_config,
**speculative_args,
)
# Uninitialized cache engine. Will be initialized by
# initialize_cache.
self.cache_engine: List[CacheEngine] = None
# Initialize gpu_cache as embedding models don't initialize kv_caches
self.gpu_cache: Optional[List[List[torch.Tensor]]] = None
# NOTE(sgm): [VERL] For offloading inference engine params
self.cpu_model = None
def init_device(self) -> None:
if self.device_config.device.type == "cuda":
# torch.distributed.all_reduce does not free the input tensor until
# the synchronization point. This causes the memory usage to grow
# as the number of all_reduce calls increases. This env var disables
# this behavior.
# Related issue:
# https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573
os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
# NOTE(sgm): Modify for verl, Env vars will be set by TORCHRUN.
self.rank = self.rank if self.rank is not None else int(os.getenv("RANK", "-1"))
local_rank = int(os.getenv("LOCAL_RANK", "0"))
self.device = torch.device(f"cuda:{local_rank}")
if self.rank < 0:
raise ValueError("Invalid or unspecified rank.")
torch.cuda.set_device(self.device)
# Use the world_size set by TORCHRUN
world_size = int(os.getenv("WORLD_SIZE", "-1"))
assert world_size != -1, "The world_size is set to -1, not initialized by TORCHRUN"
self.parallel_config.world_size = world_size
_check_if_gpu_supports_dtype(self.model_config.dtype)
torch.cuda.empty_cache()
self.init_gpu_memory = torch.cuda.mem_get_info()[0]
else:
raise RuntimeError(f"Not support device type: {self.device_config.device}")
# Initialize the distributed environment.
init_worker_distributed_environment(self.parallel_config, self.rank, self.distributed_init_method,
self.local_rank)
# Set random seed.
set_random_seed(self.model_config.seed)
# self.model = get_model(actor_model=self.model, model_config=self.model_config)
@torch.inference_mode()
def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Profiles the peak memory usage of the model to determine how many
KV blocks may be allocated without OOMs.
The engine will first conduct a profiling of the existing memory usage.
Then, it calculate the maximum possible number of GPU and CPU blocks
that can be allocated with the remaining free memory.
.. tip::
You may limit the usage of GPU memory
by adjusting the `gpu_memory_utilization` parameter.
"""
# Profile the memory usage of the model and get the maximum number of
# cache blocks that can be allocated with the remaining free memory.
torch.cuda.empty_cache()
# torch.cuda.reset_peak_memory_stats()
# Execute a forward pass with dummy inputs to profile the memory usage
# of the model.
self.model_runner.profile_run()
# Calculate the number of blocks that can be allocated with the
# profiled peak memory.
torch.cuda.synchronize()
free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info()
peak_memory = total_gpu_memory - free_gpu_memory
assert peak_memory > 0, ("Error in memory profiling. This happens when the GPU memory was "
"not properly cleaned up before initializing the vLLM instance.")
cache_block_size = self.get_cache_block_size_bytes()
# NOTE(sgm) [VERL] use the remaining memory
num_gpu_blocks = int((free_gpu_memory * self.cache_config.gpu_memory_utilization) // cache_block_size)
# num_gpu_blocks = int((total_gpu_memory * self.cache_config.gpu_memory_utilization - peak_memory) // cache_block_size)
num_cpu_blocks = int(self.cache_config.swap_space_bytes // cache_block_size)
num_gpu_blocks = max(num_gpu_blocks, 0)
num_cpu_blocks = max(num_cpu_blocks, 0)
if self.model_runner.lora_manager:
self.model_runner.remove_all_loras()
# NOTE(sgm): Add for [VERL], synchronize number of blocks with all the rank
num_gpu_blocks = torch.tensor([num_gpu_blocks], device='cuda')
num_cpu_blocks = torch.tensor([num_cpu_blocks], device='cuda')
torch.distributed.all_reduce(num_gpu_blocks,
op=torch.distributed.ReduceOp.MIN,
group=get_tensor_model_parallel_group().device_group)
torch.distributed.all_reduce(num_cpu_blocks,
op=torch.distributed.ReduceOp.MIN,
group=get_tensor_model_parallel_group().device_group)
num_gpu_blocks = num_gpu_blocks.item()
num_cpu_blocks = num_cpu_blocks.item()
gc.collect()
torch.cuda.empty_cache()
return num_gpu_blocks, num_cpu_blocks
def _init_cache_engine(self):
if self.cache_engine is None and self.gpu_cache is None:
super()._init_cache_engine()
def free_cache_engine(self):
# ensure `enforce_eager=True`
self.cache_engine = None
self.gpu_cache = None
# NOTE(sgm): [VERL]: adapt from _execute_model_spmd()
def execute_model(self,
execute_model_req: ExecuteModelRequest,
intermediate_tensors: Optional[IntermediateTensors] = None) -> Optional[List[SamplerOutput]]:
"""
Execute model in Single Program Multiple Data (SPMD) fashion.
All workers take the same request, prepare the input and
execute the model.
"""
assert execute_model_req is not None, ("_execute_model_spmd() requires each worker to take in an "
"ExecuteModelRequest")
worker_input: WorkerInput = self.prepare_worker_input(execute_model_req=execute_model_req)
model_input: ModelRunnerInputBase = (self.model_runner.prepare_model_input(
execute_model_req.seq_group_metadata_list))
# verl.worker.workerbase.WorkerBase
# swap cache
super().execute_worker(worker_input)
# If there is no input, we don't need to execute the model.
if worker_input.num_seq_groups == 0:
return []
return self.model_runner.execute_model(
model_input, self.kv_cache[worker_input.virtual_engine] if self.kv_cache is not None else None,
intermediate_tensors)
# assume the input is .state_dict()
def sync_model_weights(self, actor_weights: Dict, load_format: str):
if load_format in [LoadFormat.MEGATRON, LoadFormat.AUTO]:
load_megatron_weights(actor_weights, self.model_runner.model)
elif load_format == LoadFormat.HF:
# full model state dict without no sharding
load_hf_weights(actor_weights, self.model_runner.model)
elif load_format == LoadFormat.DTENSOR:
load_dtensor_weights(actor_weights, self.model_runner.model)
def offload_model_weights(self) -> None:
if self.cpu_model == None:
self.cpu_model = {}
for name, params in self.model_runner.model.named_parameters():
self.cpu_model[name] = torch.empty_like(params, device='cpu')
params.data = self.cpu_model[name]
else:
for name, params in self.model_runner.model.named_parameters():
params.data = self.cpu_model[name]
def init_worker_distributed_environment(
parallel_config: ParallelConfig,
rank: int,
distributed_init_method: Optional[str] = "env://",
local_rank: int = -1,
) -> None:
"""Initialize the distributed environment."""
set_custom_all_reduce(not parallel_config.disable_custom_all_reduce)
# NOTE(sgm) use tcp://localhost:xxxx will hang in HF setting without megatron
init_distributed_environment(parallel_config.world_size, rank, distributed_init_method, local_rank)
ensure_model_parallel_initialized(tensor_model_parallel_size=parallel_config.tensor_parallel_size,
pipeline_model_parallel_size=parallel_config.pipeline_parallel_size)
# TODO(sgm): check whether need this
# if pynccl_utils.is_initialized():
# pynccl_world_size = pynccl_utils.get_world_size()
# if pynccl_world_size != parallel_config.world_size:
# raise RuntimeError(
# "pynccl is already initialized but the pynccl world "
# "size does not match parallel_config.world_size "
# f"({pynccl_world_size} vs. {parallel_config.world_size}).")
# elif parallel_config.world_size > 1:
# # NOTE(woosuk): We don't initialize pynccl process group when world size
# # is 1.
# # NOTE(kaichao): By default, pynccl is initialized for tp group.
# pynccl_utils.init_process_group(
# group=get_tensor_model_parallel_cpu_group())
# # Initialize a custom fast all-reduce implementation.
# if not parallel_config.disable_custom_all_reduce:
# init_custom_ar()
# A small all_reduce for warmup.
torch.distributed.all_reduce(torch.zeros(1).cuda())
# if pynccl_utils.is_initialized():
# pynccl_utils.all_reduce(torch.zeros(1).cuda())

View File

@@ -0,0 +1,13 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@@ -0,0 +1,78 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023 The vLLM team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/engine/arg_utils.py
import os
from dataclasses import dataclass
from transformers import PretrainedConfig
from vllm.config import EngineConfig
from vllm.engine.arg_utils import EngineArgs
from .config import LoadConfig, ModelConfig
@dataclass
class EngineArgs(EngineArgs):
model_hf_config: PretrainedConfig = None # for verl
def __post_init__(self):
pass
def create_model_config(self) -> ModelConfig:
return ModelConfig(
hf_config=self.model_hf_config,
tokenizer_mode=self.tokenizer_mode,
trust_remote_code=self.trust_remote_code,
dtype=self.dtype,
seed=self.seed,
revision=self.revision,
code_revision=self.code_revision,
rope_scaling=self.rope_scaling,
rope_theta=self.rope_theta,
tokenizer_revision=self.tokenizer_revision,
max_model_len=self.max_model_len,
quantization=self.quantization,
quantization_param_path=self.quantization_param_path,
enforce_eager=self.enforce_eager,
max_context_len_to_capture=self.max_context_len_to_capture,
max_seq_len_to_capture=self.max_seq_len_to_capture,
max_logprobs=self.max_logprobs,
disable_sliding_window=self.disable_sliding_window,
skip_tokenizer_init=self.skip_tokenizer_init,
served_model_name=self.served_model_name,
limit_mm_per_prompt=self.limit_mm_per_prompt,
use_async_output_proc=not self.disable_async_output_proc,
override_neuron_config=self.override_neuron_config,
config_format=self.config_format,
mm_processor_kwargs=self.mm_processor_kwargs,
)
def create_load_config(self) -> LoadConfig:
return LoadConfig(
load_format=self.load_format,
download_dir=self.download_dir,
model_loader_extra_config=self.model_loader_extra_config,
ignore_patterns=self.ignore_patterns,
)
def create_engine_config(self) -> EngineConfig:
engine_config = super().create_engine_config()
# NOTE[VERL]: Use the world_size set by torchrun
world_size = int(os.getenv("WORLD_SIZE", "-1"))
assert world_size != -1, "The world_size is set to -1, not initialized by TORCHRUN"
engine_config.parallel_config.world_size = world_size
return engine_config

View File

@@ -0,0 +1,105 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023 The vLLM team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py
import enum
import json
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, List, Optional, Union
from transformers import PretrainedConfig
# Add for verl
from vllm.config import ModelConfig
from vllm.logger import init_logger
from vllm.utils import is_hip
if TYPE_CHECKING:
from vllm.model_executor.model_loader.loader import BaseModelLoader
logger = init_logger(__name__)
class LoadFormat(str, enum.Enum):
AUTO = "auto"
MEGATRON = "megatron"
HF = "hf"
DTENSOR = "dtensor"
DUMMY_HF = "dummy_hf"
DUMMY_MEGATRON = "dummy_megatron"
DUMMY_DTENSOR = "dummy_dtensor"
class ModelConfig(ModelConfig):
def __init__(self, hf_config: PretrainedConfig, *args, **kwargs) -> None:
super().__init__(model=hf_config._name_or_path, tokenizer=hf_config._name_or_path, *args, **kwargs)
self.hf_config = hf_config
@dataclass
class LoadConfig:
"""
download_dir: Directory to download and load the weights, default to the
default cache directory of huggingface.
load_format: The format of the model weights to load:
"auto" will try to load the weights in the safetensors format and
fall back to the pytorch bin format if safetensors format is
not available.
"pt" will load the weights in the pytorch bin format.
"safetensors" will load the weights in the safetensors format.
"npcache" will load the weights in pytorch format and store
a numpy cache to speed up the loading.
"dummy" will initialize the weights with random values, which is
mainly for profiling.
"tensorizer" will use CoreWeave's tensorizer library for
fast weight loading.
"bitsandbytes" will load nf4 type weights.
ignore_patterns: The list of patterns to ignore when loading the model.
Default to "original/**/*" to avoid repeated loading of llama's
checkpoints.
"""
load_format: Union[str, LoadFormat, "BaseModelLoader"] = LoadFormat.AUTO
download_dir: Optional[str] = None
model_loader_extra_config: Optional[Union[str, dict]] = field(default_factory=dict)
ignore_patterns: Optional[Union[List[str], str]] = None
def __post_init__(self):
model_loader_extra_config = self.model_loader_extra_config or {}
if isinstance(model_loader_extra_config, str):
self.model_loader_extra_config = json.loads(model_loader_extra_config)
self._verify_load_format()
if self.ignore_patterns is not None and len(self.ignore_patterns) > 0:
logger.info("Ignoring the following patterns when downloading weights: %s", self.ignore_patterns)
else:
self.ignore_patterns = ["original/**/*"]
def _verify_load_format(self) -> None:
if not isinstance(self.load_format, str):
return
load_format = self.load_format.lower()
self.load_format = LoadFormat(load_format)
rocm_not_supported_load_format: List[str] = []
if is_hip() and load_format in rocm_not_supported_load_format:
rocm_supported_load_format = [
f for f in LoadFormat.__members__ if (f not in rocm_not_supported_load_format)
]
raise ValueError(f"load format '{load_format}' is not supported in ROCm. "
f"Supported load formats are "
f"{rocm_supported_load_format}")

View File

@@ -0,0 +1,380 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023 The vLLM team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/model_loader
from typing import Dict
import torch.nn as nn
from torch.distributed._tensor import DTensor
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.utils import is_pp_missing_parameter
def gemma_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(vllm_model.named_parameters())
for name, loaded_weight in actor_weights.items():
for param_name, shard_name, shard_id in stacked_params_mapping:
if shard_name not in name:
continue
stacked_name = name.replace(shard_name, param_name)
# Skip loading extra bias for GPTQ models.
if stacked_name.endswith(".bias") and stacked_name not in params_dict:
continue
local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight)
param = params_dict[stacked_name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, local_loaded_weight.to(dtype=param.dtype), shard_id)
break
else:
# lm_head is not used in vllm as it is tied with embed_token.
# To prevent errors, skip loading lm_head.weight.
if "lm_head.weight" in name:
continue
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight)
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, local_loaded_weight.to(dtype=param.dtype))
def gptbigcode_dtensor_load_weights(actor_weights: Dict, vllm_model: nn.Module):
params_dict = dict(vllm_model.named_parameters(remove_duplicate=False))
for name, loaded_weight in actor_weights.items():
if "lm_head.weight" in name:
continue
if ".attn.bias" in name:
# Skip attention mask.
# NOTE: "c_attn.bias" should not be skipped.
continue
local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight)
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, local_loaded_weight.to(dtype=param.dtype))
def starcoder2_dtensor_load_weights(actor_weights: Dict, vllm_model: nn.Module):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
]
params_dict = dict(vllm_model.named_parameters(remove_duplicate=False))
for name, loaded_weight in actor_weights.items():
if "rotary_emb.inv_freq" in name:
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight)
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, local_loaded_weight.to(dtype=param.dtype), shard_id)
break
else:
if vllm_model.config.tie_word_embeddings and "lm_head.weight" in name:
continue
param = params_dict[name]
local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight)
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, local_loaded_weight.to(dtype=param.dtype))
def llama_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"),
(".qkv_proj", ".k_proj", "k"),
(".qkv_proj", ".v_proj", "v"),
(".gate_up_proj", ".gate_proj", 0),
(".gate_up_proj", ".up_proj", 1),
]
params_dict = dict(vllm_model.named_parameters())
for name, loaded_weight in actor_weights.items():
if "rotary_emb.inv_freq" in name:
continue
if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
continue
# With tie_word_embeddings, we can skip lm_head.weight
# The weight might appear unnecessarily in the files if the model is
# processed with quantization, LoRA, fine-tuning, etc.
if vllm_model.config.tie_word_embeddings and "lm_head.weight" in name:
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight)
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, local_loaded_weight.to(dtype=param.dtype), shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight)
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, local_loaded_weight)
def qwen2_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(vllm_model.named_parameters(remove_duplicate=False))
for name, loaded_weight in actor_weights.items():
if "rotary_emb.inv_freq" in name:
continue
if vllm_model.config.tie_word_embeddings and "lm_head.weight" in name:
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight)
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, local_loaded_weight.to(dtype=param.dtype), shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight)
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, local_loaded_weight.to(dtype=param.dtype))
def qwen2vl_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(vllm_model.named_parameters(remove_duplicate=False))
for name, loaded_weight in actor_weights.items():
if "rotary_emb.inv_freq" in name:
continue
if vllm_model.config.tie_word_embeddings and "lm_head.weight" in name:
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight)
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, local_loaded_weight.to(dtype=param.dtype), shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight)
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, local_loaded_weight.to(dtype=param.dtype))
from vllm.model_executor.layers.fused_moe import FusedMoE
def deepseekv2_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
expert_params_mapping = FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
num_experts=vllm_model.config.n_routed_experts,
)
params_dict = dict(vllm_model.named_parameters(remove_duplicate=False))
for name, loaded_weight in actor_weights.items():
if "rotary_emb.inv_freq" in name:
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
# Skip non-stacked layers and experts (experts handled below).
if weight_name not in name:
continue
# We have mlp.experts[0].gate_proj in the checkpoint.
# Since we handle the experts below in expert_params_mapping,
# we need to skip here BEFORE we update the name, otherwise
# name will be updated to mlp.experts[0].gate_up_proj, which
# will then be updated below in expert_params_mapping
# for mlp.experts[0].gate_gate_up_proj, which breaks load.
if ("mlp.experts." in name) and name not in params_dict:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, vllm_model):
continue
param = params_dict[name]
local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight)
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, local_loaded_weight.to(dtype=param.dtype), shard_id)
break
else:
for mapping in expert_params_mapping:
param_name, weight_name, expert_id, shard_id = mapping
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
if is_pp_missing_parameter(name, vllm_model):
continue
param = params_dict[name]
local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight)
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(
param,
local_loaded_weight.to(dtype=param.dtype),
weight_name,
shard_id=shard_id,
expert_id=expert_id,
)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, vllm_model):
continue
param = params_dict[name]
local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight)
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, local_loaded_weight.to(dtype=param.dtype))
def gpt2_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module:
pass
def redistribute_dtensor(param_name: str, loaded_weights: DTensor, parallelize_plan: Dict = None):
param_name = _process_parameter_names(name=param_name)
if parallelize_plan is not None:
assert (
param_name
in parallelize_plan.keys()), f"param name: {param_name} not in parallelize_plan :{parallelize_plan.keys()}"
placement = parallelize_plan[param_name]
local_loaded_weights = loaded_weights.redistribute(device_mesh=loaded_weights.device_mesh,
placements=placement).to_local()
else:
local_loaded_weights = loaded_weights.full_tensor()
return local_loaded_weights
def _process_parameter_names(name):
# Remove '.weight' if it exists at the end of the string
if name.endswith(".weight"):
name = name[:-7]
# Remove 'model.layers.x.' or 'model.' prefix
if "model.layers" in name:
parts = name.split(".")
# Reconstruct the string without 'model.layers.x.'
name = ".".join(parts[3:]) # parts[0] is 'model', parts[1] is 'layers', parts[2] is 'x'
elif name.startswith("model."):
name = name[6:] # Remove 'model.'
return name
__MODEL_DTENSOR_WEIGHT_LOADER_REGISTRY__ = {
"GPT2LMHeadModel": gpt2_dtensor_weight_loader,
"LlamaForCausalLM": llama_dtensor_weight_loader,
"LLaMAForCausalLM": llama_dtensor_weight_loader,
"MistralForCausalLM": llama_dtensor_weight_loader, # mistral is the same as llama in vLLM
"InternLMForCausalLM": llama_dtensor_weight_loader,
"AquilaModel": llama_dtensor_weight_loader,
"AquilaForCausalLM": llama_dtensor_weight_loader,
"Phi3ForCausalLM": llama_dtensor_weight_loader,
"GemmaForCausalLM": gemma_dtensor_weight_loader,
"Gemma2ForCausalLM": gemma_dtensor_weight_loader,
"GPTBigCodeForCausalLM": gptbigcode_dtensor_load_weights,
"Starcoder2ForCausalLM": starcoder2_dtensor_load_weights,
"Qwen2ForCausalLM": qwen2_dtensor_weight_loader,
"DeepseekV2ForCausalLM": deepseekv2_dtensor_weight_loader,
"Qwen2VLForConditionalGeneration": qwen2vl_dtensor_weight_loader,
}
# the actor model is .state_dict()
# Load dtensor weights
def load_dtensor_weights(actor_weights: Dict, vllm_model: nn.Module):
weight_loader = _get_model_weight_loader(vllm_model.__class__.__name__)
weight_loader(actor_weights, vllm_model)
# NOTE(sgm) to reduce peak memory usage, we offload vllm model to cpu
# after init, and we need this after sync model weights for in first iter.
vllm_model = vllm_model.cuda()
def _get_model_weight_loader(arch: str):
if arch in __MODEL_DTENSOR_WEIGHT_LOADER_REGISTRY__:
return __MODEL_DTENSOR_WEIGHT_LOADER_REGISTRY__[arch]
raise ValueError(f"Model architectures {arch} are not supported for now. "
f"Supported architectures: {__MODEL_DTENSOR_WEIGHT_LOADER_REGISTRY__.keys()}")
# NOTE(sgm): we use per-parameter weight loader in each vllm sub
def update_dtensor_weight_loader():
pass

View File

@@ -0,0 +1,41 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023 The vLLM team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/model_loader
from typing import Dict
import torch.nn as nn
from vllm.model_executor.model_loader.utils import set_default_torch_dtype
def update_hf_weight_loader():
print("no hf weight loader need to be updated")
return
def load_hf_weights(actor_weights: Dict, vllm_model: nn.Module):
assert isinstance(actor_weights, Dict)
with set_default_torch_dtype(next(vllm_model.parameters()).dtype): # TODO
if vllm_model.config.tie_word_embeddings and "lm_head.weight" in actor_weights.keys():
del actor_weights["lm_head.weight"]
vllm_model.load_weights(actor_weights.items())
for _, module in vllm_model.named_modules():
quant_method = getattr(module, "quant_method", None)
if quant_method is not None:
quant_method.process_weights_after_loading(module)
# FIXME: Remove this after Mixtral is updated
# to use quant_method.
if hasattr(module, "process_weights_after_loading"):
module.process_weights_after_loading()
vllm_model = vllm_model.cuda()

View File

@@ -0,0 +1,200 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023 The vLLM team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/llm.py
from typing import Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence
from transformers import PretrainedConfig, PreTrainedTokenizer, PreTrainedTokenizerFast
from verl.workers.rollout.tokenizer import HybridEngineBaseTokenizer
from vllm import LLM
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.utils import Counter
from .arg_utils import EngineArgs
from .llm_engine_sp import LLMEngine
class LLM(LLM):
"""An LLM for generating texts from given prompts and sampling parameters.
This class includes a tokenizer, a language model (possibly distributed
across multiple GPUs), and GPU memory space allocated for intermediate
states (aka KV cache). Given a batch of prompts and sampling parameters,
this class generates texts from the model, using an intelligent batching
mechanism and efficient memory management.
NOTE: This class is intended to be used for offline inference. For online
serving, use the `AsyncLLMEngine` class instead.
NOTE: For the comprehensive list of arguments, see `EngineArgs`.
Args:
model: A HuggingFace Transformers model instance.
tokenizer: A HuggingFace Transformers tokenizer instance.
tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
if available, and "slow" will always use the slow tokenizer.
trust_remote_code: Trust remote code (e.g., from HuggingFace) when
downloading the model and tokenizer.
tensor_parallel_size: The number of GPUs to use for distributed
execution with tensor parallelism.
dtype: The data type for the model weights and activations. Currently,
we support `float32`, `float16`, and `bfloat16`. If `auto`, we use
the `torch_dtype` attribute specified in the model config file.
However, if the `torch_dtype` in the config is `float32`, we will
use `float16` instead.
quantization: The method used to quantize the model weights. Currently,
we support "awq". If None, we assume the model weights are not
quantized and use `dtype` to determine the data type of the weights.
revision: The specific model version to use. It can be a branch name,
a tag name, or a commit id.
tokenizer_revision: The specific tokenizer version to use. It can be a
branch name, a tag name, or a commit id.
seed: The seed to initialize the random number generator for sampling.
gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
reserve for the model weights, activations, and KV cache. Higher
values will increase the KV cache size and thus improve the model's
throughput. However, if the value is too high, it may cause out-of-
memory (OOM) errors.
swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
This can be used for temporarily storing the states of the requests
when their `best_of` sampling parameters are larger than 1. If all
requests will have `best_of=1`, you can safely set this to 0.
Otherwise, too small values may cause out-of-memory (OOM) errors.
enforce_eager: Whether to enforce eager execution. If True, we will
disable CUDA graph and always execute the model in eager mode.
If False, we will use CUDA graph and eager execution in hybrid.
max_context_len_to_capture: Maximum context len covered by CUDA graphs.
When a sequence has context length larger than this, we fall back
to eager mode.
disable_custom_all_reduce: See ParallelConfig
"""
def __init__(
self,
model: Union[nn.Module, Dict], # model itself or its parameter dict
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast, HybridEngineBaseTokenizer],
model_hf_config: PretrainedConfig,
tokenizer_mode: str = "auto",
trust_remote_code: bool = False,
skip_tokenizer_init: bool = False,
tensor_parallel_size: int = 1,
dtype: str = "auto",
quantization: Optional[str] = None,
revision: Optional[str] = None,
tokenizer_revision: Optional[str] = None,
seed: int = 0,
gpu_memory_utilization: float = 0.9,
swap_space: int = 4,
cpu_offload_gb: float = 0,
enforce_eager: bool = False,
max_context_len_to_capture: Optional[int] = None,
max_seq_len_to_capture: int = 8192,
disable_custom_all_reduce: bool = False,
load_format="auto",
**kwargs,
) -> None:
if "disable_log_stats" not in kwargs:
kwargs["disable_log_stats"] = True
removed_vision_keys = ("image_token_id", "image_feature_size", "image_input_shape", "image_input_type")
if any(k in kwargs for k in removed_vision_keys):
raise TypeError("There is no need to pass vision-related arguments anymore.")
engine_args = EngineArgs(
model_hf_config=model_hf_config,
# tokenizer=tokenizer,
tokenizer_mode=tokenizer_mode,
skip_tokenizer_init=skip_tokenizer_init,
trust_remote_code=trust_remote_code,
tensor_parallel_size=tensor_parallel_size,
dtype=dtype,
quantization=quantization,
revision=revision,
tokenizer_revision=tokenizer_revision,
seed=seed,
gpu_memory_utilization=gpu_memory_utilization,
swap_space=swap_space,
cpu_offload_gb=cpu_offload_gb,
enforce_eager=enforce_eager,
max_context_len_to_capture=max_context_len_to_capture,
max_seq_len_to_capture=max_seq_len_to_capture,
disable_custom_all_reduce=disable_custom_all_reduce,
load_format=load_format,
**kwargs,
)
tokenizer_cls = (PreTrainedTokenizer, PreTrainedTokenizerFast, HybridEngineBaseTokenizer)
if not isinstance(tokenizer, tokenizer_cls):
raise ValueError(
f"Unexpected tokenizer type: {type(tokenizer)}. Must be"
"one of the following: PreTrainedTokenizer, PreTrainedTokenizerFast, verl.workers.rollout.HybridEngineBaseTokenizer"
)
self.llm_engine = LLMEngine.from_engine_args(model, tokenizer, engine_args) # TODO: check usagecontext
self.request_counter = Counter()
def init_cache_engine(self):
self.llm_engine.init_cache_engine()
def free_cache_engine(self):
self.llm_engine.free_cache_engine()
def get_tokenizer(self) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
return self.llm_engine.tokenizer
def set_tokenizer(
self,
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
) -> None:
self.llm_engine.tokenizer = tokenizer
def _run_engine(self, *, use_tqdm: bool) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
outputs = super()._run_engine(use_tqdm=use_tqdm)
return self._post_process_outputs(outputs)
# # NOTE(shengguangming): add for verl
# # TODO(sgm): we can optimize it by making the dataloader yield List[int] without padding.
# def _pre_process_inputs(self, prompt_token_ids: torch.Tensor) -> List[int]:
# # remove the left padding in the prompt token_id
# pad_token_id = self.llm_engine.tokenizer.pad_token_id if self.llm_engine.tokenizer.pad_token_id is not None else self.llm_engine.tokenizer.eos_token_id
# non_pad_index = torch.nonzero(prompt_token_ids != pad_token_id, as_tuple=False)[0][0]
# token_ids = prompt_token_ids[non_pad_index:].tolist()
# return token_ids
# NOTE(shengguangming): add for verl
def _post_process_outputs(self, request_outputs: List[RequestOutput]) -> Tuple[torch.Tensor, torch.Tensor]:
output_token_ids = []
logprobs = []
for request_output in request_outputs: # List[RequestOutput]
outputs = request_output.outputs
for output in outputs: # List[CompletionOutput], usually len == 1
output_token_ids.append(torch.tensor(output.token_ids))
# TODO(shengguangming): can be optimzied by rewrite the Sampler._get_logprobs() logits
logprobs_dicts = output.logprobs
if logprobs_dicts is not None:
logprob = []
for logprobs_dict, id in zip(logprobs_dicts, output.token_ids):
logprob.append(logprobs_dict[id].logprob)
logprobs.append(torch.tensor(logprob))
pad_token_id = (self.llm_engine.tokenizer.pad_token_id if self.llm_engine.tokenizer.pad_token_id is not None
else self.llm_engine.tokenizer.eos_token_id)
output_token_ids = pad_sequence(output_token_ids, batch_first=True, padding_value=pad_token_id)
if len(logprobs) > 0:
logprobs = pad_sequence(logprobs, batch_first=True, padding_value=pad_token_id)
return output_token_ids, logprobs
def sync_model_weights(self, actor_weights: Dict[str, torch.Tensor], load_format: str) -> None:
self.llm_engine.sync_model_weights(actor_weights=actor_weights, load_format=load_format)
def offload_model_weights(self) -> None:
self.llm_engine.offload_model_weights()

View File

@@ -0,0 +1,408 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023 The vLLM team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/engine/llm_engine.py
from functools import partial
from typing import Callable, Dict, Optional, Type, Union
import torch
import torch.nn as nn
from vllm.config import (
CacheConfig,
DecodingConfig,
DeviceConfig,
EngineConfig,
LoadConfig,
LoRAConfig,
ModelConfig,
ObservabilityConfig,
ParallelConfig,
PromptAdapterConfig,
SchedulerConfig,
SpeculativeConfig,
)
from vllm.core.scheduler import Scheduler
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine, SchedulerContext, SchedulerOutputState, _load_generation_config_dict
from vllm.engine.metrics_types import StatLoggerBase
from vllm.engine.output_processor.interfaces import SequenceGroupOutputProcessor
from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.executor.executor_base import ExecutorBase
from vllm.inputs import INPUT_REGISTRY, InputRegistry
from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger
from vllm.sequence import Sequence
from vllm.tracing import init_tracer
from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.usage.usage_lib import UsageContext, is_usage_stats_enabled, usage_message
from vllm.utils import Counter, weak_bind
from vllm.version import __version__ as VLLM_VERSION
from .arg_utils import EngineArgs
from .config import LoadConfig, ModelConfig
from .tokenizer import TokenizerGroup
logger = init_logger(__name__)
_LOCAL_LOGGING_INTERVAL_SEC = 5
class LLMEngine(LLMEngine):
"""An LLM engine that receives requests and generates texts.
This is the main class for the vLLM engine. It receives requests
from clients and generates texts from the LLM. It includes a tokenizer, a
language model (possibly distributed across multiple GPUs), and GPU memory
space allocated for intermediate states (aka KV cache). This class utilizes
iteration-level scheduling and efficient memory management to maximize the
serving throughput.
The :class:`~vllm.LLM` class wraps this class for offline batched inference
and the :class:`AsyncLLMEngine` class wraps this class for online serving.
The config arguments are derived from :class:`~vllm.EngineArgs`. (See
:ref:`engine_args`)
Args:
model_config: The configuration related to the LLM model.
cache_config: The configuration related to the KV cache memory
management.
parallel_config: The configuration related to distributed execution.
scheduler_config: The configuration related to the request scheduler.
device_config: The configuration related to the device.
lora_config (Optional): The configuration related to serving multi-LoRA.
speculative_config (Optional): The configuration related to speculative
decoding.
executor_class: The model executor class for managing distributed
execution.
prompt_adapter_config (Optional): The configuration related to serving
prompt adapters.
log_stats: Whether to log statistics.
usage_context: Specified entry point, used for usage info collection.
"""
def __init__(
self,
# NOTE(sgm): first two arguments are added for verl
model: Union[nn.Module, Dict], # model itself or its parameter dict
tokenizer: nn.Module,
# NOTE(sgm): vllm original arguments
model_config: ModelConfig,
cache_config: CacheConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
load_config: LoadConfig,
lora_config: Optional[LoRAConfig],
speculative_config: Optional[SpeculativeConfig],
decoding_config: Optional[DecodingConfig],
observability_config: Optional[ObservabilityConfig],
prompt_adapter_config: Optional[PromptAdapterConfig],
executor_class: Type[ExecutorBase],
log_stats: bool,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
input_registry: InputRegistry = INPUT_REGISTRY,
use_cached_outputs: bool = False,
) -> None:
logger.info(
"Initializing an LLM engine (v%s) with config: "
"model=%r, speculative_config=%r, tokenizer=%r, "
"skip_tokenizer_init=%s, tokenizer_mode=%s, revision=%s, "
"override_neuron_config=%s, "
"rope_scaling=%r, rope_theta=%r, tokenizer_revision=%s, "
"trust_remote_code=%s, dtype=%s, max_seq_len=%d, "
"download_dir=%r, load_format=%s, tensor_parallel_size=%d, "
"pipeline_parallel_size=%d, "
"disable_custom_all_reduce=%s, quantization=%s, "
"enforce_eager=%s, kv_cache_dtype=%s, "
"quantization_param_path=%s, device_config=%s, "
"decoding_config=%r, observability_config=%r, "
"seed=%d, served_model_name=%s, use_v2_block_manager=%s, "
"num_scheduler_steps=%d, chunked_prefill_enabled=%s "
"multi_step_stream_outputs=%s, enable_prefix_caching=%s, "
"use_async_output_proc=%s, use_cached_outputs=%s, "
"mm_processor_kwargs=%s)",
VLLM_VERSION,
model_config.model,
speculative_config,
model_config.tokenizer,
model_config.skip_tokenizer_init,
model_config.tokenizer_mode,
model_config.revision,
model_config.override_neuron_config,
model_config.rope_scaling,
model_config.rope_theta,
model_config.tokenizer_revision,
model_config.trust_remote_code,
model_config.dtype,
model_config.max_model_len,
load_config.download_dir,
load_config.load_format,
parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size,
parallel_config.disable_custom_all_reduce,
model_config.quantization,
model_config.enforce_eager,
cache_config.cache_dtype,
model_config.quantization_param_path,
device_config.device,
decoding_config,
observability_config,
model_config.seed,
model_config.served_model_name,
scheduler_config.use_v2_block_manager,
scheduler_config.num_scheduler_steps,
scheduler_config.chunked_prefill_enabled,
scheduler_config.multi_step_stream_outputs,
cache_config.enable_prefix_caching,
model_config.use_async_output_proc,
use_cached_outputs,
model_config.mm_processor_kwargs,
)
# TODO(woosuk): Print more configs in debug mode.
self.model_config = model_config
self.cache_config = cache_config
self.lora_config = lora_config
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
self.device_config = device_config
self.speculative_config = speculative_config
self.load_config = load_config
self.decoding_config = decoding_config or DecodingConfig()
self.prompt_adapter_config = prompt_adapter_config
self.observability_config = observability_config or ObservabilityConfig()
self.log_stats = log_stats
self.use_cached_outputs = use_cached_outputs
if not self.model_config.skip_tokenizer_init:
self.tokenizer = self._init_tokenizer(tokenizer)
self.detokenizer = Detokenizer(self.tokenizer)
tokenizer_group = self.get_tokenizer_group()
else:
self.tokenizer = None
self.detokenizer = None
tokenizer_group = None
# Ensure that the function doesn't contain a reference to self,
# to avoid engine GC issues
def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer:
assert tokenizer_group, "tokenizer_group cannot be None, " "make sure skip_tokenizer_init is False"
return tokenizer_group.get_lora_tokenizer(sequence.lora_request)
self.seq_counter = Counter()
self.generation_config_fields = _load_generation_config_dict(model_config)
self.input_preprocessor = InputPreprocessor(model_config, self.tokenizer)
self.input_registry = input_registry
self.input_processor = input_registry.create_input_processor(model_config)
self.model_executor = executor_class(
model=model, # add for spmd_gpu_executor
model_config=model_config,
cache_config=cache_config,
parallel_config=parallel_config,
scheduler_config=scheduler_config,
device_config=device_config,
lora_config=lora_config,
speculative_config=speculative_config,
load_config=load_config,
prompt_adapter_config=prompt_adapter_config,
observability_config=self.observability_config,
)
if not self.model_config.embedding_mode:
self._initialize_kv_caches()
# If usage stat is enabled, collect relevant info.
if is_usage_stats_enabled():
from vllm.model_executor.model_loader import get_architecture_class_name
usage_message.report_usage(
get_architecture_class_name(model_config),
usage_context,
extra_kvs={
# Common configuration
"dtype": str(model_config.dtype),
"tensor_parallel_size": parallel_config.tensor_parallel_size,
"block_size": cache_config.block_size,
"gpu_memory_utilization": cache_config.gpu_memory_utilization,
# Quantization
"quantization": model_config.quantization,
"kv_cache_dtype": str(cache_config.cache_dtype),
# Feature flags
"enable_lora": bool(lora_config),
"enable_prompt_adapter": bool(prompt_adapter_config),
"enable_prefix_caching": cache_config.enable_prefix_caching,
"enforce_eager": model_config.enforce_eager,
"disable_custom_all_reduce": parallel_config.disable_custom_all_reduce,
},
)
if self.tokenizer:
# Ping the tokenizer to ensure liveness if it runs in a
# different process.
self.tokenizer.ping()
self.cached_scheduler_outputs = [
SchedulerOutputState() for _ in range(self.parallel_config.pipeline_parallel_size)
]
self.scheduler_contexts = [
SchedulerContext(multi_step_stream_outputs=self.scheduler_config.multi_step_stream_outputs)
for _ in range(self.parallel_config.pipeline_parallel_size)
]
if model_config.use_async_output_proc:
process_model_outputs = weak_bind(self._process_model_outputs)
self.async_callbacks = [
partial(process_model_outputs, ctx=self.scheduler_contexts[v_id])
for v_id in range(self.parallel_config.pipeline_parallel_size)
]
else:
self.async_callbacks = []
# Currently used by AsyncLLMEngine to ensure quick append
# of request outputs to asyncio queues
self.process_request_outputs_callback: Optional[Callable] = None
# Create the scheduler.
# NOTE: the cache_config here have been updated with the numbers of
# GPU and CPU blocks, which are profiled in the distributed executor.
self.scheduler = [
Scheduler(
scheduler_config,
cache_config,
lora_config,
parallel_config.pipeline_parallel_size,
self.async_callbacks[v_id] if model_config.use_async_output_proc else None,
) for v_id in range(parallel_config.pipeline_parallel_size)
]
# Metric Logging.
if self.log_stats:
if stat_loggers is not None:
self.stat_loggers = stat_loggers
else:
# Lazy import for prometheus multiprocessing.
# We need to set PROMETHEUS_MULTIPROC_DIR environment variable
# before prometheus_client is imported.
# See https://prometheus.github.io/client_python/multiprocess/
from vllm.engine.metrics import LoggingStatLogger, PrometheusStatLogger
self.stat_loggers = {
"logging":
LoggingStatLogger(local_interval=_LOCAL_LOGGING_INTERVAL_SEC),
"prometheus":
PrometheusStatLogger(
local_interval=_LOCAL_LOGGING_INTERVAL_SEC,
labels=dict(model_name=model_config.served_model_name),
max_model_len=self.model_config.max_model_len,
),
}
self.stat_loggers["prometheus"].info("cache_config", self.cache_config)
self.tracer = None
if self.observability_config.otlp_traces_endpoint:
self.tracer = init_tracer("vllm.llm_engine", self.observability_config.otlp_traces_endpoint)
# Create sequence output processor, e.g. for beam search or
# speculative decoding.
self.output_processor = SequenceGroupOutputProcessor.create_output_processor(
self.scheduler_config,
self.detokenizer,
self.scheduler,
self.seq_counter,
get_tokenizer_for_seq,
stop_checker=StopChecker(
self.scheduler_config.max_model_len,
get_tokenizer_for_seq,
),
)
# TODO(sgm): add for verl but we may not tokenizer in Rollout
def _init_tokenizer(self, tokenizer, **tokenizer_init_kwargs):
init_kwargs = dict(enable_lora=bool(self.lora_config),
max_num_seqs=self.scheduler_config.max_num_seqs,
max_input_length=None)
init_kwargs.update(tokenizer_init_kwargs)
return TokenizerGroup(tokenizer, **init_kwargs)
def init_cache_engine(self):
# TODO: check whether we should rebuild the CUDAGraph every iter when offload/load KVCache
# Re-capture CUDAGraph would be time-consuming
self.model_executor.init_cache_engine()
def free_cache_engine(self):
self.model_executor.free_cache_engine()
# NOTE(sgm): currently, we only support GPU executor
# The GPUExecutor remove the Ray dependency
@classmethod
def _get_executor_cls(cls, engine_config: EngineConfig) -> Type[ExecutorBase]:
distributed_executor_backend = engine_config.parallel_config.distributed_executor_backend
# Initialize the cluster and specify the executor class.]
assert (engine_config.device_config.device_type == "cuda"
), "Currently, the vllm in verl only support running on GPU"
# print('Waiting for debugger'); import os,debugpy; debugpy.listen(('localhost', 5678 + int(os.getenv('RANK', '0')))); debugpy.wait_for_client()
if engine_config.parallel_config.world_size == 1:
engine_config.load_config.load_format = "dummy_hf"
from .spmd_gpu_executor import SPMDGPUExecutor
executor_class = SPMDGPUExecutor
return executor_class
@classmethod
def from_engine_args(
cls,
model,
tokenizer,
engine_args: EngineArgs,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
) -> "LLMEngine":
"""Creates an LLM engine from the engine arguments."""
# Create the engine configs.
engine_config = engine_args.create_engine_config()
executor_class = cls._get_executor_cls(engine_config)
# Initialize the cluster and specify the executor class.
assert (engine_config.device_config.device_type == "cuda"
), "Currently, the vllm in verl only support running on GPU"
from .spmd_gpu_executor import SPMDGPUExecutor
executor_class = SPMDGPUExecutor
# Create the LLM engine.
engine = cls(
model,
tokenizer,
**engine_config.to_dict(),
executor_class=executor_class,
log_stats=not engine_args.disable_log_stats,
usage_context=usage_context,
stat_loggers=stat_loggers,
)
return engine
def sync_model_weights(self, actor_weights: Dict[str, torch.Tensor], load_format: str) -> None:
self.model_executor.sync_model_weights(actor_weights=actor_weights, load_format=load_format)
def offload_model_weights(self) -> None:
self.model_executor.offload_model_weights()

View File

@@ -0,0 +1,308 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023 The vLLM team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/model_loader
from typing import Dict
import torch
import torch.nn as nn
from vllm.model_executor.layers.linear import *
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead, VocabParallelEmbedding
from vllm.model_executor.models import ModelRegistry
# NOTE(shengguangming): replace the origin weight loader function in the class
def parallel_weight_loader(self, param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
"""Parallel Linear weight loader."""
assert (param.size() == loaded_weight.size(
)), "the parameter size is not align with the loaded weight size, param size: {}, loaded_weight size: {}".format(
param.size(), loaded_weight.size())
assert (param.data.dtype == loaded_weight.data.dtype
), "if we want to shared weights, the data type should also be the same"
param.data = loaded_weight.data
def default_weight_loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
"""Default weight loader."""
assert param.size() == loaded_weight.size()
assert (param.data.dtype == loaded_weight.data.dtype
), "if we want to shared weights, the data type should also be the same"
param.data = loaded_weight.data
def gpt2_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module:
params_dict = dict(vllm_model.named_parameters(remove_duplicate=False))
for name, loaded_weight in actor_weights.items():
if "lm_head.weight" in name:
# GPT-2 ties the weights of the embedding layer and the final
# linear layer.
continue
if ".attn.bias" in name or ".attn.masked_bias" in name:
# Skip attention mask.
# NOTE: "c_attn.bias" should not be skipped.
continue
if not name.startswith("transformer."):
name = "transformer." + name
param = params_dict[name]
# The HF's GPT-2 implementation uses Conv1D instead of Linear.
# Because of this, we need to transpose the weights.
# Note(zhuohan): the logic below might break quantized models.
for conv1d_weight_name in ["c_attn", "c_proj", "c_fc"]:
if conv1d_weight_name not in name:
continue
if not name.endswith(".weight"):
continue
# TODO: check megatron
loaded_weight = loaded_weight.t()
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
def llama_megatron_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module:
# NOTE(shengguangming): the megatron llama may have this prefix
params_dict = dict(vllm_model.named_parameters())
for name, loaded_weight in actor_weights.items():
if "rotary_emb.inv_freq" in name:
continue
else:
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
def llama_megatron_core_te_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module:
params_mapping = [
# (megatron core gpt model name, vllm model name)
("embedding.word_embeddings", "model.embed_tokens"),
("self_attention.linear_qkv.layer_norm_weight", "input_layernorm.weight"),
("self_attention.linear_qkv.layer_norm_bias", "input_layernorm.bias"),
("self_attention.linear_qkv", "self_attn.qkv_proj"),
("self_attention.linear_qkv", "self_attn.qkv_proj"),
("self_attention.linear_proj", "self_attn.o_proj"),
("pre_mlp_layernorm", "post_attention_layernorm"),
("mlp.linear_fc1.layer_norm_weight", "post_attention_layernorm.weight"),
("mlp.linear_fc1.layer_norm_bias", "post_attention_layernorm.bias"),
("mlp.linear_fc1", "mlp.gate_up_proj"),
("mlp.linear_fc2", "mlp.down_proj"),
("decoder.final_layernorm", "model.norm"),
("output_layer", "lm_head"),
]
# NOTE(shengguangming): the megatron llama may have this prefix
params_dict = dict(vllm_model.named_parameters())
for name, loaded_weight in actor_weights.items():
name = _replace_name(name, params_mapping)
if name.endswith(".bias") and name not in params_dict:
continue
if "rotary_emb.inv_freq" in name:
continue
else:
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
def llama_megatron_core_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module:
params_mapping = [
# (megatron core gpt model name, vllm model name)
("embedding.word_embeddings", "model.embed_tokens"),
("self_attention.linear_qkv", "self_attn.qkv_proj"),
("self_attention.linear_proj", "self_attn.o_proj"),
(
"input_layernorm",
"input_layernorm",
),
("pre_mlp_layernorm", "post_attention_layernorm"),
("mlp.linear_fc1", "mlp.gate_up_proj"),
("mlp.linear_fc2", "mlp.down_proj"),
("decoder.final_layernorm", "model.norm"),
("output_layer", "lm_head"),
]
# NOTE(shengguangming): the megatron llama may have this prefix
params_dict = dict(vllm_model.named_parameters())
for name, loaded_weight in actor_weights.items():
name = _replace_name(name, params_mapping)
if name.endswith(".bias") and name not in params_dict:
continue
if "rotary_emb.inv_freq" in name:
continue
else:
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
def _replace_name(megatron_name, name_mapping):
for m_name, v_name in name_mapping:
if m_name not in megatron_name:
continue
if "layers" in megatron_name: # deal with decoder layers
megatron_name = megatron_name.replace("decoder", "model")
megatron_name_list = megatron_name.split(".")
if "layer_norm_weight" in megatron_name_list or "layer_norm_bias" in megatron_name_list:
param_name_list = megatron_name_list[:3]
param_name_list.append(v_name)
param_name = ".".join(param_name_list)
else:
param_name_list = megatron_name_list[:3]
weight_or_bias = megatron_name_list[-1]
param_name_list.append(v_name)
param_name_list.append(weight_or_bias)
param_name = ".".join(param_name_list)
return param_name
else:
param_name = megatron_name.replace(m_name, v_name)
return param_name
def llama_megatron_core_te_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module:
params_mapping = [
# (megatron core gpt model name, vllm model name)
("embedding.word_embeddings", "model.embed_tokens"),
("self_attention.linear_qkv.layer_norm_weight", "input_layernorm.weight"),
("self_attention.linear_qkv.layer_norm_bias", "input_layernorm.bias"),
("self_attention.linear_qkv", "self_attn.qkv_proj"),
("self_attention.linear_qkv", "self_attn.qkv_proj"),
("self_attention.linear_proj", "self_attn.o_proj"),
("pre_mlp_layernorm", "post_attention_layernorm"),
("mlp.linear_fc1.layer_norm_weight", "post_attention_layernorm.weight"),
("mlp.linear_fc1.layer_norm_bias", "post_attention_layernorm.bias"),
("mlp.linear_fc1", "mlp.gate_up_proj"),
("mlp.linear_fc2", "mlp.down_proj"),
("decoder.final_layernorm", "model.norm"),
("output_layer", "lm_head"),
]
# NOTE(shengguangming): the megatron llama may have this prefix
params_dict = dict(vllm_model.named_parameters())
for name, loaded_weight in actor_weights.items():
name = _replace_name(name, params_mapping)
if name.endswith(".bias") and name not in params_dict:
continue
if "rotary_emb.inv_freq" in name:
continue
else:
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
def llama_megatron_core_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module:
params_mapping = [
# (megatron core gpt model name, vllm model name)
("embedding.word_embeddings", "model.embed_tokens"),
("self_attention.linear_qkv", "self_attn.qkv_proj"),
("self_attention.linear_proj", "self_attn.o_proj"),
(
"input_layernorm",
"input_layernorm",
),
("pre_mlp_layernorm", "post_attention_layernorm"),
("mlp.linear_fc1", "mlp.gate_up_proj"),
("mlp.linear_fc2", "mlp.down_proj"),
("decoder.final_layernorm", "model.norm"),
("output_layer", "lm_head"),
]
# NOTE(shengguangming): the megatron llama may have this prefix
params_dict = dict(vllm_model.named_parameters())
for name, loaded_weight in actor_weights.items():
name = _replace_name(name, params_mapping)
if name.endswith(".bias") and name not in params_dict:
continue
if "rotary_emb.inv_freq" in name:
continue
else:
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
def _replace_name(megatron_name, name_mapping):
for m_name, v_name in name_mapping:
if m_name not in megatron_name:
continue
if "layers" in megatron_name: # deal with decoder layers
megatron_name = megatron_name.replace("decoder", "model")
megatron_name_list = megatron_name.split(".")
if "layer_norm_weight" in megatron_name_list or "layer_norm_bias" in megatron_name_list:
param_name_list = megatron_name_list[:3]
param_name_list.append(v_name)
param_name = ".".join(param_name_list)
else:
param_name_list = megatron_name_list[:3]
weight_or_bias = megatron_name_list[-1]
param_name_list.append(v_name)
param_name_list.append(weight_or_bias)
param_name = ".".join(param_name_list)
return param_name
else:
param_name = megatron_name.replace(m_name, v_name)
return param_name
def mistral_megatron_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module:
# TODO: need to implement a general way to deal with prefix
params_dict = dict(vllm_model.named_parameters())
for name, loaded_weight in actor_weights.items():
if "rotary_emb.inv_freq" in name:
continue
else:
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
__LAYER_WEIGHT_MEGATRON_LOADER_REGISTRY__ = {
ColumnParallelLinear: parallel_weight_loader,
MergedColumnParallelLinear: parallel_weight_loader,
QKVParallelLinear: parallel_weight_loader,
RowParallelLinear: parallel_weight_loader,
VocabParallelEmbedding: parallel_weight_loader,
ParallelLMHead: parallel_weight_loader,
# "ScaledActivation.weight_loader": ScaledActivation, # TODO(shengguangming): latest commit in vllm fix awq for this function and add load_weights
# "default_weight_loader": default_weight_loader
}
# for layer_class, weight_loader in __LAYER_WEIGHT_MEGATRON_LOADER_REGISTRY__.items():
# # setattr(layer_class, 'megatron_weight_loader', weight_loader)
# layer_class.weight_loader = weight_loader
__MODEL_MEGATRON_WEIGHT_LOADER_REGISTRY__ = {
"GPT2LMHeadModel": gpt2_weight_loader,
"LlamaForCausalLM": llama_megatron_weight_loader, # use te backend for open-source megatron
"LLaMAForCausalLM": llama_megatron_weight_loader,
"MistralForCausalLM": mistral_megatron_weight_loader,
}
# the actor model is .state_dict()
# Load megatron weights
def load_megatron_weights(actor_weights: Dict, vllm_model: nn.Module):
weight_loader = _get_model_weight_loader(vllm_model.__class__.__name__)
weight_loader(actor_weights, vllm_model)
# NOTE(sgm) to reduce peak memory usage, we offload vllm model to cpu
# after init, and we need this after sync model weights for in first iter.
vllm_model = vllm_model.cuda()
def _get_model_weight_loader(arch: str):
if arch in __MODEL_MEGATRON_WEIGHT_LOADER_REGISTRY__:
return __MODEL_MEGATRON_WEIGHT_LOADER_REGISTRY__[arch]
raise ValueError(f"Model architectures {arch} are not supported for now. "
f"Supported architectures: {ModelRegistry.get_supported_archs()}")
def update_megatron_weight_loader():
for layer_class, weight_loader in __LAYER_WEIGHT_MEGATRON_LOADER_REGISTRY__.items():
layer_class.weight_loader = weight_loader

View File

@@ -0,0 +1,338 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023 The vLLM team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/models
"""Utilities for selecting and loading models."""
from typing import Dict, Optional, Union
import torch
import torch.nn as nn
from transformers import PreTrainedModel
from vllm.config import CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig
from vllm.distributed.communication_op import tensor_model_parallel_all_gather
from vllm.model_executor.model_loader import BaseModelLoader
from vllm.model_executor.model_loader.loader import _initialize_model
from vllm.model_executor.model_loader.utils import set_default_torch_dtype
from .config import LoadConfig, LoadFormat, ModelConfig
from .dtensor_weight_loaders import load_dtensor_weights, update_dtensor_weight_loader
from .hf_weight_loader import update_hf_weight_loader
from .megatron_weight_loaders import load_megatron_weights, update_megatron_weight_loader
def get_model(
actor_model: Union[PreTrainedModel, Dict],
model_config: ModelConfig,
load_config: LoadConfig,
device_config: DeviceConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
lora_config: Optional[LoRAConfig],
cache_config: CacheConfig = None,
) -> nn.Module:
loader = get_model_loader(load_config)
if load_config.load_format.startswith("dummy"):
return loader.load_model(
model_config=model_config,
device_config=device_config,
lora_config=lora_config,
parallel_config=parallel_config,
scheduler_config=scheduler_config,
cache_config=cache_config,
)
else:
return loader.load_model(
actor_model=actor_model,
model_config=model_config,
device_config=device_config,
lora_config=lora_config,
parallel_config=parallel_config,
scheduler_config=scheduler_config,
cache_config=cache_config,
)
def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
"""Get a model loader based on the load format."""
if isinstance(load_config.load_format, type):
return load_config.load_format(load_config)
if load_config.load_format == LoadFormat.AUTO:
update_megatron_weight_loader()
return MegatronLoader(load_config)
# NOTE(sgm): change the weight_loader function in runtime
if load_config.load_format == LoadFormat.MEGATRON:
update_megatron_weight_loader()
return MegatronLoader(load_config)
if load_config.load_format == LoadFormat.HF:
update_hf_weight_loader()
return HFLoader(load_config)
if load_config.load_format == LoadFormat.DTENSOR:
update_dtensor_weight_loader()
return DTensorLoader(load_config)
if load_config.load_format == LoadFormat.DUMMY_HF:
update_hf_weight_loader()
return DummyModelLoader(load_config)
if load_config.load_format == LoadFormat.DUMMY_MEGATRON:
update_megatron_weight_loader()
return DummyModelLoader(load_config)
if load_config.load_format == LoadFormat.DUMMY_DTENSOR:
update_dtensor_weight_loader()
return DummyModelLoader(load_config)
raise ValueError("load format not supported in verl: {}, only support {} and {}".format(
load_config.load_format, LoadFormat.MEGATRON, LoadFormat.HF))
class DummyModelLoader(BaseModelLoader):
"""Model loader that will set model weights to random values."""
def __init__(self, load_config: LoadConfig):
super().__init__(load_config)
if load_config.model_loader_extra_config:
raise ValueError(f"Model loader extra config is not supported for "
f"load format {load_config.load_format}")
def download_model(self, model_config: ModelConfig) -> None:
pass
def load_model(
self,
*,
model_config: ModelConfig,
device_config: DeviceConfig,
lora_config: Optional[LoRAConfig],
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
cache_config: CacheConfig,
) -> nn.Module:
with set_default_torch_dtype(model_config.dtype):
with torch.device(device_config.device):
model = _initialize_model(model_config, self.load_config, lora_config, cache_config, scheduler_config)
# NOTE(woosuk): For accurate performance evaluation, we assign
# random values to the weights.
# initialize_dummy_weights(model)
return model.eval()
class MegatronLoader(BaseModelLoader):
"""Model loader that can load the model weights from partitioned megatron model."""
def __init__(self, load_config: LoadConfig):
super().__init__(load_config)
if load_config.model_loader_extra_config:
raise ValueError(f"Model loader extra config is not supported for "
f"load format {load_config.load_format}")
def download_model(self, model_config: ModelConfig) -> None:
pass # Nothing to download
def _get_weights_iterator(actor_model: Union[PreTrainedModel, Dict]):
# NOTE(shengguangming) Load the weights from the actor model
pass
# if isinstance(actor_model, nn.Module):
# load_weights(actor_weights=dict(actor_model.named_parameters(remove_duplicate=False)), vllm_model=model)
# else:
# load_weights(actor_weights=actor_model, vllm_model=model)
# return actor_model
def load_model(
self,
actor_model: Union[PreTrainedModel, Dict],
model_config: ModelConfig,
device_config: DeviceConfig,
lora_config: Optional[LoRAConfig],
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
cache_config: CacheConfig,
) -> nn.Module:
with set_default_torch_dtype(model_config.dtype):
with torch.device(device_config.device):
model = _initialize_model(model_config, self.load_config, lora_config, cache_config, scheduler_config)
# TODO(sgm): This is a hack, we need to register the load_weight() func for each model in vllm
if isinstance(actor_model, nn.Module):
load_megatron_weights(actor_weights=dict(actor_model.named_parameters(remove_duplicate=False)),
vllm_model=model)
else:
load_megatron_weights(actor_weights=actor_model, vllm_model=model)
for _, module in model.named_modules():
quant_method = getattr(module, "quant_method", None)
if quant_method is not None:
quant_method.process_weights_after_loading(module)
# FIXME: Remove this after Mixtral is updated
# to use quant_method.
if hasattr(module, "process_weights_after_loading"):
module.process_weights_after_loading()
# NOTE(sgm) Some weights are point to gpu, but still need this.
model = model.cuda() # NOTE (zhangchi.usc1992) We need this for vllm to profile memory usage
return model.eval()
class HFLoader(BaseModelLoader):
"""Model loader that can load the model weights from model's full params."""
def __init__(self, load_config: LoadConfig):
super().__init__(load_config)
if load_config.model_loader_extra_config:
raise ValueError(f"Model loader extra config is not supported for "
f"load format {load_config.load_format}")
def download_model(self, model_config: ModelConfig) -> None:
pass # Nothing to download
def _get_weights_iterator(self, actor_model: Union[PreTrainedModel, Dict]):
if isinstance(actor_model, Dict):
return actor_model.items()
elif isinstance(actor_model, nn.Module):
return dict(actor_model.named_parameters()).items()
else:
raise ValueError(f"actor model should be Dict or nn.Module, but get {type(actor_model)}")
def load_model(
self,
actor_model: Union[PreTrainedModel, Dict],
model_config: ModelConfig,
device_config: DeviceConfig,
lora_config: Optional[LoRAConfig],
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
cache_config: CacheConfig,
) -> nn.Module:
with set_default_torch_dtype(model_config.dtype):
# with torch.device(device_config.device):
# NOTE(sgm): init the model in cpu
model = _initialize_model(model_config, self.load_config, lora_config, cache_config, scheduler_config)
model.load_weights(self._get_weights_iterator(actor_model))
for _, module in model.named_modules():
quant_method = getattr(module, "quant_method", None)
if quant_method is not None:
quant_method.process_weights_after_loading(module)
# FIXME: Remove this after Mixtral is updated
# to use quant_method.
if hasattr(module, "process_weights_after_loading"):
module.process_weights_after_loading()
# NOTE(sgm) Some weights are point to gpu, but still need this.
model = model.cuda() # NOTE (zhangchi.usc1992) We need this for vllm to profile memory usage
return model.eval()
class DTensorLoader(BaseModelLoader):
"""Model loader that can load the model weights from partitioned megatron model."""
def __init__(self, load_config: LoadConfig):
super().__init__(load_config)
if load_config.model_loader_extra_config:
raise ValueError(f"Model loader extra config is not supported for "
f"load format {load_config.load_format}")
def download_model(self, model_config: ModelConfig) -> None:
pass # Nothing to download
def _get_weights_iterator(actor_model: Union[PreTrainedModel, Dict]):
# NOTE(shengguangming) Load the weights from the actor model
pass
# if isinstance(actor_model, nn.Module):
# load_weights(actor_weights=dict(actor_model.named_parameters(remove_duplicate=False)), vllm_model=model)
# else:
# load_weights(actor_weights=actor_model, vllm_model=model)
# return actor_model
def load_model(
self,
actor_model: Union[PreTrainedModel, Dict],
model_config: ModelConfig,
device_config: DeviceConfig,
lora_config: Optional[LoRAConfig],
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
cache_config: CacheConfig,
) -> nn.Module:
with set_default_torch_dtype(model_config.dtype):
with torch.device(device_config.device):
model = _initialize_model(model_config, self.load_config, lora_config, cache_config, scheduler_config)
# TODO(sgm): This is a hack, we need to register the load_weight() func for each model in vllm
if isinstance(actor_model, nn.Module):
load_dtensor_weights(actor_weights=dict(actor_model.named_parameters(remove_duplicate=False)),
vllm_model=model)
else:
load_dtensor_weights(actor_weights=actor_model, vllm_model=model)
for _, module in model.named_modules():
quant_method = getattr(module, "quant_method", None)
if quant_method is not None:
quant_method.process_weights_after_loading(module)
# FIXME: Remove this after Mixtral is updated
# to use quant_method.
if hasattr(module, "process_weights_after_loading"):
module.process_weights_after_loading()
# NOTE(sgm) Some weights are point to gpu, but still need this.
model = model.cuda() # NOTE (zhangchi.usc1992) We need this for vllm to profile memory usage
return model.eval()
# FIXME(sgm): hack the _get_logits function in vllm v0.4.2
# as they use ray, the _get_logits result will only need to return to the driver node,
# therefore gather is enough. However, we use SPMD instead of a central scheduler,
# all_gather is required (aligned with v0.2.6)
def _get_logits(self, hidden_states: torch.Tensor, embedding: torch.Tensor,
embedding_bias: Optional[torch.Tensor]) -> torch.Tensor:
# Get the logits for the next tokens.
logits = torch.matmul(hidden_states, embedding.t())
if embedding_bias is not None:
logits += embedding_bias
logits = tensor_model_parallel_all_gather(logits)
# Remove paddings in vocab (if any).
if logits is not None:
logits = logits[:, :self.org_vocab_size]
return logits
from vllm.model_executor.layers.logits_processor import LogitsProcessor
def logitsprocessor_init(
self,
vocab_size: int,
org_vocab_size: Optional[int] = None,
scale: float = 1.0,
logits_as_input: bool = False,
soft_cap: Optional[float] = None,
) -> None:
"""
Args:
scale: A scaling factor to apply to the logits.
"""
super(LogitsProcessor, self).__init__()
self.scale = scale
self.vocab_size = vocab_size
# Whether the input is logits (default is hidden states).
self.logits_as_input = logits_as_input
# original vocabulary size (without LoRA).
self.org_vocab_size = org_vocab_size or vocab_size
# Soft cap the logits. Used in Gemma 2.
self.soft_cap = soft_cap
# Whether to use gather or all-gather to gather the logits.
self.use_gather = False
LogitsProcessor.__init__ = logitsprocessor_init # use all_gather

View File

@@ -0,0 +1,182 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023 The vLLM team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/worker/model_runner.py
import warnings
from enum import IntEnum
from typing import Dict, Optional, Union
import torch
import torch.nn as nn
import vllm.envs as envs
from vllm.compilation.levels import CompilationLevel
from vllm.config import (
CacheConfig,
DeviceConfig,
LoadConfig,
LoRAConfig,
ModelConfig,
ObservabilityConfig,
ParallelConfig,
PromptAdapterConfig,
SchedulerConfig,
)
from vllm.inputs import INPUT_REGISTRY, InputRegistry
from vllm.logger import init_logger
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
from vllm.model_executor.models.interfaces import supports_lora
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.prompt_adapter.worker_manager import LRUCacheWorkerPromptAdapterManager
from vllm.utils import DeviceMemoryProfiler, is_hip, supports_dynamo
from vllm.worker.model_runner import ModelRunner
from .config import LoadConfig, ModelConfig
from .model_loader import get_model
logger = init_logger(__name__)
# How batches are constructed.
class BatchType(IntEnum):
# Every batch is prefill.
PREFILL = 0
# Every batch is decode.
DECODE = 1
# Batch is a mixture of prefill and decode.
MIXED = 2
class ModelRunner(ModelRunner):
def __init__(
self,
model: Union[nn.Module, Dict], # [verl] model itself or its parameter dict
model_config: ModelConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
cache_config: CacheConfig,
load_config: LoadConfig,
lora_config: Optional[LoRAConfig],
kv_cache_dtype: Optional[str] = "auto",
is_driver_worker: bool = False,
prompt_adapter_config: Optional[PromptAdapterConfig] = None,
return_hidden_states: bool = False,
observability_config: Optional[ObservabilityConfig] = None,
input_registry: InputRegistry = INPUT_REGISTRY,
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
):
super().__init__(
model_config,
parallel_config,
scheduler_config,
device_config,
cache_config,
load_config,
lora_config,
kv_cache_dtype,
is_driver_worker=True, # a hack
prompt_adapter_config=prompt_adapter_config,
return_hidden_states=return_hidden_states,
observability_config=observability_config,
input_registry=input_registry,
mm_registry=mm_registry,
)
# NOTE(sgm): add for verl
self.model = model # this will be replaced by get_model()
def load_model(self) -> None:
logger.info("Starting to load model %s...", self.model_config.model)
with DeviceMemoryProfiler() as m:
self.model = get_model(
self.model,
model_config=self.model_config,
device_config=self.device_config,
load_config=self.load_config,
lora_config=self.lora_config,
parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config,
cache_config=self.cache_config,
)
self.model_memory_usage = m.consumed_memory
logger.info("Loading model weights took %.4f GB", self.model_memory_usage / float(2**30))
if self.lora_config:
assert supports_lora(self.model), f"{self.model.__class__.__name__} does not support LoRA yet."
if supports_multimodal(self.model):
logger.warning("Regarding multimodal models, vLLM currently "
"only supports adding LoRA to language model.")
# It's necessary to distinguish between the max_position_embeddings
# of VLMs and LLMs.
if hasattr(self.model.config, "max_position_embeddings"):
max_pos_embeddings = self.model.config.max_position_embeddings
else:
max_pos_embeddings = self.model.config.text_config.max_position_embeddings
self.lora_manager = LRUCacheWorkerLoRAManager(
self.scheduler_config.max_num_seqs,
self.scheduler_config.max_num_batched_tokens,
self.vocab_size,
self.lora_config,
self.device,
self.model.embedding_modules,
self.model.embedding_padding_modules,
max_position_embeddings=max_pos_embeddings,
)
self.model = self.lora_manager.create_lora_manager(self.model)
if self.prompt_adapter_config:
self.prompt_adapter_manager = LRUCacheWorkerPromptAdapterManager(
self.scheduler_config.max_num_seqs,
self.scheduler_config.max_num_batched_tokens,
self.device,
self.prompt_adapter_config,
)
self.model = self.prompt_adapter_manager.create_prompt_adapter_manager(self.model)
if self.kv_cache_dtype == "fp8" and is_hip():
# Currently only ROCm accepts kv-cache scaling factors
# via quantization_param_path and this will be deprecated
# in the future.
if self.model_config.quantization_param_path is not None:
if callable(getattr(self.model, "load_kv_cache_scales", None)):
warnings.warn(
"Loading kv cache scaling factor from JSON is "
"deprecated and will be removed. Please include "
"kv cache scaling factors in the model checkpoint.",
FutureWarning,
stacklevel=2,
)
self.model.load_kv_cache_scales(self.model_config.quantization_param_path)
logger.info("Loaded KV cache scaling factors from %s", self.model_config.quantization_param_path)
else:
raise RuntimeError(
"Using FP8 KV cache and scaling factors provided but "
"model %s does not support loading scaling factors.",
self.model.__class__,
)
else:
logger.warning("Using FP8 KV cache but no scaling factors "
"provided. Defaulting to scaling factors of 1.0. "
"This may lead to less accurate results!")
if envs.VLLM_TORCH_COMPILE_LEVEL == CompilationLevel.DYNAMO_AS_IS and supports_dynamo():
from vllm.plugins import get_torch_compile_backend
backend = get_torch_compile_backend() or "eager"
self.model = torch.compile(self.model, fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE, backend=backend)

View File

@@ -0,0 +1,312 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023 The vLLM team.
# Adapted from
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""Model and data parallel groups."""
import os
from typing import Optional
import torch
import torch.distributed
import vllm.distributed.parallel_state as ps
from vllm.distributed.parallel_state import (
get_pp_group,
get_world_group,
init_distributed_environment,
init_model_parallel_group,
)
from vllm.logger import init_logger
logger = init_logger(__name__)
"""
This version is strongly tied with Megatron to implement HybridEngine and weight sharing between vllm and Megatron.
- We assume the Megatron tp+dp+pp world is already established before calling this function.
"""
# Device mesh for using DTensor
_DEVICE_MESH = None
# Tensor model parallel group that the current rank belongs to.
_TP = None
# Pipeline model parallel group that the current rank belongs to.
_PP = None
# This method is for initializing the ParallelGroup when using HybridEngine
def initialize_parallel_state(
distributed_init_method: str = "env://",
backend: str = "nccl",
tensor_model_parallel_size: int = 1,
num_tp_per_train_tp: int = 1,
pipeline_model_parallel_size: int = 1,
):
# torch.distributed.all_reduce does not free the input tensor until
# the synchronization point. This causes the memory usage to grow
# as the number of all_reduce calls increases. This env var disables
# this behavior.
# Related issue:
# https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573
os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
# NOTE(sgm): Modify for verl, Env vars will be set by TORCHRUN.
rank = int(os.getenv("RANK", "-1"))
local_rank = int(os.getenv("LOCAL_RANK", "0"))
# Use the world_size set by TORCHRUN
world_size = int(os.getenv("WORLD_SIZE", "-1"))
assert world_size != -1, "The world_size is set to -1, not initialized by TORCHRUN"
init_distributed_environment(world_size, rank, distributed_init_method, local_rank, backend)
if torch.distributed.get_world_size() > 1:
# NOTE: build a sepearate inference group with infer tp & micro dp
initialize_model_parallel_for_vllm(
tensor_model_parallel_size=tensor_model_parallel_size,
num_tensor_model_parallel_groups_per_train_tp=num_tp_per_train_tp,
)
else:
initialize_model_parallel(tensor_model_parallel_size, pipeline_model_parallel_size, backend)
def ensure_model_parallel_initialized(
tensor_model_parallel_size: int,
pipeline_model_parallel_size: int = 1,
backend: Optional[str] = None,
) -> None:
"""Helper to initialize model parallel groups if they are not initialized,
or ensure tensor-parallel and pipeline-parallel sizes are equal to expected
values if the model parallel groups are initialized.
"""
# get the backend of _DEVICE_WORLD_GROUP
backend = backend or torch.distributed.get_backend(get_world_group().device_group)
if not model_parallel_is_initialized():
initialize_model_parallel(tensor_model_parallel_size, pipeline_model_parallel_size, backend)
return
assert get_tensor_model_parallel_world_size() == tensor_model_parallel_size, (
"tensor parallel group already initialized, but of unexpected size: "
f"{get_tensor_model_parallel_world_size()=} vs. "
f"{tensor_model_parallel_size=}")
pp_world_size = get_pp_group().world_size
assert pp_world_size == pipeline_model_parallel_size, (
"pipeline parallel group already initialized, but of unexpected size: "
f"{pp_world_size=} vs. "
f"{pipeline_model_parallel_size=}")
# TODO(sgm): deviate from the v0.5.4, not pp now
def model_parallel_is_initialized():
"""Check if tensor and pipeline parallel groups are initialized."""
return ps._TP is not None
# and _PIPELINE_MODEL_PARALLEL_GROUP is not None)
def initialize_model_parallel_for_vllm(
tensor_model_parallel_size: int,
num_tensor_model_parallel_groups_per_train_tp: int = 1,
pipeline_model_parallel_size: int = 1,
) -> None:
pass
# Get world size and rank. Ensure some consistencies.
assert torch.distributed.is_initialized()
assert isinstance(tensor_model_parallel_size, int)
# assert num_tensor_model_parallel_groups_per_train_tp == 1 and not different_tp_group
# assert num_tensor_model_parallel_groups_per_train_tp > 1 and different_tp_group
# Build the tensor model-parallel groups.
assert ps._TP is None, "tensor model parallel group is already initialized"
global _TP
world_size: int = torch.distributed.get_world_size()
rank = torch.distributed.get_rank()
backend = torch.distributed.get_backend()
num_tensor_model_parallel_groups = world_size // tensor_model_parallel_size
if num_tensor_model_parallel_groups_per_train_tp == 1:
# if tensor_model_parallel_size == train_tensor_parallel_size:
# using the same tp group as Megatron/vllm
assert _TP is None, "tensor model parallel group is already initialized"
group_ranks = []
for i in range(num_tensor_model_parallel_groups):
ranks = range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size)
group_ranks.append(ranks)
_TP = init_model_parallel_group(
group_ranks=group_ranks,
local_rank=get_world_group().local_rank,
backend=backend,
use_custom_allreduce=False, # TODO: check why True is not work in Ray trainer
use_message_queue_broadcaster=True,
)
ps._TP = _TP
# _MICRO_DATA_PARALLEL_GROUP is move to hybrid engine
else:
# initialize a micro_dp group and a tp group
# assume training tp=4, infer tp=2, then, weight is partitioned as
# [1], [2], [3], [4] for training and [1,2], [1,2], [3,4], [3,4] for inference
# Build the inference tp groups
# train_tp = train_tensor_parallel_size
train_tp = num_tensor_model_parallel_groups_per_train_tp * tensor_model_parallel_size
# num_tensor_model_parallel_groups_per_train_tp = train_tp // tensor_model_parallel_size
assert _TP is None, "tensor model parallel group is already initialized"
group_ranks = []
for i in range(num_tensor_model_parallel_groups // num_tensor_model_parallel_groups_per_train_tp):
start = train_tp * i
end = train_tp * (i + 1)
for j in range(num_tensor_model_parallel_groups_per_train_tp):
ranks = list(range(start, end, num_tensor_model_parallel_groups_per_train_tp))
for i in range(len(ranks)):
ranks[i] += j
group_ranks.append(ranks)
_TP = init_model_parallel_group(
group_ranks=group_ranks,
local_rank=get_world_group().local_rank,
backend=backend,
use_custom_allreduce=False, # TODO: check why True is not work in Ray trainer
use_message_queue_broadcaster=True,
)
ps._TP = _TP
# Build the pipeline model-parallel groups.
# global _PIPELINE_MODEL_PARALLEL_GROUP
# global _PIPELINE_GLOBAL_RANKS
# assert ps._PIPELINE_MODEL_PARALLEL_GROUP is None, ("pipeline model parallel group is already initialized")
# ps._PIPELINE_MODEL_PARALLEL_GROUP = mpu.get_pipeline_model_parallel_group()
# ps._PIPELINE_GLOBAL_RANKS = mpu.get_pipeline_model_parallel_ranks()
# TODO: init using device mesh (not support hybrid engine now)
# Build the pipeline model-parallel groups.
num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size
global _PP
assert _PP is None, "pipeline model parallel group is already initialized"
group_ranks = []
for i in range(num_pipeline_model_parallel_groups):
ranks = list(range(i, world_size, num_pipeline_model_parallel_groups))
group_ranks.append(ranks)
# pipeline parallel does not need custom allreduce
_PP = init_model_parallel_group(group_ranks, get_world_group().local_rank, backend, use_custom_allreduce=False)
ps._PP = _PP # for verl
def initialize_model_parallel(
tensor_model_parallel_size: int = 1,
pipeline_model_parallel_size: int = 1,
backend: Optional[str] = None,
) -> None:
"""
NOTE: This method is a hack from the open-sourced version without
asertion of world_size = tp * pp
Initialize model parallel groups.
Arguments:
tensor_model_parallel_size: number of GPUs used for tensor model
parallelism.
pipeline_model_parallel_size: number of GPUs used for pipeline model
parallelism.
Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we
use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize
the model pipeline. The present function will
create 4 tensor model-parallel groups and 2 pipeline model-parallel groups:
4 tensor model-parallel groups:
[g0, g1], [g2, g3], [g4, g5], [g6, g7]
2 pipeline model-parallel groups:
[g0, g2, g4, g6], [g1, g3, g5, g7]
Note that for efficiency, the caller should make sure adjacent ranks
are on the same DGX box. For example if we are using 2 DGX-1 boxes
with a total of 16 GPUs, rank 0 to 7 belong to the first box and
ranks 8 to 15 belong to the second box.
"""
# Get world size and rank. Ensure some consistencies.
assert torch.distributed.is_initialized()
world_size: int = torch.distributed.get_world_size()
backend = backend or torch.distributed.get_backend(ps.get_world_group().device_group)
# NOTE(sgm) we don't assert world_size == tp * pp
# DP is not managed by vllm but by the VeRL WorkerGroup
# if (world_size !=
# tensor_model_parallel_size * pipeline_model_parallel_size):
# raise RuntimeError(
# f"world_size ({world_size}) is not equal to "
# f"tensor_model_parallel_size ({tensor_model_parallel_size}) x "
# f"pipeline_model_parallel_size ({pipeline_model_parallel_size})")
num_tensor_model_parallel_groups: int = world_size // tensor_model_parallel_size
rank = torch.distributed.get_rank()
global _TP
assert _TP is None, "tensor model parallel group is already initialized"
group_ranks = []
for i in range(num_tensor_model_parallel_groups):
ranks = list(range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size))
group_ranks.append(ranks)
# message queue broadcaster is only used in tensor model parallel group
_TP = init_model_parallel_group(
group_ranks,
get_world_group().local_rank,
backend,
use_custom_allreduce=False, # TODO: check why True is not work in Ray trainer
use_message_queue_broadcaster=True,
)
ps._TP = _TP
# TODO: init using device mesh (not support hybrid engine now)
# Build the pipeline model-parallel groups.
num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size
global _PP
assert _PP is None, "pipeline model parallel group is already initialized"
group_ranks = []
for i in range(num_pipeline_model_parallel_groups):
ranks = list(range(i, world_size, num_pipeline_model_parallel_groups))
group_ranks.append(ranks)
# pipeline parallel does not need custom allreduce
_PP = init_model_parallel_group(group_ranks, get_world_group().local_rank, backend, use_custom_allreduce=False)
ps._PP = _PP # for verl
"""
Device mesh utilities
"""
def get_device_mesh():
assert _DEVICE_MESH is not None, "device mesh is not initialized"
return _DEVICE_MESH
"""
Tensor model parallel utilities
"""
def get_tensor_model_parallel_group():
"""Get the tensor model parallel group the caller rank belongs to."""
assert _TP is not None, "tensor model parallel group is not initialized"
return _TP.device_group
def get_tensor_model_parallel_world_size():
"""Return world size for the tensor model parallel group."""
return torch.distributed.get_world_size(group=get_tensor_model_parallel_group())
def get_tensor_model_parallel_rank():
"""Return my rank for the tensor model parallel group."""
return torch.distributed.get_rank(group=get_tensor_model_parallel_group())
def get_tensor_model_parallel_src_rank():
"""Calculate the global rank corresponding to the first local rank
in the tensor model parallel group."""
global_rank = torch.distributed.get_rank()
local_world_size = get_tensor_model_parallel_world_size()
return (global_rank // local_world_size) * local_world_size

View File

@@ -0,0 +1,256 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023 The vLLM team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/executor/gpu_executor.py
import os
import socket
from typing import Dict, List, Optional, Set, Tuple
import torch
from vllm.config import (
CacheConfig,
DeviceConfig,
LoRAConfig,
ObservabilityConfig,
ParallelConfig,
PromptAdapterConfig,
SchedulerConfig,
SpeculativeConfig,
)
from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import ExecuteModelRequest
from .config import LoadConfig, ModelConfig
logger = init_logger(__name__)
class SPMDGPUExecutor(ExecutorBase):
"""SPMD-based multi-GPU executor implementations."""
def __init__(
self,
model, # pytorch model itself or its parameter dict
model_config: ModelConfig,
cache_config: CacheConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
load_config: LoadConfig,
lora_config: Optional[LoRAConfig],
speculative_config: Optional[SpeculativeConfig],
prompt_adapter_config: Optional[PromptAdapterConfig],
observability_config: Optional[ObservabilityConfig],
) -> None:
self.model_config = model_config
self.cache_config = cache_config
self.lora_config = lora_config
self.load_config = load_config
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
self.device_config = device_config
self.speculative_config = speculative_config
self.prompt_adapter_config = prompt_adapter_config
self.observability_config = observability_config
distributed_init_method = initialize_cluster(parallel_config)
self._init_executor(model, distributed_init_method)
# TODO(sgm): verl not support speculative decode now
def _init_executor(self, model, distributed_init_method) -> None:
assert not self.speculative_config, "Speculative decoding not yet supported for multi-GPU backend."
# Create the parallel worker for each GPU.
self._init_workers_sp(model, distributed_init_method)
def _init_workers_sp(self, model, distributed_init_method: str):
# Lazy import the Worker to avoid importing torch.cuda/xformers
# before CUDA_VISIBLE_DEVICES is set in the Worker
from .worker import Worker # pylint: disable=import-outside-toplevel
rank = int(os.getenv("RANK"))
local_rank = int(os.getenv("LOCAL_RANK"))
print(f"local rank {local_rank}")
# see https://github.com/NVIDIA/nccl/issues/1234
os.environ["NCCL_CUMEM_ENABLE"] = "0"
self.worker = Worker(
model,
self.model_config,
self.parallel_config,
self.scheduler_config,
self.device_config,
self.cache_config,
self.load_config,
local_rank,
rank,
distributed_init_method,
lora_config=self.lora_config,
speculative_config=None,
prompt_adapter_config=self.speculative_config,
is_driver_worker=True,
model_runner_cls=None, # use the default one
)
# NOTE(shengguangming): torch.distributed.init_process_group will be called inside the init_model()
self.worker.init_device()
self.worker.load_model()
def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Determine the number of available KV blocks.
This invokes `determine_num_available_blocks` on each worker and takes
the min of the results, guaranteeing that the selected cache sizes are
compatible with all workers.
Returns:
- tuple[num_gpu_blocks, num_cpu_blocks]
"""
# Get the maximum number of blocks that can be allocated on GPU and CPU.
num_blocks = self.worker.determine_num_available_blocks()
# NOTE(shengguangming): Now we don't use a shared centralized controler but each process will
# have its own scheduler
num_gpu_blocks = num_blocks[0]
num_cpu_blocks = num_blocks[1]
return num_gpu_blocks, num_cpu_blocks
def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None:
"""Initialize the KV cache in all workers."""
# NOTE: We log here to avoid multiple logs when number of workers is
# greater than one. We could log in the engine, but not all executors
# have GPUs.
logger.info("# GPU blocks: %d, # CPU blocks: %d", num_gpu_blocks, num_cpu_blocks)
self.cache_config.num_gpu_blocks = num_gpu_blocks
self.cache_config.num_cpu_blocks = num_cpu_blocks
if torch.distributed.get_rank() == 0:
print(
f"before init cache memory allocated: {torch.cuda.memory_allocated() / 1e9}GB, reserved: {torch.cuda.memory_reserved() / 1e9}GB"
)
self.worker.initialize_cache(num_gpu_blocks=num_gpu_blocks, num_cpu_blocks=num_cpu_blocks)
if torch.distributed.get_rank() == 0:
print(
f"after init cache memory allocated: {torch.cuda.memory_allocated() / 1e9}GB, reserved: {torch.cuda.memory_reserved() / 1e9}GB"
)
# NOTE(sgm): This will not profile & capture the model(CUDAGraph) when rebuilding KVCache
def init_cache_engine(self) -> None:
self.worker._init_cache_engine()
def free_cache_engine(self) -> None:
self.worker.free_cache_engine()
def execute_model(self, execute_model_req) -> List[SamplerOutput]:
all_outputs = self.worker.execute_model(execute_model_req=execute_model_req)
# NOTE(sgm):
# Each GPU in vllm under verl has its own spmd_gpu_executor, therefore all GPUs should return the outputs
# In vllm with ray, only the driver worker returns the sampling results.
return all_outputs
def add_lora(self, lora_request: LoRARequest) -> bool:
assert lora_request.lora_int_id > 0, "lora_id must be greater than 0."
return self.worker.add_lora(lora_request=lora_request)
def remove_lora(self, lora_id: int) -> bool:
assert lora_id > 0, "lora_id must be greater than 0."
return self.worker.remove_lora(lora_id=lora_id)
def list_loras(self) -> Set[int]:
return self.worker.list_loras()
def check_health(self) -> None:
# SPMDExecutor will always be healthy as long as
# it's running.
return
# NOTE(sgm) add for verl to pass the abstract class test, not used
from vllm.prompt_adapter.request import PromptAdapterRequest
def add_prompt_adapter(self, prompt_adapter_request: PromptAdapterRequest) -> bool:
assert prompt_adapter_request.prompt_adapter_id > 0, "prompt_adapter_id must be greater than 0."
return self.worker.add_prompt_adapter(prompt_adapter_request)
def list_prompt_adapters(self) -> Set[int]:
return self.worker.list_prompt_adapters()
def pin_lora(self, lora_id: int) -> bool:
assert lora_id > 0, "lora_id must be greater than 0."
return self.worker.pin_lora(lora_id)
def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool:
assert prompt_adapter_id > 0, "prompt_adapter_id must be greater than 0."
return self.worker.pin_prompt_adapter(prompt_adapter_id)
def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
assert prompt_adapter_id > 0, "prompt_adapter_id must be greater than 0."
return self.worker.remove_prompt_adapter(prompt_adapter_id)
# NOTE(sgm): add for verl
def offload_model_weights(self) -> None:
self.worker.offload_model_weights()
def sync_model_weights(self, actor_weights: Dict[str, torch.Tensor], load_format: str) -> None:
self.worker.sync_model_weights(actor_weights=actor_weights, load_format=load_format)
def initialize_cluster(
parallel_config: ParallelConfig,
engine_use_ray: bool = False,
ray_address: Optional[str] = None,
) -> Tuple[str, Optional[None]]:
"""Initialize the distributed cluster probably with Ray.
Args:
parallel_config: The configurations for parallel execution.
Returns:
The `distributed_init_method` is the address for initializing the
distributed backend.
"""
# Initialize cluster locally.
port = get_open_port()
# We need to setup the distributed init method to make sure
# the distributed megatron code (e.g., get world size) works correctly.
# distributed_init_method = f"tcp://localhost:{port}"
distributed_init_method = "env://"
return distributed_init_method
def get_open_port():
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(("", 0))
return s.getsockname()[1]
# TODO(sgm): not implemented async executor yet
class SPMDGPUExecutorAsync(SPMDGPUExecutor, ExecutorAsyncBase):
async def execute_model_async(self, execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
"""Executes one model step on the given sequences."""
raise NotImplementedError
async def check_health_async(self) -> None:
"""Checks if the executor is healthy. If not, it should raise an
exception."""
self.check_health()

View File

@@ -0,0 +1,40 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023 The vLLM team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/transformers_utils/tokenizer_group/tokenizer_group.py
from typing import Optional
from transformers import PreTrainedTokenizer
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
from vllm.utils import LRUCache
class TokenizerGroup(TokenizerGroup):
"""A group of tokenizers that can be used for LoRA adapters."""
def __init__(self, tokenizer: PreTrainedTokenizer, enable_lora: bool, max_num_seqs: int,
max_input_length: Optional[int]):
self.enable_lora = enable_lora
self.max_input_length = max_input_length
self.tokenizer = tokenizer
self.lora_tokenizers = LRUCache[PreTrainedTokenizer](capacity=max_num_seqs) if enable_lora else None
# FIXME(sgm): for simplicity, we assign the special token here
@property
def pad_token_id(self):
return self.tokenizer.pad_token_id
@property
def eos_token_id(self):
return self.tokenizer.eos_token_id

View File

@@ -0,0 +1,333 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023 The vLLM team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/worker/worker.py
"""A GPU worker class."""
import gc
import os
from typing import Dict, List, Optional, Tuple, Type, Union
import torch
import torch.distributed
import torch.nn as nn
from vllm.config import (
CacheConfig,
DeviceConfig,
LoRAConfig,
ParallelConfig,
PromptAdapterConfig,
SchedulerConfig,
SpeculativeConfig,
)
# TODO(sgm): check why vllm has similar file in vllm.model_executor.parallel_utils.parallel_state
from vllm.distributed import get_tensor_model_parallel_group, init_distributed_environment, set_custom_all_reduce
from vllm.model_executor import set_random_seed
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import ExecuteModelRequest, IntermediateTensors
from vllm.worker.cache_engine import CacheEngine
from vllm.worker.embedding_model_runner import EmbeddingModelRunner
from vllm.worker.model_runner import GPUModelRunnerBase
from vllm.worker.model_runner_base import ModelRunnerInputBase
from vllm.worker.worker import Worker, _check_if_gpu_supports_dtype
from vllm.worker.worker_base import WorkerInput
from .config import LoadConfig, LoadFormat, ModelConfig
from .dtensor_weight_loaders import load_dtensor_weights
from .hf_weight_loader import load_hf_weights
from .megatron_weight_loaders import load_megatron_weights
from .model_runner import ModelRunner
from .parallel_state import ensure_model_parallel_initialized
class Worker(Worker):
"""A worker class that executes (a partition of) the model on a GPU.
Each worker is associated with a single GPU. The worker is responsible for
maintaining the KV cache and executing the model on the GPU. In case of
distributed inference, each worker is assigned a partition of the model.
"""
def __init__(
self,
model: Union[nn.Module, Dict], # model itself or its parameter dict
model_config: ModelConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
cache_config: CacheConfig,
load_config: LoadConfig,
local_rank: int,
rank: int,
distributed_init_method: str,
lora_config: Optional[LoRAConfig] = None,
speculative_config: Optional[SpeculativeConfig] = None,
prompt_adapter_config: Optional[PromptAdapterConfig] = None,
is_driver_worker: bool = False,
model_runner_cls: Optional[Type[GPUModelRunnerBase]] = None,
) -> None:
# self.model = model # will be replaced in the init_model
self.model_config = model_config
self.parallel_config = parallel_config
self.parallel_config.rank = rank
self.scheduler_config = scheduler_config
self.device_config = device_config
self.cache_config = cache_config
self.local_rank = local_rank
self.rank = rank
self.distributed_init_method = distributed_init_method
self.lora_config = lora_config
self.load_config = load_config
self.prompt_adapter_config = prompt_adapter_config
self.is_driver_worker = is_driver_worker # TODO: we don't need driver
# if parallel_config and is_driver_worker:
# assert rank % parallel_config.tensor_parallel_size == 0, \
# "Driver worker should be rank 0 of tensor parallel group."
if self.model_config.trust_remote_code:
# note: lazy import to avoid importing torch before initializing
from vllm.utils import init_cached_hf_modules
init_cached_hf_modules()
# Return hidden states from target model if the draft model is an
# mlp_speculator
speculative_args = (
{} if speculative_config is None or (speculative_config.draft_model_config.model == model_config.model) or
(speculative_config.draft_model_config.hf_config.model_type not in ["medusa", "mlp_speculator"]) else {
"return_hidden_states": True
})
# TODO(sgm): set correct model runner class
ModelRunnerClass: Type[GPUModelRunnerBase] = ModelRunner
if model_runner_cls is not None:
ModelRunnerClass = model_runner_cls
elif self.model_config.embedding_mode:
ModelRunnerClass = EmbeddingModelRunner
self.model_runner: GPUModelRunnerBase = ModelRunnerClass(
model, # [VERL]: add for verl
model_config,
parallel_config,
scheduler_config,
device_config,
cache_config,
load_config=load_config,
lora_config=self.lora_config,
kv_cache_dtype=self.cache_config.cache_dtype,
is_driver_worker=is_driver_worker,
prompt_adapter_config=prompt_adapter_config,
**speculative_args,
)
# Uninitialized cache engine. Will be initialized by
# initialize_cache.
self.cache_engine: List[CacheEngine] = None
# Initialize gpu_cache as embedding models don't initialize kv_caches
self.gpu_cache: Optional[List[List[torch.Tensor]]] = None
# NOTE(sgm): [VERL] For offloading inference engine params
self.cpu_model = None
def init_device(self) -> None:
if self.device_config.device.type == "cuda":
# torch.distributed.all_reduce does not free the input tensor until
# the synchronization point. This causes the memory usage to grow
# as the number of all_reduce calls increases. This env var disables
# this behavior.
# Related issue:
# https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573
os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
# NOTE(sgm): Modify for verl, Env vars will be set by TORCHRUN.
self.rank = self.rank if self.rank is not None else int(os.getenv("RANK", "-1"))
local_rank = int(os.getenv("LOCAL_RANK", "0"))
self.device = torch.device(f"cuda:{local_rank}")
if self.rank < 0:
raise ValueError("Invalid or unspecified rank.")
torch.cuda.set_device(self.device)
# Use the world_size set by TORCHRUN
world_size = int(os.getenv("WORLD_SIZE", "-1"))
assert world_size != -1, "The world_size is set to -1, not initialized by TORCHRUN"
self.parallel_config.world_size = world_size
_check_if_gpu_supports_dtype(self.model_config.dtype)
torch.cuda.empty_cache()
self.init_gpu_memory = torch.cuda.mem_get_info()[0]
else:
raise RuntimeError(f"Not support device type: {self.device_config.device}")
# Initialize the distributed environment.
init_worker_distributed_environment(self.parallel_config, self.rank, self.distributed_init_method,
self.local_rank)
# Set random seed.
set_random_seed(self.model_config.seed)
# self.model = get_model(actor_model=self.model, model_config=self.model_config)
@torch.inference_mode()
def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Profiles the peak memory usage of the model to determine how many
KV blocks may be allocated without OOMs.
The engine will first conduct a profiling of the existing memory usage.
Then, it calculate the maximum possible number of GPU and CPU blocks
that can be allocated with the remaining free memory.
.. tip::
You may limit the usage of GPU memory
by adjusting the `gpu_memory_utilization` parameter.
"""
# Profile the memory usage of the model and get the maximum number of
# cache blocks that can be allocated with the remaining free memory.
torch.cuda.empty_cache()
# torch.cuda.reset_peak_memory_stats()
# Execute a forward pass with dummy inputs to profile the memory usage
# of the model.
self.model_runner.profile_run()
# Calculate the number of blocks that can be allocated with the
# profiled peak memory.
torch.cuda.synchronize()
free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info()
peak_memory = total_gpu_memory - free_gpu_memory
assert peak_memory > 0, ("Error in memory profiling. This happens when the GPU memory was "
"not properly cleaned up before initializing the vLLM instance.")
cache_block_size = self.get_cache_block_size_bytes()
# NOTE(sgm) [VERL] use the remaining memory
num_gpu_blocks = int((free_gpu_memory * self.cache_config.gpu_memory_utilization) // cache_block_size)
# num_gpu_blocks = int((total_gpu_memory * self.cache_config.gpu_memory_utilization - peak_memory) // cache_block_size)
num_cpu_blocks = int(self.cache_config.swap_space_bytes // cache_block_size)
num_gpu_blocks = max(num_gpu_blocks, 0)
num_cpu_blocks = max(num_cpu_blocks, 0)
if self.model_runner.lora_manager:
self.model_runner.remove_all_loras()
# NOTE(sgm): Add for [VERL], synchronize number of blocks with all the rank
num_gpu_blocks = torch.tensor([num_gpu_blocks], device="cuda")
num_cpu_blocks = torch.tensor([num_cpu_blocks], device="cuda")
torch.distributed.all_reduce(num_gpu_blocks,
op=torch.distributed.ReduceOp.MIN,
group=get_tensor_model_parallel_group().device_group)
torch.distributed.all_reduce(num_cpu_blocks,
op=torch.distributed.ReduceOp.MIN,
group=get_tensor_model_parallel_group().device_group)
num_gpu_blocks = num_gpu_blocks.item()
num_cpu_blocks = num_cpu_blocks.item()
gc.collect()
torch.cuda.empty_cache()
return num_gpu_blocks, num_cpu_blocks
def _init_cache_engine(self):
if self.cache_engine is None and self.gpu_cache is None:
super()._init_cache_engine()
def free_cache_engine(self):
# ensure `enforce_eager=True`
self.cache_engine = None
self.gpu_cache = None
# NOTE(sgm): [VERL]: adapt from _execute_model_spmd()
def execute_model(self,
execute_model_req: ExecuteModelRequest,
intermediate_tensors: Optional[IntermediateTensors] = None) -> Optional[List[SamplerOutput]]:
"""
Execute model in Single Program Multiple Data (SPMD) fashion.
All workers take the same request, prepare the input and
execute the model.
"""
assert execute_model_req is not None, ("_execute_model_spmd() requires each worker to take in an "
"ExecuteModelRequest")
worker_input: WorkerInput = self.prepare_worker_input(execute_model_req=execute_model_req)
model_input: ModelRunnerInputBase = self.model_runner.prepare_model_input(
execute_model_req.seq_group_metadata_list)
# verl.worker.workerbase.WorkerBase
# swap cache
super().execute_worker(worker_input)
# If there is no input, we don't need to execute the model.
if worker_input.num_seq_groups == 0:
return []
return self.model_runner.execute_model(
model_input,
self.kv_cache[worker_input.virtual_engine] if self.kv_cache is not None else None,
intermediate_tensors,
)
# assume the input is .state_dict()
def sync_model_weights(self, actor_weights: Dict, load_format: str):
if load_format in [LoadFormat.MEGATRON, LoadFormat.AUTO]:
load_megatron_weights(actor_weights, self.model_runner.model)
elif load_format == LoadFormat.HF:
# full model state dict without no sharding
load_hf_weights(actor_weights, self.model_runner.model)
elif load_format == LoadFormat.DTENSOR:
load_dtensor_weights(actor_weights, self.model_runner.model)
def offload_model_weights(self) -> None:
if self.cpu_model == None:
self.cpu_model = {}
for name, params in self.model_runner.model.named_parameters():
self.cpu_model[name] = torch.empty_like(params, device="cpu")
params.data = self.cpu_model[name]
else:
for name, params in self.model_runner.model.named_parameters():
params.data = self.cpu_model[name]
def init_worker_distributed_environment(
parallel_config: ParallelConfig,
rank: int,
distributed_init_method: Optional[str] = "env://",
local_rank: int = -1,
) -> None:
"""Initialize the distributed environment."""
set_custom_all_reduce(not parallel_config.disable_custom_all_reduce)
# NOTE(sgm) use tcp://localhost:xxxx will hang in HF setting without megatron
init_distributed_environment(parallel_config.world_size, rank, distributed_init_method, local_rank)
ensure_model_parallel_initialized(
tensor_model_parallel_size=parallel_config.tensor_parallel_size,
pipeline_model_parallel_size=parallel_config.pipeline_parallel_size,
)
# TODO(sgm): check whether need this
# if pynccl_utils.is_initialized():
# pynccl_world_size = pynccl_utils.get_world_size()
# if pynccl_world_size != parallel_config.world_size:
# raise RuntimeError(
# "pynccl is already initialized but the pynccl world "
# "size does not match parallel_config.world_size "
# f"({pynccl_world_size} vs. {parallel_config.world_size}).")
# elif parallel_config.world_size > 1:
# # NOTE(woosuk): We don't initialize pynccl process group when world size
# # is 1.
# # NOTE(kaichao): By default, pynccl is initialized for tp group.
# pynccl_utils.init_process_group(
# group=get_tensor_model_parallel_cpu_group())
# # Initialize a custom fast all-reduce implementation.
# if not parallel_config.disable_custom_all_reduce:
# init_custom_ar()
# A small all_reduce for warmup.
torch.distributed.all_reduce(torch.zeros(1).cuda())
# if pynccl_utils.is_initialized():
# pynccl_utils.all_reduce(torch.zeros(1).cuda())

13
verl/trainer/__init__.py Normal file
View File

@@ -0,0 +1,13 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@@ -0,0 +1,6 @@
data:
path: /tmp/math_Qwen2-7B-Instruct.parquet
prompt_key: prompt
response_key: responses
data_source_key: data_source
reward_model_key: reward_model

View File

@@ -0,0 +1,35 @@
trainer:
nnodes: 1
n_gpus_per_node: 8
data:
path: ~/data/rlhf/math/test.parquet
prompt_key: prompt
n_samples: 5
output_path: /opt/tiger/math_Qwen2-7B-Instruct.parquet
batch_size: 128
model:
path: ~/models/Qwen2-7B-Instruct
external_lib: null
rollout:
name: vllm
temperature: 1.0
top_k: 50 # 0 for hf rollout, -1 for vllm rollout
top_p: 0.7
prompt_length: 1536
response_length: 512
# for vllm rollout
dtype: bfloat16 # should align with FSDP
gpu_memory_utilization: 0.5
ignore_eos: False
micro_batch_size: 256
enforce_eager: True
free_cache_engine: True
load_format: dummy_dtensor
tensor_model_parallel_size: 1
max_num_batched_tokens: 8192
max_num_seqs: 1024
log_prob_micro_batch_size: 8
# for hf rollout
do_sample: True

View File

@@ -0,0 +1,148 @@
data:
tokenizer: null
train_files: ~/data/rlhf/gsm8k/train.parquet
val_files: ~/data/rlhf/gsm8k/test.parquet
prompt_key: prompt
max_prompt_length: 512
max_response_length: 512
train_batch_size: 1024
val_batch_size: 1312
return_raw_input_ids: False # This should be set to true when the tokenizer between policy and rm differs
return_raw_chat: False
actor_rollout_ref:
hybrid_engine: True
model:
path: ~/models/deepseek-llm-7b-chat
external_lib: null
override_config: {}
enable_gradient_checkpointing: False
actor:
strategy: megatron # This is for backward-compatibility
ppo_mini_batch_size: 256
ppo_micro_batch_size: 64
clip_ratio: 0.2
entropy_coeff: 0.001
ppo_epochs: 1
shuffle: True
optim:
lr: 1e-6
clip_grad: 1.0
lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime
min_lr_ratio: null # only useful for warmup with cosine
warmup_style: constant # select from constant/cosine
total_training_steps: -1 # must be override by program
megatron:
tensor_model_parallel_size: 4
pipeline_model_parallel_size: 1
num_layers_per_virtual_pipeline_stage: null # vpp will hang. need debug.
sequence_parallel: True
seed: 1
load_weight: True
ref:
megatron:
tensor_model_parallel_size: 4
pipeline_model_parallel_size: 1
num_layers_per_virtual_pipeline_stage: null # vpp will hang. need debug.
sequence_parallel: True
seed: 1
load_weight: True
param_offload: False
log_prob_micro_batch_size: 32
rollout:
name: vllm
temperature: 1.0
top_k: -1 # 0 for hf rollout, -1 for vllm rollout
top_p: 1
prompt_length: ${data.max_prompt_length} # for xperf_gpt
response_length: ${data.max_response_length}
# for vllm rollout
dtype: bfloat16 # should align with FSDP
gpu_memory_utilization: 0.5
ignore_eos: False
enforce_eager: True
free_cache_engine: True
load_format: dummy_megatron
tensor_model_parallel_size: 2
max_num_batched_tokens: 8192
max_num_seqs: 1024
log_prob_micro_batch_size: 2
# for hf rollout
do_sample: True
layer_name_map:
qkv_layer_name: qkv
gate_proj_layer_name: gate_up
# number of responses (i.e. num sample times)
n: 1
critic:
strategy: megatron
optim:
lr: 1e-5
clip_grad: 1.0
lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime
min_lr_ratio: null # only useful for warmup with cosine
warmup_style: constant # select from constant/cosine
total_training_steps: -1 # must be override by program
model:
path: ~/models/deepseek-llm-7b-chat
tokenizer_path: ${actor_rollout_ref.model.path}
override_config: {}
external_lib: ${actor_rollout_ref.model.external_lib}
enable_gradient_checkpointing: False
megatron:
tensor_model_parallel_size: 4
pipeline_model_parallel_size: 1
num_layers_per_virtual_pipeline_stage: null # vpp will hang. need debug.
sequence_parallel: True
seed: 1
load_weight: True
ppo_mini_batch_size: ${actor_rollout_ref.actor.ppo_mini_batch_size}
ppo_micro_batch_size: 2
ppo_epochs: ${actor_rollout_ref.actor.ppo_epochs}
shuffle: ${actor_rollout_ref.actor.shuffle}
cliprange_value: 0.5
kl_ctrl:
type: fixed
kl_coef: 0.001
reward_model:
enable: False
strategy: megatron
megatron:
tensor_model_parallel_size: 4
pipeline_model_parallel_size: 1
num_layers_per_virtual_pipeline_stage: null # vpp will hang. need debug.
sequence_parallel: True
seed: 1
model:
input_tokenizer: ${actor_rollout_ref.model.path} # set this to null if the chat template is identical
path: ~/models/FsfairX-LLaMA3-RM-v0.1
external_lib: ${actor_rollout_ref.model.external_lib}
load_weight: True
param_offload: False
micro_batch_size: 64
max_length: null
algorithm:
gamma: 1.0
lam: 1.0
adv_estimator: gae
kl_penalty: kl # how to estimate kl divergence
kl_ctrl:
type: fixed
kl_coef: 0.001
trainer:
total_epochs: 30
total_training_steps: null
project_name: verl_examples
experiment_name: gsm8k
logger: ['console', 'wandb']
nnodes: 1
n_gpus_per_node: 8
save_freq: -1
test_freq: 2
critic_warmup: 0
default_hdfs_dir: ~/experiments/gsm8k/ppo/${trainer.experiment_name}
default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name}

View File

@@ -0,0 +1,177 @@
data:
tokenizer: null
train_files: ~/data/rlhf/gsm8k/train.parquet
val_files: ~/data/rlhf/gsm8k/test.parquet
train_data_num: null
val_data_num: null
prompt_key: prompt
max_prompt_length: 512
max_response_length: 512
max_start_length: 256
max_obs_length: 512
train_batch_size: 1024
val_batch_size: 1312
return_raw_input_ids: False # This should be set to true when the tokenizer between policy and rm differs
return_raw_chat: False
shuffle_train_dataloader: True
actor_rollout_ref:
hybrid_engine: True
model:
path: ~/models/deepseek-llm-7b-chat
external_lib: null
override_config: { }
enable_gradient_checkpointing: False
use_remove_padding: False
actor:
strategy: fsdp # This is for backward-compatibility
ppo_mini_batch_size: 256
ppo_micro_batch_size: 64
use_dynamic_bsz: False
ppo_max_token_len_per_gpu: 16384 # n * ${data.max_prompt_length} + ${data.max_response_length}
grad_clip: 1.0
state_masking: False
clip_ratio: 0.2
entropy_coeff: 0.001
use_kl_loss: False # True for GRPO
kl_loss_coef: 0.001 # for grpo
kl_loss_type: low_var_kl # for grpo
ppo_epochs: 1
shuffle: False
ulysses_sequence_parallel_size: 1 # sp size
optim:
lr: 1e-6
lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime
min_lr_ratio: null # only useful for warmup with cosine
warmup_style: constant # select from constant/cosine
total_training_steps: -1 # must be override by program
fsdp_config:
wrap_policy:
# transformer_layer_cls_to_wrap: None
min_num_params: 0
param_offload: False
grad_offload: False
optimizer_offload: False
fsdp_size: -1
ref:
fsdp_config:
param_offload: False
wrap_policy:
# transformer_layer_cls_to_wrap: None
min_num_params: 0
fsdp_size: -1
log_prob_micro_batch_size: 128
log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}
ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size} # sp size
rollout:
name: vllm
temperature: 1.0
top_k: -1 # 0 for hf rollout, -1 for vllm rollout
top_p: 0.95
prompt_length: ${data.max_prompt_length} # not use for opensource
response_length: ${data.max_response_length}
# for vllm rollout
dtype: bfloat16 # should align with FSDP
gpu_memory_utilization: 0.5
ignore_eos: False
enforce_eager: True
free_cache_engine: True
load_format: dummy_dtensor
tensor_model_parallel_size: 2
max_num_batched_tokens: 8192
max_num_seqs: 1024
log_prob_micro_batch_size: 128
log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}
# for hf rollout
do_sample: True
# number of responses (i.e. num sample times)
n: 1 # > 1 for grpo
n_agent: 1 # different here used for agent tasks only
critic:
strategy: fsdp
optim:
lr: 1e-5
lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime
min_lr_ratio: null # only useful for warmup with cosine
warmup_style: constant # select from constant/cosine
total_training_steps: -1 # must be override by program
model:
path: ~/models/deepseek-llm-7b-chat
tokenizer_path: ${actor_rollout_ref.model.path}
override_config: { }
external_lib: ${actor_rollout_ref.model.external_lib}
enable_gradient_checkpointing: False
use_remove_padding: False
fsdp_config:
param_offload: False
grad_offload: False
optimizer_offload: False
wrap_policy:
# transformer_layer_cls_to_wrap: None
min_num_params: 0
fsdp_size: -1
ppo_mini_batch_size: ${actor_rollout_ref.actor.ppo_mini_batch_size}
ppo_micro_batch_size: 64
forward_micro_batch_size: ${critic.ppo_micro_batch_size}
use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
ppo_max_token_len_per_gpu: 32768 # (${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}) * 2
forward_max_token_len_per_gpu: ${critic.ppo_max_token_len_per_gpu}
ulysses_sequence_parallel_size: 1 # sp size
ppo_epochs: ${actor_rollout_ref.actor.ppo_epochs}
shuffle: ${actor_rollout_ref.actor.shuffle}
grad_clip: 1.0
cliprange_value: 0.5
reward_model:
enable: False
strategy: fsdp
model:
input_tokenizer: ${actor_rollout_ref.model.path} # set this to null if the chat template is identical
path: ~/models/FsfairX-LLaMA3-RM-v0.1
external_lib: ${actor_rollout_ref.model.external_lib}
use_remove_padding: False
fsdp_config:
min_num_params: 0
param_offload: False
micro_batch_size: 64
max_length: null
ulysses_sequence_parallel_size: 1 # sp size
use_dynamic_bsz: ${critic.use_dynamic_bsz}
forward_max_token_len_per_gpu: ${critic.forward_max_token_len_per_gpu}
retriever:
url: "http://127.0.0.1:8000/retrieve"
topk: 3
algorithm:
gamma: 1.0
lam: 1.0
adv_estimator: gae
no_think_rl: False
kl_penalty: kl # how to estimate kl divergence
kl_ctrl:
type: fixed
kl_coef: 0.001
state_masking:
start_state_marker: "<information>"
end_state_marker: "</information>"
trainer:
total_epochs: 30
total_training_steps: null
project_name: verl_examples
experiment_name: gsm8k
logger: [ 'console', 'wandb' ]
nnodes: 1
n_gpus_per_node: 8
save_freq: -1
test_freq: -1
critic_warmup: 0
default_hdfs_dir: ~/experiments/gsm8k/ppo/${trainer.experiment_name}
default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name}
max_turns: 10
do_search: true

View File

@@ -0,0 +1,42 @@
data:
train_batch_size: 256
micro_batch_size: 16 # this is also val batch size
train_files: ~/data/gsm8k/train.parquet
val_files: ~/data/gsm8k/test.parquet
prompt_key: question
response_key: answer
max_length: 1024
truncation: error
balance_dp_token: False
chat_template: null
model:
partial_pretrain: ~/models/gemma-1.1-7b-it
fsdp_config:
wrap_policy:
min_num_params: 0
cpu_offload: False
offload_params: False
external_lib: null
enable_gradient_checkpointing: False
trust_remote_code: False
lora_rank: 0 # Set to positive value to enable LoRA (e.g., 32)
lora_alpha: 16 # LoRA scaling factor
target_modules: [q_proj, v_proj] # Target modules for LoRA adaptation
optim:
lr: 1e-5
betas: [0.9, 0.95]
weight_decay: 0.01
warmup_steps_ratio: 0.1
clip_grad: 1.0
trainer:
default_local_dir: /tmp/sft_model
default_hdfs_dir: hdfs://tmp/experiments/gsm8k/gemma-1.1-7b-it/ # change the hdfs path here
resume_path: null
project_name: gsm8k-sft
experiment_name: test
total_epochs: 4
total_training_steps: null
validate_before_training: False
logger: ['console']
seed: 1

View File

@@ -0,0 +1,435 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
A lightweight one-file FSDP SFT Trainer
TODO(zhangchi.usc1992)
- Add calculation of mfu
- Add validation
"""
import os
os.environ['NCCL_DEBUG'] = 'WARN'
os.environ['TOKENIZERS_PARALLELISM'] = 'true'
import logging
import re
import torch
import torch.distributed
from torch import nn, optim
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecision, ShardingStrategy, CPUOffload
from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedModel, AutoConfig
from verl.utils.torch_functional import get_cosine_schedule_with_warmup
from tensordict import TensorDict
from torch.utils.data import DataLoader, DistributedSampler
from verl.utils.fsdp_utils import get_fsdp_wrap_policy, init_fn, get_init_weight_context_manager
from verl.utils.dataset import SFTDataset
from verl.utils.fs import copy_local_path_from_hdfs
from verl.utils.tracking import Tracking
from torch.distributed.device_mesh import DeviceMesh
import verl.utils.hdfs_io as hdfs_io
from verl.utils.debug import log_gpu_memory_usage
from peft import LoraConfig, TaskType, get_peft_model
logger = logging.getLogger(__file__)
logger.setLevel(os.getenv('VERL_SFT_LOGGING_LEVEL', 'WARN'))
def extract_step(path):
match = re.search(r'global_step_(\d+)', path)
if match:
return int(match.group(1))
return None
def convert_to_regular_types(obj):
"""Convert Hydra configs and other special types to regular Python types."""
from omegaconf import ListConfig, DictConfig
if isinstance(obj, (ListConfig, DictConfig)):
return {k: convert_to_regular_types(v) for k, v in obj.items()} if isinstance(obj, DictConfig) else list(obj)
elif isinstance(obj, (list, tuple)):
return [convert_to_regular_types(x) for x in obj]
elif isinstance(obj, dict):
return {k: convert_to_regular_types(v) for k, v in obj.items()}
return obj
class FSDPSFTTrainer(object):
def __init__(self, config, device_mesh: DeviceMesh):
self.config = config
self.device_mesh = device_mesh
# build tokenizer first
local_model_path = copy_local_path_from_hdfs(src=self.config.model.partial_pretrain, verbose=True)
from verl.utils import hf_tokenizer
self.tokenizer = hf_tokenizer(local_model_path, trust_remote_code=self.config.model.trust_remote_code)
if self.config.data.chat_template is not None:
raise ValueError('Apply Chat template from config is not supported yet.')
# normalize dp size
self._normalize_config_bsz()
self._build_dataloader()
# build model
self._build_model_optimizer()
# TODO: add checkpoint manager
if self.device_mesh.get_rank() == 0:
print(self.config)
def _normalize_config_bsz(self):
dp_size = self.device_mesh.size()
if self.device_mesh.get_rank() == 0:
print(f'Normalize batch size by dp {dp_size}')
assert self.config.data.train_batch_size % dp_size == 0
assert self.config.data.micro_batch_size % dp_size == 0
self.config.data.train_batch_size //= dp_size
self.config.data.micro_batch_size //= dp_size
def _build_dataloader(self):
config = self.config
# build dataset
self.train_dataset = SFTDataset(parquet_files=config.data.train_files,
tokenizer=self.tokenizer,
prompt_key=config.data.prompt_key,
prompt_dict_keys=config.data.get('prompt_dict_keys', None),
response_key=config.data.response_key,
response_dict_keys=config.data.get('response_dict_keys', None),
max_length=config.data.max_length,
truncation=config.data.truncation)
self.val_dataset = SFTDataset(parquet_files=config.data.val_files,
tokenizer=self.tokenizer,
prompt_key=config.data.prompt_key,
prompt_dict_keys=config.data.get('prompt_dict_keys', None),
response_key=config.data.response_key,
response_dict_keys=config.data.get('response_dict_keys', None),
max_length=config.data.max_length,
truncation=config.data.truncation)
# build dataloader
rank = self.device_mesh.get_rank()
world_size = self.device_mesh.size()
self.train_sampler = DistributedSampler(self.train_dataset,
shuffle=True,
num_replicas=world_size,
rank=rank,
drop_last=True)
self.train_dataloader = DataLoader(dataset=self.train_dataset,
batch_size=config.data.train_batch_size,
sampler=self.train_sampler,
num_workers=8,
pin_memory=True,
drop_last=True)
self.val_sampler = DistributedSampler(self.val_dataset,
shuffle=True,
num_replicas=world_size,
rank=rank,
drop_last=True)
self.val_dataloader = DataLoader(dataset=self.val_dataset,
batch_size=config.data.micro_batch_size,
sampler=self.val_sampler,
num_workers=8,
pin_memory=True,
drop_last=True)
def _build_model_optimizer(self):
# TODO (zhangchi.usc1992):
# 1. support pretrain from random weights
# 2. support init directly from sharded weights
local_model_path = copy_local_path_from_hdfs(src=self.config.model.partial_pretrain, verbose=True)
if self.config.model.get('external_lib', None) is not None:
# This is used to import external_lib into the huggingface systems
import importlib
importlib.import_module(self.config.model.external_lib)
log_gpu_memory_usage('Before model allocation', logger=logger)
trust_remote_code = self.config.model.trust_remote_code
# load config first
config = AutoConfig.from_pretrained(local_model_path, trust_remote_code=trust_remote_code)
# This may be very large
init_context = get_init_weight_context_manager(use_meta_tensor=not config.tie_word_embeddings)
with init_context():
self.model: PreTrainedModel = AutoModelForCausalLM.from_pretrained(local_model_path,
config=config,
torch_dtype=torch.float32,
attn_implementation='flash_attention_2',
trust_remote_code=trust_remote_code)
if self.config.model.get('lora_rank', 0) > 0:
self.model.enable_input_require_grads()
# Convert config to regular Python types before creating PEFT model
lora_config = {
'task_type': TaskType.CAUSAL_LM,
'r': self.config.model.lora_rank,
'lora_alpha': self.config.model.lora_alpha,
'target_modules': convert_to_regular_types(self.config.model.target_modules),
'bias': "none"
}
self.model = get_peft_model(self.model, LoraConfig(**lora_config))
if self.config.model.enable_gradient_checkpointing:
self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={'use_reentrant': False})
log_gpu_memory_usage('After model allocation', logger=logger)
mixed_precision = MixedPrecision(param_dtype=torch.bfloat16,
reduce_dtype=torch.float32,
buffer_dtype=torch.float32)
auto_wrap_policy = get_fsdp_wrap_policy(self.model,
config=self.config.model.fsdp_config.wrap_policy,
is_lora=self.config.model.get('lora_rank', 0) > 0)
if self.device_mesh.get_rank() == 0:
print(auto_wrap_policy)
if not self.config.model.fsdp_config.cpu_offload:
cpu_offload = None
else:
cpu_offload = CPUOffload(offload_params=self.config.model.fsdp_config.offload_params)
self.fsdp_model = FSDP(module=self.model,
auto_wrap_policy=auto_wrap_policy,
param_init_fn=init_fn,
sharding_strategy=ShardingStrategy.FULL_SHARD,
mixed_precision=mixed_precision,
device_mesh=self.device_mesh,
sync_module_states=True,
device_id=torch.cuda.current_device(),
cpu_offload=cpu_offload,
use_orig_params=False)
log_gpu_memory_usage('After FSDP wrapping', logger=logger)
self.optimizer = optim.AdamW(self.fsdp_model.parameters(),
lr=self.config.optim.lr,
betas=self.config.optim.betas,
weight_decay=self.config.optim.weight_decay)
log_gpu_memory_usage('After initialize optimizer', logger=logger)
steps_per_epoch = len(self.train_dataloader)
total_steps = steps_per_epoch * self.config.trainer.total_epochs
if self.device_mesh.get_rank() == 0:
print(
f'Number of steps/epoch {steps_per_epoch}, number of epochs {self.config.trainer.total_epochs}, total number of steps {total_steps}'
)
num_warmup_steps = int(total_steps * self.config.optim.warmup_steps_ratio)
self.lr_scheduler = get_cosine_schedule_with_warmup(optimizer=self.optimizer,
num_warmup_steps=num_warmup_steps,
num_training_steps=total_steps)
def _compute_loss(self, batch):
loss_mask = batch.pop('loss_mask')[:, :-1].reshape(-1).cuda()
labels = batch['input_ids'][:, 1:].cuda()
with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
output = self.fsdp_model(input_ids=batch['input_ids'],
attention_mask=batch['attention_mask'],
position_ids=batch['position_ids'],
use_cache=False) # prevent model thinks it it generating
logits = output.logits
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels.contiguous()
# Flatten the tokens
loss_fct = nn.CrossEntropyLoss(reduction='none')
shift_logits = shift_logits.view(-1, self.model.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
loss = loss * loss_mask
valid_token_this_rank = torch.sum(loss_mask)
if self.config.data.balance_dp_token:
torch.distributed.all_reduce(valid_token_this_rank) # becomes total valid tokens in all ranks
dp_size = torch.distributed.get_world_size()
else:
dp_size = 1
loss = torch.sum(loss) / valid_token_this_rank * dp_size # possible bugs here for dp
return loss
def training_step(self, batch: TensorDict):
self.fsdp_model.train()
log_gpu_memory_usage('Before optimizer zero_grad', logger=logger)
self.optimizer.zero_grad()
log_gpu_memory_usage('After optimizer zero_grad', logger=logger)
micro_batches = batch.split(self.config.data.micro_batch_size)
n_micro_batches = len(micro_batches)
step_loss = 0
for micro_batch in micro_batches:
loss = self._compute_loss(batch=micro_batch) / n_micro_batches
loss.backward()
step_loss += loss.item()
self.fsdp_model.clip_grad_norm_(max_norm=self.config.optim.clip_grad)
log_gpu_memory_usage('Before optimizer step', logger=logger)
self.optimizer.step()
log_gpu_memory_usage('After optimizer step', logger=logger)
self.lr_scheduler.step()
# reduce loss across dp ranks
lr = self.lr_scheduler.get_last_lr()[0]
log_gpu_memory_usage('After offload weights', logger=logger)
step_loss = torch.tensor(step_loss).cuda()
torch.distributed.all_reduce(step_loss, op=torch.distributed.ReduceOp.AVG)
return {'train/loss': step_loss.detach().item(), 'train/lr(1e-3)': lr * 1e3}
def validation_step(self, batch: TensorDict):
self.fsdp_model.eval()
with torch.no_grad():
loss = self._compute_loss(batch)
torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.AVG)
return loss
def save_checkpoint(self, step):
# save checkpoint
from torch.distributed.fsdp import FullStateDictConfig, StateDictType
cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
with FSDP.state_dict_type(self.fsdp_model, StateDictType.FULL_STATE_DICT, cfg):
state_dict = self.fsdp_model.state_dict()
path = os.path.join(self.config.trainer.default_local_dir, f'global_step_{step}')
# save huggingface model
if self.device_mesh.get_rank() == 0:
os.makedirs(path, exist_ok=True)
self.model.save_pretrained(path, state_dict=state_dict)
self.tokenizer.save_pretrained(path)
if self.config.trainer.default_hdfs_dir:
hdfs_io.makedirs(self.config.trainer.default_hdfs_dir, exist_ok=True)
hdfs_io.copy(src=path, dst=self.config.trainer.default_hdfs_dir, dirs_exist_ok=True)
torch.distributed.barrier()
def fit(self):
rank = self.device_mesh.get_rank()
# TODO: add a unified tracking
if rank == 0:
tracking = Tracking(project_name=self.config.trainer.project_name,
experiment_name=self.config.trainer.experiment_name,
default_backend=self.config.trainer.logger)
global_step = 0
# compute the total training steps.
# the total training steps in SFT is mainly for early exit
total_training_steps = len(self.train_dataloader) * self.config.trainer.total_epochs
if self.config.trainer.total_training_steps is not None:
total_training_steps = self.config.trainer.total_training_steps
self.total_training_steps = total_training_steps
print(f'Total training steps: {self.total_training_steps}')
# TODO (zhangchi.usc1992) add back checkpoint manager. Currently, it blocks when uploading to hdfs. So very slow.
if self.config.trainer.validate_before_training:
# validate before training
val_losses = []
for data in self.val_dataloader:
data = TensorDict(data, batch_size=self.config.data.micro_batch_size).cuda()
val_loss = self.validation_step(data)
val_losses.append(val_loss)
if rank == 0:
val_loss = torch.mean(torch.stack(val_losses))
metric = {'val/loss': val_loss.detach().item()}
tracking.log(data=metric, step=global_step)
torch.distributed.barrier()
for epoch in range(self.config.trainer.total_epochs):
self.train_sampler.set_epoch(epoch=epoch)
for data in self.train_dataloader:
data = TensorDict(data, batch_size=self.config.data.train_batch_size).cuda()
metric = self.training_step(data)
if rank == 0:
tracking.log(data=metric, step=global_step)
global_step += 1
# for early exit validation
if global_step >= self.total_training_steps:
# Perform final validation
val_losses = []
for val_data in self.val_dataloader:
val_data = TensorDict(val_data, batch_size=self.config.data.micro_batch_size).cuda()
val_loss = self.validation_step(val_data)
val_losses.append(val_loss)
if rank == 0:
avg_val_loss = torch.mean(torch.stack(val_losses))
metric = {'val/loss': avg_val_loss.detach().item()}
tracking.log(data=metric, step=global_step)
torch.distributed.barrier()
# Save final checkpoint
self.save_checkpoint(step=global_step)
return
# validation
val_losses = []
for data in self.val_dataloader:
data = TensorDict(data, batch_size=self.config.data.micro_batch_size).cuda()
val_loss = self.validation_step(data)
val_losses.append(val_loss)
if rank == 0:
val_loss = torch.mean(torch.stack(val_losses))
metric = {'val/loss': val_loss.detach().item()}
tracking.log(data=metric, step=global_step)
torch.distributed.barrier()
# save checkpoint
self.save_checkpoint(step=global_step)
from verl.trainer.fsdp_sft_trainer import FSDPSFTTrainer
import hydra
from torch.distributed.device_mesh import init_device_mesh
from verl.utils.distributed import initialize_global_process_group
@hydra.main(config_path='config', config_name='sft_trainer', version_base=None)
def main(config):
local_rank, rank, world_size = initialize_global_process_group()
device_mesh = init_device_mesh(device_type='cuda', mesh_shape=(world_size,), mesh_dim_names=('dp',))
trainer = FSDPSFTTrainer(config=config, device_mesh=device_mesh)
trainer.fit()
if __name__ == '__main__':
main()

69
verl/trainer/main_eval.py Normal file
View File

@@ -0,0 +1,69 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Offline evaluate the performance of a generated file using reward model and ground truth verifier.
The input is a parquet file that contains N generated sequences and (optional) the ground truth.
"""
import hydra
from verl.utils.fs import copy_local_path_from_hdfs
from verl.utils.reward_score import math, gsm8k
import pandas as pd
import numpy as np
def select_reward_fn(data_source):
if data_source == 'lighteval/MATH':
return math.compute_score
else:
raise NotImplementedError
@hydra.main(config_path='config', config_name='evaluation', version_base=None)
def main(config):
local_path = copy_local_path_from_hdfs(config.data.path)
dataset = pd.read_parquet(local_path)
prompts = dataset[config.data.prompt_key]
responses = dataset[config.data.response_key]
data_sources = dataset[config.data.data_source_key]
reward_model_data = dataset[config.data.reward_model_key]
passes = 0
total = len(dataset)
for i in range(total):
response_lst = responses[i]
data_source = data_sources[i]
# select reward score based on data_source
prompt = prompts[i]
reward_data = reward_model_data[i]
reward_fn = select_reward_fn(data_source)
ground_truth = reward_data['ground_truth']
score_lst = []
for r in response_lst:
score = reward_fn(r, ground_truth)
score_lst.append(score)
max_score = np.max(score_lst)
if max_score == 1:
passes += 1
print(f'pass@5: {passes / total}')
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,137 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Generate responses given a dataset of prompts
"""
import ray
import numpy as np
import hydra
import os
os.environ['NCCL_DEBUG'] = 'WARN'
os.environ['TOKENIZERS_PARALLELISM'] = 'true'
# os.environ['TORCH_COMPILE_DISABLE'] = '1'
from verl.utils.model import compute_position_id_with_mask
import pandas as pd
from transformers import AutoTokenizer
from verl import DataProto
from verl.utils.fs import copy_local_path_from_hdfs
from verl.workers.fsdp_workers import ActorRolloutRefWorker
from verl.utils.hdfs_io import makedirs
from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup
@hydra.main(config_path='config', config_name='generation', version_base=None)
def main(config):
from pprint import pprint
from omegaconf import OmegaConf
pprint(OmegaConf.to_container(config, resolve=True)) # resolve=True will eval symbol values
OmegaConf.resolve(config)
local_path = copy_local_path_from_hdfs(config.model.path)
from verl.utils import hf_tokenizer
tokenizer = hf_tokenizer(local_path)
if config.rollout.temperature == 0.:
assert config.data.n_samples == 1, 'When temperature=0, n_samples must be 1.'
# read dataset. Note that the dataset should directly contain chat template format (e.g., a list of dictionary)
dataset = pd.read_parquet(config.data.path)
chat_lst = dataset[config.data.prompt_key].tolist()
chat_lst = [chat.tolist() for chat in chat_lst]
tokenizer.padding_side = 'left'
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
ray_cls_with_init = RayClassWithInitArgs(cls=ray.remote(ActorRolloutRefWorker), config=config, role='rollout')
resource_pool = RayResourcePool(process_on_nodes=[config.trainer.n_gpus_per_node] * config.trainer.nnodes)
wg = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=ray_cls_with_init)
wg.init_model()
total_samples = len(dataset)
# real_batch_size = data.batch['input_ids'].shape[0]
config_batch_size = config.data.batch_size
dp_size = wg.world_size // config.rollout.tensor_model_parallel_size
num_batch = (total_samples // config_batch_size) + 1
output_lst = [[] for _ in range(config.data.n_samples)]
for batch_idx in range(num_batch):
print(f'[{batch_idx+1}/{num_batch}] Start to process.')
batch_chat_lst = chat_lst[batch_idx * config_batch_size:(batch_idx + 1) * config_batch_size]
inputs = tokenizer.apply_chat_template(batch_chat_lst,
add_generation_prompt=True,
padding=True,
truncation=True,
max_length=config.rollout.prompt_length,
return_tensors='pt',
return_dict=True,
tokenize=True)
input_ids = inputs['input_ids']
attention_mask = inputs['attention_mask']
position_ids = compute_position_id_with_mask(attention_mask)
batch_dict = {'input_ids': input_ids, 'attention_mask': attention_mask, 'position_ids': position_ids}
data = DataProto.from_dict(batch_dict)
real_batch_size = data.batch['input_ids'].shape[0]
if real_batch_size % dp_size != 0:
dummy_data_size = dp_size - real_batch_size % dp_size
dummy_data = data[:dummy_data_size]
data = DataProto.concat([data, dummy_data])
print(
f'dp_size {dp_size} is not divisible by real_batch_size {real_batch_size}, add {dummy_data_size} dummy data'
)
batch_size = data.batch['input_ids'].shape[0]
assert batch_size % dp_size == 0, f'batch_size {batch_size} is not divisible by dp_size {dp_size}'
print(f'[{batch_idx+1}/{num_batch}] Start to generate.')
# START TO GENERATE FOR n_samples TIMES
for i in range(config.data.n_samples):
output = wg.generate_sequences(data)
# remove dummy data
output = output[:real_batch_size]
output_text = tokenizer.batch_decode(output.batch['input_ids'][:, -config.rollout.response_length:],
skip_special_tokens=False)
# remove the padding
pad_token = tokenizer.pad_token
output_text_unpad = []
for text in output_text:
output_text_unpad.append(text.replace(pad_token, ''))
output_lst[i].extend(output_text_unpad)
# convert output_lst from (n_samples, n_data) to (n_data, n_sampels)
output_lst = np.array(output_lst, dtype=object)
output_lst = np.transpose(output_lst, axes=(1, 0)).tolist()
# add to the data frame
dataset[f'responses'] = output_lst
# write to a new parquet
output_dir = os.path.dirname(config.data.output_path)
makedirs(output_dir, exist_ok=True)
dataset.to_parquet(config.data.output_path)
return output_text
if __name__ == '__main__':
main()

Some files were not shown because too many files have changed in this diff Show More