Initial commit
20
verl/single_controller/__init__.py
Normal file
@@ -0,0 +1,20 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

version_folder = os.path.dirname(os.path.join(os.path.abspath(__file__)))

with open(os.path.join(version_folder, 'version/version')) as f:
    __version__ = f.read().strip()
16
verl/single_controller/base/__init__.py
Normal file
@@ -0,0 +1,16 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .worker import Worker
from .worker_group import WorkerGroup, ClassWithInitArgs, ResourcePool
410
verl/single_controller/base/decorator.py
Normal file
@@ -0,0 +1,410 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from enum import Enum
from functools import wraps
from typing import Dict, List, Tuple
from types import FunctionType
from verl.protocol import DataProtoFuture

# here we add a magic number to avoid colliding with an attribute a user-defined function may already have
MAGIC_ATTR = 'attrs_3141562937'


class Dispatch(Enum):
    RANK_ZERO = 0
    ONE_TO_ALL = 1
    ALL_TO_ALL = 2
    MEGATRON_COMPUTE = 3
    MEGATRON_PP_AS_DP = 4
    MEGATRON_PP_ONLY = 5
    MEGATRON_COMPUTE_PROTO = 6
    MEGATRON_PP_AS_DP_PROTO = 7
    DP_COMPUTE = 8
    DP_COMPUTE_PROTO = 9
    DP_COMPUTE_PROTO_WITH_FUNC = 10
    DP_COMPUTE_METRIC = 11


class Execute(Enum):
    ALL = 0
    RANK_ZERO = 1


def _split_args_kwargs_data_proto(chunks, *args, **kwargs):
    from verl.protocol import DataProto, DataProtoFuture
    splitted_args = []
    for arg in args:
        assert isinstance(arg, (DataProto, DataProtoFuture))
        splitted_args.append(arg.chunk(chunks=chunks))

    splitted_kwargs = {}
    for key, val in kwargs.items():
        assert isinstance(val, (DataProto, DataProtoFuture))
        splitted_kwargs[key] = val.chunk(chunks=chunks)

    return splitted_args, splitted_kwargs


def dispatch_one_to_all(worker_group, *args, **kwargs):
    args = tuple([arg] * worker_group.world_size for arg in args)
    kwargs = {k: [v] * worker_group.world_size for k, v in kwargs.items()}
    return args, kwargs


def dispatch_all_to_all(worker_group, *args, **kwargs):
    return args, kwargs


def collect_all_to_all(worker_group, output):
    return output


def dispatch_megatron_compute(worker_group, *args, **kwargs):
    """
    User passes in dp data. The data is dispatched to all tp/pp ranks with the same dp rank.
    """
    from verl.single_controller.base.megatron.worker_group import MegatronWorkerGroup
    assert isinstance(worker_group,
                      MegatronWorkerGroup), f'worker_group must be MegatronWorkerGroup, Got {type(worker_group)}'

    all_args = []
    for arg in args:
        assert isinstance(arg, (Tuple, List)) and len(arg) == worker_group.dp_size
        transformed_args = []
        for i in range(worker_group.world_size):
            local_dp_rank = worker_group.get_megatron_rank_info(rank=i).dp_rank
            transformed_args.append(arg[local_dp_rank])
        all_args.append(transformed_args)
    all_args = tuple(all_args)

    all_kwargs = {}
    for k, v in kwargs.items():
        assert isinstance(v, (Tuple, List)) and len(v) == worker_group.dp_size
        transformed_v = []
        for i in range(worker_group.world_size):
            local_dp_rank = worker_group.get_megatron_rank_info(rank=i).dp_rank
            transformed_v.append(v[local_dp_rank])
        all_kwargs[k] = transformed_v
    return all_args, all_kwargs


def collect_megatron_compute(worker_group, output):
    """
    Only collect the data from tp=0 and pp=last, for every dp rank.
    """
    from verl.single_controller.base.megatron.worker_group import MegatronWorkerGroup
    assert isinstance(worker_group, MegatronWorkerGroup)
    output_in_dp = []
    pp_size = worker_group.get_megatron_global_info().pp_size
    for global_rank in range(worker_group.world_size):
        local_rank_info = worker_group.get_megatron_rank_info(rank=global_rank)
        if local_rank_info.tp_rank == 0 and local_rank_info.pp_rank == pp_size - 1:
            output_in_dp.append(output[global_rank])
    return output_in_dp


def dispatch_megatron_compute_data_proto(worker_group, *args, **kwargs):
    """
    All the args and kwargs must be DataProto. The batch will be chunked by dp_size and passed to each rank.
    """
    from verl.single_controller.base.megatron.worker_group import MegatronWorkerGroup
    assert isinstance(worker_group, MegatronWorkerGroup)

    splitted_args, splitted_kwargs = _split_args_kwargs_data_proto(worker_group.dp_size, *args, **kwargs)
    return dispatch_megatron_compute(worker_group, *splitted_args, **splitted_kwargs)


def _concat_data_proto_or_future(output: List):
    from verl.protocol import DataProto, DataProtoFuture
    import ray

    # make sure all the elements in output have the same type
    for o in output:
        assert type(o) == type(output[0])

    o = output[0]

    if isinstance(o, DataProto):
        return DataProto.concat(output)
    elif isinstance(o, ray.ObjectRef):
        return DataProtoFuture.concat(output)
    else:
        raise NotImplementedError


def collect_megatron_compute_data_proto(worker_group, output):
    """
    Each output must be a DataProto. We concat the outputs along dim=0.
    """
    from verl.protocol import DataProto
    import ray

    output = collect_megatron_compute(worker_group, output)
    for o in output:
        assert isinstance(o, (DataProto, ray.ObjectRef)), f"expecting {o} to be DataProto, but got {type(o)}"

    return _concat_data_proto_or_future(output)


def dispatch_megatron_pp_as_dp(worker_group, *args, **kwargs):
    """
    Treat pp as dp.
    """
    from verl.single_controller.base.megatron.worker_group import MegatronWorkerGroup
    assert isinstance(worker_group, MegatronWorkerGroup)

    pp_size = worker_group.pp_size
    dp_size = worker_group.dp_size

    pp_dp_size = pp_size * dp_size

    all_args = []
    for arg in args:
        assert isinstance(arg, (List, Tuple)) and len(arg) == pp_dp_size
        transformed_args = []
        for i in range(worker_group.world_size):
            local_dp_rank = worker_group.get_megatron_rank_info(rank=i).dp_rank
            local_pp_rank = worker_group.get_megatron_rank_info(rank=i).pp_rank
            # compute the rank in arg. Note that the order is dp then pp
            # Also note that the outputs within a pp group will first be allgathered, then only the output of pp0 will be collected.
            # For pp=2 dp=4, a batch of data "ABCDEFGH" should be dispatched and collected in the order below:
            # dispatch:          pp_allgather:           collect:
            #   dp 0 1 2 3         dp  0  1  2  3
            # pp +---------+     pp +-------------+
            #  0 | A C E G |      0 | AB CD EF GH |       ABCDEFGH
            #  1 | B D F H |      1 | AB CD EF GH |
            #    +---------+        +-------------+
            arg_rank = local_dp_rank * worker_group.pp_size + local_pp_rank

            transformed_args.append(arg[arg_rank])
        all_args.append(transformed_args)
    all_args = tuple(all_args)

    all_kwargs = {}
    for k, v in kwargs.items():
        assert isinstance(v, (List, Tuple)) and len(v) == pp_dp_size, f'expect len(v)=={pp_dp_size}, got {len(v)}'
        transformed_v = []
        for i in range(worker_group.world_size):
            local_dp_rank = worker_group.get_megatron_rank_info(rank=i).dp_rank
            local_pp_rank = worker_group.get_megatron_rank_info(rank=i).pp_rank
            # compute the rank in arg. Note that the order is dp then pp
            arg_rank = local_dp_rank * worker_group.pp_size + local_pp_rank
            transformed_v.append(v[arg_rank])
        all_kwargs[k] = transformed_v
    return all_args, all_kwargs


def collect_megatron_pp_as_dp(worker_group, output):
    """
    Treat pp as dp. Only collect data on tp=0.
    """
    from verl.single_controller.base.megatron.worker_group import MegatronWorkerGroup
    assert isinstance(worker_group, MegatronWorkerGroup)
    output_in_dp = []
    for global_rank in range(worker_group.world_size):
        local_rank_info = worker_group.get_megatron_rank_info(rank=global_rank)
        if local_rank_info.tp_rank == 0 and local_rank_info.pp_rank == 0:
            output_in_dp.append(output[global_rank])
    return output_in_dp


def collect_megatron_pp_only(worker_group, output):
    """
    Only collect output along megatron pp. This is useful when examining weight names, as they are identical across tp/dp.
    """
    from verl.single_controller.base.megatron.worker_group import MegatronWorkerGroup
    assert isinstance(worker_group, MegatronWorkerGroup)
    output_in_pp = []
    for global_rank in range(worker_group.world_size):
        local_rank_info = worker_group.get_megatron_rank_info(rank=global_rank)
        if local_rank_info.tp_rank == 0 and local_rank_info.dp_rank == 0:
            output_in_pp.append(output[global_rank])
    return output_in_pp


def dispatch_megatron_pp_as_dp_data_proto(worker_group, *args, **kwargs):
    from verl.single_controller.base.megatron.worker_group import MegatronWorkerGroup
    assert isinstance(worker_group, MegatronWorkerGroup)

    pp_dp_size = worker_group.dp_size * worker_group.pp_size
    splitted_args, splitted_kwargs = _split_args_kwargs_data_proto(pp_dp_size, *args, **kwargs)
    return dispatch_megatron_pp_as_dp(worker_group, *splitted_args, **splitted_kwargs)


def collect_megatron_pp_as_dp_data_proto(worker_group, output):
    from verl.protocol import DataProto
    from verl.single_controller.base.megatron.worker_group import MegatronWorkerGroup
    assert isinstance(worker_group, MegatronWorkerGroup)

    output = collect_megatron_pp_as_dp(worker_group, output)
    return _concat_data_proto_or_future(output)


def dispatch_dp_compute(worker_group, *args, **kwargs):
    from verl.single_controller.base.worker_group import WorkerGroup
    assert isinstance(worker_group, WorkerGroup)
    for arg in args:
        assert isinstance(arg, (Tuple, List)) and len(arg) == worker_group.world_size
    for k, v in kwargs.items():
        assert isinstance(v, (Tuple, List)) and len(v) == worker_group.world_size
    return args, kwargs


def collect_dp_compute(worker_group, output):
    from verl.single_controller.base.worker_group import WorkerGroup
    assert isinstance(worker_group, WorkerGroup)
    assert len(output) == worker_group.world_size
    return output


def dispatch_dp_compute_data_proto(worker_group, *args, **kwargs):
    from verl.single_controller.base.worker_group import WorkerGroup
    assert isinstance(worker_group, WorkerGroup)
    splitted_args, splitted_kwargs = _split_args_kwargs_data_proto(worker_group.world_size, *args, **kwargs)
    return splitted_args, splitted_kwargs


def dispatch_dp_compute_data_proto_with_func(worker_group, *args, **kwargs):
    from verl.single_controller.base.worker_group import WorkerGroup
    assert isinstance(worker_group, WorkerGroup)
    assert type(args[0]) == FunctionType  # NOTE: the first argument must be a function!

    splitted_args, splitted_kwargs = _split_args_kwargs_data_proto(worker_group.world_size, *args[1:], **kwargs)
    splitted_args_with_func = [[args[0]] * worker_group.world_size] + splitted_args
    return splitted_args_with_func, splitted_kwargs


def collect_dp_compute_data_proto(worker_group, output):
    from verl.protocol import DataProto
    import ray

    for o in output:
        assert isinstance(o, (DataProto, ray.ObjectRef)), f"expecting {o} to be DataProto, but got {type(o)}"

    output = collect_dp_compute(worker_group, output)
    return _concat_data_proto_or_future(output)


def get_predefined_dispatch_fn(dispatch_mode):
    predefined_dispatch_mode_fn = {
        Dispatch.ONE_TO_ALL: {
            'dispatch_fn': dispatch_one_to_all,
            'collect_fn': collect_all_to_all,
        },
        Dispatch.ALL_TO_ALL: {
            'dispatch_fn': dispatch_all_to_all,
            'collect_fn': collect_all_to_all,
        },
        Dispatch.MEGATRON_COMPUTE: {
            'dispatch_fn': dispatch_megatron_compute,
            'collect_fn': collect_megatron_compute,
        },
        Dispatch.MEGATRON_PP_AS_DP: {
            'dispatch_fn': dispatch_megatron_pp_as_dp,
            'collect_fn': collect_megatron_pp_as_dp,
        },
        Dispatch.MEGATRON_PP_ONLY: {
            'dispatch_fn': dispatch_one_to_all,
            'collect_fn': collect_megatron_pp_only
        },
        Dispatch.MEGATRON_COMPUTE_PROTO: {
            'dispatch_fn': dispatch_megatron_compute_data_proto,
            'collect_fn': collect_megatron_compute_data_proto
        },
        Dispatch.MEGATRON_PP_AS_DP_PROTO: {
            'dispatch_fn': dispatch_megatron_pp_as_dp_data_proto,
            'collect_fn': collect_megatron_pp_as_dp_data_proto
        },
        Dispatch.DP_COMPUTE: {
            'dispatch_fn': dispatch_dp_compute,
            'collect_fn': collect_dp_compute
        },
        Dispatch.DP_COMPUTE_PROTO: {
            'dispatch_fn': dispatch_dp_compute_data_proto,
            'collect_fn': collect_dp_compute_data_proto
        },
        Dispatch.DP_COMPUTE_PROTO_WITH_FUNC: {
            'dispatch_fn': dispatch_dp_compute_data_proto_with_func,
            'collect_fn': collect_dp_compute_data_proto
        },
        Dispatch.DP_COMPUTE_METRIC: {
            'dispatch_fn': dispatch_dp_compute_data_proto,
            'collect_fn': collect_dp_compute
        }
    }
    return predefined_dispatch_mode_fn[dispatch_mode]


def get_predefined_execute_fn(execute_mode):
    """
    Note that here we only ask execute_all and execute_rank_zero to be implemented.
    The choice of how these two functions handle the 'blocking' argument is left to the user.
    """
    predefined_execute_mode_fn = {
        Execute.ALL: {
            'execute_fn_name': 'execute_all'
        },
        Execute.RANK_ZERO: {
            'execute_fn_name': 'execute_rank_zero'
        }
    }
    return predefined_execute_mode_fn[execute_mode]


def _check_dispatch_mode(dispatch_mode):
    assert isinstance(dispatch_mode,
                      (Dispatch, Dict)), f'dispatch_mode must be a Dispatch or a Dict. Got {dispatch_mode}'
    if isinstance(dispatch_mode, Dict):
        necessary_keys = ['dispatch_fn', 'collect_fn']
        for key in necessary_keys:
            assert key in dispatch_mode, f'key {key} should be in dispatch_mode if it is a dictionary'


def _check_execute_mode(execute_mode):
    assert isinstance(execute_mode, Execute), f'execute_mode must be an Execute. Got {execute_mode}'


def _materialize_futures(*args, **kwargs):
    new_args = []
    for arg in args:
        if isinstance(arg, DataProtoFuture):
            arg = arg.get()
        # add more types to materialize here if needed
        new_args.append(arg)
    for k, v in kwargs.items():
        if isinstance(v, DataProtoFuture):
            kwargs[k] = v.get()

    new_args = tuple(new_args)
    return new_args, kwargs


def register(dispatch_mode=Dispatch.ALL_TO_ALL, execute_mode=Execute.ALL, blocking=True, materialize_futures=True):
    _check_dispatch_mode(dispatch_mode=dispatch_mode)
    _check_execute_mode(execute_mode=execute_mode)

    def decorator(func):

        @wraps(func)
        def inner(*args, **kwargs):
            if materialize_futures:
                args, kwargs = _materialize_futures(*args, **kwargs)
            return func(*args, **kwargs)

        attrs = {'dispatch_mode': dispatch_mode, 'execute_mode': execute_mode, 'blocking': blocking}
        setattr(inner, MAGIC_ATTR, attrs)
        return inner

    return decorator
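
For orientation, a minimal sketch (not part of this commit; class and method names are illustrative) of how the register decorator above is meant to be used on a Worker subclass:

from verl.single_controller.base.worker import Worker
from verl.single_controller.base.decorator import register, Dispatch

class MyWorker(Worker):

    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
    def compute(self, data):
        # `data` arrives as this rank's chunk of the caller's DataProto;
        # the per-rank return values are concatenated by the collect_fn.
        return data

The decorator only attaches MAGIC_ATTR metadata to the method; the actual chunking and collection happen later, when a WorkerGroup binds the method through its dispatch/collect functions.
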
13
verl/single_controller/base/megatron/__init__.py
Normal file
@@ -0,0 +1,13 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
39
verl/single_controller/base/megatron/worker.py
Normal file
@@ -0,0 +1,39 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from dataclasses import dataclass
from verl.single_controller.base.worker import Worker, DistRankInfo, DistGlobalInfo


class MegatronWorker(Worker):

    def __init__(self, cuda_visible_devices=None) -> None:
        super().__init__(cuda_visible_devices)

    def get_megatron_global_info(self):
        from megatron.core import parallel_state as mpu
        tp_size = mpu.get_tensor_model_parallel_world_size()
        dp_size = mpu.get_data_parallel_world_size()
        pp_size = mpu.get_pipeline_model_parallel_world_size()
        info = DistGlobalInfo(tp_size=tp_size, dp_size=dp_size, pp_size=pp_size)
        return info

    def get_megatron_rank_info(self):
        from megatron.core import parallel_state as mpu
        tp_rank = mpu.get_tensor_model_parallel_rank()
        dp_rank = mpu.get_data_parallel_rank()
        pp_rank = mpu.get_pipeline_model_parallel_rank()
        info = DistRankInfo(tp_rank=tp_rank, dp_rank=dp_rank, pp_rank=pp_rank)
        return info
51
verl/single_controller/base/megatron/worker_group.py
Normal file
@@ -0,0 +1,51 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict

from .worker import DistRankInfo, DistGlobalInfo
from verl.single_controller.base import ResourcePool, WorkerGroup


class MegatronWorkerGroup(WorkerGroup):

    def __init__(self, resource_pool: ResourcePool, **kwargs):
        super().__init__(resource_pool=resource_pool, **kwargs)
        self._megatron_rank_info = None
        self._megatron_global_info: DistGlobalInfo = None

    def init_megatron(self, default_megatron_kwargs: Dict = None):
        raise NotImplementedError("MegatronWorkerGroup.init_megatron should be overwritten")

    def get_megatron_rank_info(self, rank: int) -> DistRankInfo:
        assert 0 <= rank < self.world_size, f'rank must be from [0, world_size), Got {rank}'
        return self._megatron_rank_info[rank]

    @property
    def tp_size(self):
        assert self._megatron_global_info is not None, "MegatronWorkerGroup._megatron_global_info must be initialized"
        return self._megatron_global_info.tp_size

    @property
    def dp_size(self):
        assert self._megatron_global_info is not None, "MegatronWorkerGroup._megatron_global_info must be initialized"
        return self._megatron_global_info.dp_size

    @property
    def pp_size(self):
        assert self._megatron_global_info is not None, "MegatronWorkerGroup._megatron_global_info must be initialized"
        return self._megatron_global_info.pp_size

    def get_megatron_global_info(self):
        return self._megatron_global_info
13
verl/single_controller/base/register_center/__init__.py
Normal file
@@ -0,0 +1,13 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
29
verl/single_controller/base/register_center/ray.py
Normal file
@@ -0,0 +1,29 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import ray


@ray.remote
class WorkerGroupRegisterCenter:

    def __init__(self, rank_zero_info):
        self.rank_zero_info = rank_zero_info

    def get_rank_zero_info(self):
        return self.rank_zero_info


def create_worker_group_register_center(name, info):
    return WorkerGroupRegisterCenter.options(name=name).remote(info)
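
A minimal sketch of the rank-zero handshake this register center enables (assuming ray.init() has already been called; the actor name and info values below are illustrative):

import ray
from verl.single_controller.base.register_center.ray import create_worker_group_register_center

center = create_worker_group_register_center(
    name="demo_register_center",
    info={"MASTER_ADDR": "127.0.0.1", "MASTER_PORT": "29500"})

# Other processes in the same job can look the named actor up and read the same info.
same_center = ray.get_actor("demo_register_center")
print(ray.get(same_center.get_rank_zero_info.remote()))
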
186
verl/single_controller/base/worker.py
Normal file
@@ -0,0 +1,186 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
The base Worker class.
"""
import os
import socket
from dataclasses import dataclass
from verl.single_controller.base.decorator import register, Dispatch, Execute


@dataclass
class DistRankInfo:
    tp_rank: int
    dp_rank: int
    pp_rank: int


@dataclass
class DistGlobalInfo:
    tp_size: int
    dp_size: int
    pp_size: int


class WorkerHelper:

    def _get_node_ip(self):

        def get_node_ip_by_sdk():
            if os.getenv("WG_BACKEND", None) == "ray":
                import ray
                return ray._private.services.get_node_ip_address()
            elif os.getenv("WG_BACKEND", None) == "torch_rpc":
                from verl.single_controller.torchrpc.k8s_client import get_ip_addr
                return get_ip_addr()
            return None

        host_ipv4 = os.getenv("MY_HOST_IP", None)
        host_ipv6 = os.getenv("MY_HOST_IPV6", None)
        host_ip_by_env = host_ipv4 or host_ipv6
        host_ip_by_sdk = get_node_ip_by_sdk()

        host_ip = host_ip_by_env or host_ip_by_sdk
        return host_ip

    def _get_free_port(self):
        with socket.socket() as sock:
            sock.bind(('', 0))
            return sock.getsockname()[1]

    def get_availale_master_addr_port(self):
        return self._get_node_ip(), str(self._get_free_port())

    def _get_pid(self):
        return


class WorkerMeta:
    keys = [
        "WORLD_SIZE", "RANK", "LOCAL_WORLD_SIZE", "LOCAL_RANK", "MASTER_ADDR", "MASTER_PORT", "CUDA_VISIBLE_DEVICES"
    ]

    def __init__(self, store) -> None:
        self._store = store

    def to_dict(self):
        return {f"_{key.lower()}": self._store.get(f"_{key.lower()}", None) for key in WorkerMeta.keys}


# we assume that in each WorkerGroup, there is a Master Worker
class Worker(WorkerHelper):

    def __new__(cls, *args, **kwargs):
        instance = super().__new__(cls)

        # note that here we use an int to decide whether worker init is disabled
        disable_worker_init = int(os.environ.get('DISABLE_WORKER_INIT', 0))
        if disable_worker_init:
            return instance

        rank = os.environ.get("RANK", None)
        worker_group_prefix = os.environ.get("WG_PREFIX", None)

        # when the @ray.remote decorator is applied, __new__ is also called, but we don't want to run _configure_before_init in that case
        if None not in [rank, worker_group_prefix] and 'ActorClass(' not in cls.__name__:
            instance._configure_before_init(f"{worker_group_prefix}_register_center", int(rank))

        return instance

    def _configure_before_init(self, register_center_name: str, rank: int):
        assert isinstance(rank, int), f"rank must be int, instead of {type(rank)}"

        if rank == 0:
            master_addr, master_port = self.get_availale_master_addr_port()
            rank_zero_info = {
                "MASTER_ADDR": master_addr,
                "MASTER_PORT": master_port,
            }

            if os.getenv("WG_BACKEND", None) == "ray":
                from verl.single_controller.base.register_center.ray import create_worker_group_register_center
                self.register_center = create_worker_group_register_center(name=register_center_name,
                                                                           info=rank_zero_info)

            os.environ.update(rank_zero_info)

    def __init__(self, cuda_visible_devices=None) -> None:
        # construct a meta from environment variables. Note that the import must be inside the class because it is executed remotely
        import os
        world_size = int(os.environ['WORLD_SIZE'])
        rank = int(os.environ['RANK'])
        self._rank = rank
        self._world_size = world_size

        master_addr = os.environ["MASTER_ADDR"]
        master_port = os.environ["MASTER_PORT"]

        local_world_size = int(os.getenv("LOCAL_WORLD_SIZE", "1"))
        local_rank = int(os.getenv("LOCAL_RANK", "0"))

        store = {
            '_world_size': world_size,
            '_rank': rank,
            '_local_world_size': local_world_size,
            '_local_rank': local_rank,
            '_master_addr': master_addr,
            '_master_port': master_port
        }
        if cuda_visible_devices is not None:
            store['_cuda_visible_devices'] = cuda_visible_devices

        meta = WorkerMeta(store=store)
        self._configure_with_meta(meta=meta)

    def _configure_with_meta(self, meta: WorkerMeta):
        """
        This function should only be called by WorkerGroup internally.
        """
        assert isinstance(meta, WorkerMeta)
        self.__dict__.update(meta.to_dict())  # this is hacky
        # print(f"__dict__: {self.__dict__}")
        for key in WorkerMeta.keys:
            val = self.__dict__.get(f"_{key.lower()}", None)
            if val is not None:
                # print(f"set {key} to {val}")
                os.environ[key] = str(val)
        os.environ["REDIS_STORE_SERVER_HOST"] = str(self._master_addr).replace("[", "").replace(
            "]", "") if self._master_addr else ""

    def get_master_addr_port(self):
        return self._master_addr, self._master_port

    def get_cuda_visible_devices(self):
        import os
        cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", "not set")
        return cuda_visible_devices

    @property
    def world_size(self):
        return self._world_size

    @property
    def rank(self):
        return self._rank

    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO_WITH_FUNC)
    def execute_with_func_generator(self, func, *args, **kwargs):
        ret_proto = func(self, *args, **kwargs)
        return ret_proto

    @register(dispatch_mode=Dispatch.ALL_TO_ALL, execute_mode=Execute.RANK_ZERO)
    def execute_func_rank_zero(self, func, *args, **kwargs):
        result = func(*args, **kwargs)
        return result
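
A minimal local sketch (not part of this commit; the addresses and sizes are placeholders) of how a Worker bootstraps itself purely from environment variables, with the register-center logic in __new__ skipped via DISABLE_WORKER_INIT:

import os

os.environ.update({
    "WORLD_SIZE": "2",
    "RANK": "0",
    "MASTER_ADDR": "127.0.0.1",
    "MASTER_PORT": "29500",
    "DISABLE_WORKER_INIT": "1",  # skip the register-center handshake in __new__
})

from verl.single_controller.base.worker import Worker

w = Worker()
assert w.world_size == 2 and w.rank == 0
print(w.get_master_addr_port())  # ('127.0.0.1', '29500')
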
196
verl/single_controller/base/worker_group.py
Normal file
@@ -0,0 +1,196 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
The base WorkerGroup class.
"""
import logging
import threading
import signal
import time
from typing import List, Any, Callable, Dict

from verl.single_controller.base.decorator import MAGIC_ATTR, Dispatch, get_predefined_dispatch_fn, get_predefined_execute_fn


class ResourcePool:

    def __init__(self, process_on_nodes=None, max_collocate_count: int = 10, n_gpus_per_node=8) -> None:
        if process_on_nodes is None:
            process_on_nodes = []
        self._store = process_on_nodes
        self.max_collocate_count = max_collocate_count
        self.n_gpus_per_node = n_gpus_per_node  # left for future Huawei hardware that contains 16 GPUs per node

    def add_node(self, process_count):
        self._store.append(process_count)

    @property
    def world_size(self):
        return sum(self._store)

    def __call__(self) -> Any:
        return self._store

    @property
    def store(self):
        return self._store

    def local_world_size_list(self) -> List[int]:
        nested_local_world_size_list = [
            [local_world_size for _ in range(local_world_size)] for local_world_size in self._store
        ]
        return [item for row in nested_local_world_size_list for item in row]

    def local_rank_list(self) -> List[int]:
        nested_local_rank_list = [[i for i in range(local_world_size)] for local_world_size in self._store]
        return [item for row in nested_local_rank_list for item in row]


class ClassWithInitArgs:
    """
    This class stores a class constructor and the args/kwargs to construct the class.
    It is used to instantiate the remote class.
    """

    def __init__(self, cls, *args, **kwargs) -> None:
        self.cls = cls
        self.args = args
        self.kwargs = kwargs

    # def add_arg(self, arg):
    #     self.args += (arg,)

    # def add_kwarg(self, key, value):
    #     self.kwargs[key] = value

    def __call__(self) -> Any:
        return self.cls(*self.args, **self.kwargs)


def check_workers_alive(workers: List, is_alive: Callable, gap_time: float = 1) -> None:
    while True:
        for worker in workers:
            if not is_alive(worker):
                logging.warning(f"worker {worker} is not alive, sending signal to main thread")
                signal.raise_signal(signal.SIGABRT)
        time.sleep(gap_time)


class WorkerGroup:

    def __init__(self, resource_pool: ResourcePool, **kwargs) -> None:
        self._is_init_with_detached_workers = resource_pool is None

        if resource_pool is not None:
            self._process_dispatch_config = resource_pool()
        else:
            # handle the case when the WorkerGroup is attached to existing (detached) workers
            self._process_dispatch_config = None

        self._workers = []
        self._worker_names = []

        self._master_addr = None
        self._master_port = None

        self._checker_thread: threading.Thread = None

    def _is_worker_alive(self, worker):
        raise NotImplementedError("WorkerGroup._is_worker_alive called, should be implemented in derived class.")

    def _block_until_all_workers_alive(self) -> None:
        while True:
            all_state = [self._is_worker_alive(worker) for worker in self._workers]
            if False in all_state:
                time.sleep(1)
            else:
                break

    def start_worker_aliveness_check(self, every_n_seconds=1) -> None:
        # before starting the worker aliveness check, make sure all workers are already alive
        self._block_until_all_workers_alive()

        self._checker_thread = threading.Thread(target=check_workers_alive,
                                                args=(self._workers, self._is_worker_alive, every_n_seconds))
        self._checker_thread.start()

    @property
    def world_size(self):
        return len(self._workers)

    # execute_all_async and execute_rank_zero_async should be implemented by RayWorkerGroup and TorchRPCWorkerGroup;
    # MegatronWorkerGroup and XperfWorkerGroup should skip them

    def _bind_worker_method(self, user_defined_cls, func_generator):
        """
        Bind the worker methods to the WorkerGroup.
        """

        for method_name in dir(user_defined_cls):

            try:
                method = getattr(user_defined_cls, method_name)
                assert callable(method), f"{method_name} in {user_defined_cls} is not callable"
            except Exception as e:
                # if it is a property, it will fail because the class doesn't have an instance property
                continue

            if hasattr(method, MAGIC_ATTR):
                # this method is decorated by register
                attribute = getattr(method, MAGIC_ATTR)
                assert isinstance(attribute, Dict), f'attribute must be a dictionary. Got {type(attribute)}'
                assert 'dispatch_mode' in attribute, 'attribute must contain dispatch_mode in its keys'

                dispatch_mode = attribute['dispatch_mode']
                execute_mode = attribute['execute_mode']
                blocking = attribute['blocking']

                # get dispatch fn
                if isinstance(dispatch_mode, Dispatch):
                    # get default dispatch fn
                    fn = get_predefined_dispatch_fn(dispatch_mode=dispatch_mode)
                    dispatch_fn = fn['dispatch_fn']
                    collect_fn = fn['collect_fn']
                else:
                    assert isinstance(dispatch_mode, dict)
                    assert 'dispatch_fn' in dispatch_mode
                    assert 'collect_fn' in dispatch_mode
                    dispatch_fn = dispatch_mode['dispatch_fn']
                    collect_fn = dispatch_mode['collect_fn']

                # get execute_fn_name
                execute_mode = get_predefined_execute_fn(execute_mode=execute_mode)
                wg_execute_fn_name = execute_mode['execute_fn_name']

                # get execute_fn from string
                try:
                    execute_fn = getattr(self, wg_execute_fn_name)
                    assert callable(execute_fn), 'execute_fn must be callable'
                except Exception as e:
                    print(f'execute_fn {wg_execute_fn_name} is invalid')
                    raise

                # bind a new method to the RayWorkerGroup
                func = func_generator(self,
                                      method_name,
                                      dispatch_fn=dispatch_fn,
                                      collect_fn=collect_fn,
                                      execute_fn=execute_fn,
                                      blocking=blocking)

                try:
                    setattr(self, method_name, func)
                except Exception as e:
                    raise ValueError(f'Fail to set method_name {method_name}')
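
A minimal sketch (illustrative values only) of the two plain data holders defined above: ResourcePool flattens per-node process counts, while ClassWithInitArgs defers construction until it is called.

from verl.single_controller.base.worker_group import ResourcePool, ClassWithInitArgs

pool = ResourcePool(process_on_nodes=[8, 8])  # two nodes, 8 processes each
assert pool.world_size == 16
assert pool.local_rank_list()[:8] == list(range(8))

class Dummy:

    def __init__(self, x):
        self.x = x

cia = ClassWithInitArgs(Dummy, 42)  # stores the constructor and its args
obj = cia()                         # instantiated only on call
assert obj.x == 42
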
16
verl/single_controller/ray/__init__.py
Normal file
@@ -0,0 +1,16 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .base import RayResourcePool, RayClassWithInitArgs, RayWorkerGroup, create_colocated_worker_cls
from .megatron import (MegatronRayWorkerGroup, DistRankInfo, DistGlobalInfo)
459
verl/single_controller/ray/base.py
Normal file
@@ -0,0 +1,459 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import time
from typing import Dict, List, Any, Tuple

import ray
from ray.util import list_named_actors
from ray.util.placement_group import placement_group, PlacementGroup
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy, NodeAffinitySchedulingStrategy
from ray.experimental.state.api import get_actor

from verl.single_controller.base import WorkerGroup, ResourcePool, ClassWithInitArgs, Worker

__all__ = ['Worker']


def get_random_string(length: int) -> str:
    import random
    import string
    letters_digits = string.ascii_letters + string.digits
    return ''.join(random.choice(letters_digits) for _ in range(length))


def func_generator(self, method_name, dispatch_fn, collect_fn, execute_fn, blocking):

    def func(*args, **kwargs):
        args, kwargs = dispatch_fn(self, *args, **kwargs)
        output = execute_fn(method_name, *args, **kwargs)
        if blocking:
            output = ray.get(output)
        output = collect_fn(self, output)
        return output

    return func


class RayResourcePool(ResourcePool):

    def __init__(self,
                 process_on_nodes: List[int] = None,
                 use_gpu: bool = True,
                 name_prefix: str = "",
                 max_colocate_count: int = 5,
                 detached=False) -> None:
        super().__init__(process_on_nodes, max_colocate_count)
        self.use_gpu = use_gpu
        # print(f"in RayProcessDispatchConfiguration: name_prefix = {name_prefix}")
        self.name_prefix = name_prefix
        self.pgs = None
        self.detached = detached

    def get_placement_groups(self, strategy="STRICT_PACK", name=None):
        if self.pgs is not None:
            return self.pgs

        pg_name_prefix = name if name else \
            f"{self.name_prefix}verl_group_{'_'.join([str(count) for count in self._store])}:"
        # print(f"pg_name_prefix = {pg_name_prefix}")
        pg_scheme = [[{
            "CPU": self.max_collocate_count,
            "GPU": 1
        } if self.use_gpu else {
            "CPU": self.max_collocate_count
        } for _ in range(process_count)] for process_count in self._store]

        lifetime = 'detached' if self.detached else None

        pgs = [
            placement_group(bundles=bundles, strategy=strategy, name=pg_name_prefix + str(idx), lifetime=lifetime)
            for idx, bundles in enumerate(pg_scheme)
        ]

        ray.get([pg.ready() for pg in pgs])

        self.pgs = pgs
        return pgs


def extract_pg_from_exist(resource_pools: Dict[str, RayResourcePool], src_role_names: List[str],
                          resource_pool: RayResourcePool) -> List:

    src_pgs = [
        pg for role_name, resource_pool in resource_pools.items() for pg in resource_pool.get_placement_groups()
        if role_name in src_role_names
    ]

    sorted_src_pgs = sorted(src_pgs, key=lambda pg: pg.bundle_count, reverse=True)
    sorted_process_on_nodes = sorted([(val, idx) for idx, val in enumerate(resource_pool.store)], reverse=True)

    unsorted_pgs: List[Tuple[int, PlacementGroup]] = []
    searching_idx = 0
    for request_process, original_idx in sorted_process_on_nodes:
        assert searching_idx < len(sorted_src_pgs), f"not enough nodes for request: searching {searching_idx} th node"
        assert request_process <= sorted_src_pgs[searching_idx].bundle_count, \
            f"requesting {request_process} processes, bundle count cannot satisfy"
        unsorted_pgs.append((original_idx, sorted_src_pgs[searching_idx]))
        searching_idx += 1

    return [pg for _, pg in sorted(unsorted_pgs)]


def merge_resource_pool(rp1: RayResourcePool, rp2: RayResourcePool) -> RayResourcePool:
    assert rp1.use_gpu == rp2.use_gpu, 'Both RayResourcePool must either use_gpu or not'
    assert rp1.max_collocate_count == rp2.max_collocate_count, 'Both RayResourcePool must have the same max_collocate_count'
    assert rp1.n_gpus_per_node == rp2.n_gpus_per_node, 'Both RayResourcePool must have the same n_gpus_per_node'
    assert rp1.detached == rp2.detached, 'Detached ResourcePool cannot be merged with non-detached ResourcePool'

    new_store = rp1.store + rp2.store

    merged = RayResourcePool(new_store, rp1.use_gpu, f"{rp1.name_prefix}_{rp2.name_prefix}")
    merged.pgs = rp1.get_placement_groups() + rp2.get_placement_groups()

    return merged


class RayClassWithInitArgs(ClassWithInitArgs):

    def __init__(self, cls, *args, **kwargs) -> None:
        # self._options = kwargs.pop('options', dict())
        super().__init__(cls, *args, **kwargs)
        self._options = {}
        self._additional_resource = {}

    def set_additional_resource(self, additional_resource):
        self._additional_resource = additional_resource

    def update_options(self, options: Dict):
        self._options.update(options)

    def __call__(self,
                 placement_group,
                 placement_group_bundle_idx,
                 use_gpu: bool = True,
                 num_gpus=1,
                 sharing_with=None) -> Any:
        if sharing_with is not None:
            target_node_id = ray.get(sharing_with.get_node_id.remote())
            cuda_visible_devices = ray.get(sharing_with.get_cuda_visible_devices.remote())
            options = {"scheduling_strategy": NodeAffinitySchedulingStrategy(node_id=target_node_id, soft=False)}
            return self.cls.options(**options).remote(*self.args,
                                                       cuda_visible_devices=cuda_visible_devices,
                                                       **self.kwargs)

        options = {
            "scheduling_strategy":
                PlacementGroupSchedulingStrategy(placement_group=placement_group,
                                                 placement_group_bundle_index=placement_group_bundle_idx)
        }
        options.update(self._options)

        if use_gpu:
            options["num_gpus"] = num_gpus

        if len(self._additional_resource) > 1:
            for k, v in self._additional_resource.items():
                options[k] = v

        # print("cls:", self.cls)
        # print("args: ", self.args)
        # print("kwargs: ", self.kwargs)
        return self.cls.options(**options).remote(*self.args, **self.kwargs)


class RayWorkerGroup(WorkerGroup):

    def __init__(self,
                 resource_pool: RayResourcePool = None,
                 ray_cls_with_init: RayClassWithInitArgs = None,
                 bin_pack: bool = True,
                 name_prefix: str = None,
                 detached=False,
                 worker_names=None,
                 **kwargs) -> None:
        super().__init__(resource_pool=resource_pool, **kwargs)
        self.ray_cls_with_init = ray_cls_with_init
        self.name_prefix = get_random_string(length=6) if name_prefix is None else name_prefix

        if worker_names is not None:
            assert self._is_init_with_detached_workers
            self._worker_names = worker_names

        if self._is_init_with_detached_workers:
            self._init_with_detached_workers(worker_names=worker_names)
        else:
            self._init_with_resource_pool(resource_pool=resource_pool,
                                          ray_cls_with_init=ray_cls_with_init,
                                          bin_pack=bin_pack,
                                          detached=detached)

        if ray_cls_with_init is not None:
            self._bind_worker_method(self.ray_cls_with_init.cls, func_generator)

    def _is_worker_alive(self, worker: ray.actor.ActorHandle):
        worker_state_dict = get_actor(worker._actor_id.hex())
        return worker_state_dict.get("state", "undefined") == "ALIVE" if worker_state_dict is not None else False

    def _init_with_detached_workers(self, worker_names):
        workers = [ray.get_actor(name=name) for name in worker_names]
        self._workers = workers
        self._world_size = len(worker_names)

    def _init_with_resource_pool(self, resource_pool, ray_cls_with_init, bin_pack, detached):
        use_gpu = resource_pool.use_gpu

        strategy = "PACK"
        if bin_pack:
            strategy = "STRICT_PACK"
        pgs = resource_pool.get_placement_groups(strategy=strategy)
        world_size = resource_pool.world_size
        self._world_size = world_size
        # cia.add_kwarg("_world_size", world_size)
        num_gpus = 1 / resource_pool.max_collocate_count

        rank = -1
        for pg_idx, local_world_size in enumerate(resource_pool.store):
            pg = pgs[pg_idx]
            assert local_world_size <= pg.bundle_count, \
                f"when generating for {self.name_prefix}, local_world_size ({local_world_size}) exceeds the bundle count ({pg.bundle_count}) of placement group {pg_idx}"
            for local_rank in range(local_world_size):
                rank += 1

                # we pass in environment variables via options so that the Worker can configure itself from them
                env_vars = {
                    'WORLD_SIZE': str(world_size),
                    'RANK': str(rank),
                    'WG_PREFIX': self.name_prefix,
                    'WG_BACKEND': 'ray',
                    'RAY_LOCAL_WORLD_SIZE': str(local_world_size),
                    'RAY_LOCAL_RANK': str(local_rank),
                }
                if rank != 0:
                    env_vars['MASTER_ADDR'] = self._master_addr
                    env_vars['MASTER_PORT'] = self._master_port

                import re
                cia_name = type(ray_cls_with_init.cls).__name__
                match = re.search(r"ActorClass\(([^)]+)\)", cia_name)  # ray.remote(Obj) -> "ActorClass(Obj)"
                cia_name = match.group(1) if match else cia_name  # "ActorClass(Obj)" -> "Obj"
                name = f"{self.name_prefix}{cia_name}_{pg_idx}:{local_rank}"  # e.g. Worker_2:5

                ray_cls_with_init.update_options({'runtime_env': {'env_vars': env_vars}, 'name': name})

                if detached:
                    ray_cls_with_init.update_options({'lifetime': 'detached'})

                # create a worker
                worker = ray_cls_with_init(placement_group=pg,
                                           placement_group_bundle_idx=local_rank,
                                           use_gpu=use_gpu,
                                           num_gpus=num_gpus)
                self._workers.append(worker)
                self._worker_names.append(name)

                if rank == 0:
                    register_center_actor = None
                    for _ in range(120):
                        if f"{self.name_prefix}_register_center" not in list_named_actors():
                            time.sleep(1)
                        else:
                            register_center_actor = ray.get_actor(f"{self.name_prefix}_register_center")
                            break
                    assert register_center_actor is not None, f"failed to get register_center_actor: {self.name_prefix}_register_center in {list_named_actors(all_namespaces=True)}"
                    rank_zero_info = ray.get(register_center_actor.get_rank_zero_info.remote())
                    self._master_addr, self._master_port = rank_zero_info['MASTER_ADDR'], rank_zero_info['MASTER_PORT']
                    # print(f"rank_zero_info: {rank_zero_info}")
                    # print(f"master_addr: {self._master_addr}, master_port: {self._master_port}")

    @property
    def worker_names(self):
        return self._worker_names

    @classmethod
    def from_detached(cls, worker_names=None, ray_cls_with_init=None):
        worker_group = cls(resource_pool=None,
                           ray_cls_with_init=ray_cls_with_init,
                           name_prefix=None,
                           worker_names=worker_names)
        return worker_group

    def spawn(self, prefix_set):
        """
        Spawn a dictionary of worker groups, each exposing the subset of methods with the given prefix.
        """

        def _rebind_actor_methods(worker_group, actor_name):
            """
            Bind the methods carrying actor_name as a prefix back to their original names.
            """
            prefix: str = actor_name + '_'
            for method_name in dir(worker_group):
                if method_name.startswith(prefix):
                    # only valid when Python >= 3.9
                    original_method_name = method_name.removeprefix(prefix)
                    method = getattr(worker_group, method_name)
                    setattr(worker_group, original_method_name, method)

        new_worker_group_dict = {}
        for prefix in prefix_set:
            new_worker_group = self.from_detached(worker_names=self._worker_names,
                                                  ray_cls_with_init=self.ray_cls_with_init)

            _rebind_actor_methods(new_worker_group, prefix)
            new_worker_group_dict[prefix] = new_worker_group
        return new_worker_group_dict

    def execute_rank_zero_sync(self, method_name: str, *args, **kwargs):
        return ray.get(self.execute_rank_zero_async(method_name, *args, **kwargs))

    def execute_rank_zero_async(self, method_name: str, *args, **kwargs):
        remote_call = getattr(self._workers[0], method_name)
        return remote_call.remote(*args, **kwargs)

    def execute_rank_zero(self, method_name: str, *args, **kwargs):
        return self.execute_rank_zero_async(method_name, *args, **kwargs)

    def execute_all(self, method_name: str, *args, **kwargs):
        return self.execute_all_async(method_name, *args, **kwargs)

    def execute_all_sync(self, method_name: str, *args, **kwargs):
        return ray.get(self.execute_all_async(method_name, *args, **kwargs))

    def execute_all_async(self, method_name: str, *args, **kwargs):
        # Here we assume that if every argument in args and kwargs is a list, and every list has the same
        # length as len(self._workers), we dispatch each element of each list to the corresponding worker.
        # print(f"execute_all_async: method {method_name}({args}, {kwargs})")
        length = len(self._workers)
        if all(isinstance(arg, list) for arg in args) and all(isinstance(kwarg, list) for kwarg in kwargs.values()):
            if all(len(arg) == length for arg in args) and all(len(kwarg) == length for kwarg in kwargs.values()):
                # print(f"splitting args and kwargs into {length} shards")
                result = []
                for i in range(length):
                    sliced_args = tuple(arg[i] for arg in args)
                    sliced_kwargs = {k: v[i] for k, v in kwargs.items()}
                    remote_call = getattr(self._workers[i], method_name)
                    result.append(remote_call.remote(*sliced_args, **sliced_kwargs))
                return result

        return [getattr(worker, method_name).remote(*args, **kwargs) for worker in self._workers]

    @property
    def master_address(self):
        return self._master_addr

    @property
    def master_port(self):
        return self._master_port

    @property
    def workers(self):
        return self._workers

    @property
    def world_size(self):
        return self._world_size


"""
Utilities that enable creating workers inside the same ray.Actor,
with code written as separate ray.Actors.
"""

from unittest.mock import patch
from verl.single_controller.base.decorator import MAGIC_ATTR
import os


def _bind_workers_method_to_parent(cls, key, user_defined_cls):
    """
    Binds the methods of each worker to the WorkerDict.
    Note that we only bind public methods that are decorated by register.
    """
    for method_name in dir(user_defined_cls):
        try:
            method = getattr(user_defined_cls, method_name)
            assert callable(method), f"{method_name} in {user_defined_cls} is not callable"
        except Exception as e:
            # if it is a property, it will fail because the class doesn't have an instance property
            continue

        if hasattr(method, MAGIC_ATTR):

            def generate_function(name):

                def func(self, *args, **kwargs):
                    # dispatch to the actual worker
                    return getattr(self.worker_dict[key], name)(*args, **kwargs)

                return func

            func = generate_function(method_name)
            # pass MAGIC_ATTR through for the outer worker group
            setattr(func, MAGIC_ATTR, getattr(method, MAGIC_ATTR))
            try:
                method_name_with_prefix = key + '_' + method_name
                setattr(cls, method_name_with_prefix, func)
                # print(f'Binding {method_name_with_prefix}')
            except Exception as e:
                raise ValueError(f'Fail to set method_name {method_name}')


def _unwrap_ray_remote(cls):
    if hasattr(cls, '__ray_actor_class__'):
        cls = cls.__ray_actor_class__
    return cls


def create_colocated_worker_cls(class_dict: dict[str, RayClassWithInitArgs]):
    """
    This function should return a class instance that delegates the calls to every
    cls in cls_dict.
    """
    cls_dict = {}
    init_args_dict = {}
    worker_cls = None
    for key, cls in class_dict.items():
        if worker_cls is None:
            worker_cls = cls.cls.__ray_actor_class__.__base__
        else:
            assert worker_cls == cls.cls.__ray_actor_class__.__base__, \
                'the worker class should be the same when sharing the same process'
        cls_dict[key] = cls.cls
        init_args_dict[key] = {'args': cls.args, 'kwargs': cls.kwargs}

    assert cls_dict.keys() == init_args_dict.keys()

    # TODO: create a class with customizable name
    class WorkerDict(worker_cls):

        def __init__(self):
            super().__init__()
            self.worker_dict = {}
            for key, user_defined_cls in cls_dict.items():
                user_defined_cls = _unwrap_ray_remote(user_defined_cls)
                # directly instantiate the class without remote
                with patch.dict(os.environ, {'DISABLE_WORKER_INIT': '1'}):
                    self.worker_dict[key] = user_defined_cls(*init_args_dict[key].get('args', ()),
                                                             **init_args_dict[key].get('kwargs', {}))

    # now monkey-patch the methods from the inner classes to WorkerDict
    for key, user_defined_cls in cls_dict.items():
        user_defined_cls = _unwrap_ray_remote(user_defined_cls)
        _bind_workers_method_to_parent(WorkerDict, key, user_defined_cls)

    remote_cls = ray.remote(WorkerDict)
    remote_cls = RayClassWithInitArgs(cls=remote_cls)
    return remote_cls
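
To tie the pieces together, a minimal end-to-end sketch (not part of this commit; it assumes a running Ray cluster with at least two GPUs, and the class/method names are illustrative):

import ray
from verl.single_controller.base.worker import Worker
from verl.single_controller.base.decorator import register, Dispatch
from verl.single_controller.ray.base import RayResourcePool, RayClassWithInitArgs, RayWorkerGroup

@ray.remote
class EchoWorker(Worker):

    @register(dispatch_mode=Dispatch.ONE_TO_ALL)
    def ping(self, msg):
        return f"rank {self.rank}: {msg}"

ray.init()
pool = RayResourcePool(process_on_nodes=[2], use_gpu=True)
cls_with_init = RayClassWithInitArgs(cls=EchoWorker)
wg = RayWorkerGroup(resource_pool=pool, ray_cls_with_init=cls_with_init)
print(wg.ping("hello"))  # ONE_TO_ALL replicates the argument; one reply per worker is collected
ray.shutdown()
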
62
verl/single_controller/ray/megatron.py
Normal file
@@ -0,0 +1,62 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, Optional

import ray

from .base import RayWorkerGroup, RayResourcePool, RayClassWithInitArgs
from verl.single_controller.base.megatron.worker import DistRankInfo, DistGlobalInfo
from verl.single_controller.base.megatron.worker_group import MegatronWorkerGroup


# NOTE(sgm): for open-source megatron-core
class NVMegatronRayWorkerGroup(RayWorkerGroup, MegatronWorkerGroup):
    """
    MegatronWorkerGroup queries each worker for its megatron rank info and stores it inside the WorkerGroup
    so that the dispatcher can use it to dispatch data.
    """

    def __init__(self, resource_pool: RayResourcePool, ray_cls_with_init: RayClassWithInitArgs, **kwargs):
        super().__init__(resource_pool=resource_pool, ray_cls_with_init=ray_cls_with_init, **kwargs)
        self._megatron_rank_info: DistRankInfo = self.execute_all_sync(method_name='get_megatron_rank_info')
        self._megatron_global_info: DistGlobalInfo = ray.get(
            self.execute_rank_zero_async(method_name='get_megatron_global_info'))


class MegatronRayWorkerGroup(RayWorkerGroup, MegatronWorkerGroup):
    """
    MegatronWorkerGroup queries each worker for its megatron rank info and stores it inside the WorkerGroup
    so that the dispatcher can use it to dispatch data.
    """

    def __init__(self,
                 resource_pool: RayResourcePool,
                 ray_cls_with_init: RayClassWithInitArgs,
                 default_megatron_kwargs: Dict = None,
                 **kwargs):
        super().__init__(resource_pool=resource_pool,
                         ray_cls_with_init=ray_cls_with_init,
                         default_megatron_kwargs=default_megatron_kwargs,
                         **kwargs)
        self.init_megatron(default_megatron_kwargs=default_megatron_kwargs)
        self._megatron_rank_info: DistRankInfo = self.execute_all_sync(method_name='get_megatron_rank_info')
        self._megatron_global_info: DistGlobalInfo = ray.get(
            self.execute_rank_zero_async(method_name='get_megatron_global_info'))

    def init_megatron(self, default_megatron_kwargs: Optional[Dict] = None):
        # after super().__init__, we call init on each worker
        if not self._is_init_with_detached_workers:
            # only init_megatron if the WorkerGroup is created from scratch
            self.execute_all_sync(method_name='init_megatron', default_megatron_kwargs=default_megatron_kwargs)
1
verl/single_controller/version/version
Normal file
@@ -0,0 +1 @@
0.0.2