Initial commit
verl/single_controller/base/worker_group.py (new file, 196 lines)
@@ -0,0 +1,196 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
The base class of WorkerGroup
"""
import logging
import signal
import threading
import time
from typing import Any, Callable, Dict, List

from verl.single_controller.base.decorator import MAGIC_ATTR, Dispatch, get_predefined_dispatch_fn, get_predefined_execute_fn


class ResourcePool:

    def __init__(self, process_on_nodes=None, max_collocate_count: int = 10, n_gpus_per_node=8) -> None:
        if process_on_nodes is None:
            process_on_nodes = []
        self._store = process_on_nodes
        self.max_collocate_count = max_collocate_count
        self.n_gpus_per_node = n_gpus_per_node  # reserved for future Huawei hardware with 16 GPUs per node

    def add_node(self, process_count):
        self._store.append(process_count)

    @property
    def world_size(self):
        return sum(self._store)

    def __call__(self) -> Any:
        return self._store

    @property
    def store(self):
        return self._store

    def local_world_size_list(self) -> List[int]:
        nested_local_world_size_list = [
            [local_world_size for _ in range(local_world_size)] for local_world_size in self._store
        ]
        return [item for row in nested_local_world_size_list for item in row]

    def local_rank_list(self) -> List[int]:
        nested_local_rank_list = [[i for i in range(local_world_size)] for local_world_size in self._store]
        return [item for row in nested_local_rank_list for item in row]
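# Illustrative sketch (not part of the original file): how the flattening helpers
# above behave for a hypothetical pool with two nodes of 8 processes each.
#   pool = ResourcePool(process_on_nodes=[8, 8])
#   pool.world_size               # 16
#   pool.local_world_size_list()  # [8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8]
#   pool.local_rank_list()        # [0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7]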


class ClassWithInitArgs:
    """
    This class stores a class constructor and the args/kwargs to construct the class.
    It is used to instantiate the remote class.
    """

    def __init__(self, cls, *args, **kwargs) -> None:
        self.cls = cls
        self.args = args
        self.kwargs = kwargs

    # def add_arg(self, arg):
    #     self.args += (arg,)

    # def add_kwarg(self, key, value):
    #     self.kwargs[key] = value

    def __call__(self) -> Any:
        return self.cls(*self.args, **self.kwargs)
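# Illustrative sketch (not part of the original file): ClassWithInitArgs defers
# construction so a backend can instantiate the class later (e.g. remotely).
# MyWorker and config below are hypothetical.
#   cls_with_args = ClassWithInitArgs(MyWorker, config=config)
#   ...  # ship cls_with_args to wherever the worker should live
#   worker = cls_with_args()  # equivalent to MyWorker(config=config)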


def check_workers_alive(workers: List, is_alive: Callable, gap_time: float = 1) -> None:
    while True:
        for worker in workers:
            if not is_alive(worker):
                logging.warning(f"worker {worker} is not alive, sending signal to main thread")
                signal.raise_signal(signal.SIGABRT)
        time.sleep(gap_time)


class WorkerGroup:

    def __init__(self, resource_pool: ResourcePool, **kwargs) -> None:
        # a WorkerGroup created without a resource pool attaches to existing (detached) workers
        self._is_init_with_detached_workers = resource_pool is None

        if resource_pool is not None:
            # build the process dispatch config from the resource pool
            self._procecss_dispatch_config = resource_pool()
        else:
            self._procecss_dispatch_config = None

        self._workers = []
        self._worker_names = []

        self._master_addr = None
        self._master_port = None

        self._checker_thread: threading.Thread = None

    def _is_worker_alive(self, worker):
        raise NotImplementedError("WorkerGroup._is_worker_alive called, should be implemented in derived class.")

    def _block_until_all_workers_alive(self) -> None:
        while True:
            all_state = [self._is_worker_alive(worker) for worker in self._workers]
            if False in all_state:
                time.sleep(1)
            else:
                break

    def start_worker_aliveness_check(self, every_n_seconds=1) -> None:
        # before starting the aliveness check, make sure all workers are already alive
        self._block_until_all_workers_alive()

        self._checker_thread = threading.Thread(target=check_workers_alive,
                                                args=(self._workers, self._is_worker_alive, every_n_seconds))
        self._checker_thread.start()
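
    # Illustrative sketch (not part of the original file): a concrete subclass is
    # expected to override _is_worker_alive with a backend-specific check and then
    # start the background checker. MyWorkerGroup and worker.is_alive() are hypothetical.
    #   class MyWorkerGroup(WorkerGroup):
    #       def _is_worker_alive(self, worker):
    #           return worker.is_alive()
    #
    #   wg = MyWorkerGroup(resource_pool)
    #   wg.start_worker_aliveness_check(every_n_seconds=1)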

    @property
    def world_size(self):
        return len(self._workers)

    # execute_all_async and execute_rank_zero_async should be implemented by the concrete
    # worker groups that need them (e.g. RayWorkerGroup, TorchRPCWorkerGroup, MegatronWorkerGroup);
    # XperfWorkerGroup should skip them

    def _bind_worker_method(self, user_defined_cls, func_generator):
        """
        Bind the methods of the user-defined worker class to this WorkerGroup.
        """

        for method_name in dir(user_defined_cls):

            try:
                method = getattr(user_defined_cls, method_name)
                assert callable(method), f"{method_name} in {user_defined_cls} is not callable"
            except Exception:
                # if it is a property, getattr fails because the class has no instance to read it from
                continue

            if hasattr(method, MAGIC_ATTR):
                # this method is decorated by register
                attribute = getattr(method, MAGIC_ATTR)
                assert isinstance(attribute, Dict), f'attribute must be a dictionary. Got {type(attribute)}'
                assert 'dispatch_mode' in attribute, 'attribute must contain dispatch_mode in its keys'

                dispatch_mode = attribute['dispatch_mode']
                execute_mode = attribute['execute_mode']
                blocking = attribute['blocking']

                # get dispatch fn
                if isinstance(dispatch_mode, Dispatch):
                    # get default dispatch fn
                    fn = get_predefined_dispatch_fn(dispatch_mode=dispatch_mode)
                    dispatch_fn = fn['dispatch_fn']
                    collect_fn = fn['collect_fn']
                else:
                    assert isinstance(dispatch_mode, dict)
                    assert 'dispatch_fn' in dispatch_mode
                    assert 'collect_fn' in dispatch_mode
                    dispatch_fn = dispatch_mode['dispatch_fn']
                    collect_fn = dispatch_mode['collect_fn']

                # get execute_fn_name
                execute_mode = get_predefined_execute_fn(execute_mode=execute_mode)
                wg_execute_fn_name = execute_mode['execute_fn_name']

                # get execute_fn from string
                try:
                    execute_fn = getattr(self, wg_execute_fn_name)
                    assert callable(execute_fn), 'execute_fn must be callable'
                except Exception:
                    print(f'execute_fn {wg_execute_fn_name} is invalid')
                    raise

                # bind a new method to the WorkerGroup
                func = func_generator(self,
                                      method_name,
                                      dispatch_fn=dispatch_fn,
                                      collect_fn=collect_fn,
                                      execute_fn=execute_fn,
                                      blocking=blocking)

                try:
                    setattr(self, method_name, func)
                except Exception:
                    raise ValueError(f'Failed to set method_name {method_name}')
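
# Illustrative sketch (not part of the original file): how a decorated worker method
# ends up bound on the group. It assumes the @register decorator and the Dispatch enum
# from verl.single_controller.base.decorator; MyWorker, wg and func_generator are hypothetical.
#   class MyWorker:
#       @register(dispatch_mode=Dispatch.ONE_TO_ALL)
#       def ping(self, msg):
#           return msg
#
#   wg._bind_worker_method(MyWorker, func_generator)
#   wg.ping("hello")  # dispatched to every worker, results gathered by collect_fn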