Initial commit
This commit is contained in:
410
verl/single_controller/base/decorator.py
Normal file
410
verl/single_controller/base/decorator.py
Normal file
@@ -0,0 +1,410 @@
|
||||
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from enum import Enum
|
||||
from functools import wraps
|
||||
from typing import Dict, List, Tuple
|
||||
from types import FunctionType
|
||||
from verl.protocol import DataProtoFuture
|
||||
|
||||
# Name of the attribute used to stash dispatch metadata on decorated functions.
# The random-looking numeric suffix is a "magic number" chosen so it cannot
# collide with an attribute a user-defined function might already carry.
MAGIC_ATTR = 'attrs_3141562937'
|
||||
|
||||
|
||||
class Dispatch(Enum):
    """Dispatch modes: how a decorated method's inputs are split across the
    worker group and how per-rank outputs are collected back.

    Each member (except RANK_ZERO) maps to a dispatch_fn/collect_fn pair in
    ``get_predefined_dispatch_fn`` below.
    """
    # NOTE(review): RANK_ZERO has no entry in get_predefined_dispatch_fn —
    # presumably handled elsewhere; confirm before relying on it.
    RANK_ZERO = 0
    ONE_TO_ALL = 1  # replicate each argument to every worker
    ALL_TO_ALL = 2  # pass arguments through unchanged
    MEGATRON_COMPUTE = 3  # dp-sized lists fanned out to all tp/pp ranks sharing a dp rank
    MEGATRON_PP_AS_DP = 4  # treat pp ranks as additional dp ranks
    MEGATRON_PP_ONLY = 5  # replicate input; collect only pp-distinct outputs
    MEGATRON_COMPUTE_PROTO = 6  # MEGATRON_COMPUTE over DataProto chunked by dp_size
    MEGATRON_PP_AS_DP_PROTO = 7  # MEGATRON_PP_AS_DP over DataProto chunked by dp*pp
    DP_COMPUTE = 8  # caller supplies per-rank lists of world_size entries
    DP_COMPUTE_PROTO = 9  # DataProto chunked by world_size; outputs concatenated
    DP_COMPUTE_PROTO_WITH_FUNC = 10  # like DP_COMPUTE_PROTO, but first arg is a function
    DP_COMPUTE_METRIC = 11  # DataProto dispatch; raw per-rank outputs (no concat)
|
||||
|
||||
|
||||
class Execute(Enum):
    """Execution modes: run a worker-group method on every rank or rank 0 only.

    Mapped to worker-group method names by ``get_predefined_execute_fn``.
    """
    ALL = 0  # dispatched via 'execute_all'
    RANK_ZERO = 1  # dispatched via 'execute_rank_zero'
|
||||
|
||||
|
||||
def _split_args_kwargs_data_proto(chunks, *args, **kwargs):
    """Chunk every positional and keyword argument into ``chunks`` pieces.

    Every argument must be a DataProto or DataProtoFuture; each is split via
    its ``chunk`` method. Returns (list_of_chunked_args, dict_of_chunked_kwargs).
    """
    from verl.protocol import DataProto, DataProtoFuture

    def _chunk(value):
        # validate then split one argument
        assert isinstance(value, (DataProto, DataProtoFuture))
        return value.chunk(chunks=chunks)

    splitted_args = [_chunk(a) for a in args]
    splitted_kwargs = {name: _chunk(v) for name, v in kwargs.items()}
    return splitted_args, splitted_kwargs
|
||||
|
||||
|
||||
def dispatch_one_to_all(worker_group, *args, **kwargs):
    """Replicate every argument once per worker: each arg/kwarg value becomes a
    world_size-long list of the same object."""
    world_size = worker_group.world_size
    replicated_args = tuple([value] * world_size for value in args)
    replicated_kwargs = {name: [value] * world_size for name, value in kwargs.items()}
    return replicated_args, replicated_kwargs
|
||||
|
||||
|
||||
def dispatch_all_to_all(worker_group, *args, **kwargs):
    """Identity dispatch: forward the caller's args/kwargs untouched."""
    return args, kwargs
