Initial commit

This commit is contained in:
PeterGriffinJin
2025-02-28 15:16:19 +00:00
commit 068516be64
207 changed files with 33063 additions and 0 deletions

View File

@@ -0,0 +1,16 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .worker import Worker
from .worker_group import WorkerGroup, ClassWithInitArgs, ResourcePool

View File

@@ -0,0 +1,410 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from enum import Enum
from functools import wraps
from typing import Dict, List, Tuple
from types import FunctionType
from verl.protocol import DataProtoFuture
# here we add a magic number of avoid user-defined function already have this attribute
MAGIC_ATTR = 'attrs_3141562937'
class Dispatch(Enum):
RANK_ZERO = 0
ONE_TO_ALL = 1
ALL_TO_ALL = 2
MEGATRON_COMPUTE = 3
MEGATRON_PP_AS_DP = 4
MEGATRON_PP_ONLY = 5
MEGATRON_COMPUTE_PROTO = 6
MEGATRON_PP_AS_DP_PROTO = 7
DP_COMPUTE = 8
DP_COMPUTE_PROTO = 9
DP_COMPUTE_PROTO_WITH_FUNC = 10
DP_COMPUTE_METRIC = 11
class Execute(Enum):
ALL = 0
RANK_ZERO = 1
def _split_args_kwargs_data_proto(chunks, *args, **kwargs):
from verl.protocol import DataProto, DataProtoFuture
splitted_args = []
for arg in args:
assert isinstance(arg, (DataProto, DataProtoFuture))
splitted_args.append(arg.chunk(chunks=chunks))
splitted_kwargs = {}
for key, val in kwargs.items():
assert isinstance(val, (DataProto, DataProtoFuture))
splitted_kwargs[key] = val.chunk(chunks=chunks)
return splitted_args, splitted_kwargs
def dispatch_one_to_all(worker_group, *args, **kwargs):
args = tuple([arg] * worker_group.world_size for arg in args)
kwargs = {k: [v] * worker_group.world_size for k, v in kwargs.items()}
return args, kwargs
def dispatch_all_to_all(worker_group, *args, **kwargs):
return args, kwargs
def collect_all_to_all(worker_group, output):
return output
def dispatch_megatron_compute(worker_group, *args, **kwargs):
"""
User passes in dp data. The data is dispatched to all tp/pp ranks with the same dp
"""
from verl.single_controller.base.megatron.worker_group import MegatronWorkerGroup
assert isinstance(worker_group,
MegatronWorkerGroup), f'worker_group must be MegatronWorkerGroup, Got {type(worker_group)}'
all_args = []
for arg in args:
assert isinstance(arg, (Tuple, List)) and len(arg) == worker_group.dp_size
transformed_args = []
for i in range(worker_group.world_size):
local_dp_rank = worker_group.get_megatron_rank_info(rank=i).dp_rank
transformed_args.append(arg[local_dp_rank])
all_args.append(transformed_args)
all_args = tuple(all_args)
all_kwargs = {}
for k, v in kwargs.items():
assert isinstance(v, (Tuple, List)) and len(v) == worker_group.dp_size
transformed_v = []
for i in range(worker_group.world_size):
local_dp_rank = worker_group.get_megatron_rank_info(rank=i).dp_rank
transformed_v.append(v[local_dp_rank])
all_kwargs[k] = transformed_v
return all_args, all_kwargs
def collect_megatron_compute(worker_group, output):
"""
Only collect the data from the tp=0 and pp=last and every dp ranks
"""
from verl.single_controller.base.megatron.worker_group import MegatronWorkerGroup
assert isinstance(worker_group, MegatronWorkerGroup)
output_in_dp = []
pp_size = worker_group.get_megatron_global_info().pp_size
for global_rank in range(worker_group.world_size):
local_rank_info = worker_group.get_megatron_rank_info(rank=global_rank)
if local_rank_info.tp_rank == 0 and local_rank_info.pp_rank == pp_size - 1:
output_in_dp.append(output[global_rank])
return output_in_dp
def dispatch_megatron_compute_data_proto(worker_group, *args, **kwargs):
"""
All the args and kwargs must be DataProto. The batch will be chunked by dp_size and passed to each rank
"""
from verl.single_controller.base.megatron.worker_group import MegatronWorkerGroup
assert isinstance(worker_group, MegatronWorkerGroup)
splitted_args, splitted_kwargs = _split_args_kwargs_data_proto(worker_group.dp_size, *args, **kwargs)
return dispatch_megatron_compute(worker_group, *splitted_args, **splitted_kwargs)
def _concat_data_proto_or_future(output: List):
from verl.protocol import DataProto, DataProtoFuture
import ray
# make sure all the elements in output has the same type
for o in output:
assert type(o) == type(output[0])
o = output[0]
if isinstance(o, DataProto):
return DataProto.concat(output)
elif isinstance(o, ray.ObjectRef):
return DataProtoFuture.concat(output)
else:
raise NotImplementedError
def collect_megatron_compute_data_proto(worker_group, output):
"""
Each output must be a DataProto. We concat the dim=0 of output
"""
from verl.protocol import DataProto
import ray
output = collect_megatron_compute(worker_group, output)
for o in output:
assert isinstance(o, (DataProto, ray.ObjectRef)), f"expecting {o} to be DataProto, but got {type(o)}"
return _concat_data_proto_or_future(output)
def dispatch_megatron_pp_as_dp(worker_group, *args, **kwargs):
"""
treat pp as dp.
"""
from verl.single_controller.base.megatron.worker_group import MegatronWorkerGroup
assert isinstance(worker_group, MegatronWorkerGroup)
pp_size = worker_group.pp_size
dp_size = worker_group.dp_size
pp_dp_size = pp_size * dp_size
all_args = []
for arg in args:
assert isinstance(arg, (List, Tuple)) and len(arg) == pp_dp_size
transformed_args = []
for i in range(worker_group.world_size):
local_dp_rank = worker_group.get_megatron_rank_info(rank=i).dp_rank
local_pp_rank = worker_group.get_megatron_rank_info(rank=i).pp_rank
# compute the rank in arg. Note that the order is dp then pp
# Also note that the outputs within a pp group will be firstly allgathered, then only the output of pp0 will be collected.
# For pp=2 dp=4, a batch of data "ABCDEFGH" should be dispatched and collected in below order:
# dispatch: pp_allgther: collect:
# dp 0 1 2 3 dp 0 1 2 3
# pp +---------+ pp +-------------+
# 0 | A C E G | 0 | AB CD EF GH | ABCDEFGH
# 1 | B D F H | 1 | AB CD EF GH |
# +---------+ +-------------+
arg_rank = local_dp_rank * worker_group.pp_size + local_pp_rank
transformed_args.append(arg[arg_rank])
all_args.append(transformed_args)
all_args = tuple(all_args)
all_kwargs = {}
for k, v in kwargs.items():
assert isinstance(v, (List, Tuple)) and len(v) == pp_dp_size, f'expect len(v)=={pp_dp_size}, got {len(v)}'
transformed_v = []
for i in range(worker_group.world_size):
local_dp_rank = worker_group.get_megatron_rank_info(rank=i).dp_rank
local_pp_rank = worker_group.get_megatron_rank_info(rank=i).pp_rank
# compute the rank in arg. Note that the order is dp then pp
arg_rank = local_dp_rank * worker_group.pp_size + local_pp_rank
transformed_v.append(v[arg_rank])
all_kwargs[k] = transformed_v
return all_args, all_kwargs
def collect_megatron_pp_as_dp(worker_group, output):
"""
treat pp as dp. Only collect data on tp=0
"""
from verl.single_controller.base.megatron.worker_group import MegatronWorkerGroup
assert isinstance(worker_group, MegatronWorkerGroup)
output_in_dp = []
for global_rank in range(worker_group.world_size):
local_rank_info = worker_group.get_megatron_rank_info(rank=global_rank)
if local_rank_info.tp_rank == 0 and local_rank_info.pp_rank == 0:
output_in_dp.append(output[global_rank])
return output_in_dp
def collect_megatron_pp_only(worker_group, output):
"""
Only collect output of megatron pp. This is useful when examine weight names as they are identical in tp/dp
"""
from verl.single_controller.base.megatron.worker_group import MegatronWorkerGroup
assert isinstance(worker_group, MegatronWorkerGroup)
output_in_pp = []
for global_rank in range(worker_group.world_size):
local_rank_info = worker_group.get_megatron_rank_info(rank=global_rank)
if local_rank_info.tp_rank == 0 and local_rank_info.dp_rank == 0:
output_in_pp.append(output[global_rank])
return output_in_pp
def dispatch_megatron_pp_as_dp_data_proto(worker_group, *args, **kwargs):
from verl.single_controller.base.megatron.worker_group import MegatronWorkerGroup
assert isinstance(worker_group, MegatronWorkerGroup)
pp_dp_size = worker_group.dp_size * worker_group.pp_size
splitted_args, splitted_kwargs = _split_args_kwargs_data_proto(pp_dp_size, *args, **kwargs)
return dispatch_megatron_pp_as_dp(worker_group, *splitted_args, **splitted_kwargs)
def collect_megatron_pp_as_dp_data_proto(worker_group, output):
from verl.protocol import DataProto
from verl.single_controller.base.megatron.worker_group import MegatronWorkerGroup
assert isinstance(worker_group, MegatronWorkerGroup)
output = collect_megatron_pp_as_dp(worker_group, output)
return _concat_data_proto_or_future(output)
def dispatch_dp_compute(worker_group, *args, **kwargs):
from verl.single_controller.base.worker_group import WorkerGroup
assert isinstance(worker_group, WorkerGroup)
for arg in args:
assert isinstance(arg, (Tuple, List)) and len(arg) == worker_group.world_size
for k, v in kwargs.items():
assert isinstance(v, (Tuple, List)) and len(v) == worker_group.world_size
return args, kwargs
def collect_dp_compute(worker_group, output):
from verl.single_controller.base.worker_group import WorkerGroup
assert isinstance(worker_group, WorkerGroup)
assert len(output) == worker_group.world_size
return output
def dispatch_dp_compute_data_proto(worker_group, *args, **kwargs):
from verl.single_controller.base.worker_group import WorkerGroup
assert isinstance(worker_group, WorkerGroup)
splitted_args, splitted_kwargs = _split_args_kwargs_data_proto(worker_group.world_size, *args, **kwargs)
return splitted_args, splitted_kwargs
def dispatch_dp_compute_data_proto_with_func(worker_group, *args, **kwargs):
from verl.single_controller.base.worker_group import WorkerGroup
assert isinstance(worker_group, WorkerGroup)
assert type(args[0]) == FunctionType # NOTE: The first one args is a function!
splitted_args, splitted_kwargs = _split_args_kwargs_data_proto(worker_group.world_size, *args[1:], **kwargs)
splitted_args_with_func = [[args[0]] * worker_group.world_size] + splitted_args
return splitted_args_with_func, splitted_kwargs
def collect_dp_compute_data_proto(worker_group, output):
from verl.protocol import DataProto
import ray
for o in output:
assert isinstance(o, (DataProto, ray.ObjectRef)), f"expecting {o} to be DataProto, but got {type(o)}"
output = collect_dp_compute(worker_group, output)
return _concat_data_proto_or_future(output)
def get_predefined_dispatch_fn(dispatch_mode):
predefined_dispatch_mode_fn = {
Dispatch.ONE_TO_ALL: {
'dispatch_fn': dispatch_one_to_all,
'collect_fn': collect_all_to_all,
},
Dispatch.ALL_TO_ALL: {
'dispatch_fn': dispatch_all_to_all,
'collect_fn': collect_all_to_all,
},
Dispatch.MEGATRON_COMPUTE: {
'dispatch_fn': dispatch_megatron_compute,
'collect_fn': collect_megatron_compute,
},
Dispatch.MEGATRON_PP_AS_DP: {
'dispatch_fn': dispatch_megatron_pp_as_dp,
'collect_fn': collect_megatron_pp_as_dp,
},
Dispatch.MEGATRON_PP_ONLY: {
'dispatch_fn': dispatch_one_to_all,
'collect_fn': collect_megatron_pp_only
},
Dispatch.MEGATRON_COMPUTE_PROTO: {
'dispatch_fn': dispatch_megatron_compute_data_proto,
'collect_fn': collect_megatron_compute_data_proto
},
Dispatch.MEGATRON_PP_AS_DP_PROTO: {
'dispatch_fn': dispatch_megatron_pp_as_dp_data_proto,
'collect_fn': collect_megatron_pp_as_dp_data_proto
},
Dispatch.DP_COMPUTE: {
'dispatch_fn': dispatch_dp_compute,
'collect_fn': collect_dp_compute
},
Dispatch.DP_COMPUTE_PROTO: {
'dispatch_fn': dispatch_dp_compute_data_proto,
'collect_fn': collect_dp_compute_data_proto
},
Dispatch.DP_COMPUTE_PROTO_WITH_FUNC: {
'dispatch_fn': dispatch_dp_compute_data_proto_with_func,
'collect_fn': collect_dp_compute_data_proto
},
Dispatch.DP_COMPUTE_METRIC: {
'dispatch_fn': dispatch_dp_compute_data_proto,
'collect_fn': collect_dp_compute
}
}
return predefined_dispatch_mode_fn[dispatch_mode]
def get_predefined_execute_fn(execute_mode):
"""
Note that here we only asks execute_all and execute_rank_zero to be implemented
Leave the choice of how these two functions handle argument 'blocking' to users
"""
predefined_execute_mode_fn = {
Execute.ALL: {
'execute_fn_name': 'execute_all'
},
Execute.RANK_ZERO: {
'execute_fn_name': 'execute_rank_zero'
}
}
return predefined_execute_mode_fn[execute_mode]
def _check_dispatch_mode(dispatch_mode):
assert isinstance(dispatch_mode,
(Dispatch, Dict)), f'dispatch_mode must be a Dispatch or a Dict. Got {dispatch_mode}'
if isinstance(dispatch_mode, Dict):
necessary_keys = ['dispatch_fn', 'collect_fn']
for key in necessary_keys:
assert key in dispatch_mode, f'key {key} should be in dispatch_mode if it is a dictionary'
def _check_execute_mode(execute_mode):
assert isinstance(execute_mode, Execute), f'execute_mode must be a Execute. Got {execute_mode}'
def _materialize_futures(*args, **kwargs):
new_args = []
for arg in args:
if isinstance(arg, DataProtoFuture):
arg = arg.get()
# add more type to materialize
new_args.append(arg)
for k, v in kwargs.items():
if isinstance(v, DataProtoFuture):
kwargs[k] = v.get()
new_args = tuple(new_args)
return new_args, kwargs
def register(dispatch_mode=Dispatch.ALL_TO_ALL, execute_mode=Execute.ALL, blocking=True, materialize_futures=True):
_check_dispatch_mode(dispatch_mode=dispatch_mode)
_check_execute_mode(execute_mode=execute_mode)
def decorator(func):
@wraps(func)
def inner(*args, **kwargs):
if materialize_futures:
args, kwargs = _materialize_futures(*args, **kwargs)
return func(*args, **kwargs)
attrs = {'dispatch_mode': dispatch_mode, 'execute_mode': execute_mode, 'blocking': blocking}
setattr(inner, MAGIC_ATTR, attrs)
return inner
return decorator

