Initial commit
This commit is contained in:
288
verl/utils/ulysses.py
Normal file
288
verl/utils/ulysses.py
Normal file
@@ -0,0 +1,288 @@
|
||||
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Utilities for DeepSpeed Ulysses Sequence Parallelism.
|
||||
DeepSpeed Ulysses Paper: https://arxiv.org/abs/2309.14509
|
||||
Inspired from: https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/sequence/layer.py
|
||||
"""
|
||||
from typing import Any, Optional, List, Tuple
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
import torch.distributed as dist
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
_ULYSSES_SEQUENCE_PARALLEL_GROUP = None
|
||||
|
||||
|
||||
def set_ulysses_sequence_parallel_group(group: dist.ProcessGroup):
|
||||
"""
|
||||
Set ulysses sequence parallel process group.
|
||||
"""
|
||||
global _ULYSSES_SEQUENCE_PARALLEL_GROUP
|
||||
_ULYSSES_SEQUENCE_PARALLEL_GROUP = group
|
||||
|
||||
|
||||
def get_ulysses_sequence_parallel_group() -> Optional[dist.ProcessGroup]:
|
||||
"""
|
||||
Get ulysses sequence parallel process group.
|
||||
"""
|
||||
global _ULYSSES_SEQUENCE_PARALLEL_GROUP
|
||||
return _ULYSSES_SEQUENCE_PARALLEL_GROUP
|
||||
|
||||
|
||||
def get_ulysses_sequence_parallel_world_size(group: ProcessGroup = None) -> int:
|
||||
"""
|
||||
Get ulysses sequence parallel world size.
|
||||
"""
|
||||
group = get_ulysses_sequence_parallel_group() if group is None else group
|
||||
return dist.get_world_size(group) if group else 1
|
||||
|
||||
|
||||
def get_ulysses_sequence_parallel_rank(group: ProcessGroup = None) -> int:
|
||||
"""
|
||||
Get ulysses sequence parallel rank.
|
||||
"""
|
||||
group = get_ulysses_sequence_parallel_group() if group is None else group
|
||||
return dist.get_rank(group) if group else 0
|
||||
|
||||
|
||||
def gather_seq_scatter_heads(
|
||||
x: Tensor,
|
||||
seq_dim: int,
|
||||
head_dim: int,
|
||||
unpadded_dim_size: int = 0,
|
||||
group: ProcessGroup = None,
|
||||
) -> Tensor:
|
||||
"""
|
||||
A func to sync embedding input with alltoall in sequence parallel
|
||||
gather sequence dimension and scatter head dim:
|
||||
e.g. seq_dim: 1, head_dim: 2
|
||||
[bsz, seq/n, h, ...] -> [bsz, seq, h/n, ...]
|
||||
"""
|
||||
group = get_ulysses_sequence_parallel_group() if group is None else group
|
||||
if not group:
|
||||
return x
|
||||
sp_world = get_ulysses_sequence_parallel_world_size(group)
|
||||
x = SeqAllToAll.apply(group, x, head_dim, seq_dim)
|
||||
if unpadded_dim_size and unpadded_dim_size % sp_world != 0:
|
||||
padding_size = x.size(seq_dim) - unpadded_dim_size
|
||||
x = _unpad_tensor(x, seq_dim, padding_size)
|
||||
return x
|
||||
|
||||
|
||||
def gather_heads_scatter_seq(x: Tensor, head_dim: int, seq_dim: int, group: ProcessGroup = None) -> Tensor:
|
||||
"""
|
||||
A func to sync attention result with alltoall in sequence parallel
|
||||
gather head dimension and scatter seq dim:
|
||||
e.g. seq_dim: 1, head_dim: 2
|
||||
[bsz, seq, h/n, ...] -> [bsz, seq/n, h, ...]
|
||||
"""
|
||||
group = get_ulysses_sequence_parallel_group() if group is None else group
|
||||
if not group:
|
||||
return x
|
||||
dim_size = x.size(seq_dim)
|
||||
sp_world = get_ulysses_sequence_parallel_world_size(group)
|
||||
if dim_size % sp_world != 0:
|
||||
padding_size = sp_world - (dim_size % sp_world)
|
||||
x = _pad_tensor(x, seq_dim, padding_size)
|
||||
return SeqAllToAll.apply(group, x, seq_dim, head_dim, False)
|
||||
|
||||
|
||||
def _pad_tensor(x: Tensor, dim: int, padding_size: int) -> Tensor:
|
||||
shape = list(x.shape)
|
||||
shape[dim] = padding_size
|
||||
pad = torch.zeros(shape, dtype=x.dtype, device=x.device)
|
||||
return torch.cat([x, pad], dim=dim)
|
||||
|
||||
|
||||
def _unpad_tensor(x: Tensor, dim: int, padding_size: int) -> Tensor:
|
||||
slc = [slice(None)] * len(x.shape)
|
||||
slc[dim] = slice(0, -padding_size)
|
||||
return x[slc]
|
||||
|
||||
|
||||
def slice_input_tensor(x: Tensor, dim: int, padding: bool = True, group: ProcessGroup = None) -> Tensor:
|
||||
group = get_ulysses_sequence_parallel_group() if group is None else group
|
||||
sp_world_size = dist.get_world_size(group)
|
||||
sp_rank = get_ulysses_sequence_parallel_rank()
|
||||
dim_size = x.size(dim)
|
||||
# pad before slice
|
||||
if padding and dim_size % sp_world_size:
|
||||
padding_size = sp_world_size - (dim_size % sp_world_size)
|
||||
x = _pad_tensor(x, dim, padding_size)
|
||||
# slice the input tensor
|
||||
parts = x.size(dim) // sp_world_size
|
||||
slc = [slice(None)] * len(x.shape)
|
||||
slc[dim] = slice(sp_rank * parts, (sp_rank + 1) * parts)
|
||||
return x[slc].contiguous()
|
||||
|
||||
|
||||
def all_to_all_tensor(
|
||||
local_input: Tensor,
|
||||
scatter_dim: int,
|
||||
gather_dim: int,
|
||||
group: Optional[dist.ProcessGroup] = None,
|
||||
async_op: bool = False,
|
||||
):
|
||||
group = get_ulysses_sequence_parallel_group() if group is None else group
|
||||
seq_world_size = dist.get_world_size(group)
|
||||
input_list = [t.contiguous() for t in torch.tensor_split(local_input, seq_world_size, scatter_dim)]
|
||||
output_list = [torch.empty_like(input_list[0]) for _ in range(seq_world_size)]
|
||||
comm = dist.all_to_all(output_list, input_list, group=group, async_op=async_op)
|
||||
if async_op:
|
||||
|
||||
def wait():
|
||||
comm.wait()
|
||||
return torch.cat(output_list, dim=gather_dim).contiguous()
|
||||
|
||||
return wait
|
||||
return torch.cat(output_list, dim=gather_dim).contiguous()
|
||||
|
||||
|
||||
def all_gather_tensor(local_tensor: Tensor, group: Optional[dist.ProcessGroup] = None, async_op: bool = False):
|
||||
group = get_ulysses_sequence_parallel_group() if group is None else group
|
||||
sp_world_size = dist.get_world_size(group=group)
|
||||
output_shape = list(local_tensor.shape)
|
||||
output_shape[0] = output_shape[0] * sp_world_size
|
||||
output = torch.empty(output_shape, dtype=local_tensor.dtype, device=local_tensor.device)
|
||||
dist.all_gather_into_tensor(output, local_tensor, group=group, async_op=async_op)
|
||||
return output
|
||||
|
||||
|
||||
class SeqAllToAll(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(
|
||||
ctx: Any,
|
||||
group: dist.ProcessGroup,
|
||||
local_input: Tensor,
|
||||
scatter_dim: int,
|
||||
gather_dim: int,
|
||||
async_op: bool = False,
|
||||
) -> Tensor:
|
||||
ctx.group = group
|
||||
ctx.scatter_dim = scatter_dim
|
||||
ctx.gather_dim = gather_dim
|
||||
ctx.async_op = async_op
|
||||
return all_to_all_tensor(local_input, scatter_dim, gather_dim, group, async_op)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor, None, None]:
|
||||
if ctx.async_op:
|
||||
input_t = torch.cat(grad_output[1:], dim=ctx.gather_dim).contiguous()
|
||||
else:
|
||||
input_t = grad_output[0]
|
||||
return (
|
||||
None,
|
||||
all_to_all_tensor(input_t, ctx.gather_dim, ctx.scatter_dim, ctx.group, False),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
|
||||
|
||||
class Gather(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx: Any,
|
||||
group: dist.ProcessGroup,
|
||||
local_tensor: Tensor,
|
||||
gather_dim: int,
|
||||
grad_scaler: bool = True,
|
||||
async_op=False) -> Tensor:
|
||||
ctx.group = group
|
||||
ctx.gather_dim = gather_dim
|
||||
ctx.grad_scaler = grad_scaler
|
||||
ctx.async_op = async_op
|
||||
|
||||
sp_world_size = dist.get_world_size(group=group)
|
||||
ctx.sp_world_size = sp_world_size
|
||||
|
||||
sp_rank = dist.get_rank(group=group)
|
||||
ctx.sp_rank = sp_rank
|
||||
|
||||
local_shape = list(local_tensor.size())
|
||||
split_size = local_shape[0]
|
||||
part_size = local_shape[gather_dim] # store original size
|
||||
ctx.part_size = part_size
|
||||
|
||||
output = all_gather_tensor(local_tensor, group, async_op)
|
||||
return torch.cat(output.split(split_size, dim=0), dim=gather_dim)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx: Any, grad_output: Tensor) -> Any:
|
||||
if ctx.grad_scaler:
|
||||
grad_output = grad_output * ctx.sp_world_size
|
||||
return (None, grad_output.split(ctx.part_size,
|
||||
dim=ctx.gather_dim)[ctx.sp_rank].contiguous(), None, None, None, None)
|
||||
|
||||
|
||||
def gather_outpus_and_unpad(x: Tensor,
|
||||
gather_dim: int,
|
||||
unpad_dim: int = None,
|
||||
padding_size: int = 0,
|
||||
grad_scaler: bool = True,
|
||||
group: Optional[dist.ProcessGroup] = None):
|
||||
group = get_ulysses_sequence_parallel_group() if group is None else group
|
||||
sp_size = get_ulysses_sequence_parallel_world_size()
|
||||
if group == None:
|
||||
return x
|
||||
x = Gather.apply(group, x, gather_dim, grad_scaler)
|
||||
if unpad_dim is not None:
|
||||
assert isinstance(padding_size, int), 'padding size is not given or is not an integer'
|
||||
if padding_size == 0:
|
||||
return x
|
||||
x = _unpad_tensor(x, unpad_dim, padding_size)
|
||||
return x
|
||||
|
||||
|
||||
def ulysses_pad_and_slice_inputs(input_ids_rmpad: torch.Tensor,
|
||||
position_ids_rmpad: Optional[torch.Tensor] = None,
|
||||
sp_size: int = 1):
|
||||
"""
|
||||
Pad and slice input_ids to be divisible by sp_size
|
||||
Pad position_ids to be divisible by sp_size.
|
||||
|
||||
Note both input_ids_rmpad and position_ids_rmpad will be padded,
|
||||
but only input_ids will be sliced.
|
||||
|
||||
The is the utility of pre-forward for ulysses sequence parallelism
|
||||
|
||||
Args:
|
||||
input_ids_rmpad: shape of [bsz, seqlen]
|
||||
position_ids_rmpad: shape of [bsz, seqlen], where bsz must be 1
|
||||
sp_size (int): ulysses sequence parallelism size
|
||||
|
||||
Returns:
|
||||
torch.Tensor: padded and sliced input_ids
|
||||
torch.Tensor: padded and sliced position_ids
|
||||
int: pad size
|
||||
"""
|
||||
if position_ids_rmpad is not None:
|
||||
assert position_ids_rmpad.size(0) == 1
|
||||
assert input_ids_rmpad.size(1) == position_ids_rmpad.size(1)
|
||||
if sp_size <= 1:
|
||||
return input_ids_rmpad, position_ids_rmpad, 0
|
||||
_, total_seq_len = input_ids_rmpad.shape
|
||||
pad_size = (sp_size - total_seq_len % sp_size) % sp_size
|
||||
if pad_size > 0:
|
||||
input_ids_rmpad = torch.nn.functional.pad(input_ids_rmpad, (0, pad_size), value=0)
|
||||
if position_ids_rmpad is not None:
|
||||
pad_pos_ids = torch.arange(pad_size, device=position_ids_rmpad.device).unsqueeze(0)
|
||||
position_ids_rmpad = torch.cat((position_ids_rmpad, pad_pos_ids), dim=-1)
|
||||
# we don't need to slice position ids
|
||||
input_ids_rmpad = slice_input_tensor(input_ids_rmpad, dim=1, padding=False)
|
||||
return input_ids_rmpad, position_ids_rmpad, pad_size
|
||||
Reference in New Issue
Block a user