|
||||
|
||||
|
||||
def collect_all_to_all(worker_group, output):
    """Identity collect: return the per-rank outputs unchanged."""
    return output
|
||||
|
||||
|
||||
def dispatch_megatron_compute(worker_group, *args, **kwargs):
    """
    User passes in dp data. The data is dispatched to all tp/pp ranks with the same dp.

    Every arg/kwarg must be a list/tuple of length dp_size; rank i receives the
    element at its own dp_rank, so all tp/pp peers of a dp rank see the same data.
    """
    from verl.single_controller.base.megatron.worker_group import MegatronWorkerGroup
    assert isinstance(worker_group,
                      MegatronWorkerGroup), f'worker_group must be MegatronWorkerGroup, Got {type(worker_group)}'

    # The rank topology is static, so resolve each global rank's dp_rank once
    # instead of re-querying get_megatron_rank_info for every argument.
    dp_rank_of = [worker_group.get_megatron_rank_info(rank=i).dp_rank for i in range(worker_group.world_size)]

    def _fan_out(value):
        # Expand a dp-sized sequence to a world_size-long list indexed by dp_rank.
        assert isinstance(value, (Tuple, List)) and len(value) == worker_group.dp_size
        return [value[dp_rank] for dp_rank in dp_rank_of]

    all_args = tuple(_fan_out(arg) for arg in args)
    all_kwargs = {key: _fan_out(val) for key, val in kwargs.items()}
    return all_args, all_kwargs
|
||||
|
||||
|
||||
def collect_megatron_compute(worker_group, output):
    """
    Only collect the data from the tp=0 and pp=last and every dp ranks
    """
    from verl.single_controller.base.megatron.worker_group import MegatronWorkerGroup
    assert isinstance(worker_group, MegatronWorkerGroup)

    last_pp_rank = worker_group.get_megatron_global_info().pp_size - 1
    collected = []
    for rank in range(worker_group.world_size):
        rank_info = worker_group.get_megatron_rank_info(rank=rank)
        # keep exactly one output per dp rank: the tp=0 / last-pp-stage worker
        if rank_info.tp_rank == 0 and rank_info.pp_rank == last_pp_rank:
            collected.append(output[rank])
    return collected
|
||||
|
||||
|
||||
def dispatch_megatron_compute_data_proto(worker_group, *args, **kwargs):
    """
    All the args and kwargs must be DataProto. The batch will be chunked by dp_size and passed to each rank
    """
    from verl.single_controller.base.megatron.worker_group import MegatronWorkerGroup
    assert isinstance(worker_group, MegatronWorkerGroup)

    # chunk by dp_size, then fan each chunk out to the tp/pp peers of its dp rank
    chunked_args, chunked_kwargs = _split_args_kwargs_data_proto(worker_group.dp_size, *args, **kwargs)
    return dispatch_megatron_compute(worker_group, *chunked_args, **chunked_kwargs)
|
||||
|
||||
|
||||
def _concat_data_proto_or_future(output: List):
    """Concatenate a homogeneous list of DataProto objects (or Ray futures of them).

    Raises NotImplementedError for any other element type.
    """
    from verl.protocol import DataProto, DataProtoFuture
    import ray

    # make sure all the elements in output has the same type, so one concat applies
    first = output[0]
    for o in output:
        assert type(o) == type(first), f'expecting all elements to be {type(first)}, got {type(o)}'

    if isinstance(first, DataProto):
        return DataProto.concat(output)
    elif isinstance(first, ray.ObjectRef):
        # futures are concatenated lazily without materializing them here
        return DataProtoFuture.concat(output)
    else:
        # give the caller something actionable instead of a bare NotImplementedError
        raise NotImplementedError(f'cannot concat outputs of type {type(first)}')
|
||||
|
||||
|
||||
def collect_megatron_compute_data_proto(worker_group, output):
    """
    Each output must be a DataProto. We concat the dim=0 of output
    """
    from verl.protocol import DataProto
    import ray

    collected = collect_megatron_compute(worker_group, output)
    for o in collected:
        assert isinstance(o, (DataProto, ray.ObjectRef)), f"expecting {o} to be DataProto, but got {type(o)}"

    return _concat_data_proto_or_future(collected)