View File

@@ -0,0 +1,13 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@@ -0,0 +1,39 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from dataclasses import dataclass
from verl.single_controller.base.worker import Worker, DistRankInfo, DistGlobalInfo
class MegatronWorker(Worker):
def __init__(self, cuda_visible_devices=None) -> None:
super().__init__(cuda_visible_devices)
def get_megatron_global_info(self):
from megatron.core import parallel_state as mpu
tp_size = mpu.get_tensor_model_parallel_world_size()
dp_size = mpu.get_data_parallel_world_size()
pp_size = mpu.get_pipeline_model_parallel_world_size()
info = DistGlobalInfo(tp_size=tp_size, dp_size=dp_size, pp_size=pp_size)
return info
def get_megatron_rank_info(self):
from megatron.core import parallel_state as mpu
tp_rank = mpu.get_tensor_model_parallel_rank()
dp_rank = mpu.get_data_parallel_rank()
pp_rank = mpu.get_pipeline_model_parallel_rank()
info = DistRankInfo(tp_rank=tp_rank, dp_rank=dp_rank, pp_rank=pp_rank)
return info

View File

@@ -0,0 +1,51 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict
from .worker import DistRankInfo, DistGlobalInfo
from verl.single_controller.base import ResourcePool, WorkerGroup
class MegatronWorkerGroup(WorkerGroup):
def __init__(self, resource_pool: ResourcePool, **kwargs):
super().__init__(resource_pool=resource_pool, **kwargs)
self._megatron_rank_info = None
self._megatron_global_info: DistGlobalInfo = None
def init_megatron(self, default_megatron_kwargs: Dict = None):
raise NotImplementedError(f"MegatronWorkerGroup.init_megatron should be overwritten")
def get_megatron_rank_info(self, rank: int) -> DistRankInfo:
assert 0 <= rank < self.world_size, f'rank must be from [0, world_size), Got {rank}'
return self._megatron_rank_info[rank]
@property
def tp_size(self):
assert self._megatron_global_info is not None, "MegatronWorkerGroup._megatron_global_info must be initialized"
return self._megatron_global_info.tp_size
@property
def dp_size(self):
assert self._megatron_global_info is not None, "MegatronWorkerGroup._megatron_global_info must be initialized"
return self._megatron_global_info.dp_size
@property
def pp_size(self):
assert self._megatron_global_info is not None, "MegatronWorkerGroup._megatron_global_info must be initialized"
return self._megatron_global_info.pp_size
def get_megatron_global_info(self):
return self._megatron_global_info

