Files
Search-R1/verl/workers/sharding_manager/fsdp_ulysses.py
PeterGriffinJin 068516be64 Initial commit
2025-02-28 15:16:19 +00:00

88 lines
3.5 KiB
Python

# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Contains a sharding manager that reshards data across the sequence-parallel (SP) group when using FSDP + Ulysses
"""
from typing import Optional
from .base import BaseShardingManager
import random
from torch.distributed.device_mesh import DeviceMesh
from verl.utils.torch_functional import allgather_dict_tensors
from verl.utils.ulysses import set_ulysses_sequence_parallel_group, get_ulysses_sequence_parallel_group
import numpy as np
import torch
import torch.distributed
from verl import DataProto
class FSDPUlyssesShardingManager(BaseShardingManager):
    """
    Sharding manager for data resharding under FSDP + Ulysses.

    Used as a context manager, it swaps the process-wide Ulysses
    sequence-parallel (SP) group for the ``'sp'`` sub-mesh group of
    ``device_mesh`` and puts the previous group back on exit.
    ``preprocess_data`` all-gathers a ``DataProto`` over that SP group and
    ``postprocess_data`` hands each rank back its own chunk.

    When ``device_mesh`` is ``None`` every method is a no-op / pass-through.
    """

    def __init__(self, device_mesh: DeviceMesh):
        """
        Store the device mesh.

        Args:
            device_mesh: mesh indexed with ``'sp'`` to obtain the SP sub-mesh.
                The other methods check it against ``None`` and skip all
                work in that case.
        """
        super().__init__()
        self.device_mesh = device_mesh
        # Not read anywhere in this class; kept as a public attribute.
        self.seed_offset = 12345

    def __enter__(self):
        """Remember the current global SP group and install this mesh's one."""
        if self.device_mesh is None:
            return
        # The SP group is global state, so save whatever is active now and
        # switch to the group that belongs to this model's mesh.
        self.prev_sp_group = get_ulysses_sequence_parallel_group()
        model_sp_group = self.device_mesh['sp'].get_group()
        set_ulysses_sequence_parallel_group(model_sp_group)
        # TODO: check how to set seed for each model

    def __exit__(self, exc_type, exc_value, traceback):
        """Put back the SP group saved by ``__enter__``; exceptions propagate."""
        if self.device_mesh is None:
            return
        set_ulysses_sequence_parallel_group(self.prev_sp_group)
        # TODO: check how to set seed for each model

    def preprocess_data(self, data: DataProto) -> DataProto:
        """
        All-gather ``data`` across the SP group.

        The data arrives sharded along the FSDP dimension, while Ulysses
        needs every rank of an SP group to work on the same data.

        Args:
            data: the local shard; it is modified in place.

        Returns:
            The same ``data`` object with ``batch`` and ``non_tensor_batch``
            replaced by their gathered versions (unchanged when there is no
            device mesh).
        """
        if self.device_mesh is None:
            return data

        sp_mesh = self.device_mesh['sp']
        world = sp_mesh.size()
        pg = sp_mesh.get_group()

        # Tensor part: move to the current CUDA device for the collective,
        # gather along dim 0, then return to the device it came from.
        origin_device = data.batch.device
        on_gpu = data.batch.cuda(device=torch.cuda.current_device())
        data.batch = on_gpu
        gathered_batch = allgather_dict_tensors(on_gpu.contiguous(), size=world, group=pg, dim=0)
        data.batch = gathered_batch
        data.batch = gathered_batch.to(origin_device)

        # Non-tensor part: gather the per-rank dicts as Python objects and
        # concatenate the arrays key by key.
        per_rank = [None] * world
        torch.distributed.all_gather_object(per_rank, data.non_tensor_batch, group=pg)
        merged = {}
        for key in data.non_tensor_batch:
            merged[key] = np.concatenate([shard[key] for shard in per_rank])
        data.non_tensor_batch = merged

        return data

    def postprocess_data(self, data: DataProto) -> DataProto:
        """
        Split ``data`` to follow the FSDP partition.

        Args:
            data: data shared by all ranks of the SP group.

        Returns:
            The chunk at this rank's local SP index out of ``sp_size`` chunks,
            or ``data`` itself when there is no device mesh.
        """
        if self.device_mesh is None:
            return data

        sp_mesh = self.device_mesh['sp']
        n_chunks = sp_mesh.size()
        local_idx = sp_mesh.get_local_rank()
        pieces = data.chunk(chunks=n_chunks)
        return pieces[local_idx]