|
||||
|
||||
|
||||
def dispatch_megatron_pp_as_dp(worker_group, *args, **kwargs):
    """
    treat pp as dp.

    Every arg/kwarg must be a list/tuple of length dp_size * pp_size; each global
    rank receives the element at index dp_rank * pp_size + pp_rank.
    """
    from verl.single_controller.base.megatron.worker_group import MegatronWorkerGroup
    assert isinstance(worker_group, MegatronWorkerGroup)

    pp_size = worker_group.pp_size
    dp_size = worker_group.dp_size
    pp_dp_size = pp_size * dp_size

    # compute the rank in arg for each global rank. Note that the order is dp then pp
    # Also note that the outputs within a pp group will be firstly allgathered, then only the output of pp0 will be collected.
    # For pp=2 dp=4, a batch of data "ABCDEFGH" should be dispatched and collected in below order:
    #    dispatch:       pp_allgther:        collect:
    #   dp 0 1 2 3      dp  0  1  2  3
    # pp +---------+  pp +-------------+
    #  0 | A C E G |   0 | AB CD EF GH |     ABCDEFGH
    #  1 | B D F H |   1 | AB CD EF GH |
    #    +---------+     +-------------+
    # The rank topology is static, so resolve it once rather than per argument.
    arg_rank_of = []
    for i in range(worker_group.world_size):
        rank_info = worker_group.get_megatron_rank_info(rank=i)
        arg_rank_of.append(rank_info.dp_rank * pp_size + rank_info.pp_rank)

    def _fan_out(value):
        # Expand a (dp*pp)-sized sequence to a world_size-long list.
        assert isinstance(value, (List, Tuple)) and len(value) == pp_dp_size, \
            f'expect len(v)=={pp_dp_size}, got {len(value)}'
        return [value[arg_rank] for arg_rank in arg_rank_of]

    all_args = tuple(_fan_out(arg) for arg in args)
    all_kwargs = {key: _fan_out(val) for key, val in kwargs.items()}
    return all_args, all_kwargs
|
||||
|
||||
|
||||
def collect_megatron_pp_as_dp(worker_group, output):
    """
    treat pp as dp. Only collect data on tp=0
    """
    from verl.single_controller.base.megatron.worker_group import MegatronWorkerGroup
    assert isinstance(worker_group, MegatronWorkerGroup)

    collected = []
    for rank in range(worker_group.world_size):
        rank_info = worker_group.get_megatron_rank_info(rank=rank)
        # keep outputs only from the tp=0 / pp=0 worker of each dp group
        if rank_info.tp_rank == 0 and rank_info.pp_rank == 0:
            collected.append(output[rank])
    return collected
|
||||
|
||||
|
||||
def collect_megatron_pp_only(worker_group, output):
    """
    Only collect output of megatron pp. This is useful when examine weight names as they are identical in tp/dp
    """
    from verl.single_controller.base.megatron.worker_group import MegatronWorkerGroup
    assert isinstance(worker_group, MegatronWorkerGroup)

    collected = []
    for rank in range(worker_group.world_size):
        rank_info = worker_group.get_megatron_rank_info(rank=rank)
        # one output per pp stage: take the tp=0 / dp=0 worker of each stage
        if rank_info.tp_rank == 0 and rank_info.dp_rank == 0:
            collected.append(output[rank])
    return collected
|
||||
|
||||
|
||||
def dispatch_megatron_pp_as_dp_data_proto(worker_group, *args, **kwargs):
    """Chunk DataProto args by dp_size * pp_size and dispatch treating pp as dp."""
    from verl.single_controller.base.megatron.worker_group import MegatronWorkerGroup
    assert isinstance(worker_group, MegatronWorkerGroup)

    num_chunks = worker_group.dp_size * worker_group.pp_size
    chunked_args, chunked_kwargs = _split_args_kwargs_data_proto(num_chunks, *args, **kwargs)
    return dispatch_megatron_pp_as_dp(worker_group, *chunked_args, **chunked_kwargs)