View File

@@ -0,0 +1,13 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@@ -0,0 +1,29 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import ray
@ray.remote
class WorkerGroupRegisterCenter:
def __init__(self, rank_zero_info):
self.rank_zero_info = rank_zero_info
def get_rank_zero_info(self):
return self.rank_zero_info
def create_worker_group_register_center(name, info):
return WorkerGroupRegisterCenter.options(name=name).remote(info)

View File

@@ -0,0 +1,186 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
the class for Worker
"""
import os
import socket
from dataclasses import dataclass
from verl.single_controller.base.decorator import register, Dispatch, Execute
@dataclass
class DistRankInfo:
tp_rank: int
dp_rank: int
pp_rank: int
@dataclass
class DistGlobalInfo:
tp_size: int
dp_size: int
pp_size: int
class WorkerHelper:
def _get_node_ip(self):
def get_node_ip_by_sdk():
if os.getenv("WG_BACKEND", None) == "ray":
import ray
return ray._private.services.get_node_ip_address()
elif os.getenv("WG_BACKEND", None) == "torch_rpc":
from verl.single_controller.torchrpc.k8s_client import get_ip_addr
return get_ip_addr()
return None
host_ipv4 = os.getenv("MY_HOST_IP", None)
host_ipv6 = os.getenv("MY_HOST_IPV6", None)
host_ip_by_env = host_ipv4 or host_ipv6
host_ip_by_sdk = get_node_ip_by_sdk()
host_ip = host_ip_by_env or host_ip_by_sdk
return host_ip
def _get_free_port(self):
with socket.socket() as sock:
sock.bind(('', 0))
return sock.getsockname()[1]
def get_availale_master_addr_port(self):
return self._get_node_ip(), str(self._get_free_port())
def _get_pid(self):
return
class WorkerMeta:
keys = [
"WORLD_SIZE", "RANK", "LOCAL_WORLD_SIZE", "LOCAL_RANK", "MASTER_ADDR", "MASTER_PORT", "CUDA_VISIBLE_DEVICES"
]
def __init__(self, store) -> None:
self._store = store
def to_dict(self):
return {f"_{key.lower()}": self._store.get(f"_{key.lower()}", None) for key in WorkerMeta.keys}
# we assume that in each WorkerGroup, there is a Master Worker
class Worker(WorkerHelper):
def __new__(cls, *args, **kwargs):
instance = super().__new__(cls)
# note that here we use int to distinguish
disable_worker_init = int(os.environ.get('DISABLE_WORKER_INIT', 0))
if disable_worker_init:
return instance
rank = os.environ.get("RANK", None)
worker_group_prefix = os.environ.get("WG_PREFIX", None)
# when decorator @ray.remote applies, __new__ will be called while we don't want to apply _configure_before_init
if None not in [rank, worker_group_prefix] and 'ActorClass(' not in cls.__name__:
instance._configure_before_init(f"{worker_group_prefix}_register_center", int(rank))
return instance
def _configure_before_init(self, register_center_name: str, rank: int):
assert isinstance(rank, int), f"rank must be int, instead of {type(rank)}"
if rank == 0:
master_addr, master_port = self.get_availale_master_addr_port()
rank_zero_info = {
"MASTER_ADDR": master_addr,
"MASTER_PORT": master_port,
}
if os.getenv("WG_BACKEND", None) == "ray":
from verl.single_controller.base.register_center.ray import create_worker_group_register_center
self.register_center = create_worker_group_register_center(name=register_center_name,
info=rank_zero_info)
os.environ.update(rank_zero_info)
def __init__(self, cuda_visible_devices=None) -> None:
# construct a meta from envrionment variable. Note that the import must be inside the class because it is executed remotely
import os
world_size = int(os.environ['WORLD_SIZE'])
rank = int(os.environ['RANK'])
self._rank = rank
self._world_size = world_size
master_addr = os.environ["MASTER_ADDR"]
master_port = os.environ["MASTER_PORT"]
local_world_size = int(os.getenv("LOCAL_WORLD_SIZE", "1"))
local_rank = int(os.getenv("LOCAL_RANK", "0"))
store = {
'_world_size': world_size,
'_rank': rank,
'_local_world_size': local_world_size,
'_local_rank': local_rank,
'_master_addr': master_addr,
'_master_port': master_port
}
if cuda_visible_devices is not None:
store['_cuda_visible_devices'] = cuda_visible_devices
meta = WorkerMeta(store=store)
self._configure_with_meta(meta=meta)
def _configure_with_meta(self, meta: WorkerMeta):
"""
This function should only be called inside by WorkerGroup
"""
assert isinstance(meta, WorkerMeta)
self.__dict__.update(meta.to_dict()) # this is hacky
# print(f"__dict__: {self.__dict__}")
for key in WorkerMeta.keys:
val = self.__dict__.get(f"_{key.lower()}", None)
if val is not None:
# print(f"set {key} to {val}")
os.environ[key] = str(val)
os.environ["REDIS_STORE_SERVER_HOST"] = str(self._master_addr).replace("[", "").replace(
"]", "") if self._master_addr else ""
def get_master_addr_port(self):
return self._master_addr, self._master_port
def get_cuda_visible_devices(self):
import os
cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", "not set")
return cuda_visible_devices
@property
def world_size(self):
return self._world_size
@property
def rank(self):
return self._rank
@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO_WITH_FUNC)
def execute_with_func_generator(self, func, *args, **kwargs):
ret_proto = func(self, *args, **kwargs)
return ret_proto
@register(dispatch_mode=Dispatch.ALL_TO_ALL, execute_mode=Execute.RANK_ZERO)
def execute_func_rank_zero(self, func, *args, **kwargs):
result = func(*args, **kwargs)
return result

View File

@@ -0,0 +1,196 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
the class of WorkerGroup
"""
import logging
import threading
import signal
import time
from typing import List, Any, Callable, Dict
from verl.single_controller.base.decorator import MAGIC_ATTR, Dispatch, get_predefined_dispatch_fn, get_predefined_execute_fn
class ResourcePool:
def __init__(self, process_on_nodes=None, max_collocate_count: int = 10, n_gpus_per_node=8) -> None:
if process_on_nodes is None:
process_on_nodes = []
self._store = process_on_nodes
self.max_collocate_count = max_collocate_count
self.n_gpus_per_node = n_gpus_per_node # this is left for future huawei GPU that contains 16 GPUs per node
def add_node(self, process_count):
self._store.append(process_count)
@property
def world_size(self):
return sum(self._store)
def __call__(self) -> Any:
return self._store
@property
def store(self):
return self._store
def local_world_size_list(self) -> List[int]:
nested_local_world_size_list = [
[local_world_size for _ in range(local_world_size)] for local_world_size in self._store
]
return [item for row in nested_local_world_size_list for item in row]
def local_rank_list(self) -> List[int]:
nested_local_rank_list = [[i for i in range(local_world_size)] for local_world_size in self._store]
return [item for row in nested_local_rank_list for item in row]
class ClassWithInitArgs:
"""
This class stores a class constructor and the args/kwargs to construct the class.
It is used to instantiate the remote class.
"""
def __init__(self, cls, *args, **kwargs) -> None:
self.cls = cls
self.args = args
self.kwargs = kwargs
# def add_arg(self, arg):
# self.args += (arg,)
# def add_kwarg(self, key, value):
# self.kwargs[key] = value
def __call__(self) -> Any:
return self.cls(*self.args, **self.kwargs)
def check_workers_alive(workers: List, is_alive: Callable, gap_time: float = 1) -> None:
import time
while True:
for worker in workers:
if not is_alive(worker):
logging.warning(f"worker {worker} is not alive" + " sending signal to main thread")
signal.raise_signal(signal.SIGABRT)
time.sleep(gap_time)
class WorkerGroup:
def __init__(self, resource_pool: ResourcePool, **kwargs) -> None:
self._is_init_with_detached_workers = True if resource_pool is None else False
if resource_pool is not None:
# handle the case when WorkGroup is attached to an existing one
self._procecss_dispatch_config = resource_pool()
else:
self._procecss_dispatch_config = None
self._workers = []
self._worker_names = []
self._master_addr = None
self._master_port = None
self._checker_thread: threading.Thread = None
def _is_worker_alive(self, worker):
raise NotImplementedError(f"WorkerGroup._is_worker_alive called, should be implemented in derived class.")
def _block_until_all_workers_alive(self) -> None:
while True:
all_state = [self._is_worker_alive(worker) for worker in self._workers]
if False in all_state:
time.sleep(1)
else:
break
def start_worker_aliveness_check(self, every_n_seconds=1) -> None:
# before starting checking worker aliveness, make sure all workers are already alive
self._block_until_all_workers_alive()
self._checker_thread = threading.Thread(target=check_workers_alive,
args=(self._workers, self._is_worker_alive, every_n_seconds))
self._checker_thread.start()
@property
def world_size(self):
return len(self._workers)
# execute_all_async and execute_rank_zero_async should be implemented by RayWorkerGroup, TorchRPCWorkerGroup,
# MegatronWorkerGroup, XperfWorkerGroup should skip
def _bind_worker_method(self, user_defined_cls, func_generator):
"""
Bind the worker method to the WorkerGroup
"""
for method_name in dir(user_defined_cls):
try:
method = getattr(user_defined_cls, method_name)
assert callable(method), f"{method_name} in {user_defined_cls} is not callable"
except Exception as e:
# if it is a property, it will fail because Class doesn't have instance property
continue
if hasattr(method, MAGIC_ATTR):
# this method is decorated by register
attribute = getattr(method, MAGIC_ATTR)
assert isinstance(attribute, Dict), f'attribute must be a dictionary. Got {type(attribute)}'
assert 'dispatch_mode' in attribute, f'attribute must contain dispatch_mode in its key'
dispatch_mode = attribute['dispatch_mode']
execute_mode = attribute['execute_mode']
blocking = attribute['blocking']
# get dispatch fn
if isinstance(dispatch_mode, Dispatch):
# get default dispatch fn
fn = get_predefined_dispatch_fn(dispatch_mode=dispatch_mode)
dispatch_fn = fn['dispatch_fn']
collect_fn = fn['collect_fn']
else:
assert isinstance(dispatch_mode, dict)
assert 'dispatch_fn' in dispatch_mode
assert 'collect_fn' in dispatch_mode
dispatch_fn = dispatch_mode['dispatch_fn']
collect_fn = dispatch_mode['collect_fn']
# get execute_fn_name
execute_mode = get_predefined_execute_fn(execute_mode=execute_mode)
wg_execute_fn_name = execute_mode['execute_fn_name']
# get execute_fn from string
try:
execute_fn = getattr(self, wg_execute_fn_name)
assert callable(execute_fn), 'execute_fn must be callable'
except Exception as e:
print(f'execute_fn {wg_execute_fn_name} is invalid')
raise
# bind a new method to the RayWorkerGroup
func = func_generator(self,
method_name,
dispatch_fn=dispatch_fn,
collect_fn=collect_fn,
execute_fn=execute_fn,
blocking=blocking)
try:
setattr(self, method_name, func)
except Exception as e:
raise ValueError(f'Fail to set method_name {method_name}')