|
||||
|
||||
|
||||
def collect_megatron_pp_as_dp_data_proto(worker_group, output):
    """Collect pp-as-dp outputs (tp=0, pp=0 of each dp group) and concat them
    into a single DataProto / DataProtoFuture."""
    from verl.single_controller.base.megatron.worker_group import MegatronWorkerGroup
    assert isinstance(worker_group, MegatronWorkerGroup)

    # removed unused `from verl.protocol import DataProto` (never referenced here)
    output = collect_megatron_pp_as_dp(worker_group, output)
    return _concat_data_proto_or_future(output)
|
||||
|
||||
|
||||
def dispatch_dp_compute(worker_group, *args, **kwargs):
    """Pass-through dispatch for dp compute: every arg/kwarg must already be a
    per-rank list/tuple of exactly world_size entries."""
    from verl.single_controller.base.worker_group import WorkerGroup
    assert isinstance(worker_group, WorkerGroup)

    world_size = worker_group.world_size
    for value in args:
        assert isinstance(value, (Tuple, List)) and len(value) == world_size
    for value in kwargs.values():
        assert isinstance(value, (Tuple, List)) and len(value) == world_size
    return args, kwargs
|
||||
|
||||
|
||||
def collect_dp_compute(worker_group, output):
    """Return per-rank outputs unchanged after checking one output per rank."""
    from verl.single_controller.base.worker_group import WorkerGroup
    assert isinstance(worker_group, WorkerGroup)

    assert len(output) == worker_group.world_size
    return output
|
||||
|
||||
|
||||
def dispatch_dp_compute_data_proto(worker_group, *args, **kwargs):
    """Chunk DataProto args/kwargs by world_size, one chunk per rank."""
    from verl.single_controller.base.worker_group import WorkerGroup
    assert isinstance(worker_group, WorkerGroup)

    return _split_args_kwargs_data_proto(worker_group.world_size, *args, **kwargs)
|
||||
|
||||
|
||||
def dispatch_dp_compute_data_proto_with_func(worker_group, *args, **kwargs):
    """Like dispatch_dp_compute_data_proto, but args[0] is a plain function that
    is replicated (not chunked) to every rank; the remaining args/kwargs must be
    DataProto and are chunked by world_size."""
    from verl.single_controller.base.worker_group import WorkerGroup
    assert isinstance(worker_group, WorkerGroup)
    # NOTE: The first one args is a function!
    # isinstance is the idiomatic type check (type(x) == T was an anti-pattern)
    assert isinstance(args[0], FunctionType)

    splitted_args, splitted_kwargs = _split_args_kwargs_data_proto(worker_group.world_size, *args[1:], **kwargs)
    # prepend the function, replicated once per rank, back into position 0
    splitted_args_with_func = [[args[0]] * worker_group.world_size] + splitted_args
    return splitted_args_with_func, splitted_kwargs
|
||||
|
||||
|
||||
def collect_dp_compute_data_proto(worker_group, output):
    """Validate per-rank DataProto (or future) outputs and concat them along dim 0."""
    from verl.protocol import DataProto
    import ray

    for o in output:
        assert isinstance(o, (DataProto, ray.ObjectRef)), f"expecting {o} to be DataProto, but got {type(o)}"

    validated = collect_dp_compute(worker_group, output)
    return _concat_data_proto_or_future(validated)
|
||||
|
||||
|
||||
def get_predefined_dispatch_fn(dispatch_mode):
    """Map a Dispatch enum member to its {'dispatch_fn', 'collect_fn'} pair.

    Raises KeyError for modes without a predefined pair (e.g. RANK_ZERO).
    """
    mode_table = {
        Dispatch.ONE_TO_ALL: (dispatch_one_to_all, collect_all_to_all),
        Dispatch.ALL_TO_ALL: (dispatch_all_to_all, collect_all_to_all),
        Dispatch.MEGATRON_COMPUTE: (dispatch_megatron_compute, collect_megatron_compute),
        Dispatch.MEGATRON_PP_AS_DP: (dispatch_megatron_pp_as_dp, collect_megatron_pp_as_dp),
        Dispatch.MEGATRON_PP_ONLY: (dispatch_one_to_all, collect_megatron_pp_only),
        Dispatch.MEGATRON_COMPUTE_PROTO: (dispatch_megatron_compute_data_proto, collect_megatron_compute_data_proto),
        Dispatch.MEGATRON_PP_AS_DP_PROTO: (dispatch_megatron_pp_as_dp_data_proto, collect_megatron_pp_as_dp_data_proto),
        Dispatch.DP_COMPUTE: (dispatch_dp_compute, collect_dp_compute),
        Dispatch.DP_COMPUTE_PROTO: (dispatch_dp_compute_data_proto, collect_dp_compute_data_proto),
        Dispatch.DP_COMPUTE_PROTO_WITH_FUNC: (dispatch_dp_compute_data_proto_with_func, collect_dp_compute_data_proto),
        Dispatch.DP_COMPUTE_METRIC: (dispatch_dp_compute_data_proto, collect_dp_compute),
    }
    dispatch_fn, collect_fn = mode_table[dispatch_mode]
    return {'dispatch_fn': dispatch_fn, 'collect_fn': collect_fn}
|
||||
|
||||
|
||||
def get_predefined_execute_fn(execute_mode):
    """
    Note that here we only asks execute_all and execute_rank_zero to be implemented
    Leave the choice of how these two functions handle argument 'blocking' to users
    """
    mode_table = {
        Execute.ALL: {'execute_fn_name': 'execute_all'},
        Execute.RANK_ZERO: {'execute_fn_name': 'execute_rank_zero'},
    }
    return mode_table[execute_mode]
|
||||
|
||||
|
||||
def _check_dispatch_mode(dispatch_mode):
    """Validate that dispatch_mode is a Dispatch member, or a dict providing
    both 'dispatch_fn' and 'collect_fn'."""
    assert isinstance(dispatch_mode,
                      (Dispatch, Dict)), f'dispatch_mode must be a Dispatch or a Dict. Got {dispatch_mode}'
    if isinstance(dispatch_mode, Dict):
        for key in ('dispatch_fn', 'collect_fn'):
            assert key in dispatch_mode, f'key {key} should be in dispatch_mode if it is a dictionary'
|
||||
|
||||
|
||||
def _check_execute_mode(execute_mode):
    """Validate that execute_mode is an Execute enum member."""
    # message grammar fixed: "a Execute" -> "an Execute"
    assert isinstance(execute_mode, Execute), f'execute_mode must be an Execute. Got {execute_mode}'
|
||||
|
||||
|
||||
def _materialize_futures(*args, **kwargs):
    """Resolve any DataProtoFuture among args/kwargs by blocking on .get();
    other values pass through unchanged."""
    materialized_args = []
    for value in args:
        if isinstance(value, DataProtoFuture):
            value = value.get()
        # add more type to materialize
        materialized_args.append(value)
    for name, value in kwargs.items():
        if isinstance(value, DataProtoFuture):
            kwargs[name] = value.get()

    return tuple(materialized_args), kwargs
|
||||
|
||||
|
||||
def register(dispatch_mode=Dispatch.ALL_TO_ALL, execute_mode=Execute.ALL, blocking=True, materialize_futures=True):
    """Decorator factory that tags a worker method with dispatch/execute metadata.

    The metadata dict is attached to the wrapper under MAGIC_ATTR; when
    materialize_futures is True, DataProtoFuture arguments are resolved
    before the wrapped function runs.
    """
    _check_dispatch_mode(dispatch_mode=dispatch_mode)
    _check_execute_mode(execute_mode=execute_mode)

    def decorator(func):

        @wraps(func)
        def inner(*args, **kwargs):
            if materialize_futures:
                args, kwargs = _materialize_futures(*args, **kwargs)
            return func(*args, **kwargs)

        setattr(inner, MAGIC_ATTR, {
            'dispatch_mode': dispatch_mode,
            'execute_mode': execute_mode,
            'blocking': blocking,
        })
        return inner

    return decorator
|
||||
Reference in New Issue
Block a user