Initial commit

This commit is contained in:
PeterGriffinJin
2025-02-28 15:16:19 +00:00
commit 068516be64
207 changed files with 33063 additions and 0 deletions

18
verl/utils/__init__.py Normal file

@@ -0,0 +1,18 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import tokenizer
from .tokenizer import *
__all__ = tokenizer.__all__

23
verl/utils/config.py Normal file

@@ -0,0 +1,23 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict
from omegaconf import DictConfig
def update_dict_with_config(dictionary: Dict, config: DictConfig):
for key in dictionary:
if hasattr(config, key):
dictionary[key] = getattr(config, key)
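A minimal usage sketch of `update_dict_with_config` (not part of the commit; the keys and values are illustrative):

```python
# Hypothetical example: override a plain dict of defaults with an OmegaConf config.
from omegaconf import OmegaConf

from verl.utils.config import update_dict_with_config

defaults = {"lr": 1e-5, "batch_size": 8, "seed": 0}
overrides = OmegaConf.create({"lr": 3e-6, "batch_size": 16})

update_dict_with_config(defaults, overrides)
print(defaults)  # keys present in `overrides` replace the defaults; `seed` is left unchanged
```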


@@ -0,0 +1,16 @@
# Dataset Format
## RLHF dataset
We combine all the data sources into a single parquet file. We organize the prompt directly in the chat format so that multi-turn chats can be easily incorporated. In the prompt, we may add instruction-following text that guides the model to output the answer in a particular format, so that we can extract the answer.
An example for math problems:
```json
{
    "data_source": "openai/gsm8k",
    "prompt": [{"role": "user", "content": "Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May? Let's think step by step and output the final answer after \"####\""}],
    "ability": "math",
    "reward_model": {
        "style": "rule",
        "ground_truth": ["72"]
    }
}
```
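As an illustration (not part of the commit), a file in this format could be produced with pandas and pyarrow; the record and the output filename below are placeholders:

```python
# Write one GSM8K-style record to a parquet file (illustrative only).
import pandas as pd

records = [{
    "data_source": "openai/gsm8k",
    "prompt": [{"role": "user", "content": "Natalia sold clips to 48 of her friends in April, ..."}],
    "ability": "math",
    "reward_model": {"style": "rule", "ground_truth": ["72"]},
}]

pd.DataFrame(records).to_parquet("gsm8k_train.parquet")
```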


@@ -0,0 +1,16 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .rl_dataset import RLHFDataset
from .rm_dataset import RMDataset


@@ -0,0 +1,155 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from omegaconf import ListConfig
import os
from typing import List, Union
import pandas as pd
import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, PreTrainedTokenizer
from verl.utils.fs import copy_local_path_from_hdfs
from verl.utils.model import compute_position_id_with_mask
import verl.utils.torch_functional as verl_F
def collate_fn(data_list: list[dict]) -> dict:
tensors = {}
non_tensors = {}
for data in data_list:
for key, val in data.items():
if isinstance(val, torch.Tensor):
if key not in tensors:
tensors[key] = []
tensors[key].append(val)
else:
if key not in non_tensors:
non_tensors[key] = []
non_tensors[key].append(val)
for key, val in tensors.items():
tensors[key] = torch.stack(val, dim=0)
for key, val in non_tensors.items():
non_tensors[key] = np.array(val, dtype=object)
output = {}
output.update(tensors)
output.update(non_tensors)
return output
class RLHFDataset(Dataset):
"""
We assume the dataset contains a column that contains prompts and other information
"""
def __init__(self,
parquet_files: Union[str, List[str]],
tokenizer: PreTrainedTokenizer,
prompt_key='prompt',
max_prompt_length=1024,
filter_prompts=True,
cache_dir='~/.cache/verl/rlhf',
chat_template_func=None,
return_raw_chat=False,
truncation='error'):
if not isinstance(parquet_files, (List, ListConfig)):
parquet_files = [parquet_files]
self.parquet_files = parquet_files
self.cache_dir = os.path.expanduser(cache_dir)
self.tokenizer = tokenizer
self.prompt_key = prompt_key
self.max_prompt_length = max_prompt_length
self.filter_prompts = filter_prompts
self.return_raw_chat = return_raw_chat
self.chat_template_func = chat_template_func
self.truncation = truncation
self._download()
self._read_files_and_tokenize()
def _download(self):
from verl.utils.fs import copy_local_path_from_hdfs
for i, parquet_file in enumerate(self.parquet_files):
self.parquet_files[i] = copy_local_path_from_hdfs(src=parquet_file, cache_dir=self.cache_dir)
def _read_files_and_tokenize(self):
dataframes = []
for parquet_file in self.parquet_files:
# read parquet files and cache
dataframe = pd.read_parquet(parquet_file)
dataframes.append(dataframe)
self.dataframe = pd.concat(dataframes)
print(f'original dataset len: {len(self.dataframe)}')
# filter out too long prompts
tokenizer = self.tokenizer
prompt_key = self.prompt_key
# prompt length filtering is currently disabled (see the commented-out code below)
# self.dataframe = self.dataframe[self.dataframe.apply(lambda doc: len(
# tokenizer.apply_chat_template(doc[prompt_key], add_generation_prompt=True)) <= self.max_prompt_length,
# axis=1)]
print(f'filter dataset len: {len(self.dataframe)}')
def __len__(self):
return len(self.dataframe)
def __getitem__(self, item):
"""
Note that we also return the raw chat (raw_prompt, when return_raw_chat is enabled) so that it can be combined with other chat templates
"""
row_dict = self.dataframe.iloc[item].to_dict()
chat = row_dict.pop(self.prompt_key)
if self.tokenizer.chat_template:
prompt_with_chat_template = self.tokenizer.apply_chat_template(chat, add_generation_prompt=True, tokenize=False)
else:
prompt_with_chat_template = chat[0]['content']
# prompt_with_chat_template = chat
input_ids, attention_mask = verl_F.tokenize_and_postprocess_data(prompt=prompt_with_chat_template,
tokenizer=self.tokenizer,
max_length=self.max_prompt_length,
pad_token_id=self.tokenizer.pad_token_id,
left_pad=True,
truncation=self.truncation)
position_ids = compute_position_id_with_mask(attention_mask)
row_dict['input_ids'] = input_ids[0]
row_dict['attention_mask'] = attention_mask[0]
row_dict['position_ids'] = position_ids[0]
# encode prompts without chat template
if self.return_raw_chat:
row_dict['raw_prompt'] = chat.tolist()
# add index for each prompt
index = row_dict.get("extra_info", {}).get("index", 0)
row_dict["index"] = index
return row_dict
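A minimal usage sketch of `RLHFDataset` together with the `collate_fn` above (the parquet path and tokenizer name are placeholders, not part of the commit):

```python
# Build a DataLoader over an RLHF prompt dataset (path and model name are placeholders).
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

from verl.utils.dataset.rl_dataset import RLHFDataset, collate_fn

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
dataset = RLHFDataset(parquet_files="gsm8k_train.parquet",
                      tokenizer=tokenizer,
                      prompt_key="prompt",
                      max_prompt_length=1024,
                      truncation="error")
loader = DataLoader(dataset, batch_size=4, shuffle=True, collate_fn=collate_fn)

batch = next(iter(loader))
print(batch["input_ids"].shape)  # (4, 1024) after left padding
```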


@@ -0,0 +1,143 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import List, Union
import pandas as pd
import torch
from torch.utils.data import Dataset
from transformers import AutoTokenizer
from verl.utils import hf_tokenizer
def download_files_distributed(download_fn):
import torch.distributed
if torch.distributed.is_initialized():
if torch.distributed.get_rank() == 0:
# download files
download_fn()
torch.distributed.barrier()
else:
# download anyway
download_fn()
class RMDataset(Dataset):
def __init__(self,
parquet_files: Union[str, List[str]],
tokenizer,
prompt_key='prompt',
chosen_key='chosen',
rejected_key='rejected',
max_length=1024,
add_eos=True,
cache_dir='~/.cache/verl/rm'):
if not isinstance(parquet_files, List):
parquet_files = [parquet_files]
self.parquet_files = parquet_files
self.cache_dir = os.path.expanduser(cache_dir)
if isinstance(tokenizer, str):
tokenizer = hf_tokenizer(tokenizer)
self.tokenizer = tokenizer
self.prompt_key = prompt_key
self.chosen_key = chosen_key
self.rejected_key = rejected_key
self.add_eos = add_eos
self.max_length = max_length
self._download()
self._read_files_and_tokenize()
def _download(self):
def _download_files():
from verl.utils.fs import copy, _is_non_local
os.makedirs(self.cache_dir, exist_ok=True)
assert os.path.exists(self.cache_dir)
for i, parquet_file in enumerate(self.parquet_files):
if _is_non_local(parquet_file):
dst = os.path.join(self.cache_dir, os.path.basename(parquet_file))
if not os.path.exists(dst):
copy(src=parquet_file, dst=dst)
self.parquet_files[i] = dst
download_files_distributed(_download_files)
def _read_files_and_tokenize(self):
dataframes = []
for parquet_file in self.parquet_files:
# read parquet files and cache
dataframe = pd.read_parquet(parquet_file)
dataframes.append(dataframe)
self.dataframe = pd.concat(dataframes)
self.prompts = self.dataframe[self.prompt_key].tolist()
self.chosen_responses = self.dataframe[self.chosen_key].tolist()
self.rejected_responses = self.dataframe[self.rejected_key].tolist()
def __len__(self):
return len(self.prompts)
def _pad_to_length(self, input_ids, attention_mask):
curr_length = input_ids.shape[-1]
if curr_length < self.max_length:
input_ids = torch.cat(
(input_ids, torch.zeros(size=(self.max_length - curr_length,), dtype=input_ids.dtype)), dim=-1)
attention_mask = torch.cat(
(attention_mask, torch.zeros(size=(self.max_length - curr_length,), dtype=attention_mask.dtype)),
dim=-1)
elif curr_length > self.max_length:
input_ids = input_ids[:self.max_length]
attention_mask = attention_mask[:self.max_length]
return input_ids, attention_mask
def __getitem__(self, item):
prompt = self.prompts[item]
chosen_response = self.chosen_responses[item]
rejected_response = self.rejected_responses[item]
prompt_ids = self.tokenizer(prompt, return_tensors='pt')['input_ids'][0]
chosen_response_ids = self.tokenizer(chosen_response, return_tensors='pt')['input_ids'][0]
rejected_response_ids = self.tokenizer(rejected_response, return_tensors='pt')['input_ids'][0]
if self.add_eos:
chosen_response_ids = torch.cat((chosen_response_ids, torch.tensor([self.tokenizer.eos_token_id])), dim=-1)
rejected_response_ids = torch.cat((rejected_response_ids, torch.tensor([self.tokenizer.eos_token_id])),
dim=-1)
chosen_input_ids = torch.cat((prompt_ids, chosen_response_ids), dim=-1)
chosen_attention_mask = torch.ones_like(chosen_input_ids)
rejected_input_ids = torch.cat((prompt_ids, rejected_response_ids), dim=-1)
rejected_attention_mask = torch.ones_like(rejected_input_ids)
chosen_input_ids, chosen_attention_mask = self._pad_to_length(chosen_input_ids, chosen_attention_mask)
rejected_input_ids, rejected_attention_mask = self._pad_to_length(rejected_input_ids, rejected_attention_mask)
input_ids = torch.stack((chosen_input_ids, rejected_input_ids), dim=0)
attention_mask = torch.stack((chosen_attention_mask, rejected_attention_mask), dim=0)
return {
'input_ids': input_ids,
'attention_mask': attention_mask,
}
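A minimal usage sketch of `RMDataset` (the parquet path and tokenizer name are placeholders):

```python
# Load preference pairs; each sample stacks the chosen/rejected sequences along dim 0.
from verl.utils.dataset.rm_dataset import RMDataset

dataset = RMDataset(parquet_files="rm_pairs.parquet",        # placeholder path
                    tokenizer="Qwen/Qwen2.5-0.5B-Instruct",  # placeholder tokenizer
                    prompt_key="prompt",
                    chosen_key="chosen",
                    rejected_key="rejected",
                    max_length=1024)

sample = dataset[0]
print(sample["input_ids"].shape)       # (2, 1024): chosen first, rejected second
print(sample["attention_mask"].shape)  # (2, 1024)
```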


@@ -0,0 +1,15 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .performance import log_gpu_memory_usage


@@ -0,0 +1,30 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.distributed as dist
import logging
def log_gpu_memory_usage(head: str, logger: logging.Logger = None, level=logging.DEBUG, rank: int = 0):
if (not dist.is_initialized()) or (rank is None) or (dist.get_rank() == rank):
memory_allocated = torch.cuda.memory_allocated() / 1024**3
memory_reserved = torch.cuda.memory_reserved() / 1024**3
message = f'{head}, memory allocated (GB): {memory_allocated}, memory reserved (GB): {memory_reserved}'
if logger is None:
print(message)
else:
logger.log(msg=message, level=level)
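A usage sketch of `log_gpu_memory_usage` (requires a CUDA device; the import path `verl.utils.debug` is an assumption based on the package `__init__` above):

```python
# Bracket a training step with GPU-memory logging.
import logging
import torch

from verl.utils.debug import log_gpu_memory_usage  # path assumed

logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger("mem")

model = torch.nn.Linear(4096, 4096).cuda()
log_gpu_memory_usage("after model init", logger=logger)

out = model(torch.randn(8, 4096, device="cuda"))
log_gpu_memory_usage("after forward", logger=logger)
```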


@@ -0,0 +1,108 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Trajectory tracker can be inserted into code to save intermediate results.
The results will be dumped to HDFS for offline comparison.
Each process has a client that first moves all the tensors to CPU.
"""
from verl.utils.hdfs_io import makedirs, copy
import torch
import os
import ray
import io
import tempfile
from collections import deque
remote_copy = ray.remote(copy)
@ray.remote
def save_to_hdfs(data: io.BytesIO, name, hdfs_dir, verbose):
filename = name + '.pth'
with tempfile.TemporaryDirectory() as tmpdirname:
local_filepath = os.path.join(tmpdirname, filename)
with open(local_filepath, 'wb') as f:
f.write(data.getbuffer())
# upload to hdfs
if verbose:
print(f'Saving {local_filepath} to {hdfs_dir}')
try:
copy(local_filepath, hdfs_dir)
except Exception as e:
print(e)
@ray.remote
class TrajectoryTracker():
def __init__(self, hdfs_dir, verbose) -> None:
self.hdfs_dir = hdfs_dir
makedirs(hdfs_dir)
self.verbose = verbose
self.handle = deque()
def dump(self, data: io.BytesIO, name):
# get a temp file and write to it
self.handle.append(save_to_hdfs.remote(data, name, self.hdfs_dir, self.verbose))
def wait_for_hdfs(self):
while len(self.handle) != 0:
future = self.handle.popleft()
ray.get(future)
def dump_data(data, name):
enable = os.getenv('VERL_ENABLE_TRACKER', '0') == '1'
if not enable:
return
buffer = io.BytesIO()
torch.save(data, buffer)
tracker = get_trajectory_tracker()
ray.get(tracker.dump.remote(buffer, name))
def get_trajectory_tracker():
hdfs_dir = os.getenv('VERL_TRACKER_HDFS_DIR', default=None)
verbose = os.getenv('VERL_TRACKER_VERBOSE', default='0') == '1'
assert hdfs_dir is not None
tracker = TrajectoryTracker.options(name="global_tracker", get_if_exists=True,
lifetime="detached").remote(hdfs_dir, verbose)
return tracker
if __name__ == '__main__':
# testing
os.environ['VERL_ENABLE_TRACKER'] = '1'
os.environ['VERL_TRACKER_HDFS_DIR'] = '~/debug/test'
@ray.remote
def process(iter):
data = {'obs': torch.randn(10, 20)}
dump_data(data, f'process_{iter}_obs')
ray.init()
output_lst = []
for i in range(10):
output_lst.append(process.remote(i))
out = ray.get(output_lst)
tracker = get_trajectory_tracker()
ray.get(tracker.wait_for_hdfs.remote())

28
verl/utils/distributed.py Normal file

@@ -0,0 +1,28 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for distributed training."""
import os
def initialize_global_process_group(timeout_second=36000):
import torch.distributed
from datetime import timedelta
torch.distributed.init_process_group('nccl', timeout=timedelta(seconds=timeout_second))
local_rank = int(os.environ["LOCAL_RANK"])
rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
if torch.distributed.is_initialized():
torch.cuda.set_device(local_rank)
return local_rank, rank, world_size
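A usage sketch (the script name and process count are placeholders); `initialize_global_process_group` relies on the `LOCAL_RANK`/`RANK`/`WORLD_SIZE` variables set by torchrun:

```python
# train.py -- launch with: torchrun --nproc_per_node=2 train.py
import torch

from verl.utils.distributed import initialize_global_process_group

local_rank, rank, world_size = initialize_global_process_group()
print(f"rank {rank}/{world_size} bound to local GPU {local_rank}")

# sanity check: all-reduce across the freshly created process group
x = torch.ones(1, device="cuda")
torch.distributed.all_reduce(x)
assert int(x.item()) == world_size
```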

123
verl/utils/flops_counter.py Normal file

@@ -0,0 +1,123 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from transformers import PretrainedConfig, Qwen2Config, LlamaConfig
VALID_CONFIG_TYPE = (Qwen2Config, LlamaConfig)
def get_device_flops(unit="T"):
def unit_convert(number, level):
units = ["B", "K", "M", "G", "T", "P"]
if number <= 0:
return number
ptr = 0
while ptr < len(units) and units[ptr] != level:
number /= 1000
ptr += 1
return number
device_name = torch.cuda.get_device_name()
flops = float("inf")  # INF flops for unknown GPU type
if "H100" in device_name or "H800" in device_name:
flops = 989e12
elif "A100" in device_name or "A800" in device_name:
flops = 312e12
elif "L40" in device_name:
flops = 181.05e12
elif "L20" in device_name:
flops = 119.5e12
elif "H20" in device_name:
flops = 148e12
elif "910B" in device_name:
flops = 354e12
flops_unit = unit_convert(flops, unit)
return flops_unit
class FlopsCounter:
"""
Used to estimate MFU (model FLOPs utilization) during the training loop
Example:
flops_counter = FlopsCounter(config)
flops_achieved, flops_promised = flops_counter.estimate_flops(tokens_list, delta_time)
"""
def __init__(self, config: PretrainedConfig):
if not isinstance(config, VALID_CONFIG_TYPE):
print(f"Only support config type of {VALID_CONFIG_TYPE}, but got {type(config)}. "
f"MFU will always be zero.")
self.estimate_func = {"qwen2": self._estimate_qwen2_flops, 'llama': self._estimate_qwen2_flops}
self.config = config
def _estimate_unknown_flops(self, tokens_sum, batch_seqlens, delta_time):
return 0
def _estimate_qwen2_flops(self, tokens_sum, batch_seqlens, delta_time):
assert isinstance(self.config, (Qwen2Config, LlamaConfig))
hidden_size = self.config.hidden_size
vocab_size = self.config.vocab_size
num_hidden_layers = self.config.num_hidden_layers
num_key_value_heads = self.config.num_key_value_heads
num_attention_heads = self.config.num_attention_heads
intermediate_size = self.config.intermediate_size
head_dim = hidden_size // num_attention_heads
q_size = num_attention_heads * head_dim
k_size = num_key_value_heads * head_dim
v_size = num_key_value_heads * head_dim
# non-attention parameters per layer
# Qwen2/Llama use SwiGLU in the MLP: gate, up and down linear layers
mlp_N = hidden_size * intermediate_size * 3
attn_linear_N = hidden_size * (q_size + k_size + v_size + num_attention_heads * head_dim)
emd_and_lm_head_N = vocab_size * hidden_size * 2
# non-attention parameters over all layers
dense_N = (mlp_N + attn_linear_N) * num_hidden_layers + emd_and_lm_head_N
# non-attn all_layer & all_token fwd & bwd flops
dense_N_flops = 6 * dense_N * tokens_sum
# attn all_layer & all_token fwd & bwd flops
seqlen_square_sum = 0
for seqlen in batch_seqlens:
seqlen_square_sum += seqlen * seqlen
attn_qkv_flops = 12 * seqlen_square_sum * head_dim * num_attention_heads * num_hidden_layers
# all_layer & all_token fwd & bwd flops
flops_all_token = dense_N_flops + attn_qkv_flops
flops_achieved = flops_all_token * (1.0 / delta_time) / 1e12
return flops_achieved
def estimate_flops(self, batch_seqlens, delta_time):
"""
Estimate the FLOPS based on the number of valid tokens in the current batch and the time taken.
Args:
batch_seqlens (List[int]): A list where each element is the number of valid tokens in one sequence of the current batch.
delta_time (float): The time taken to process the batch, in seconds.
Returns:
estimated_flops (float): The estimated achieved FLOPS (in TFLOPS) based on the input tokens and time.
promised_flops (float): The theoretical peak FLOPS (in TFLOPS) of the current device.
"""
tokens_sum = sum(batch_seqlens)
func = self.estimate_func.get(self.config.model_type, self._estimate_unknown_flops)
estimated_flops = func(tokens_sum, batch_seqlens, delta_time)
promised_flops = get_device_flops()
return estimated_flops, promised_flops
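A usage sketch of `FlopsCounter` (requires a CUDA device for `get_device_flops`; the config, token counts, and step time are illustrative values only):

```python
# Estimate achieved vs. promised TFLOPS for one training step.
from transformers import Qwen2Config

from verl.utils.flops_counter import FlopsCounter

config = Qwen2Config()            # stand-in for the actual model config
flops_counter = FlopsCounter(config)

batch_seqlens = [512, 768, 1024]  # valid tokens per sequence in the batch
delta_time = 1.7                  # seconds spent on the fwd+bwd step (example value)

estimated, promised = flops_counter.estimate_flops(batch_seqlens, delta_time)
print(f"achieved {estimated:.2f} TFLOPS, device peak {promised:.2f} TFLOPS, "
      f"MFU ~ {estimated / promised:.2%}")
```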

88
verl/utils/fs.py Normal file

@@ -0,0 +1,88 @@
#!/usr/bin/env python
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -*- coding: utf-8 -*-
"""File-system agnostic IO APIs"""
import os
import tempfile
import hashlib
from .hdfs_io import copy, makedirs, exists
__all__ = ["copy", "exists", "makedirs"]
_HDFS_PREFIX = "hdfs://"
def _is_non_local(path):
return path.startswith(_HDFS_PREFIX)
def md5_encode(path: str) -> str:
return hashlib.md5(path.encode()).hexdigest()
def get_local_temp_path(hdfs_path: str, cache_dir: str) -> str:
"""Return a local temp path that joins cache_dir and basename of hdfs_path
Args:
hdfs_path: the source path (typically an HDFS path)
cache_dir: the local cache directory
Returns:
the local temp path: cache_dir/<md5 of hdfs_path>/<basename of hdfs_path>
"""
# use an md5 hash of hdfs_path to avoid directory conflicts
encoded_hdfs_path = md5_encode(hdfs_path)
temp_dir = os.path.join(cache_dir, encoded_hdfs_path)
os.makedirs(temp_dir, exist_ok=True)
dst = os.path.join(temp_dir, os.path.basename(hdfs_path))
return dst
def copy_local_path_from_hdfs(src: str, cache_dir=None, filelock='.file.lock', verbose=False) -> str:
"""Copy src from hdfs to local if src is on hdfs or directly return src.
If cache_dir is None, we will use the default cache dir of the system. Note that this may cause conflicts if
the src name is the same between calls
Args:
src (str): an HDFS path or a local path
Returns:
a local path of the copied file
"""
from filelock import FileLock
assert src[-1] != '/', f'Make sure the last char in src is not / because it will cause error. Got {src}'
if _is_non_local(src):
# download from hdfs to local
if cache_dir is None:
# get a temp folder
cache_dir = tempfile.gettempdir()
os.makedirs(cache_dir, exist_ok=True)
assert os.path.exists(cache_dir)
local_path = get_local_temp_path(src, cache_dir)
# get a specific lock
filelock = md5_encode(src) + '.lock'
lock_file = os.path.join(cache_dir, filelock)
with FileLock(lock_file=lock_file):
if not os.path.exists(local_path):
if verbose:
print(f'Copy from {src} to {local_path}')
copy(src, local_path)
return local_path
else:
return src
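A usage sketch of `copy_local_path_from_hdfs` (the HDFS URI is a placeholder; a working `hdfs` client is needed for actual HDFS paths):

```python
# Resolve a path that may live on HDFS to a local file.
from verl.utils.fs import copy_local_path_from_hdfs

# HDFS inputs are copied into a local cache; plain local paths are returned unchanged.
local_path = copy_local_path_from_hdfs("hdfs://namenode/user/me/data/train.parquet")
print(local_path)
```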

329
verl/utils/fsdp_utils.py Normal file

@@ -0,0 +1,329 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict
import functools
import json
import math
import itertools
import os
from contextlib import contextmanager
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy, transformer_auto_wrap_policy
from transformers.trainer_pt_utils import get_module_class_from_name
import torch
import torch.nn as nn
import torch.distributed as dist
def init_fn(x: torch.nn.Module):
if not torch.distributed.get_rank() == 0:
x = x.to_empty(device=torch.cuda.current_device(), recurse=False)
torch.cuda.empty_cache()
return x
def get_init_weight_context_manager(use_meta_tensor=True):
from accelerate import init_empty_weights
cpu_init_weights = lambda: torch.device('cpu')
if use_meta_tensor:
init_context = init_empty_weights if torch.distributed.get_rank() != 0 else cpu_init_weights
else:
init_context = cpu_init_weights
return init_context
# Copyright 2020-present the HuggingFace Inc. team.
# Adapted from https://github.com/huggingface/transformers/src/transformers/trainer.py
def get_fsdp_wrap_policy(module, config=None, is_lora=False):
"""Get FSDP wrap policy for the module.
Args:
module: The module to get wrap policy for
config: Configuration for wrap policy
is_lora: Whether to enable lambda policy for LoRA modules
"""
if config is None:
config = {}
if config.get('disable', False):
return None
default_transformer_cls_names_to_wrap = getattr(module, "_no_split_modules", None)
fsdp_transformer_layer_cls_to_wrap = config.get("transformer_layer_cls_to_wrap",
default_transformer_cls_names_to_wrap)
min_num_params = config.get('min_num_params', 0)
auto_wrap_policy = None
policies = []
from torch.distributed.fsdp.wrap import _or_policy, lambda_auto_wrap_policy, transformer_auto_wrap_policy
# Add lambda policy for LoRA modules if is_lora is True
if is_lora:
def lambda_policy_fn(module):
if (len(list(module.named_children())) == 0 and getattr(module, "weight", None) is not None and
module.weight.requires_grad):
return True
return False
lambda_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn)
policies.append(lambda_policy)
if min_num_params > 0:
size_policy = functools.partial(size_based_auto_wrap_policy, min_num_params=min_num_params)
policies.append(size_policy)
elif fsdp_transformer_layer_cls_to_wrap is not None:
transformer_cls_to_wrap = set()
for layer_class in fsdp_transformer_layer_cls_to_wrap:
transformer_cls = get_module_class_from_name(module, layer_class)
if transformer_cls is None:
raise Exception("Could not find the transformer layer class to wrap in the model.")
else:
transformer_cls_to_wrap.add(transformer_cls)
transformer_policy = functools.partial(
transformer_auto_wrap_policy,
transformer_layer_cls=transformer_cls_to_wrap,
)
policies.append(transformer_policy)
if len(policies) > 0:
auto_wrap_policy = functools.partial(_or_policy, policies=policies)
return auto_wrap_policy
def offload_fsdp_grad(module):
for _, param in module.named_parameters():
if param.grad is not None:
param.grad = param.grad.to("cpu", non_blocking=True)
torch.cuda.empty_cache()
def load_fsdp_grad(module, device_id):
for _, param in module.named_parameters():
if param.grad is not None:
param.grad = param.grad.to(device_id, non_blocking=True)
torch.cuda.empty_cache()
def offload_fsdp_param_and_grad(module, offload_grad=False):
for _, param in module.named_parameters():
if hasattr(param, "_local_shard"):
param._local_shard = param._local_shard.to("cpu", non_blocking=True)
param.data = param.data.to('cpu', non_blocking=True)
if offload_grad and param.grad is not None:
param.grad = param.grad.to("cpu", non_blocking=True)
torch.cuda.empty_cache()
def load_fsdp_param_and_grad(module, device_id, load_grad=False):
for _, param in module.named_parameters():
if hasattr(param, "_local_shard"):
param._local_shard = param._local_shard.to(device_id, non_blocking=True)
param.data = param.data.to(device_id, non_blocking=True)
if load_grad and param.grad is not None:
param.grad = param.grad.to(device_id, non_blocking=True)
torch.cuda.empty_cache()
def offload_fsdp_optimizer(optimizer):
for param_group in optimizer.param_groups:
for param in param_group['params']:
state = optimizer.state[param]
for key, value in state.items():
if isinstance(value, torch.Tensor):
state[key] = value.to("cpu", non_blocking=True)
torch.cuda.empty_cache()
def load_fsdp_optimizer(optimizer, device_id):
for param_group in optimizer.param_groups:
for param in param_group['params']:
state = optimizer.state[param]
for key, value in state.items():
if isinstance(value, torch.Tensor):
state[key] = value.to(device_id, non_blocking=True)
torch.cuda.empty_cache()
@contextmanager
def meta_device_init():
"""
Create model parameters with meta device.
Note buffers in model will still be initialized in default device (e.g., CPU),
since the buffers can be non-persistent and filled with expected values that can
NOT be captured in meta device.
"""
device = torch.device("meta")
old_register_parameter = nn.Module.register_parameter
registered = set()
def register_empty_parameter(module, name, param):
old_register_parameter(module, name, param)
# we will skip register shared parameters as it
# is already registered previously
if param is not None and param not in registered:
param_cls = type(module._parameters[name])
kwargs = module._parameters[name].__dict__
kwargs["requires_grad"] = param.requires_grad
module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs)
registered.add(module._parameters[name])
try:
nn.Module.register_parameter = register_empty_parameter
yield
finally:
registered.clear()
nn.Module.register_parameter = old_register_parameter
def parallel_load_safetensors(filepath):
"""
Parallel load safetensors from huggingface checkpoint
Huggingface checkpoint contains:
- config.json: a json file for model configuration
- model.safetensor.index.json: a json file for safetensors (parameters & buffers) index
- model-000x-of-ooxx.safetensors: a binary file for safetensors (parameters & buffers) chunks
Or (when model is small),
- model.safetensors: a binary file for all parameters and buffers
Each rank will own a part of model chunks and load them directly into GPU memory.
"""
from safetensors.torch import load_file
safetensors2param = {}
index_file = os.path.join(filepath, "model.safetensors.index.json")
if os.path.exists(index_file):
index = json.load(open(index_file, "rb"))
for param_name, filename in index["weight_map"].items():
safetensors2param.setdefault(filename, []).append(param_name)
else:
# in this case, the model is small and we can load it all at once
param_file = os.path.join(filepath, "model.safetensors")
assert os.path.exists(param_file), f"Cannot find {param_file}"
states = load_file(param_file)
for param_name in states:
safetensors2param.setdefault("model.safetensors", []).append(param_name)
del states
total_files = len(safetensors2param)
ckpt_chunks = sorted(safetensors2param.keys())
world_size = dist.get_world_size()
size = int(math.ceil(total_files / world_size))
ckpt_chunks = [ckpt_chunks[rank * size:rank * size + size] for rank in range(world_size)]
shard_states = {}
device = torch.cuda.current_device()
for rank, files in enumerate(ckpt_chunks):
if rank == dist.get_rank():
for file in files:
file = os.path.join(filepath, file)
states = load_file(file, device=device)
# print(f"rank {rank} loading {file}...")
shard_states.update(states)
else:
for file in files:
for param_name in safetensors2param[file]:
shard_states[param_name] = rank
return shard_states
def parallel_init_module_fn(module: torch.nn.Module, shard_states: Dict[str, torch.nn.Parameter]):
"""
Generate a function to initialize sub-modules in the `module` with `shard_states`
from huggingface checkpoint.
Args:
module (torch.nn.Module): the global module to be initialized
shard_states (Dict[str, torch.nn.Parameter]): the shard states from huggingface checkpoint
Returns:
init_fn (Callable): a function to initialize sub-modules in the `module` with `shard_states`
"""
state2fqn = {}
for name, state in itertools.chain(module.named_parameters(remove_duplicate=False),
module.named_buffers(remove_duplicate=False)):
state2fqn.setdefault(state, []).append(name)
# remove standalone parameters and buffers
shared = {s for s, names in state2fqn.items() if len(names) > 1}
materialized_states = {}
@torch.no_grad()
def create_and_sync_state(param_name, state, is_param):
assert param_name in shard_states, f"{param_name} not loaded"
device = torch.cuda.current_device()
if is_param:
param = torch.nn.Parameter(torch.empty_like(state.data, device=device), requires_grad=state.requires_grad)
else: # buffer
param = torch.empty_like(state.data, device=device)
loaded = shard_states[param_name]
if isinstance(loaded, (torch.nn.Parameter, torch.Tensor)):
# NOTE: loaded.dtype can be different with param.dtype
param.data.copy_(loaded.data)
dist.broadcast(param.data, src=dist.get_rank())
else:
assert isinstance(loaded, int) # the rank that holds the state
dist.broadcast(param.data, src=loaded)
shard_states.pop(param_name)
del loaded
return param
def init_fn(sub_mod: torch.nn.Module, recurse: bool = True):
param_and_buffers = tuple(sub_mod.named_parameters(recurse=False)) + tuple(sub_mod.named_buffers(recurse=False))
# param_and_buffers = sorted(sub_mod.named_parameters(recurse=False), key=lambda x: x[0])
for name, state in param_and_buffers:
if not state.is_meta:
continue
is_param = name in sub_mod._parameters
fqn = state2fqn[state].pop(0)
# non-persistent buffers will not be saved in state dict, we can safely skip it
if (not is_param) and fqn not in shard_states:
if state.is_meta:
raise RuntimeError(
f"found a non-persistent buffer ({fqn}) initialized on the meta device. "
"Such buffers are not saved in the checkpoint, so the user must ensure they are initialized on a CPU / GPU device.")
continue
# for shared parameter, we get it from the first time it is created
if state in shared:
if state not in materialized_states:
materialized_states[state] = create_and_sync_state(fqn, state, is_param)
else:
if fqn in shard_states:
shard_states.pop(fqn)
materialize_state = materialized_states[state]
# for not shared parameter, we create it directly
else:
materialize_state = create_and_sync_state(fqn, state, is_param)
if is_param:
sub_mod._parameters[name] = materialize_state
else:
sub_mod._buffers[name] = materialize_state
if recurse:
for module in sub_mod.children():
init_fn(module, recurse=True)
# for debug
# if len(shard_states) == 0: print("clear")
return sub_mod
return init_fn
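A usage sketch of `get_fsdp_wrap_policy` (the model name is a placeholder; actually wrapping the model with FSDP additionally needs an initialized process group and is only hinted at in the comment):

```python
# Derive an FSDP auto-wrap policy from a HuggingFace model's _no_split_modules.
from transformers import AutoModelForCausalLM

from verl.utils.fsdp_utils import get_fsdp_wrap_policy

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
wrap_policy = get_fsdp_wrap_policy(model, config={"min_num_params": 0})
# FSDP(model, auto_wrap_policy=wrap_policy, ...)  # inside an initialized torch.distributed group
```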

144
verl/utils/hdfs_io.py Normal file

@@ -0,0 +1,144 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import shutil
import logging
logger = logging.getLogger(__file__)
logger.setLevel(os.getenv('VERL_SFT_LOGGING_LEVEL', 'WARN'))
_HDFS_PREFIX = "hdfs://"
_HDFS_BIN_PATH = shutil.which('hdfs')
def exists(path: str, **kwargs) -> bool:
r"""Works like os.path.exists() but supports hdfs.
Test whether a path exists. Returns False for broken symbolic links.
Args:
path (str): path to test
Returns:
bool: True if the path exists, False otherwise
"""
if _is_non_local(path):
return _exists(path, **kwargs)
return os.path.exists(path)
def _exists(file_path: str):
"""HDFS-capable check of whether a file_path exists."""
if file_path.startswith("hdfs"):
return _run_cmd(_hdfs_cmd(f"-test -e {file_path}")) == 0
return os.path.exists(file_path)
def makedirs(name, mode=0o777, exist_ok=False, **kwargs) -> None:
r"""Works like os.makedirs() but supports hdfs.
Super-mkdir; create a leaf directory and all intermediate ones. Works like
mkdir, except that any intermediate path segment (not just the rightmost)
will be created if it does not exist. If the target directory already
exists, raise an OSError if exist_ok is False. Otherwise no exception is
raised. This is recursive.
Args:
name (str): directory to create
mode (int): file mode bits
exist_ok (bool): if True, do not raise an exception if the directory already exists
kwargs: keyword arguments for hdfs
"""
if _is_non_local(name):
# TODO(haibin.lin):
# - handle OSError for hdfs(?)
# - support exist_ok for hdfs(?)
_mkdir(name, **kwargs)
else:
os.makedirs(name, mode=mode, exist_ok=exist_ok)
def _mkdir(file_path: str) -> bool:
"""hdfs mkdir"""
if file_path.startswith("hdfs"):
_run_cmd(_hdfs_cmd(f"-mkdir -p {file_path}"))
else:
os.makedirs(file_path, exist_ok=True)
return True
def copy(src: str, dst: str, **kwargs) -> bool:
r"""Works like shutil.copy() for files and shutil.copytree() for directories, and supports hdfs.
Copy data and mode bits ("cp src dst"). Return the file's destination.
The destination may be a directory.
If source and destination are the same file, a SameFileError will be
raised.
Args:
src (str): source file path
dst (str): destination file path
kwargs: keyword arguments for hdfs copy
Returns:
the destination path for local copies, or a bool indicating success for HDFS copies
"""
if _is_non_local(src) or _is_non_local(dst):
# TODO(haibin.lin):
# - handle SameFileError for hdfs files(?)
# - return file destination for hdfs files
return _copy(src, dst)
else:
if os.path.isdir(src):
return shutil.copytree(src, dst, **kwargs)
else:
return shutil.copy(src, dst, **kwargs)
def _copy(from_path: str, to_path: str, timeout: int = None) -> bool:
if to_path.startswith("hdfs"):
if from_path.startswith("hdfs"):
returncode = _run_cmd(_hdfs_cmd(f"-cp -f {from_path} {to_path}"), timeout=timeout)
else:
returncode = _run_cmd(_hdfs_cmd(f"-put -f {from_path} {to_path}"), timeout=timeout)
else:
if from_path.startswith("hdfs"):
returncode = _run_cmd(_hdfs_cmd(f"-get {from_path} {to_path}"), timeout=timeout)
else:
try:
shutil.copy(from_path, to_path)
returncode = 0
except shutil.SameFileError:
returncode = 0
except Exception as e:
logger.warning(f"copy {from_path} {to_path} failed: {e}")
returncode = -1
return returncode == 0
def _run_cmd(cmd: str, timeout=None):
return os.system(cmd)
def _hdfs_cmd(cmd: str) -> str:
return f"{_HDFS_BIN_PATH} dfs {cmd}"
def _is_non_local(path: str):
return path.startswith(_HDFS_PREFIX)
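A usage sketch of the HDFS helpers (the hdfs:// URIs are placeholders; an `hdfs` binary must be on PATH for them):

```python
# The same calls work for local paths and hdfs:// paths.
from verl.utils.hdfs_io import copy, exists, makedirs

ckpt_dir = "hdfs://namenode/user/me/exp1/checkpoints"  # placeholder URI
makedirs(ckpt_dir)
if not exists(f"{ckpt_dir}/global_step_100"):
    copy(src="/tmp/global_step_100", dst=ckpt_dir)
```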


@@ -0,0 +1,48 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Utilities to check if packages are available.
We assume package availability won't change during runtime.
"""
from functools import cache
from typing import List
@cache
def is_megatron_core_available():
try:
from megatron.core import parallel_state as mpu
return True
except ImportError:
return False
@cache
def is_vllm_available():
try:
import vllm
return True
except ImportError:
return False
def import_external_libs(external_libs=None):
if external_libs is None:
return
if not isinstance(external_libs, List):
external_libs = [external_libs]
import importlib
for external_lib in external_libs:
importlib.import_module(external_lib)
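A usage sketch of the availability checks (the module path `verl.utils.import_utils` is assumed for this file):

```python
# Gate optional backends on package availability.
from verl.utils.import_utils import is_megatron_core_available, is_vllm_available

if is_vllm_available():
    from vllm import SamplingParams  # only imported when vLLM is installed

if not is_megatron_core_available():
    print("megatron.core not found; Megatron-based workers are unavailable")
```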


@@ -0,0 +1,13 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


@@ -0,0 +1,41 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
class MemoryBuffer:
def __init__(self, numel, numel_padded, dtype):
self.numel = numel
self.numel_padded = numel_padded
self.dtype = dtype
self.data = torch.zeros(self.numel_padded,
dtype=self.dtype,
device=torch.cuda.current_device(),
requires_grad=False)
def zero(self):
"""Reset the buffer to zero."""
self.data.zero_()
def get(self, shape, start_index):
"""Return a tensor with the input `shape` as a view into the
1-D data starting at `start_index`."""
end_index = start_index + shape.numel()
assert end_index <= self.numel, \
'requested tensor is out of the buffer range.'
buffer_tensor = self.data[start_index:end_index]
buffer_tensor = buffer_tensor.view(shape)
return buffer_tensor
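A usage sketch of `MemoryBuffer` (the module path `verl.utils.memory_buffer` is assumed; requires a CUDA device, since the buffer is allocated on the current GPU):

```python
# Carve two tensor views out of one contiguous GPU buffer.
import torch

from verl.utils.memory_buffer import MemoryBuffer  # path assumed

shapes = [torch.Size([1024, 1024]), torch.Size([1024])]
numel = sum(s.numel() for s in shapes)
buf = MemoryBuffer(numel=numel, numel_padded=numel, dtype=torch.bfloat16)

weight = buf.get(shapes[0], start_index=0)
bias = buf.get(shapes[1], start_index=shapes[0].numel())
assert weight.data_ptr() == buf.data.data_ptr()  # views share the buffer's storage
```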


@@ -0,0 +1,92 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from apex.optimizers import FusedAdam as Adam
from apex.optimizers import FusedSGD as SGD
from megatron.optimizer.distrib_optimizer import DistributedOptimizer
from megatron.optimizer.grad_scaler import ConstantGradScaler, DynamicGradScaler
from megatron.optimizer import Float16OptimizerWithFloat16Params, FP32Optimizer
from megatron.optimizer import get_param_groups
from verl.utils.megatron.optimizer_config import OptimizerConfig
def get_megatron_optimizer(
model,
config: OptimizerConfig,
no_weight_decay_cond=None,
scale_lr_cond=None,
lr_mult=1.0,
check_for_nan_in_loss_and_grad=False,
overlap_param_gather=False # add for verl
):
# Base optimizer.
param_groups = get_param_groups(model, no_weight_decay_cond, scale_lr_cond, lr_mult)
if config.optimizer == 'adam':
optimizer = Adam(param_groups,
lr=config.lr,
weight_decay=config.weight_decay,
betas=(config.adam_beta1, config.adam_beta2),
eps=config.adam_eps)
elif config.optimizer == 'sgd':
optimizer = SGD(param_groups, lr=config.lr, weight_decay=config.weight_decay, momentum=config.sgd_momentum)
else:
raise Exception('{} optimizer is not supported.'.format(config.optimizer))
# Determine whether the params have main-grad field.
params_have_main_grad = True
# Mixed precision optimizer.
# - Note: both the Float16Optimizer and the DistributedOptimizer inherit
# from the MixedPrecisionOptimizer, which manages any optimizer where
# the model params and main params are distinct.
if config.fp16 or config.bf16 or config.use_distributed_optimizer:
# Grad scaler:
# if loss-scale is provided, instantiate the constant scaler.
# if we are using fp16 and loss-scale is not present, use a
# dynamic scaler.
# otherwise we are running in bf16 with no loss-scale so
# leave it as None.
grad_scaler = None
# Constant loss scale.
if config.loss_scale:
grad_scaler = ConstantGradScaler(config.loss_scale)
# Dynamic loss scale.
else:
if config.fp16:
grad_scaler = DynamicGradScaler(initial_scale=config.initial_loss_scale,
min_scale=config.min_loss_scale,
growth_factor=2.0,
backoff_factor=0.5,
growth_interval=config.loss_scale_window,
hysteresis=config.hysteresis)
# Megatron optimizer.
if config.use_distributed_optimizer:
return DistributedOptimizer(optimizer, config.clip_grad, config.log_num_zeros_in_grad,
check_for_nan_in_loss_and_grad, params_have_main_grad, config.fp16, config.bf16,
config.params_dtype, grad_scaler, model, overlap_param_gather)
else:
return Float16OptimizerWithFloat16Params(optimizer, config.clip_grad, config.log_num_zeros_in_grad,
check_for_nan_in_loss_and_grad, params_have_main_grad, config.fp16,
config.bf16, config.params_dtype, grad_scaler, model)
# FP32.
return FP32Optimizer(optimizer, config.clip_grad, config.log_num_zeros_in_grad, check_for_nan_in_loss_and_grad,
params_have_main_grad, model)


@@ -0,0 +1,129 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Callable, Optional
import torch
@dataclass
class OptimizerConfig:
"""Configuration for optimizer."""
##############
# General
##############
optimizer: str = 'adam'
"""Optimizer to use (one of Adam or SGD)."""
lr: Optional[float] = None
"""Initial learning rate. Depending on decay style and initial warmup, the learning rate at each
iteration would be different.
"""
min_lr: Optional[float] = None
"""Minimum value for learning rate. The scheduler clips values below this threshold."""
decoupled_lr: Optional[float] = None
"""Separate learning rate for the input and output layer."""
decoupled_min_lr: Optional[float] = None
"""Minimum value for learning rate for the input and output layer. The scheduler clips values
below this threshold.
"""
weight_decay: float = 0.01
"""Weight decay coefficient for L2 regularization."""
##############
# Precision
##############
fp16: bool = False
"""If true, train with fp16 mixed precision training. Defaults to False."""
bf16: bool = False
"""If true, train with bf16 mixed precision training. Defaults to False."""
params_dtype: torch.dtype = torch.float32
"""dtype used when initializing the weights. Defaults to torch.float32."""
###############
# Loss scaling
###############
loss_scale: Optional[float] = None
"""Static loss scaling, positive power of 2 values can improve fp16 convergence. If None,
dynamic loss scaling is used.
"""
initial_loss_scale: float = 2**32
"""Initial loss-scale for dynamic loss scaling."""
min_loss_scale: float = 1.0
"""Minimum loss scale for dynamic loss scaling."""
loss_scale_window: float = 1000
"""Window over which to raise/lower dynamic scale."""
hysteresis: int = 2
"""Hysteresis for dynamic loss scaling."""
##############
# Optimizer
##############
# Adam
adam_beta1: float = 0.9
"""First coefficient for computing running averages of gradient and its square in Adam
optimizer.
"""
adam_beta2: float = 0.999
"""Second coefficient for computing running averages of gradient and its square in Adam
optimizer.
"""
adam_eps: float = 1e-08
"""Term added to the denominator to improve numerical stability in Adam optimizer."""
# SGD.
sgd_momentum: float = 0.9
"""Momentum factor for SGD optimizer."""
#######################
# Distributed optimizer
#######################
use_distributed_optimizer: bool = False
"""Distribute optimizer state over data-parallel replicas."""
overlap_grad_reduce: bool = False
"""If true, overlap grad reduce-scatter with backward compute in distributed optimizer."""
overlap_param_gather: bool = False
"""If true, overlap param all-gather with forward compute in distributed optimizer."""
################
# Miscellaneous
################
clip_grad: float = 1.0
"""Gradient clipping based on global L2 norm."""
log_num_zeros_in_grad: bool = False
"""If true, calculate and log the number of zeros in gradient."""
barrier_with_L1_time: bool = False
"""If true, use barrier with level 1 time measurements."""
timers: Callable = None
"""Function to get timers."""

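A configuration sketch for `get_megatron_optimizer` (values are illustrative):

```python
# A bf16 distributed-optimizer configuration (illustrative values).
import torch

from verl.utils.megatron.optimizer_config import OptimizerConfig

optim_config = OptimizerConfig(
    optimizer="adam",
    lr=1e-6,
    weight_decay=0.01,
    bf16=True,
    params_dtype=torch.bfloat16,
    use_distributed_optimizer=True,
    clip_grad=1.0,
)
# optimizer = get_megatron_optimizer(model, optim_config)  # needs an initialized Megatron model
```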

@@ -0,0 +1,51 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from megatron.core import parallel_state as mpu
from .sequence_parallel import pad_to_sequence_parallel
def compute_transformers_input_shapes(batches, meta_info):
from flash_attn.bert_padding import unpad_input # flash 2 is a must for Megatron
# pre-compute input shapes for each micro-batch at each pp stage
input_shapes = []
for model_inputs in batches:
input_ids = model_inputs['input_ids']
attention_mask = model_inputs['attention_mask']
input_ids_rmpad = unpad_input(input_ids.unsqueeze(dim=-1), attention_mask)[0] # (total_nnz, 1)
if meta_info['sequence_parallel']:
input_ids_rmpad = pad_to_sequence_parallel(input_ids_rmpad)
# compute shapes for model_inputs
input_shapes.append(
torch.Size([
input_ids_rmpad.shape[0] // mpu.get_tensor_model_parallel_world_size(), 1, meta_info['hidden_size']
]))
else:
# compute shapes for model_inputs
input_shapes.append(torch.Size([input_ids_rmpad.shape[0], 1, meta_info['hidden_size']]))
return input_shapes
def make_batch_generator(batches, vpp_size):
if vpp_size > 1:
# has vpp
batch_generator = [batches] * vpp_size # number of vpp chunks
batch_generator = [iter(b) for b in batch_generator]
else:
# no vpp
batch_generator = iter(batches)
return batch_generator


@@ -0,0 +1,54 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn.functional as F
from megatron.core import parallel_state as mpu
def mark_parameter_as_sequence_parallel(parameter):
setattr(parameter, 'sequence_parallel', True)
def is_sequence_parallel_param(param):
return hasattr(param, 'sequence_parallel') and param.sequence_parallel
def pad_to_sequence_parallel(unpad_tokens: torch.Tensor):
"""pad the tokens such that the total length is a multiple of sp world size
Args:
unpad_tokens: (total_nnz, ...). Tokens after removing padding
Returns:
unpad_tokens padded so that its total length is a multiple of the sp world size
"""
total_nnz = unpad_tokens.shape[0]
sp_world_size = mpu.get_tensor_model_parallel_world_size()
if total_nnz % sp_world_size == 0:
pad_size = 0
else:
pad_size = sp_world_size - total_nnz % sp_world_size
if pad_size > 0:
if unpad_tokens.ndim == 1:
unpad_tokens = F.pad(unpad_tokens, (0, pad_size))
elif unpad_tokens.ndim == 2:
unpad_tokens = F.pad(unpad_tokens, (0, 0, 0, pad_size))
else:
raise NotImplementedError(f'Padding dim {unpad_tokens.ndim} is not supported')
return unpad_tokens
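A standalone illustration of the padding rule used by `pad_to_sequence_parallel` (the real function reads the sp world size from Megatron's `mpu`; here it is passed explicitly):

```python
# Pad a 1-D token tensor so its length is a multiple of the sp world size.
import torch
import torch.nn.functional as F

def pad_size_for(total_nnz: int, sp_world_size: int) -> int:
    return 0 if total_nnz % sp_world_size == 0 else sp_world_size - total_nnz % sp_world_size

tokens = torch.arange(10)               # total_nnz = 10
pad = pad_size_for(tokens.shape[0], 4)  # sp world size 4 -> pad by 2
padded = F.pad(tokens, (0, pad))
assert padded.shape[0] % 4 == 0
```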


@@ -0,0 +1,184 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Utilities for using tensor_parallel in megatron
"""
from typing import Dict
import torch
from torch.nn import init
import torch.distributed as dist
from megatron.core import ModelParallelConfig
from megatron.core import parallel_state as mpu, tensor_parallel
import verl.utils.torch_functional as verl_F
def update_kwargs_with_config(dictionary: Dict, config: ModelParallelConfig):
dictionary['config'] = config
return dictionary
def get_default_kwargs_for_model_parallel_config():
model_parallel_config_kwargs = {
'params_dtype': torch.float32,
'use_cpu_initialization': False,
'perform_initialization': True,
'gradient_accumulation_fusion': False,
'sequence_parallel': False,
}
return model_parallel_config_kwargs
def get_default_model_parallel_config():
return ModelParallelConfig(**get_default_kwargs_for_model_parallel_config())
def get_common_default_kwargs_for_parallel_linear():
default_model_parallel_config = get_default_model_parallel_config()
common_default_kwargs = {
'init_method': init.xavier_normal_,
'stride': 1,
'keep_master_weight_for_test': False,
'config': default_model_parallel_config,
}
return common_default_kwargs
def get_default_kwargs_for_column_parallel_linear():
model_parallel_config_kwargs = get_default_kwargs_for_model_parallel_config()
column_parallel_config_kwargs = {
'async_tensor_model_parallel_allreduce': False,
}
model_parallel_config_kwargs.update(column_parallel_config_kwargs)
column_default_kwargs = {
'config': ModelParallelConfig(**model_parallel_config_kwargs),
}
common_default_kwargs = get_common_default_kwargs_for_parallel_linear()
common_default_kwargs.update(column_default_kwargs)
return common_default_kwargs
def get_default_kwargs_for_row_parallel_linear():
common_default_kwargs = get_common_default_kwargs_for_parallel_linear()
return common_default_kwargs
def get_default_kwargs_for_parallel_embedding():
model_parallel_config_kwargs = get_default_kwargs_for_model_parallel_config()
embedding_default_kwargs = {
'init_method': init.xavier_normal_,
'config': ModelParallelConfig(**model_parallel_config_kwargs),
}
return embedding_default_kwargs
def is_tensor_parallel_param(param):
return (hasattr(param, 'tensor_model_parallel') and param.tensor_model_parallel)
def get_tensor_parallel_partition_dim(param):
assert is_tensor_parallel_param(param)
return param.partition_dim
def get_tensor_parallel_partition_stride(param):
assert is_tensor_parallel_param(param)
return param.partition_stride
class _VocabParallelEntropy(torch.autograd.Function):
@staticmethod
def forward(ctx, vocab_parallel_logits: torch.Tensor) -> torch.Tensor:
logits_max = vocab_parallel_logits.max(dim=-1, keepdim=True).values
dist.all_reduce(logits_max, op=dist.ReduceOp.MAX, group=mpu.get_tensor_model_parallel_group())
normalized_vocab_parallel_logits = vocab_parallel_logits - logits_max
normalized_exp_logits = normalized_vocab_parallel_logits.exp()
normalized_sum_exp_logits = normalized_exp_logits.sum(dim=-1, keepdim=True)
dist.all_reduce(normalized_sum_exp_logits, group=mpu.get_tensor_model_parallel_group())
softmax_logits = normalized_exp_logits / normalized_sum_exp_logits
sum_softmax_times_logits = (softmax_logits * vocab_parallel_logits).sum(dim=-1, keepdim=True)
dist.all_reduce(sum_softmax_times_logits, group=mpu.get_tensor_model_parallel_group())
entropy = logits_max + normalized_sum_exp_logits.log() - sum_softmax_times_logits
ctx.save_for_backward(vocab_parallel_logits, softmax_logits, sum_softmax_times_logits)
return entropy.squeeze(dim=-1)
@staticmethod
def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
vocab_parallel_logits, softmax_logits, sum_softmax_times_logits = ctx.saved_tensors
grad_input = grad_output.unsqueeze(dim=-1) * softmax_logits * (sum_softmax_times_logits - vocab_parallel_logits)
return grad_input
def vocab_parallel_entropy(vocab_parallel_logits: torch.Tensor) -> torch.Tensor:
"""Compute entropy when the logits are sharded in tp ranks
Args:
vocab_parallel_logits: (total_nnz, vocab_size // tp_size)
Returns: (total_nnz,)
"""
return _VocabParallelEntropy.apply(vocab_parallel_logits)
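As a sanity reference for the identity implemented above (entropy = logsumexp(logits) - sum softmax(logits) * logits), a hedged single-device version without any vocab sharding, which should agree with vocab_parallel_entropy when tp == 1:
```python
import torch

def entropy_reference(logits: torch.Tensor) -> torch.Tensor:
    # H(p) = logsumexp(logits) - sum_i softmax(logits)_i * logits_i over the full vocab
    logsumexp = torch.logsumexp(logits, dim=-1)
    probs = torch.softmax(logits, dim=-1)
    return logsumexp - (probs * logits).sum(dim=-1)

logits = torch.randn(5, 128)      # (total_nnz, vocab_size)
print(entropy_reference(logits))  # shape (total_nnz,)
```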
def vocab_parallel_log_probs_from_logits(logits, labels):
"""TODO(zhangchi.usc1992): We may change the implementation later"""
return -tensor_parallel.vocab_parallel_cross_entropy(vocab_parallel_logits=logits, target=labels)
def vocab_parallel_log_probs_from_logits_response_rmpad(input_ids, attention_mask, logits_rmpad, response_length):
"""Similar to log_probs_from_logits_response_rmpad, but the logits_rmpad is now spliited across tensor parallel region.
This will further reduce the peak memory usage during training
Args:
input_ids: [batch_size, seqlen]
attention_mask: [batch_size, seqlen]
logits_rmpad: [total_nnz, vocab_size // tp_size]
response_length: int
"""
from flash_attn.bert_padding import pad_input, unpad_input
batch_size, seqlen = input_ids.shape
input_ids_rmpad, indices, *_ = unpad_input(input_ids.unsqueeze(-1), attention_mask=attention_mask)
input_ids_rmpad = input_ids_rmpad.squeeze(-1)
input_ids_rmpad_rolled = torch.roll(input_ids_rmpad, shifts=-1, dims=0)
full_log_probs_rmpad = vocab_parallel_log_probs_from_logits(logits=logits_rmpad,
labels=input_ids_rmpad_rolled) # (total_nnz,)
full_output = pad_input(hidden_states=full_log_probs_rmpad.unsqueeze(-1),
indices=indices,
batch=batch_size,
seqlen=seqlen)
output = full_output.squeeze(-1)[:, -response_length - 1:-1] # [batch_size, response_length]
return output
def vocab_parallel_compute_entropy_loss(logits, eos_mask):
"""Compute Categorical entropy loss
Args:
logits: `(torch.Tensor)`
shape: (bs, response_length, vocab_size)
eos_mask: `(torch.Tensor)`
shape: (bs, response_length)
Returns:
entropy: a scalar torch.Tensor
"""
# compute entropy
entropy = vocab_parallel_entropy(logits)
entropy_loss = verl_F.masked_mean(entropy, mask=eos_mask)
return entropy_loss


@@ -0,0 +1,253 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pretrain utilities."""
from typing import Any, Dict
import time
from omegaconf import DictConfig
from verl.utils.torch_dtypes import PrecisionType
from verl.utils.memory_buffer import build_memory_reference_from_module
import torch
import torch.nn as nn
import torch.nn.functional as F
from megatron.core import mpu, tensor_parallel
from megatron.core.utils import get_model_config
from megatron.core.transformer import TransformerConfig
from megatron.core.transformer.module import Float16Module
# from megatron.core.distributed import DistributedDataParallelConfig
from megatron.core.distributed import DistributedDataParallel as DDP
from megatron.core.enums import ModelType
def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap_with_ddp=True):
"""Build the model."""
# Build model.
if mpu.get_pipeline_model_parallel_world_size() > 1 and \
mpu.get_virtual_pipeline_model_parallel_world_size() is not None:
assert model_type != ModelType.encoder_and_decoder, \
"Interleaved schedule not supported for model with both encoder and decoder"
model = []
for i in range(mpu.get_virtual_pipeline_model_parallel_world_size()):
mpu.set_virtual_pipeline_model_parallel_rank(i)
# Set pre_process and post_process only after virtual rank is set.
pre_process = mpu.is_pipeline_first_stage()
post_process = mpu.is_pipeline_last_stage()
this_model = model_provider_func(pre_process=pre_process, post_process=post_process)
this_model.model_type = model_type
model.append(this_model)
else:
pre_process = mpu.is_pipeline_first_stage()
post_process = mpu.is_pipeline_last_stage()
add_encoder = True
add_decoder = True
if model_type == ModelType.encoder_and_decoder:
if mpu.get_pipeline_model_parallel_world_size() > 1:
assert mpu.get_pipeline_model_parallel_split_rank() is not None, \
"Split rank needs to be specified for model with both encoder and decoder"
rank = mpu.get_pipeline_model_parallel_rank()
split_rank = mpu.get_pipeline_model_parallel_split_rank()
world_size = mpu.get_pipeline_model_parallel_world_size()
pre_process = rank == 0 or rank == split_rank
post_process = (rank == (split_rank - 1)) or (rank == (world_size - 1))
add_encoder = mpu.is_pipeline_stage_before_split()
add_decoder = mpu.is_pipeline_stage_after_split()
model = model_provider_func(pre_process=pre_process,
post_process=post_process,
add_encoder=add_encoder,
add_decoder=add_decoder)
else:
model = model_provider_func(pre_process=pre_process, post_process=post_process)
model.model_type = model_type
if not isinstance(model, list):
model = [model]
# Set tensor model parallel attributes if not set.
# Only parameters that are already tensor model parallel have these
# attributes set for them. We should make sure the default attributes
# are set for all params so the optimizer can use them.
for model_module in model:
for param in model_module.parameters():
tensor_parallel.set_defaults_if_not_set_tensor_model_parallel_attributes(param)
# Print number of parameters.
if mpu.get_data_parallel_rank() == 0:
print(' > number of parameters on (tensor, pipeline) '
'model parallel rank ({}, {}): {}'.format(
mpu.get_tensor_model_parallel_rank(), mpu.get_pipeline_model_parallel_rank(),
sum([sum([p.nelement() for p in model_module.parameters()]) for model_module in model])),
flush=True)
# GPU allocation.
for model_module in model:
model_module.cuda(torch.cuda.current_device())
# Fp16 conversion.
config = get_model_config(model[0])
if config.fp16 or config.bf16: # the ModelParallelConfig in GPTModel
model = [Float16Module(config, model_module) for model_module in model]
if wrap_with_ddp:
model = [
DDP(config=config,
module=model_chunk,
data_parallel_group=mpu.get_data_parallel_group(with_context_parallel=True),
accumulate_allreduce_grads_in_fp32=True,
overlap_grad_reduce=False,
use_distributed_optimizer=True,
disable_bucketing=(model_chunk_idx > 0)) for (model_chunk_idx, model_chunk) in enumerate(model)
]
# # Broadcast params from data parallel src rank to other data parallel ranks.
# if args.data_parallel_random_init:
for model_module in model:
model_module.broadcast_params()
return model
ALL_MODULE_WRAPPER_CLASSNAMES = (DDP, Float16Module)
def unwrap_model(model, module_instances=ALL_MODULE_WRAPPER_CLASSNAMES):
return_list = True
if not isinstance(model, list):
model = [model]
return_list = False
unwrapped_model = []
for model_module in model:
while isinstance(model_module, module_instances):
model_module = model_module.module
unwrapped_model.append(model_module)
if not return_list:
return unwrapped_model[0]
return unwrapped_model
from transformers import PretrainedConfig
def convert_config(hf_config: PretrainedConfig, megatron_config) -> TransformerConfig:
print(f'megatron config {megatron_config}')
dt = PrecisionType.to_dtype(megatron_config['param_dtype'])
print(f'pipeline_dtype={dt}')
transformer_config = TransformerConfig(
num_layers=hf_config.num_hidden_layers,
hidden_size=hf_config.hidden_size,
num_attention_heads=hf_config.num_attention_heads,
num_query_groups=hf_config.num_key_value_heads,
ffn_hidden_size=hf_config.intermediate_size,
# max_position_embeddings=hf_config.max_position_embeddings,
activation_func=F.silu,
normalization='RMSNorm',
# rotary_percent=False, # default,
gated_linear_unit=True, # for llama
use_cpu_initialization=True,
apply_residual_connection_post_layernorm=False, # TODO: check what this means
add_bias_linear=False,
tensor_model_parallel_size=mpu.get_tensor_model_parallel_world_size(),
pipeline_model_parallel_size=mpu.get_pipeline_model_parallel_world_size(),
virtual_pipeline_model_parallel_size=mpu.get_virtual_pipeline_model_parallel_world_size(),
pipeline_dtype=PrecisionType.to_dtype(megatron_config['param_dtype']),
params_dtype=PrecisionType.to_dtype(megatron_config['param_dtype']),
sequence_parallel=megatron_config['sequence_parallel_enabled'],
variable_seq_lengths=True,
masked_softmax_fusion=True,
bf16=PrecisionType.to_dtype(megatron_config['param_dtype']) is torch.bfloat16)
if torch.distributed.get_rank() == 0:
print(f'tensor_parallel_size={transformer_config.tensor_model_parallel_size} \n \
pipeline_model_parallel_size={transformer_config.pipeline_model_parallel_size} \n \
virtual_pipeline_model_parallel_size={transformer_config.virtual_pipeline_model_parallel_size} \n \
pipeline_dtype={transformer_config.pipeline_dtype} \n \
params_dtype={transformer_config.params_dtype} \n \
sequence_parallel={transformer_config.sequence_parallel} \n \
variable_seq_lengths={transformer_config.variable_seq_lengths} \n \
masked_softmax_fusion={transformer_config.masked_softmax_fusion} \n ')
return transformer_config
# from megatron.core.optimizer import OptimizerConfig
from verl.utils.megatron.optimizer_config import OptimizerConfig
def init_megatron_optim_config(optim_config: Dict) -> OptimizerConfig:
config = OptimizerConfig(
optimizer='adam',
lr=optim_config.get('lr'),
clip_grad=optim_config.get('clip_grad'),
weight_decay=1e-2,
bf16=True,
params_dtype=torch.bfloat16,
use_distributed_optimizer=True,
)
return config
from megatron.core import ModelParallelConfig
def init_model_parallel_config(config: DictConfig) -> ModelParallelConfig:
# TODO(sgm): check how to disable megatron timers
timers = FakeTimers()
return ModelParallelConfig(tensor_model_parallel_size=config.get('tensor_model_parallel_size'),
pipeline_model_parallel_size=config.get('pipeline_model_parallel_size'),
virtual_pipeline_model_parallel_size=config.get('virtual_pipeline_model_parallel_size'),
sequence_parallel=config.get('sequence_parallel'),
params_dtype=PrecisionType.to_dtype(config.get('param_dtype')),
pipeline_dtype=PrecisionType.to_dtype(config.get('param_dtype')),
bf16=True,
fp16=False,
timers=timers)
class FakeTimers:
"""Disable All Megatron Timing with FakeTimers"""
def __init__(self):
from megatron.timers import DummyTimer
self.dummy_timer = DummyTimer()
def __call__(self, *args: Any, **kwds: Any) -> Any:
return self.dummy_timer
def offload_megatron_param_and_grad(module_list: nn.ModuleList, offload_grad=False, hybrid_engine=None):
if hybrid_engine is not None:
pp_rank = mpu.get_pipeline_model_parallel_rank()
for buffer in hybrid_engine.memory_buffers[pp_rank].values():
buffer.data = buffer.data.to('cpu', non_blocking=True)
build_memory_reference_from_module(module_list, hybrid_engine.memory_buffers[pp_rank], maintain_weight=True)
else:
for module in module_list:
for _, param in module.named_parameters():
param.data = param.data.to('cpu', non_blocking=True)
if offload_grad and param.grad is not None:
param.grad = param.grad.to("cpu", non_blocking=True)
torch.cuda.empty_cache()
def load_megatron_param_and_grad(module_list: nn.ModuleList, device_id, load_grad=False, hybrid_engine=None):
if hybrid_engine is not None:
pp_rank = mpu.get_pipeline_model_parallel_rank()
for buffer in hybrid_engine.memory_buffers[pp_rank].values():
buffer.data = buffer.data.to(device_id, non_blocking=True)
build_memory_reference_from_module(module_list, hybrid_engine.memory_buffers[pp_rank], maintain_weight=True)
else:
for module in module_list:
for _, param in module.named_parameters():
param.data = param.data.to(device_id, non_blocking=True)
if load_grad and param.grad is not None:
param.grad = param.grad.to(device_id, non_blocking=True)
torch.cuda.empty_cache()

verl/utils/memory_buffer.py Normal file

@@ -0,0 +1,214 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file contains utilities to manipulate torch memory buffers
"""
from typing import Dict, List
import torch
from torch import nn
class MemoryBuffer:
"""
A memory buffer is a contiguous torch tensor that can back multiple tensors sharing the same underlying
memory. It must have a single dtype to support this behavior.
"""
def __init__(self, numel: int, numel_padded: int, dtype: torch.dtype):
self.numel = numel
self.numel_padded = numel_padded
self.dtype = dtype
self.data = torch.zeros(self.numel_padded, dtype=self.dtype, device='cuda', requires_grad=False)
def zero(self):
"""Reset the buffer to zero."""
self.data.zero_()
def get(self, shape, start_index):
"""Return a tensor with the input `shape` as a view into the
1-D data starting at `start_index`."""
end_index = start_index + shape.numel()
assert end_index <= self.numel, \
'requested tensor is out of the buffer range.'
buffer_tensor = self.data[start_index:end_index]
buffer_tensor = buffer_tensor.view(shape)
return buffer_tensor
def calc_padded_numel(shape: torch.Size, dtype: torch.dtype):
"""for cuda memory alignment, make sure alignment by 128-bits"""
align_numel = 128 // torch.finfo(dtype).bits
numel = shape.numel()
return (numel + align_numel - 1) // align_numel * align_numel
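A quick worked example of the alignment rule, assuming the package is importable (the file path comes from the diff header above):
```python
import torch
from verl.utils.memory_buffer import calc_padded_numel

# 128 bits / 16 bits = 8 bf16 elements per alignment unit, so 10 elements pad up to 16
print(calc_padded_numel(torch.Size([2, 5]), torch.bfloat16))  # 16
```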
def get_weight_buffer_meta_from_module(module: nn.Module) -> Dict[str, Dict]:
"""
Return a dictionary containing name to a shape and dtype.
"""
weight_buffer_meta = {}
for name, param in sorted(module.named_parameters()):
weight_buffer_meta[name] = {'shape': param.shape, 'dtype': param.dtype}
return weight_buffer_meta
def build_memory_buffer(weight_buffer_meta: Dict[str, Dict]) -> Dict[torch.dtype, MemoryBuffer]:
"""Build the memory buffer given weight_buffer_meta
Args:
weight_buffer_meta: contains mapping from name to a dictionary containing shape and dtype of the tensors
Returns: a large memory buffer for each dtype that can hold all the tensors
"""
memory_buffers = {}
total_numel_map = {} # map from dtype to the total numel
for name, meta_info in sorted(weight_buffer_meta.items()):
shape = meta_info['shape']
dtype = meta_info['dtype']
assert isinstance(shape, torch.Size)
assert isinstance(dtype, torch.dtype)
if dtype not in total_numel_map:
total_numel_map[dtype] = 0
total_numel_map[dtype] += calc_padded_numel(shape, dtype)
for dtype, total_numel in total_numel_map.items():
memory_buffers[dtype] = MemoryBuffer(total_numel, total_numel, dtype)
return memory_buffers
def build_memory_reference_from_module(module: torch.nn.Module,
memory_buffers: Dict[torch.dtype, MemoryBuffer],
maintain_weight=True):
start_index = {}
for dtype in memory_buffers.keys():
start_index[dtype] = 0
for name, param in sorted(module.named_parameters()):
memory_buffer = memory_buffers[param.dtype]
buffer = memory_buffer.get(shape=param.shape, start_index=start_index[param.dtype])
# need to increment start_index
start_index[param.dtype] += calc_padded_numel(param.shape, param.dtype)
if maintain_weight:
buffer.copy_(param.data)
param.data = buffer
def build_memory_reference(weight_buffer_meta: Dict[str, Dict], memory_buffers: Dict[torch.dtype, MemoryBuffer]):
"""Build the memory references. The memory buffers are built using the build_memory_buffer API.
This API will allocate a weight buffer pointer to the memory buffer according to the weight_buffer_meta.
Args:
weight_buffer_meta:
memory_buffers:
Returns:
"""
start_idx = {}
weight_buffers = {}
for dtype in memory_buffers.keys():
start_idx[dtype] = 0
for name, meta_info in sorted(weight_buffer_meta.items()):
shape = meta_info['shape']
dtype = meta_info['dtype']
buffer = memory_buffers[dtype].get(shape, start_index=start_idx[dtype])
start_idx[dtype] += calc_padded_numel(shape, dtype)
weight_buffers[name] = buffer
return weight_buffers
class MemoryBufferModuleWrapper:
"""
Note that we do not design MemoryBufferModuleWrapper as an nn.Module due to
- It will change the checkpoint name
"""
def __init__(self, module: nn.Module):
super().__init__()
self.module = module
self.weight_buffer_meta = get_weight_buffer_meta_from_module(self.module)
self.memory_buffers = build_memory_buffer(self.weight_buffer_meta)
build_memory_reference_from_module(self.module, self.memory_buffers)
def get_memory_buffers(self):
return self.memory_buffers
def get_weight_buffer_meta(self):
return self.weight_buffer_meta
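A minimal usage sketch of the wrapper, assuming a CUDA device is available (MemoryBuffer allocates on 'cuda'):
```python
import torch
from torch import nn
from verl.utils.memory_buffer import MemoryBufferModuleWrapper

module = nn.Linear(4, 4).cuda()
wrapped = MemoryBufferModuleWrapper(module)
buffers = wrapped.get_memory_buffers()            # {torch.float32: MemoryBuffer}
weight = dict(module.named_parameters())['weight']
# After wrapping, every parameter is a view into the contiguous per-dtype buffer.
print(weight.data_ptr() >= buffers[torch.float32].data.data_ptr())  # True
```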
class MegatronMemoryBufferForRollout(object):
"""
We assume that
- inference engine has tp + dp
- actor has tp + pp + dp
- the tp between inference engine and actor should be the same
- memory_buffers: contains a list of memory_buffers, each is a dict from dtype to MemoryBuffer
- weight_buffers: contains a list of weight_buffers, each is a dict from name to param
- named_parameters: a dict from name to parameter that normalizes the names from pp and vpp. Note that
the named_parameters may not be directly compatible with inference engine. User has to take care of
this part such as the layout mismatches. (e.g. qkv transpose)
- Note that weight_buffer, named_parameters and memory_buffers share the same underlying GPU memory.
- When doing weight sync, the data is transferred via the memory buffers
"""
def __init__(self, transform_memory_param_fn):
self._memory_buffers = []
self._weight_buffers = []
self._named_parameters = {}
self.transform_memory_param_fn = transform_memory_param_fn
def initialize_weight_buffer(self, weight_buffer_meta_pp: List[Dict[str, Dict]]):
"""
Initialize the weight buffer. The weight buffer is obtained according to the actor. We will construct
a large buffer for each dtype in the weight_buffer.
Args:
weight_buffer_meta_pp: a list with one entry per pp stage; each entry maps parameter names to their shape/dtype meta info
Returns: None
"""
self.weight_buffer_meta_pp = weight_buffer_meta_pp
for weight_buffer_meta in self.weight_buffer_meta_pp:
memory_buffer = build_memory_buffer(weight_buffer_meta)
self._memory_buffers.append(memory_buffer)
self._weight_buffers.append(None)
def build_memory_reference(self):
for i, weight_buffer_meta in enumerate(self.weight_buffer_meta_pp):
self._weight_buffers[i] = build_memory_reference(weight_buffer_meta, self._memory_buffers[i])
self._named_parameters = self.transform_memory_param_fn(self._weight_buffers)
@property
def named_parameters(self):
return self._named_parameters
@property
def weight_buffers(self):
return self._weight_buffers
@property
def memory_buffers(self):
return self._memory_buffers

verl/utils/model.py Normal file

@@ -0,0 +1,332 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Utilities to create common models from huggingface
"""
import os
import warnings
from typing import Dict, Type
import numpy as np
import torch
from torch import nn
from transformers import AutoConfig, AutoModelForCausalLM, PretrainedConfig, MistralForSequenceClassification
from verl.models.registry import ModelRegistry
class LambdaLayer(nn.Module):
def __init__(self, fn):
super().__init__()
self.fn = fn
def forward(self, *args, **kwargs):
return self.fn(*args, **kwargs)
def squeeze(x):
return torch.squeeze(x, dim=-1)
def update_model_config(module_config, override_config_kwargs):
for key, val in override_config_kwargs.items():
setattr(module_config, key, val)
def get_huggingface_actor_config(model_name: str, override_config_kwargs=None, trust_remote_code=False) -> Dict:
if override_config_kwargs is None:
override_config_kwargs = {}
assert isinstance(override_config_kwargs, Dict), \
f'override_config_kwargs must be a dict, got {type(override_config_kwargs)}'
module_config = AutoConfig.from_pretrained(model_name, trust_remote_code=trust_remote_code)
update_model_config(module_config, override_config_kwargs)
return module_config
def create_huggingface_actor(model_name: str, override_config_kwargs=None, automodel_kwargs=None) -> nn.Module:
"""
Args:
model_name: name or path of the huggingface model
override_config_kwargs: config fields to override before instantiation
automodel_kwargs: extra kwargs forwarded to AutoModelForCausalLM.from_config
Returns: the instantiated nn.Module
"""
if override_config_kwargs is None:
override_config_kwargs = {}
if automodel_kwargs is None:
automodel_kwargs = {}
assert isinstance(override_config_kwargs, Dict), \
f'override_config_kwargs must be a dict, got {type(override_config_kwargs)}'
module_config = get_huggingface_actor_config(model_name,
override_config_kwargs,
trust_remote_code=automodel_kwargs.get('trust_remote_code', False))
module: nn.Module = AutoModelForCausalLM.from_config(module_config, **automodel_kwargs)
return module
def create_huggingface_critic(model_name: str, override_config_kwargs=None, automodel_kwargs=None) -> nn.Module:
"""
Args:
model_name:
override_config_kwargs:
Returns:
"""
critic_module: nn.Module = create_huggingface_actor(model_name,
override_config_kwargs=override_config_kwargs,
automodel_kwargs=automodel_kwargs)
if automodel_kwargs is None:
automodel_kwargs = {}
torch_dtype = automodel_kwargs.get('torch_dtype', torch.float32)
critic_module.lm_head = nn.Sequential(nn.Linear(critic_module.config.hidden_size, 1, dtype=torch_dtype),
LambdaLayer(fn=squeeze))
return critic_module
def get_model_size(model: nn.Module, scale='auto'):
n_params = sum(p.numel() for p in model.parameters())
if scale == 'auto':
if n_params > 1e9:
scale = 'B'
elif n_params > 1e6:
scale = 'M'
elif n_params > 1e3:
scale = 'K'
else:
scale = ''
if scale == 'B':
n_params = n_params / 1e9
elif scale == 'M':
n_params = n_params / 1e6
elif scale == 'K':
n_params = n_params / 1e3
elif scale == '':
pass
else:
raise NotImplementedError(f'Unknown scale {scale}')
return n_params, scale
def print_model_size(model: nn.Module, name: str = None):
n_params, scale = get_model_size(model, scale='auto')
if name is None:
name = model.__class__.__name__
print(f'{name} contains {n_params:.2f}{scale} parameters')
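A small usage example (the import path comes from the file header above):
```python
from torch import nn
from verl.utils.model import get_model_size, print_model_size

model = nn.Linear(1024, 1024)
n_params, scale = get_model_size(model)       # ~1.05, 'M'
print(n_params, scale)
print_model_size(model, name='toy-linear')    # "toy-linear contains 1.05M parameters"
```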
def create_random_mask(input_ids: torch.Tensor,
max_ratio_of_valid_token: float,
max_ratio_of_left_padding: float,
min_ratio_of_valid_token: float = 0):
"""Create a random mask given input_ids. Support left padding and right padding.
Process:
- Sample valid token length
- Sample left_padding length
- Generate padding
Args:
input_ids:
shape (batch_size, seq_len)
Returns:
"""
assert max_ratio_of_valid_token > 0 and max_ratio_of_valid_token <= 1.
assert max_ratio_of_left_padding >= 0 and max_ratio_of_left_padding < 1.
assert min_ratio_of_valid_token <= max_ratio_of_valid_token
batch_size, sequence_length = input_ids.shape
max_num_valid_tokens = int(sequence_length * max_ratio_of_valid_token)
min_num_valid_tokens = max(1, int(sequence_length * min_ratio_of_valid_token))
max_left_padding = int(sequence_length * max_ratio_of_left_padding)
assert max_num_valid_tokens + max_left_padding <= sequence_length
assert max_num_valid_tokens > 0 and max_num_valid_tokens <= sequence_length
masks = torch.ones_like(input_ids, dtype=torch.int64)
# TODO: we can make this faster
for i in range(batch_size):
num_left_padding = np.random.randint(low=0, high=max_left_padding + 1, dtype=np.int64)
num_valid = np.random.randint(low=min_num_valid_tokens, high=max_num_valid_tokens + 1, dtype=np.int64)
for index in range(num_left_padding):
masks[i, index] = 0
for index in range(num_left_padding + num_valid, sequence_length):
masks[i, index] = 0
return masks
def compute_position_id_with_mask(mask):
return torch.clip(torch.cumsum(mask, dim=-1) - 1, min=0, max=None)
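A worked example of the position-id computation for a left-padded sequence:
```python
import torch
from verl.utils.model import compute_position_id_with_mask

mask = torch.tensor([[0, 0, 1, 1, 1]])       # two left-padding tokens, three valid tokens
print(compute_position_id_with_mask(mask))   # tensor([[0, 0, 0, 1, 2]])
```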
def normalize_pp_vpp_params(params, num_hidden_layers, layer_name='layers'):
"""
Normalize the pp vpp params into a complete named parameters.
This is useful when gather parameters from pp ranks and passed to a model without pp
params: List[List[Dict[str, param]]]
params contains a list of pp, with a list of vpp named_parameters in each vpp chunk.
output: Dict[str, param]
"""
def normalize_model_name(name, pp_rank, vpp_rank, pp_size, vpp_size, num_layers):
"""
Transform the model name in each model_chunk in each pp stage into the name in inference engine
"""
if vpp_size > 1:
# print(f'try to bind vpp params to inference engine...')
layers_per_pp = num_layers // pp_size
layers_per_vpp = layers_per_pp // vpp_size
pp_offset = layers_per_vpp * pp_rank
vpp_offset = (layers_per_vpp * pp_size) * vpp_rank
layer_offset = pp_offset + vpp_offset
else:
layers_per_pp = num_layers // pp_size
layer_offset = layers_per_pp * pp_rank
if layer_name in name: # belong to an intermediate layer
split_name = name.split('.')
# find the num next to split_name
for i, sub_name in enumerate(split_name):
if sub_name == layer_name:
break
layer_num_idx = i + 1
# check the name
assert len(split_name) >= layer_num_idx + 1, f'split_name = {split_name}'
assert split_name[layer_num_idx].isdigit(), f'split_name = {split_name}'
# increment layer_num_idx by layer_offset
split_name[layer_num_idx] = str(int(split_name[layer_num_idx]) + layer_offset)
name = '.'.join(split_name) # weight name in inference_tp_model
return name
pp_size = len(params)
normalized_name_to_param = {}
for pp_rank in range(len(params)):
vpp_size = len(params[pp_rank])
for vpp_rank in range(vpp_size):
for name, param in params[pp_rank][vpp_rank].items():
normalized_name = normalize_model_name(name, pp_rank, vpp_rank, pp_size, vpp_size, num_hidden_layers)
normalized_name_to_param[normalized_name] = param
return normalized_name_to_param
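A toy illustration of the name normalization with two pp stages, no vpp, and four layers in total (the tensors are dummies):
```python
import torch
from verl.utils.model import normalize_pp_vpp_params

params = [
    [{'model.layers.0.weight': torch.zeros(1), 'model.layers.1.weight': torch.zeros(1)}],  # pp rank 0
    [{'model.layers.0.weight': torch.zeros(1), 'model.layers.1.weight': torch.zeros(1)}],  # pp rank 1
]
merged = normalize_pp_vpp_params(params, num_hidden_layers=4)
print(sorted(merged))  # ['model.layers.0.weight', ..., 'model.layers.3.weight']
```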
def get_parallel_model_from_config(config, megatron_config, pre_process=None, post_process=None, value=False):
from megatron.core import ModelParallelConfig
assert isinstance(megatron_config, ModelParallelConfig)
model_class = _get_parallel_model_architecture_from_config(config, value)
model = model_class(config, megatron_config, pre_process=pre_process, post_process=post_process)
return model
def _get_parallel_model_architecture_from_config(config: PretrainedConfig, value=False) -> Type[nn.Module]:
architectures = getattr(config, "architectures", [])
for arch in architectures:
model_cls = ModelRegistry.load_model_cls(arch, value)
if model_cls is not None:
return model_cls
raise ValueError(f"Model architectures {architectures} are not supported for now. "
f"Supported architectures: {ModelRegistry.get_supported_archs()}")
def load_megatron_model_weights(config,
model_config,
parallel_model,
params_dtype,
is_value_model=False,
local_cache_path='~/.cache/verl/rlhf'):
assert hasattr(model_config, "architectures"), "architectures cannot be empty when load weight!"
architectures = getattr(model_config, "architectures", [])
local_cache_path = os.path.expanduser(local_cache_path)
if config.model.path.startswith("hdfs:"):
from verl.utils.fs import copy_local_path_from_hdfs
print(f'start download from {config.model.path}')
local_model_path = copy_local_path_from_hdfs(src=config.model.path, cache_dir=local_cache_path)
print('finish download')
else:
print(f"load from local dir {config.model.path}")
local_model_path = config.model.path
# TODO: to find a better way to load mistral7b-rm lm_head
if 'mistral7b-rm' in config.model.path:
model = MistralForSequenceClassification.from_pretrained(local_model_path) # use score head instead of lm_head
state_dict = model.state_dict()
state_dict['lm_head.weight'] = state_dict['score.weight']
state_dict['model.embed_tokens.weight'] = state_dict[
'model.embed_tokens.weight'][:32000] # workaround, 32001 -> 32000
is_value_model = True
else:
with warnings.catch_warnings():
warnings.simplefilter("ignore")
model = AutoModelForCausalLM.from_pretrained(local_model_path)
state_dict = model.state_dict()
from verl.models.weight_loader_registry import get_weight_loader
print(f'before weight loader: architectures = {architectures}...')
for arch in architectures:
print(f'call weight loader arch = {arch}, model config = {model.config}')
weight_loader = get_weight_loader(arch)
weight_loader(state_dict=state_dict,
wrapped_models=parallel_model,
config=model.config,
params_dtype=params_dtype,
is_value_model=is_value_model)
# pad input_ids_rmpad, cu_seqlens and max_seqlen_in_batch to be divisible by tp
def pad_packed_inputs(unpad_tokens: torch.Tensor, cu_seqlens, max_seqlen_in_batch, size):
"""pad the tokens such that the total length is a multiple of size.
This function is useful when applying sequence parallel and context parallel
Args:
unpad_tokens: (total_nnz, ...). Tokens after removing padding
cu_seqlens: (total_nnz + 1,)
max_seqlen_in_batch: int
Returns:
"""
F = nn.functional
total_nnz = unpad_tokens.shape[0]
if total_nnz % size == 0:
pad_size = 0
else:
pad_size = size - total_nnz % size
# we treat the padding as a new sequence of length pad_size appended to the batch
if pad_size > 0:
if unpad_tokens.ndim == 1:
unpad_tokens = F.pad(unpad_tokens, (0, pad_size))
elif unpad_tokens.ndim == 2:
unpad_tokens = F.pad(unpad_tokens, (0, 0, 0, pad_size))
else:
raise NotImplementedError(f'Padding dim {unpad_tokens.ndim} is not supported')
cu_seqlens = F.pad(cu_seqlens, (0, 1), value=pad_size + cu_seqlens[-1])
max_seqlen_in_batch = max(max_seqlen_in_batch, pad_size)
return unpad_tokens, cu_seqlens, max_seqlen_in_batch
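A usage sketch of the packed-input padding (the import path comes from the file header above):
```python
import torch
from verl.utils.model import pad_packed_inputs

tokens = torch.arange(10)                    # total_nnz = 10, e.g. two sequences of length 3 and 7
cu_seqlens = torch.tensor([0, 3, 10])
tokens, cu_seqlens, max_seqlen = pad_packed_inputs(tokens, cu_seqlens, max_seqlen_in_batch=7, size=4)
print(tokens.shape, cu_seqlens, max_seqlen)  # torch.Size([12]) tensor([0, 3, 10, 12]) 7
```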


@@ -0,0 +1,56 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Contain small python utility functions
"""
from typing import Dict
from types import SimpleNamespace
def union_two_dict(dict1: Dict, dict2: Dict):
"""Union two dict. Will throw an error if there is an item not the same object with the same key.
Args:
dict1:
dict2:
Returns:
"""
for key, val in dict2.items():
if key in dict1:
assert dict2[key] == dict1[key], \
f'{key} in dict1 and dict2 are mapped to different values'
dict1[key] = val
return dict1
def append_to_dict(data: Dict, new_data: Dict):
for key, val in new_data.items():
if key not in data:
data[key] = []
data[key].append(val)
class NestedNamespace(SimpleNamespace):
def __init__(self, dictionary, **kwargs):
super().__init__(**kwargs)
for key, value in dictionary.items():
if isinstance(value, dict):
self.__setattr__(key, NestedNamespace(value))
else:
self.__setattr__(key, value)
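A small usage sketch of these helpers. The diff header above does not name this file, so the import path below is an assumption:
```python
from verl.utils.py_functional import append_to_dict, union_two_dict, NestedNamespace  # assumed path

metrics = {}
append_to_dict(metrics, {'loss': 0.5, 'acc': 0.9})
append_to_dict(metrics, {'loss': 0.4})
print(metrics)                              # {'loss': [0.5, 0.4], 'acc': [0.9]}

print(union_two_dict({'a': 1}, {'b': 2}))   # {'a': 1, 'b': 2}

cfg = NestedNamespace({'train': {'lr': 1e-3}})
print(cfg.train.lr)                         # 0.001
```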

verl/utils/ray_utils.py Normal file

@@ -0,0 +1,43 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Contains commonly used utilities for ray
"""
import ray
import concurrent.futures
def parallel_put(data_list, max_workers=None):
def put_data(index, data):
return index, ray.put(data)
if max_workers is None:
max_workers = min(len(data_list), 16)
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
data_list_f = [executor.submit(put_data, i, data) for i, data in enumerate(data_list)]
res_lst = []
for future in concurrent.futures.as_completed(data_list_f):
res_lst.append(future.result())
# reorder based on index
output = [None for _ in range(len(data_list))]
for res in res_lst:
index, data_ref = res
output[index] = data_ref
return output
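A usage example (requires a running ray instance; the import path comes from the file header above):
```python
import ray
from verl.utils.ray_utils import parallel_put

ray.init(ignore_reinit_error=True)
refs = parallel_put([{'rank': i} for i in range(4)])
print(ray.get(refs))   # order is preserved: [{'rank': 0}, ..., {'rank': 3}]
```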


@@ -0,0 +1,13 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


@@ -0,0 +1,77 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import time
from cupy.cuda.nccl import NcclCommunicator, get_unique_id
import ray
from ray.util import list_named_actors
@ray.remote
class NCCLIDStore:
def __init__(self, nccl_id):
self._nccl_id = nccl_id
def get(self):
return self._nccl_id
def get_nccl_id_store_by_name(name):
all_actors = list_named_actors(all_namespaces=True)
matched_actors = [actor for actor in all_actors if actor.get("name", None) == name]
if len(matched_actors) == 1:
actor = matched_actors[0]
return ray.get_actor(**actor)
elif len(matched_actors) > 1:
logging.warning(f"multiple actors with same name found: {matched_actors}")
elif len(matched_actors) == 0:
logging.info(f"failed to get any actor named {name}")
return None
def create_nccl_communicator_in_ray(rank: int,
world_size: int,
group_name: str,
max_retries: int = 100,
interval_s: int = 5):
if rank == 0:
nccl_id = get_unique_id()
nccl_id_store = NCCLIDStore.options(name=group_name).remote(nccl_id)
assert ray.get(nccl_id_store.get.remote()) == nccl_id
communicator = NcclCommunicator(
ndev=world_size,
commId=nccl_id,
rank=0,
)
return communicator
else:
for i in range(max_retries):
nccl_id_store = get_nccl_id_store_by_name(group_name)
if nccl_id_store is not None:
logging.info(f"nccl_id_store {group_name} got")
nccl_id = ray.get(nccl_id_store.get.remote())
logging.info(f"nccl id for {group_name} got: {nccl_id}")
communicator = NcclCommunicator(
ndev=world_size,
commId=nccl_id,
rank=rank,
)
return communicator
logging.info(f"failed to get nccl_id for {i+1} time, sleep for {interval_s} seconds")
time.sleep(interval_s)


@@ -0,0 +1,13 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


@@ -0,0 +1,111 @@
import re
import random
import ast
import operator
def extract_solution(solution_str):
"""Extract the equation from the solution string."""
# Remove everything before the first "Assistant:"
if "Assistant:" in solution_str:
solution_str = solution_str.split("Assistant:", 1)[1]
elif "<|im_start|>assistant" in solution_str:
solution_str = solution_str.split("<|im_start|>assistant", 1)[1]
else:
return None
solution_str = solution_str.split('\n')[-1]
answer_pattern = r'<answer>(.*?)</answer>'
match = re.finditer(answer_pattern, solution_str)
matches = list(match)
if matches:
final_answer = matches[-1].group(1).strip()
else:
final_answer = None
return final_answer
def validate_equation(equation_str, available_numbers):
"""Validate that equation only uses available numbers and each number once."""
try:
# Extract all numbers from the equation
numbers_in_eq = [int(n) for n in re.findall(r'\d+', equation_str)]
# Check if all numbers in equation are available
available_numbers = sorted(available_numbers)
numbers_in_eq = sorted(numbers_in_eq)
# Each number should be used exactly once
return numbers_in_eq == available_numbers
except:
return False
def evaluate_equation(equation_str):
"""Safely evaluate the arithmetic equation using eval() with precautions."""
try:
# Define a regex pattern that only allows numbers, operators, parentheses, and whitespace
allowed_pattern = r'^[\d+\-*/().\s]+$'
if not re.match(allowed_pattern, equation_str):
raise ValueError("Invalid characters in equation.")
# Evaluate the equation with restricted globals and locals
result = eval(equation_str, {"__builtins__": None}, {})
return result
except Exception as e:
return None
def compute_score(solution_str, ground_truth, method='strict', format_score=0.1, score=1.):
"""The scoring function for countdown task.
Args:
solution_str: the solution text
ground_truth: dictionary containing target number and available numbers
method: the method to extract the solution
format_score: the score for correct format but wrong answer
score: the score for the correct answer
"""
target = ground_truth['target']
numbers = ground_truth['numbers']
equation = extract_solution(solution_str=solution_str)
do_print = random.randint(1, 64) == 1
if do_print:
print(f"--------------------------------")
print(f"Target: {target} | Numbers: {numbers}")
print(f"Extracted equation: {equation}")
print(f"Solution string: {solution_str}")
if equation is None:
if do_print:
print(f"No equation found")
return 0
# Validate equation uses correct numbers
if not validate_equation(equation, numbers):
if do_print:
print(f"Invalid equation")
return format_score
# Evaluate equation
try:
result = evaluate_equation(equation)
if result is None:
if do_print:
print(f"Could not evaluate equation")
return format_score
if abs(result - target) < 1e-5: # Account for floating point precision
if do_print:
print(f"Correct equation: {equation} = {result}")
return score
else:
if do_print:
print(f"Wrong result: equation = {result}, target = {target}")
return format_score
except:
if do_print:
print(f"Error evaluating equation")
return format_score
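A smoke test for the scorer above. The diff header does not name this file, so the import path is an assumption:
```python
from verl.utils.reward_score.countdown import compute_score  # assumed path

ground_truth = {'target': 9, 'numbers': [1, 2, 3]}
good = "Assistant: combine them as follows <answer>(1 + 2) * 3</answer>"
bad = "Assistant: <answer>1 + 2 + 3</answer>"
print(compute_score(good, ground_truth))  # 1.0  (correct equation)
print(compute_score(bad, ground_truth))   # 0.1  (right numbers and format, wrong result)
```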


@@ -0,0 +1,63 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
def extract_solution(solution_str, method='strict'):
assert method in ['strict', 'flexible']
if method == 'strict':
# this also tests the formatting of the model
solution = re.search("#### (\\-?[0-9\\.\\,]+)", solution_str)
if solution is None:
final_answer = None
else:
final_answer = solution.group(0)
final_answer = final_answer.split('#### ')[1].replace(',', '').replace('$', '')
elif method == 'flexible':
answer = re.findall("(\\-?[0-9\\.\\,]+)", solution_str)
final_answer = None
if len(answer) == 0:
# no reward if there is no answer
pass
else:
invalid_str = ['', '.']
# find the last number that is not '.'
for final_answer in reversed(answer):
if final_answer not in invalid_str:
break
return final_answer
def compute_score(solution_str, ground_truth, method='strict', format_score=0., score=1.):
"""The scoring function for GSM8k.
Reference: Trung, Luong, et al. "Reft: Reasoning with reinforced fine-tuning." Proceedings of the 62nd Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers). 2024.
Args:
solution_str: the solution text
ground_truth: the ground truth
method: the method to extract the solution, choices are 'strict' and 'flexible'
format_score: the score for the format
score: the score for the correct answer
"""
answer = extract_solution(solution_str=solution_str, method=method)
if answer is None:
return 0
else:
if answer == ground_truth:
return score
else:
return format_score
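A quick check of the strict extraction path. The diff header does not name this file, so the import path is an assumption:
```python
from verl.utils.reward_score.gsm8k import compute_score, extract_solution  # assumed path

solution = "Natalia sold 48 + 24 = 72 clips.\n#### 72"
print(extract_solution(solution, method='strict'))   # '72'
print(compute_score(solution, ground_truth='72'))    # 1.0
```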


@@ -0,0 +1,227 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/tasks/hendrycks_math/utils.py
def compute_score(solution_str, ground_truth) -> float:
retval = 0.
try:
string_in_last_boxed = last_boxed_only_string(solution_str)
if string_in_last_boxed is not None:
answer = remove_boxed(string_in_last_boxed)
if is_equiv(answer, ground_truth):
retval = 1.
except Exception as e:
print(e)
return retval
# string normalization from https://github.com/EleutherAI/lm-evaluation-harness/blob/master/lm_eval/tasks/hendrycks_math.py
def is_equiv(str1, str2, verbose=False):
if str1 is None and str2 is None:
print("WARNING: Both None")
return True
if str1 is None or str2 is None:
return False
try:
ss1 = strip_string(str1)
ss2 = strip_string(str2)
if verbose:
print(ss1, ss2)
return ss1 == ss2
except Exception:
return str1 == str2
def remove_boxed(s):
if "\\boxed " in s:
left = "\\boxed "
assert s[:len(left)] == left
return s[len(left):]
left = "\\boxed{"
assert s[:len(left)] == left
assert s[-1] == "}"
return s[len(left):-1]
def last_boxed_only_string(string):
idx = string.rfind("\\boxed")
if "\\boxed " in string:
return "\\boxed " + string.split("\\boxed ")[-1].split("$")[0]
if idx < 0:
idx = string.rfind("\\fbox")
if idx < 0:
return None
i = idx
right_brace_idx = None
num_left_braces_open = 0
while i < len(string):
if string[i] == "{":
num_left_braces_open += 1
if string[i] == "}":
num_left_braces_open -= 1
if num_left_braces_open == 0:
right_brace_idx = i
break
i += 1
if right_brace_idx is None:
retval = None
else:
retval = string[idx:right_brace_idx + 1]
return retval
def fix_fracs(string):
substrs = string.split("\\frac")
new_str = substrs[0]
if len(substrs) > 1:
substrs = substrs[1:]
for substr in substrs:
new_str += "\\frac"
if substr[0] == "{":
new_str += substr
else:
try:
assert len(substr) >= 2
except AssertionError:
return string
a = substr[0]
b = substr[1]
if b != "{":
if len(substr) > 2:
post_substr = substr[2:]
new_str += "{" + a + "}{" + b + "}" + post_substr
else:
new_str += "{" + a + "}{" + b + "}"
else:
if len(substr) > 2:
post_substr = substr[2:]
new_str += "{" + a + "}" + b + post_substr
else:
new_str += "{" + a + "}" + b
string = new_str
return string
def fix_a_slash_b(string):
if len(string.split("/")) != 2:
return string
a = string.split("/")[0]
b = string.split("/")[1]
try:
a = int(a)
b = int(b)
assert string == "{}/{}".format(a, b)
new_string = "\\frac{" + str(a) + "}{" + str(b) + "}"
return new_string
except AssertionError:
return string
def remove_right_units(string):
# "\\text{ " only ever occurs (at least in the val set) when describing units
if "\\text{ " in string:
splits = string.split("\\text{ ")
assert len(splits) == 2
return splits[0]
else:
return string
def fix_sqrt(string):
if "\\sqrt" not in string:
return string
splits = string.split("\\sqrt")
new_string = splits[0]
for split in splits[1:]:
if split[0] != "{":
a = split[0]
new_substr = "\\sqrt{" + a + "}" + split[1:]
else:
new_substr = "\\sqrt" + split
new_string += new_substr
return new_string
def strip_string(string):
# linebreaks
string = string.replace("\n", "")
# remove inverse spaces
string = string.replace("\\!", "")
# replace \\ with \
string = string.replace("\\\\", "\\")
# replace tfrac and dfrac with frac
string = string.replace("tfrac", "frac")
string = string.replace("dfrac", "frac")
# remove \left and \right
string = string.replace("\\left", "")
string = string.replace("\\right", "")
# Remove circ (degrees)
string = string.replace("^{\\circ}", "")
string = string.replace("^\\circ", "")
# remove dollar signs
string = string.replace("\\$", "")
# remove units (on the right)
string = remove_right_units(string)
# remove percentage
string = string.replace("\\%", "")
string = string.replace("\%", "") # noqa: W605
# " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string
string = string.replace(" .", " 0.")
string = string.replace("{.", "{0.")
# if empty, return empty string
if len(string) == 0:
return string
if string[0] == ".":
string = "0" + string
# to consider: get rid of e.g. "k = " or "q = " at beginning
if len(string.split("=")) == 2:
if len(string.split("=")[0]) <= 2:
string = string.split("=")[1]
# fix sqrt3 --> sqrt{3}
string = fix_sqrt(string)
# remove spaces
string = string.replace(" ", "")
# \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b}
string = fix_fracs(string)
# manually change 0.5 --> \frac{1}{2}
if string == "0.5":
string = "\\frac{1}{2}"
# NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y
string = fix_a_slash_b(string)
return string
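A hedged smoke test for the normalization above (both "0.5" and "\frac{1}{2}" normalize to the same string). The diff header does not name this file, so the import path is an assumption:
```python
from verl.utils.reward_score.math import compute_score, is_equiv  # assumed path

print(is_equiv("\\frac{1}{2}", "0.5"))                                # True
print(compute_score("The answer is $\\boxed{\\frac{1}{2}}$", "0.5"))  # 1.0
```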


@@ -0,0 +1,58 @@
import re
import random
def extract_solution(solution_str):
# Remove everything before the first "Assistant:"
if "Assistant:" in solution_str:
solution_str = solution_str.split("Assistant:", 1)[1]
else:
return None
answer_pattern = r'<answer>(.*?)</answer>'
match = re.finditer(answer_pattern, solution_str)
matches = list(match)
if matches:
final_answer = matches[-1].group(1).strip()
else:
final_answer = None
if final_answer is not None:
try:
int(final_answer)  # the answer must parse as an integer; otherwise treat it as missing
except ValueError:
final_answer = None
return final_answer
def compute_score(solution_str, ground_truth, method='strict', format_score=0.1, score=1.):
"""The scoring function for GSM8k.
Reference: Trung, Luong, et al. "Reft: Reasoning with reinforced fine-tuning." Proceedings of the 62nd Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers). 2024.
Args:
solution_str: the solution text
ground_truth: the ground truth
method: the method to extract the solution, choices are 'strict' and 'flexible'
format_score: the score for the format
score: the score for the correct answer
"""
answer = extract_solution(solution_str=solution_str)
do_print = random.randint(1, 64) == 1
if do_print:
print(f"--------------------------------")
print(f"Ground truth: {ground_truth} | Extracted answer: {answer}")
print(f"Solution string: {solution_str}")
if answer is None:
if do_print:
print(f"No answer found")
return 0
else:
if int(answer) == int(ground_truth):
if do_print:
print(f"Correct answer: {answer}")
return score
else:
if do_print:
print(f"Incorrect answer {answer} | Ground truth: {ground_truth}")
return format_score


@@ -0,0 +1,138 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
import string
import random
def normalize_answer(s):
def remove_articles(text):
return re.sub(r"\b(a|an|the)\b", " ", text)
def white_space_fix(text):
return " ".join(text.split())
def remove_punc(text):
exclude = set(string.punctuation)
return "".join(ch for ch in text if ch not in exclude)
def lower(text):
return text.lower()
return white_space_fix(remove_articles(remove_punc(lower(s))))
def em_check(prediction, golden_answers):
if isinstance(golden_answers, str):
golden_answers = [golden_answers]
normalized_prediction = normalize_answer(prediction)
score = 0
for golden_answer in golden_answers:
golden_answer = normalize_answer(golden_answer)
if golden_answer == normalized_prediction:
score = 1
break
return score
def subem_check(prediction, golden_answers):
if isinstance(golden_answers, str):
golden_answers = [golden_answers]
normalized_prediction = normalize_answer(prediction)
score = 0
for golden_answer in golden_answers:
golden_answer = normalize_answer(golden_answer)
if golden_answer in normalized_prediction:
score = 1
break
return score
def extract_solution(solution_str):
"""Extract the equation from the solution string."""
# Remove everything before the first "Assistant:"
# if "Assistant:" in solution_str:
# solution_str = solution_str.split("Assistant:", 1)[1]
# elif "<|im_start|>assistant" in solution_str:
# solution_str = solution_str.split("<|im_start|>assistant", 1)[1]
# else:
# return None
# solution_str = solution_str.split('\n')[-1]
answer_pattern = r'<answer>(.*?)</answer>'
match = re.finditer(answer_pattern, solution_str, re.DOTALL)
matches = list(match)
# If there are 0 or exactly 1 matches, return None
if len(matches) <= 1:
return None
# If there are 2 or more matches, return the last one
return matches[-1].group(1).strip()
def compute_score_em(solution_str, ground_truth, method='strict', format_score=0., score=1.):
"""The scoring function for exact match (EM).
Args:
solution_str: the solution text
ground_truth: the ground truth
method: the method to extract the solution, choices are 'strict' and 'flexible'
format_score: the score for the format
score: the score for the correct answer
"""
answer = extract_solution(solution_str=solution_str)
do_print = random.randint(1, 64) == 1
if do_print:
print(f"--------------------------------")
print(f"Golden answers: {ground_truth['target']}")
print(f"Extracted answer: {answer}")
print(f"Solution string: {solution_str}")
if answer is None:
return 0
else:
if em_check(answer, ground_truth['target']):
return score
else:
return format_score
def compute_score_subem(solution_str, ground_truth, method='strict', format_score=0., score=1.):
"""The scoring function for substring exact match (EM).
Args:
solution_str: the solution text
ground_truth: the ground truth
method: the method to extract the solution, choices are 'strict' and 'flexible'
format_score: the score for the format
score: the score for the correct answer
"""
answer = extract_solution(solution_str=solution_str)
do_print = random.randint(1, 64) == 1
if do_print:
print(f"--------------------------------")
print(f"Golden answers: {ground_truth['target']}")
print(f"Extracted answer: {answer}")
print(f"Solution string: {solution_str}")
if answer is None:
return 0
else:
if subem_check(answer, ground_truth['target']):
return score
else:
return format_score
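Note that the extractor above only returns an answer when at least two <answer> blocks are present (it then takes the last one). A smoke test, with an assumed import path since the diff header does not name this file:
```python
from verl.utils.reward_score.qa_em import compute_score_em  # assumed path

ground_truth = {'target': ['Paris']}
solution = "<answer>example</answer> ... reasoning ... <answer>paris!</answer>"
print(compute_score_em(solution, ground_truth))   # 1.0 after case/punctuation normalization
```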


@@ -0,0 +1,265 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Tuple, Callable
import heapq
import torch
from torch import distributed as dist
from tensordict import TensorDict
import copy
def karmarkar_karp(seqlen_list: List[int], k_partitions: int, equal_size: bool):
# see: https://en.wikipedia.org/wiki/Largest_differencing_method
class Set:
def __init__(self) -> None:
self.sum = 0
self.items = []
def add(self, idx: int, val: int):
self.items.append((idx, val))
self.sum += val
def merge(self, other):
for idx, val in other.items:
self.items.append((idx, val))
self.sum += val
def __lt__(self, other):
if self.sum != other.sum:
return self.sum < other.sum
if len(self.items) != len(other.items):
return len(self.items) < len(other.items)
return self.items < other.items
class State:
def __init__(self, items: List[Tuple[int, int]], k: int) -> None:
self.k = k
# sets are kept in decreasing order of sum
self.sets = [Set() for _ in range(k)]
assert len(items) in [1, k], f"{len(items)} not in [1, {k}]"
for i, (idx, seqlen) in enumerate(items):
self.sets[i].add(idx=idx, val=seqlen)
self.sets = sorted(self.sets, reverse=True)
def spread(self):
return self.sets[0].sum - self.sets[-1].sum
def get_partitions(self):
partitions = []
for i in range(len(self.sets)):
cur_partition = []
for idx, _ in self.sets[i].items:
cur_partition.append(idx)
partitions.append(cur_partition)
return partitions
def merge(self, other):
for i in range(self.k):
self.sets[i].merge(other.sets[self.k - 1 - i])
self.sets = sorted(self.sets, reverse=True)
@property
def spread(self) -> int:
return self.sets[0].sum - self.sets[-1].sum
def __lt__(self, other):
# min-heap: the state with the largest spread is popped first;
# if the spreads are equal, the state with the largest set is popped first.
if self.spread != other.spread:
return self.spread > other.spread
return self.sets[0] > other.sets[0]
def __repr__(self) -> str:
repr_str = "["
for i in range(self.k):
if i > 0:
repr_str += ","
repr_str += "{"
for j, (_, seqlen) in enumerate(self.sets[i].items):
if j > 0:
repr_str += ","
repr_str += str(seqlen)
repr_str += "}"
repr_str += "]"
return repr_str
sorted_seqlen_list = sorted([(seqlen, i) for i, seqlen in enumerate(seqlen_list)])
states_pq = []
if equal_size:
assert len(seqlen_list) % k_partitions == 0, f"{len(seqlen_list)} % {k_partitions} != 0"
for offset in range(0, len(sorted_seqlen_list), k_partitions):
items = []
for i in range(k_partitions):
seqlen, idx = sorted_seqlen_list[offset + i]
items.append((idx, seqlen))
heapq.heappush(states_pq, State(items=items, k=k_partitions))
else:
for seqlen, idx in sorted_seqlen_list:
heapq.heappush(states_pq, State(items=[(idx, seqlen)], k=k_partitions))
while len(states_pq) > 1:
state0 = heapq.heappop(states_pq)
state1 = heapq.heappop(states_pq)
# merge states
state0.merge(state1)
heapq.heappush(states_pq, state0)
final_state = states_pq[0]
partitions = final_state.get_partitions()
if equal_size:
for i, partition in enumerate(partitions):
assert len(partition) * \
k_partitions == len(seqlen_list), f"{len(partition)} * {k_partitions} != {len(seqlen_list)}"
return partitions
def greedy_partition(seqlen_list: List[int], k_partitions: int, equal_size: bool):
bias = sum(seqlen_list) + 1 if equal_size else 0
sorted_seqlen = [(seqlen + bias, i) for i, seqlen in enumerate(seqlen_list)]
partitions = [[] for _ in range(k_partitions)]
partition_sums = [0 for _ in range(k_partitions)]
for seqlen, i in sorted_seqlen:
min_idx = None
for j in range(k_partitions):
if min_idx is None or partition_sums[j] < partition_sums[min_idx]:
min_idx = j
partitions[min_idx].append(i)
partition_sums[min_idx] += seqlen
if equal_size:
for i, partition in enumerate(partitions):
assert len(partition) * \
k_partitions == len(seqlen_list), f"{len(partition)} * {k_partitions} != {len(seqlen_list)}"
return partitions
def get_seqlen_balanced_partitions(seqlen_list: List[int], k_partitions: int, equal_size: bool):
""" get order of seq lengths to make partitions balanced, this is
used in balacing sum of seqlength across dp ranks and microbatches
Parameters:
seqlen_list (List[int]):
seq lengths of each items
k_partitions (int):
resulting number of partitions
equal_size (bool):
if True, number of items in each partitions must be equal.
if False, only consider balancing the sum, each partition can have
variable number of items
Returns:
partitions (List[List[int]]):
return k_partitions list containing the index of items.
"""
assert len(seqlen_list) >= k_partitions, f"number of items:[{len(seqlen_list)}] < k_partitions:[{k_partitions}]"
def _check_and_sort_partitions(partitions):
assert len(partitions) == k_partitions, f"{len(partitions)} != {k_partitions}"
seen_idx = set()
sorted_partitions = [None] * k_partitions
for i, partition in enumerate(partitions):
assert len(partition) > 0, f"the {i}-th partition is empty"
for idx in partition:
seen_idx.add(idx)
sorted_partitions[i] = sorted(partition)
assert seen_idx == set(range(len(seqlen_list)))
return sorted_partitions
partitions = karmarkar_karp(seqlen_list=seqlen_list, k_partitions=k_partitions, equal_size=equal_size)
return _check_and_sort_partitions(partitions)
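# Illustrative usage sketch (not authoritative): balance 6 sequences across 2
# equal-size partitions; for this particular input the per-partition token sums
# come out as 13 and 14.
def _demo_seqlen_balanced_partitions():
    seqlen_list = [7, 1, 5, 3, 9, 2]
    parts = get_seqlen_balanced_partitions(seqlen_list, k_partitions=2, equal_size=True)
    assert len(parts) == 2 and all(len(p) == 3 for p in parts)
    sums = sorted(sum(seqlen_list[i] for i in p) for p in parts)
    assert max(sums) - min(sums) <= 1  # 13 vs 14 for this input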
def log_seqlen_unbalance(seqlen_list: List[int], partitions: List[List[int]], prefix):
# add some metrics of seqlen sum on dp ranks
k_partition = len(partitions)
# assert len(seqlen_list) % k_partition == 0
batch_size = len(seqlen_list) // k_partition
min_sum_seqlen = None
max_sum_seqlen = None
total_sum_seqlen = 0
for offset in range(0, len(seqlen_list), batch_size):
cur_sum_seqlen = sum(seqlen_list[offset:offset + batch_size])
if min_sum_seqlen is None or cur_sum_seqlen < min_sum_seqlen:
min_sum_seqlen = cur_sum_seqlen
if max_sum_seqlen is None or cur_sum_seqlen > max_sum_seqlen:
max_sum_seqlen = cur_sum_seqlen
total_sum_seqlen += cur_sum_seqlen
balanced_sum_seqlen_list = []
for partition in partitions:
cur_sum_seqlen_balanced = sum([seqlen_list[i] for i in partition])
balanced_sum_seqlen_list.append(cur_sum_seqlen_balanced)
# print("balanced_sum_seqlen_list: ", balanced_sum_seqlen_list)
min_sum_seqlen_balanced = min(balanced_sum_seqlen_list)
max_sum_seqlen_balanced = max(balanced_sum_seqlen_list)
return {
f'{prefix}/min': min_sum_seqlen,
f'{prefix}/max': max_sum_seqlen,
f'{prefix}/minmax_diff': max_sum_seqlen - min_sum_seqlen,
f'{prefix}/balanced_min': min_sum_seqlen_balanced,
f'{prefix}/balanced_max': max_sum_seqlen_balanced,
f'{prefix}/mean': total_sum_seqlen / len(partitions)
}
def ceildiv(a, b):
return -(a // -b)
def rearrange_micro_batches(batch: TensorDict, max_token_len, dp_group=None):
"""Split the batch into a list of micro_batches, where the max_token_len is smaller than max_token_len
and the number of valid tokens in each micro batch is well balanced.
"""
# this is per local micro_bsz
max_seq_len = batch['attention_mask'].shape[-1]
assert max_token_len >= max_seq_len, \
        f'max_token_len must be at least the maximum sequence length. Got {max_token_len=} and {max_seq_len=}'
seq_len_effective: torch.Tensor = batch['attention_mask'].sum(dim=1)
total_seqlen = seq_len_effective.sum().item()
num_micro_batches = ceildiv(total_seqlen, max_token_len)
if dist.is_initialized():
num_micro_batches = torch.tensor([num_micro_batches], device='cuda')
dist.all_reduce(num_micro_batches, op=dist.ReduceOp.MAX, group=dp_group)
num_micro_batches = num_micro_batches.cpu().item()
seq_len_effective = seq_len_effective.tolist()
assert num_micro_batches <= len(seq_len_effective)
micro_bsz_idx = get_seqlen_balanced_partitions(seq_len_effective, num_micro_batches, equal_size=False)
micro_batches = []
for partition in micro_bsz_idx:
curr_micro_batch = []
for idx in partition:
curr_micro_batch.append(batch[idx:idx + 1])
curr_micro_batch = torch.cat(curr_micro_batch)
micro_batches.append(curr_micro_batch)
return micro_batches, micro_bsz_idx
def get_reverse_idx(idx_map):
reverse_idx_map = copy.deepcopy(idx_map)
for i, idx in enumerate(idx_map):
reverse_idx_map[idx] = i
return reverse_idx_map
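# Illustrative usage sketch (not authoritative): rearrange_micro_batches packs a
# padded batch into token-balanced micro batches and get_reverse_idx inverts the
# resulting permutation. Assumes torch.distributed is NOT initialized, so the
# all_reduce branch above is skipped.
def _demo_rearrange_micro_batches():
    bsz, seqlen = 4, 8
    attention_mask = torch.zeros(bsz, seqlen, dtype=torch.int64)
    for row, n_valid in enumerate([8, 2, 5, 3]):
        attention_mask[row, :n_valid] = 1
    batch = TensorDict(
        {'input_ids': torch.arange(bsz * seqlen).view(bsz, seqlen), 'attention_mask': attention_mask},
        batch_size=[bsz])
    micro_batches, micro_bsz_idx = rearrange_micro_batches(batch, max_token_len=12)
    for mb in micro_batches:
        assert mb['attention_mask'].sum().item() <= 12  # token budget respected
    flat_idx = [i for part in micro_bsz_idx for i in part]
    reverse = get_reverse_idx(flat_idx)
    assert all(flat_idx[reverse[i]] == i for i in range(bsz))  # permutation is invertible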

58
verl/utils/tokenizer.py Normal file
View File

@@ -0,0 +1,58 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utils for tokenization."""
import warnings
__all__ = ['hf_tokenizer']
def set_pad_token_id(tokenizer):
"""Set pad_token_id to eos_token_id if it is None.
Args:
tokenizer (transformers.PreTrainedTokenizer): The tokenizer to be set.
"""
if tokenizer.pad_token_id is None:
tokenizer.pad_token_id = tokenizer.eos_token_id
warnings.warn(f'tokenizer.pad_token_id is None. Now set to {tokenizer.eos_token_id}')
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
warnings.warn(f'tokenizer.pad_token is None. Now set to {tokenizer.eos_token}')
def hf_tokenizer(name_or_path, correct_pad_token=True, correct_gemma2=True, **kwargs):
"""Create a huggingface pretrained tokenizer.
Args:
        name_or_path (str): The name or path of the tokenizer.
correct_pad_token (bool): Whether to correct the pad token id.
correct_gemma2 (bool): Whether to correct the gemma2 tokenizer.
**kwargs: The keyword arguments for the tokenizer.
Returns:
transformers.PreTrainedTokenizer: The pretrained tokenizer.
"""
from transformers import AutoTokenizer
if correct_gemma2 and isinstance(name_or_path, str) and 'gemma-2-2b-it' in name_or_path:
        # the EOS token in gemma2 is ambiguous, which may worsen RL performance.
# https://huggingface.co/google/gemma-2-2b-it/commit/17a01657f5c87135bcdd0ec7abb4b2dece04408a
warnings.warn('Found gemma-2-2b-it tokenizer. Set eos_token and eos_token_id to <end_of_turn> and 107.')
kwargs['eos_token'] = '<end_of_turn>'
kwargs['eos_token_id'] = 107
tokenizer = AutoTokenizer.from_pretrained(name_or_path, **kwargs)
if correct_pad_token:
set_pad_token_id(tokenizer)
return tokenizer
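# Illustrative usage sketch (not authoritative): the model id below is only an
# example, and loading it requires network access or a local Hugging Face cache.
def _demo_hf_tokenizer():
    tokenizer = hf_tokenizer('Qwen/Qwen2.5-0.5B-Instruct')
    assert tokenizer.pad_token_id is not None  # guaranteed by set_pad_token_id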

View File

@@ -0,0 +1,82 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Adapted from Cruise.
"""
import torch
from typing import Union
HALF_LIST = [16, "16", "fp16", "float16"]
FLOAT_LIST = [32, "32", "fp32", "float32"]
BFLOAT_LIST = ["bf16", "bfloat16"]
class PrecisionType(object):
"""Type of precision used.
>>> PrecisionType.HALF == 16
True
>>> PrecisionType.HALF in (16, "16")
True
"""
HALF = "16"
FLOAT = "32"
FULL = "64"
BFLOAT = "bf16"
MIXED = "mixed"
@staticmethod
def supported_type(precision: Union[str, int]) -> bool:
return any(x == precision for x in PrecisionType)
@staticmethod
def supported_types() -> list[str]:
return [x.value for x in PrecisionType]
@staticmethod
def is_fp16(precision):
return precision in HALF_LIST
@staticmethod
def is_fp32(precision):
return precision in FLOAT_LIST
@staticmethod
def is_bf16(precision):
return precision in BFLOAT_LIST
@staticmethod
def to_dtype(precision):
if precision in HALF_LIST:
return torch.float16
elif precision in FLOAT_LIST:
return torch.float32
elif precision in BFLOAT_LIST:
return torch.bfloat16
else:
raise RuntimeError(f"unexpected precision: {precision}")
@staticmethod
def to_str(precision):
if precision == torch.float16:
return 'fp16'
elif precision == torch.float32:
return 'fp32'
elif precision == torch.bfloat16:
return 'bf16'
else:
raise RuntimeError(f"unexpected precision: {precision}")
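# Illustrative usage sketch (not authoritative): round-trip between the loose
# precision strings used in configs and torch dtypes.
def _demo_precision_type():
    assert PrecisionType.to_dtype('bf16') == torch.bfloat16
    assert PrecisionType.to_str(torch.float16) == 'fp16'
    assert PrecisionType.is_fp32('float32') and not PrecisionType.is_fp32('fp16')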

View File

@@ -0,0 +1,492 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Contain small torch utilities
"""
from typing import Dict, Union, List, Optional
import os
import torch
import torch.distributed
import torch.nn.functional as F
from tensordict import TensorDict
from torch import nn
try:
    from flash_attn.ops.triton.cross_entropy import cross_entropy_loss
    FLASH_ATTN_CROSS_ENTROPY_LOSS_AVAILABLE = True
except ImportError:
    FLASH_ATTN_CROSS_ENTROPY_LOSS_AVAILABLE = False
def gather_from_labels(data, label):
"""Gather the label from data. The value in label should be [0, vocab_size)
Args:
data: (..., vocab_size)
label (torch.IntTensor) : (...,)
Returns:
"""
output = torch.gather(data, -1, label.unsqueeze(-1)).squeeze(-1)
return output
def logprobs_from_logits(logits, labels):
"""
See: https://github.com/pytorch/pytorch/issues/563#issuecomment-330103591
"""
    if FLASH_ATTN_CROSS_ENTROPY_LOSS_AVAILABLE:
batch_dim = logits.shape[:-1]
last_dim = logits.shape[-1]
logits = logits.reshape(-1, last_dim)
labels = labels.reshape(-1)
output = logprobs_from_logits_flash_attn(logits, labels)
output = output.view(*batch_dim)
else:
output = logprobs_from_logits_naive(logits, labels)
return output
def logprobs_from_logits_flash_attn(logits, labels):
output = -cross_entropy_loss(logits, labels)[0]
return output
def logprobs_from_logits_naive(logits, labels):
logp = F.log_softmax(logits, dim=-1)
logpy = gather_from_labels(logp, labels)
return logpy
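# Illustrative usage sketch (not authoritative): the fused flash-attn path needs
# the triton kernel and a GPU, so only the naive log_softmax + gather variant is
# exercised here.
def _demo_logprobs_from_logits_naive():
    torch.manual_seed(0)
    logits = torch.randn(2, 5, 11)          # [batch, seq, vocab]
    labels = torch.randint(0, 11, (2, 5))   # [batch, seq]
    out = logprobs_from_logits_naive(logits, labels)
    assert out.shape == (2, 5)
    assert (out <= 0).all()  # log-probabilities are never positive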
def logprobs_of_labels_v2(logits: torch.FloatTensor, labels):
"""
A memory efficient implementation of logprobs_from_logits
"""
assert logits.dtype == torch.float32, 'Using bf16 logits with logprobs_of_labels_v2 may lead to divergence'
logprobs_labels = torch.gather(logits, dim=-1, index=labels.unsqueeze(-1))
logprobs_labels = logprobs_labels - torch.logsumexp(logits, dim=-1, keepdim=True)
return logprobs_labels.squeeze(-1)
def clip_by_value(x, tensor_min, tensor_max):
"""
    Tensor extension to torch.clamp
https://github.com/pytorch/pytorch/issues/2793#issuecomment-428784713
"""
clipped = torch.max(torch.min(x, tensor_max), tensor_min)
return clipped
def entropy_from_logits(logits: torch.Tensor):
"""Calculate entropy from logits."""
pd = torch.nn.functional.softmax(logits, dim=-1)
entropy = torch.logsumexp(logits, dim=-1) - torch.sum(pd * logits, dim=-1)
return entropy
def masked_sum(values, mask, axis=None):
    """Compute the sum of values over positions where mask == 1."""
    return (values * mask).sum(axis=axis)
def masked_mean(values, mask, axis=None):
    """Compute the mean of values over positions where mask == 1."""
    return (values * mask).sum(axis=axis) / mask.sum(axis=axis)
def masked_var(values, mask, unbiased=True):
"""Compute variance of tensor with masked values."""
mean = masked_mean(values, mask)
centered_values = values - mean
variance = masked_mean(centered_values**2, mask)
if unbiased:
mask_sum = mask.sum()
if mask_sum == 0:
raise ValueError("At least one element in the mask has to be 1.")
# note that if mask_sum == 1, then there is a division by zero issue
# to avoid it you just need to use a larger minibatch_size
if mask_sum == 1:
raise ValueError("The sum of the mask is one, which can cause a division by zero.")
bessel_correction = mask_sum / (mask_sum - 1)
variance = variance * bessel_correction
return variance
def masked_whiten(values, mask, shift_mean=True):
"""Whiten values with masked values."""
mean, var = masked_mean(values, mask), masked_var(values, mask)
whitened = (values - mean) * torch.rsqrt(var + 1e-8)
if not shift_mean:
whitened += mean
return whitened
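# Illustrative usage sketch (not authoritative): the masked_* helpers only look at
# positions where mask == 1, which is how response-token statistics are computed
# over padded batches.
def _demo_masked_ops():
    values = torch.tensor([[1., 2., 3., 100.], [4., 5., 6., 100.]])
    mask = torch.tensor([[1., 1., 1., 0.], [1., 1., 1., 0.]])
    assert masked_sum(values, mask).item() == 21.0
    assert abs(masked_mean(values, mask).item() - 3.5) < 1e-6
    whitened = masked_whiten(values, mask)
    assert abs(masked_mean(whitened, mask).item()) < 1e-5  # ~zero mean on the mask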
def get_eos_mask(response_id: torch.Tensor, eos_token: int = 2, dtype=torch.int64):
    '''
    Build a mask that keeps every position up to and including the first EOS token.
    e.g. with eos_token=1:
        response_id: [0, 0, 2, 42, 3, 5, 1, 0, 0]
        eos_mask:    [1, 1, 1,  1, 1, 1, 1, 0, 0]
    '''
eos_mask = response_id.eq(eos_token).long()
eos_mask = (torch.cumsum(eos_mask, dim=1) - eos_mask).bool()
eos_mask = torch.logical_not(eos_mask).to(dtype)
return eos_mask
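# Illustrative usage sketch (not authoritative): reproduces the docstring example,
# keeping everything up to and including the first EOS token.
def _demo_get_eos_mask():
    response_id = torch.tensor([[0, 0, 2, 42, 3, 5, 1, 0, 0]])
    assert get_eos_mask(response_id, eos_token=1).tolist() == [[1, 1, 1, 1, 1, 1, 1, 0, 0]]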
def compute_grad_norm(model: nn.Module):
    # note: returns the sum of squared gradient entries (i.e. the squared global L2 norm)
    total_grad_square = 0
for param in model.parameters():
if param.grad is not None:
total_grad_square += torch.sum(torch.square(param.grad.detach())).item()
return total_grad_square
def broadcast_dict_tensor(tensors: Union[Dict[str, torch.Tensor], TensorDict], src, group):
"""
TODO: optimize this. Technically, we only need one broadcast
"""
for key in tensors.sorted_keys:
torch.distributed.broadcast(tensors[key], src=src, group=group, async_op=False)
def allgather_dict_tensors(tensors: Union[Dict[str, torch.Tensor], TensorDict], size, group, dim=0):
"""
TODO: optimize this.
- We can use async ops
- We can use only one allgather
Args:
tensors:
size:
group:
Returns:
"""
if isinstance(tensors, TensorDict):
is_tensor_dict = True
tensors_as_dict = tensors.to_dict()
else:
tensors_as_dict = tensors
is_tensor_dict = False
output = {}
sorted_keys = sorted(tensors_as_dict.keys())
for key in sorted_keys:
val = tensors_as_dict[key]
output[key] = [torch.empty_like(val) for _ in range(size)]
torch.distributed.all_gather(output[key], val, group=group, async_op=False)
output[key] = torch.cat(output[key], dim=dim)
if is_tensor_dict:
output = TensorDict(source=output, batch_size=tensors.batch_size[0] * size)
return output
def split_dict_tensor_into_batches(tensors: TensorDict, batch_size) -> List[TensorDict]:
assert tensors.batch_size[0] % batch_size == 0, \
f'input data batch size: {tensors.batch_size[0]}, split batch size: {batch_size}'
return tensors.split(batch_size)
def pad_sequence_to_length(tensors, max_seq_len, pad_token_id, left_pad=False):
"""
pad a 2D tensors (e.g. responses, logprobs) in the last dim to max_seq_length.
input shape: [bs, seq_length]
output shape: [bs, max_seq_length]
(0, max_seq_len - tensors.shape[-1]) means right pad to max_seq_length and no left pad
"""
if tensors.shape[-1] >= max_seq_len:
return tensors
pad_tuple = (max_seq_len - tensors.shape[-1], 0) if left_pad else (0, max_seq_len - tensors.shape[-1])
return F.pad(tensors, pad_tuple, 'constant', pad_token_id)
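# Illustrative usage sketch (not authoritative): right padding appends pad_token_id,
# left padding prepends it.
def _demo_pad_sequence_to_length():
    x = torch.tensor([[7, 8, 9]])
    assert pad_sequence_to_length(x, max_seq_len=5, pad_token_id=0).tolist() == [[7, 8, 9, 0, 0]]
    assert pad_sequence_to_length(x, max_seq_len=5, pad_token_id=0, left_pad=True).tolist() == [[0, 0, 7, 8, 9]]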
from transformers import PreTrainedTokenizer
def tokenize_and_postprocess_data(prompt: str,
tokenizer: PreTrainedTokenizer,
max_length: int,
pad_token_id: int,
left_pad=True,
truncation='error'):
"""
input_data is the output from tokenizer.
"""
assert truncation in ['left', 'right', 'error']
input_data = tokenizer(prompt, return_tensors='pt', add_special_tokens=False)
input_ids = input_data['input_ids']
attention_mask = input_data['attention_mask']
assert input_ids.ndim == 2
sequence_length = input_ids.shape[-1]
if sequence_length < max_length:
input_ids = pad_sequence_to_length(input_ids,
max_seq_len=max_length,
pad_token_id=pad_token_id,
left_pad=left_pad)
attention_mask = pad_sequence_to_length(attention_mask,
max_seq_len=max_length,
pad_token_id=0,
left_pad=left_pad)
elif sequence_length > max_length:
if truncation == 'left':
# actually, left truncation may not be reasonable
input_ids = input_ids[:, -max_length:]
attention_mask = attention_mask[:, -max_length:]
elif truncation == 'right':
input_ids = input_ids[:, :max_length]
attention_mask = attention_mask[:, :max_length]
elif truncation == 'error':
raise NotImplementedError(f'{sequence_length=} is larger than {max_length=}')
else:
raise NotImplementedError(f'Unknown truncation method {truncation}')
return input_ids, attention_mask
def remove_pad_token(input_ids: torch.Tensor, attention_mask: torch.Tensor):
""" Remove the pad token.
Args:
input_ids shape: [bs, seq_length]
attention_mask shape: [bs, seq_length]
Returns:
no_padding_batch(List[List[int]]): contains the rmpad token ids per query.
"""
no_padding_batch = []
for ids, mask in zip(input_ids, attention_mask):
no_padding_batch.append((ids[len(ids) - mask.sum():]).cpu().numpy().tolist())
return no_padding_batch
def log_probs_from_logits_response(input_ids, logits, response_length):
"""Compute the response log_probs from full logits. Note that logits = model(input_ids)
Args:
input_ids: [batch_size, seqlen]
logits: [batch_size, seqlen, vocab_size]
Returns:
        response_log_prob: [batch_size, response_length]
"""
response_logits = logits[:, -response_length - 1:-1]
response = input_ids[:, -response_length:]
response_log_prob = logprobs_from_logits(logits=response_logits, labels=response)
return response_log_prob
def log_probs_from_logits_response_rmpad(input_ids, attention_mask, logits_rmpad, response_length):
"""Compute the log_probs from logits with rmpad logits and pad input. Note that
logits_rmpad = model(input_ids_rmpad). For each sentences, there is a shift between
logits and input_ids.
The reason for this function to is to compute logprobs_from_logits in rmpad mode because it is memory-intensive
for large vocab_size
Args:
input_ids: [batch_size, seqlen]
attention_mask: [batch_size, seqlen]
logits_rmpad: [total_nnz, vocab_size]
response_length: int
"""
from flash_attn.bert_padding import pad_input, unpad_input
batch_size, seqlen = input_ids.shape
input_ids_rmpad, indices, *_ = unpad_input(input_ids.unsqueeze(-1), attention_mask=attention_mask)
input_ids_rmpad = input_ids_rmpad.squeeze(-1)
input_ids_rmpad_rolled = torch.roll(input_ids_rmpad, shifts=-1, dims=0)
full_log_probs_rmpad = logprobs_from_logits(logits=logits_rmpad, labels=input_ids_rmpad_rolled) # (total_nnz,)
full_output = pad_input(hidden_states=full_log_probs_rmpad.unsqueeze(-1),
indices=indices,
batch=batch_size,
seqlen=seqlen)
output = full_output.squeeze(-1)[:, -response_length - 1:-1] # [batch_size, response_length]
return output
def log_probs_from_logits_all_rmpad(input_ids_rmpad, logits_rmpad, indices, batch_size, seqlen, response_length):
"""Compute the log_probs from logits with rmpad input_ids and logits. Note that
logits_rmpad = model(input_ids_rmpad). For each sentences, there is a shift between
logits and input_ids.
The reason for this function to is to compute logprobs_from_logits in rmpad mode because it is memory-intensive
for large vocab_size
Args:
input_ids_rmpad: [1, total_nnz]
logits_rmpad: [total_nnz, vocab_size]
indices: [total_nnz]
batch_size: int
seqlen: int
response_length: int
"""
from flash_attn.bert_padding import pad_input
input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # transpose back to [total_nnz, 1]
input_ids_rmpad = input_ids_rmpad.squeeze(-1)
input_ids_rmpad_rolled = torch.roll(input_ids_rmpad, shifts=-1, dims=0)
full_log_probs_rmpad = logprobs_from_logits(logits=logits_rmpad, labels=input_ids_rmpad_rolled) # (total_nnz,)
full_output = pad_input(hidden_states=full_log_probs_rmpad.unsqueeze(-1),
indices=indices,
batch=batch_size,
seqlen=seqlen)
output = full_output.squeeze(-1)[:, -response_length - 1:-1] # [batch_size, response_length]
return output
from transformers.generation.logits_process import (TemperatureLogitsWarper, TopKLogitsWarper, TopPLogitsWarper)
def post_process_logits(input_ids, logits, temperature, top_k, top_p):
if temperature != 1.:
logits = logits.div_(temperature) # inplace operation to avoid OOM
# TODO: add them back
# if top_k is not None and top_k > 0:
# logits = TopKLogitsWarper(top_k=top_k)(input_ids, logits)
# if top_p is not None and top_p < 1.0 and top_p > 0.0:
# logits = TopPLogitsWarper(top_p=top_p)(input_ids, logits)
return logits
"""
Optimizer related
"""
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR
import math
def get_cosine_schedule_with_warmup(
optimizer: Optimizer,
num_warmup_steps: int,
num_training_steps: int,
min_lr_ratio: float = 0.0,
num_cycles: float = 0.5,
last_epoch: int = -1,
):
"""
Create a schedule with a learning rate that decreases following the values of the cosine function between the
initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the
initial lr set in the optimizer.
Args:
optimizer (:class:`~torch.optim.Optimizer`):
The optimizer for which to schedule the learning rate.
num_warmup_steps (:obj:`int`):
The number of steps for the warmup phase.
num_training_steps (:obj:`int`):
The total number of training steps.
min_lr_ratio (:obj:`float`, `optional`, defaults to 0.0):
The minimum lr ratio w.r.t the maximum.
num_cycles (:obj:`float`, `optional`, defaults to 0.5):
The number of waves in the cosine schedule (the defaults is to just decrease from the max value to 0
following a half-cosine).
last_epoch (:obj:`int`, `optional`, defaults to -1):
The index of the last epoch when resuming training.
Return:
:obj:`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
"""
assert min_lr_ratio >= 0 and min_lr_ratio <= 1.
coef = (1 - min_lr_ratio) * 0.5
intercept = (1 + min_lr_ratio) * 0.5
def lr_lambda(current_step):
if current_step < num_warmup_steps:
return float(current_step) / float(max(1, num_warmup_steps))
progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
x = math.cos(math.pi * float(num_cycles) * 2.0 * progress)
return max(0.0, x * coef + intercept)
return LambdaLR(optimizer, lr_lambda, last_epoch)
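# Illustrative usage sketch (not authoritative): with min_lr_ratio=0.1 the multiplier
# ramps from 0 to 1 over warmup, then decays along a half cosine to 0.1 of the peak lr.
def _demo_cosine_schedule():
    model = nn.Linear(2, 2)
    opt = torch.optim.SGD(model.parameters(), lr=1.0)
    sched = get_cosine_schedule_with_warmup(opt, num_warmup_steps=10, num_training_steps=110, min_lr_ratio=0.1)
    lrs = []
    for _ in range(110):
        opt.step()
        sched.step()
        lrs.append(sched.get_last_lr()[0])
    assert abs(lrs[9] - 1.0) < 1e-6   # end of warmup reaches the peak lr
    assert abs(lrs[-1] - 0.1) < 1e-6  # final lr settles at min_lr_ratio * peak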
def get_constant_schedule_with_warmup(
optimizer: Optimizer,
num_warmup_steps: int,
last_epoch: int = -1,
):
def lr_lambda(current_step):
return min(1, float(current_step) / float(max(1, num_warmup_steps)))
return LambdaLR(optimizer, lr_lambda, last_epoch)
def prepare_decoder_attention_mask(attention_mask, input_shape, inputs_embeds):
# create causal mask
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
combined_attention_mask = None
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(
input_shape,
inputs_embeds.dtype,
device=inputs_embeds.device,
)
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype,
tgt_len=input_shape[-1]).to(inputs_embeds.device)
combined_attention_mask = (expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask +
combined_attention_mask)
return combined_attention_mask
# Copied from transformers.models.bart.modeling_bart._make_causal_mask
def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device):
"""
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
mask_cond = torch.arange(mask.size(-1), device=device)
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len)
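# Illustrative usage sketch (not authoritative): the causal mask blocks future
# positions with the dtype's most negative value and leaves past positions at 0.
def _demo_make_causal_mask():
    mask = _make_causal_mask(torch.Size([2, 4]), dtype=torch.float32, device=torch.device('cpu'))
    assert mask.shape == (2, 1, 4, 4)
    assert mask[0, 0, 0, 1].item() == torch.finfo(torch.float32).min  # future blocked
    assert mask[0, 0, 3, 0].item() == 0.0                             # past visible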
# Copied from transformers.models.bart.modeling_bart._expand_mask
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""
bsz, src_len = mask.size()
tgt_len = tgt_len if tgt_len is not None else src_len
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
inverted_mask = 1.0 - expanded_mask
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
def get_unpad_data(attention_mask):
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
max_seqlen_in_batch = seqlens_in_batch.max().item()
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
return (
indices,
cu_seqlens,
max_seqlen_in_batch,
)
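# Illustrative usage sketch (not authoritative): get_unpad_data follows the
# flash-attn unpadding convention (flat indices of the valid tokens, cumulative
# sequence lengths, and the longest sequence in the batch).
def _demo_get_unpad_data():
    attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])
    indices, cu_seqlens, max_seqlen = get_unpad_data(attention_mask)
    assert indices.tolist() == [0, 1, 2, 4, 5]
    assert cu_seqlens.tolist() == [0, 3, 5]
    assert max_seqlen == 3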

103
verl/utils/tracking.py Normal file
View File

@@ -0,0 +1,103 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
A unified tracking interface that supports logging data to different backend
"""
import dataclasses
from enum import Enum
from functools import partial
from pathlib import Path
from typing import List, Union, Dict, Any
class Tracking(object):
supported_backend = ['wandb', 'mlflow', 'console']
def __init__(self, project_name, experiment_name, default_backend: Union[str, List[str]] = 'console', config=None):
if isinstance(default_backend, str):
default_backend = [default_backend]
for backend in default_backend:
if backend == 'tracking':
import warnings
warnings.warn("`tracking` logger is deprecated. use `wandb` instead.", DeprecationWarning)
else:
assert backend in self.supported_backend, f'{backend} is not supported'
self.logger = {}
if 'tracking' in default_backend or 'wandb' in default_backend:
import wandb
import os
WANDB_API_KEY = os.environ.get("WANDB_API_KEY", None)
if WANDB_API_KEY:
wandb.login(key=WANDB_API_KEY)
wandb.init(project=project_name, name=experiment_name, config=config)
self.logger['wandb'] = wandb
if 'mlflow' in default_backend:
import mlflow
mlflow.start_run(run_name=experiment_name)
mlflow.log_params(_compute_mlflow_params_from_objects(config))
self.logger['mlflow'] = _MlflowLoggingAdapter()
if 'console' in default_backend:
from verl.utils.logger.aggregate_logger import LocalLogger
self.console_logger = LocalLogger(print_to_console=True)
self.logger['console'] = self.console_logger
def log(self, data, step, backend=None):
        for backend_name, logger_instance in self.logger.items():
            if backend is None or backend_name in backend:
logger_instance.log(data=data, step=step)
class _MlflowLoggingAdapter:
def log(self, data, step):
import mlflow
mlflow.log_metrics(metrics=data, step=step)
def _compute_mlflow_params_from_objects(params) -> Dict[str, Any]:
if params is None:
return {}
return _flatten_dict(_transform_params_to_json_serializable(params, convert_list_to_dict=True), sep='/')
def _transform_params_to_json_serializable(x, convert_list_to_dict: bool):
_transform = partial(_transform_params_to_json_serializable, convert_list_to_dict=convert_list_to_dict)
if dataclasses.is_dataclass(x):
return _transform(dataclasses.asdict(x))
if isinstance(x, dict):
return {k: _transform(v) for k, v in x.items()}
if isinstance(x, list):
if convert_list_to_dict:
return {'list_len': len(x)} | {f'{i}': _transform(v) for i, v in enumerate(x)}
else:
return [_transform(v) for v in x]
if isinstance(x, Path):
return str(x)
if isinstance(x, Enum):
return x.value
return x
def _flatten_dict(raw: Dict[str, Any], *, sep: str) -> Dict[str, Any]:
import pandas as pd
ans = pd.json_normalize(raw, sep=sep).to_dict(orient='records')[0]
assert isinstance(ans, dict)
return ans
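# Illustrative usage sketch (not authoritative): the console backend only needs the
# local logger shipped with verl, so it is the easiest way to smoke-test the facade;
# 'wandb' or 'mlflow' are configured the same way via default_backend.
def _demo_tracking_console():
    tracker = Tracking(project_name='demo', experiment_name='run-0', default_backend='console')
    tracker.log(data={'reward/mean': 0.42}, step=1)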

288
verl/utils/ulysses.py Normal file
View File

@@ -0,0 +1,288 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Utilities for DeepSpeed Ulysses Sequence Parallelism.
DeepSpeed Ulysses Paper: https://arxiv.org/abs/2309.14509
Inspired by: https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/sequence/layer.py
"""
from typing import Any, Optional, List, Tuple
import torch
from torch import Tensor
import torch.distributed as dist
from torch.distributed import ProcessGroup
_ULYSSES_SEQUENCE_PARALLEL_GROUP = None
def set_ulysses_sequence_parallel_group(group: dist.ProcessGroup):
"""
Set ulysses sequence parallel process group.
"""
global _ULYSSES_SEQUENCE_PARALLEL_GROUP
_ULYSSES_SEQUENCE_PARALLEL_GROUP = group
def get_ulysses_sequence_parallel_group() -> Optional[dist.ProcessGroup]:
"""
Get ulysses sequence parallel process group.
"""
global _ULYSSES_SEQUENCE_PARALLEL_GROUP
return _ULYSSES_SEQUENCE_PARALLEL_GROUP
def get_ulysses_sequence_parallel_world_size(group: ProcessGroup = None) -> int:
"""
Get ulysses sequence parallel world size.
"""
group = get_ulysses_sequence_parallel_group() if group is None else group
return dist.get_world_size(group) if group else 1
def get_ulysses_sequence_parallel_rank(group: ProcessGroup = None) -> int:
"""
Get ulysses sequence parallel rank.
"""
group = get_ulysses_sequence_parallel_group() if group is None else group
return dist.get_rank(group) if group else 0
def gather_seq_scatter_heads(
x: Tensor,
seq_dim: int,
head_dim: int,
unpadded_dim_size: int = 0,
group: ProcessGroup = None,
) -> Tensor:
"""
A func to sync embedding input with alltoall in sequence parallel
gather sequence dimension and scatter head dim:
e.g. seq_dim: 1, head_dim: 2
[bsz, seq/n, h, ...] -> [bsz, seq, h/n, ...]
"""
group = get_ulysses_sequence_parallel_group() if group is None else group
if not group:
return x
sp_world = get_ulysses_sequence_parallel_world_size(group)
x = SeqAllToAll.apply(group, x, head_dim, seq_dim)
if unpadded_dim_size and unpadded_dim_size % sp_world != 0:
padding_size = x.size(seq_dim) - unpadded_dim_size
x = _unpad_tensor(x, seq_dim, padding_size)
return x
def gather_heads_scatter_seq(x: Tensor, head_dim: int, seq_dim: int, group: ProcessGroup = None) -> Tensor:
"""
A func to sync attention result with alltoall in sequence parallel
gather head dimension and scatter seq dim:
e.g. seq_dim: 1, head_dim: 2
[bsz, seq, h/n, ...] -> [bsz, seq/n, h, ...]
"""
group = get_ulysses_sequence_parallel_group() if group is None else group
if not group:
return x
dim_size = x.size(seq_dim)
sp_world = get_ulysses_sequence_parallel_world_size(group)
if dim_size % sp_world != 0:
padding_size = sp_world - (dim_size % sp_world)
x = _pad_tensor(x, seq_dim, padding_size)
return SeqAllToAll.apply(group, x, seq_dim, head_dim, False)
def _pad_tensor(x: Tensor, dim: int, padding_size: int) -> Tensor:
shape = list(x.shape)
shape[dim] = padding_size
pad = torch.zeros(shape, dtype=x.dtype, device=x.device)
return torch.cat([x, pad], dim=dim)
def _unpad_tensor(x: Tensor, dim: int, padding_size: int) -> Tensor:
slc = [slice(None)] * len(x.shape)
slc[dim] = slice(0, -padding_size)
return x[slc]
def slice_input_tensor(x: Tensor, dim: int, padding: bool = True, group: ProcessGroup = None) -> Tensor:
group = get_ulysses_sequence_parallel_group() if group is None else group
sp_world_size = dist.get_world_size(group)
sp_rank = get_ulysses_sequence_parallel_rank()
dim_size = x.size(dim)
# pad before slice
if padding and dim_size % sp_world_size:
padding_size = sp_world_size - (dim_size % sp_world_size)
x = _pad_tensor(x, dim, padding_size)
# slice the input tensor
parts = x.size(dim) // sp_world_size
slc = [slice(None)] * len(x.shape)
slc[dim] = slice(sp_rank * parts, (sp_rank + 1) * parts)
return x[slc].contiguous()
def all_to_all_tensor(
local_input: Tensor,
scatter_dim: int,
gather_dim: int,
group: Optional[dist.ProcessGroup] = None,
async_op: bool = False,
):
group = get_ulysses_sequence_parallel_group() if group is None else group
seq_world_size = dist.get_world_size(group)
input_list = [t.contiguous() for t in torch.tensor_split(local_input, seq_world_size, scatter_dim)]
output_list = [torch.empty_like(input_list[0]) for _ in range(seq_world_size)]
comm = dist.all_to_all(output_list, input_list, group=group, async_op=async_op)
if async_op:
def wait():
comm.wait()
return torch.cat(output_list, dim=gather_dim).contiguous()
return wait
return torch.cat(output_list, dim=gather_dim).contiguous()
def all_gather_tensor(local_tensor: Tensor, group: Optional[dist.ProcessGroup] = None, async_op: bool = False):
group = get_ulysses_sequence_parallel_group() if group is None else group
sp_world_size = dist.get_world_size(group=group)
output_shape = list(local_tensor.shape)
output_shape[0] = output_shape[0] * sp_world_size
output = torch.empty(output_shape, dtype=local_tensor.dtype, device=local_tensor.device)
dist.all_gather_into_tensor(output, local_tensor, group=group, async_op=async_op)
return output
class SeqAllToAll(torch.autograd.Function):
@staticmethod
def forward(
ctx: Any,
group: dist.ProcessGroup,
local_input: Tensor,
scatter_dim: int,
gather_dim: int,
async_op: bool = False,
) -> Tensor:
ctx.group = group
ctx.scatter_dim = scatter_dim
ctx.gather_dim = gather_dim
ctx.async_op = async_op
return all_to_all_tensor(local_input, scatter_dim, gather_dim, group, async_op)
@staticmethod
def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor, None, None]:
if ctx.async_op:
input_t = torch.cat(grad_output[1:], dim=ctx.gather_dim).contiguous()
else:
input_t = grad_output[0]
return (
None,
all_to_all_tensor(input_t, ctx.gather_dim, ctx.scatter_dim, ctx.group, False),
None,
None,
None,
None,
)
class Gather(torch.autograd.Function):
@staticmethod
def forward(ctx: Any,
group: dist.ProcessGroup,
local_tensor: Tensor,
gather_dim: int,
grad_scaler: bool = True,
async_op=False) -> Tensor:
ctx.group = group
ctx.gather_dim = gather_dim
ctx.grad_scaler = grad_scaler
ctx.async_op = async_op
sp_world_size = dist.get_world_size(group=group)
ctx.sp_world_size = sp_world_size
sp_rank = dist.get_rank(group=group)
ctx.sp_rank = sp_rank
local_shape = list(local_tensor.size())
split_size = local_shape[0]
part_size = local_shape[gather_dim] # store original size
ctx.part_size = part_size
output = all_gather_tensor(local_tensor, group, async_op)
return torch.cat(output.split(split_size, dim=0), dim=gather_dim)
@staticmethod
def backward(ctx: Any, grad_output: Tensor) -> Any:
if ctx.grad_scaler:
grad_output = grad_output * ctx.sp_world_size
return (None, grad_output.split(ctx.part_size,
dim=ctx.gather_dim)[ctx.sp_rank].contiguous(), None, None, None, None)
def gather_outpus_and_unpad(x: Tensor,
gather_dim: int,
unpad_dim: int = None,
padding_size: int = 0,
grad_scaler: bool = True,
group: Optional[dist.ProcessGroup] = None):
group = get_ulysses_sequence_parallel_group() if group is None else group
sp_size = get_ulysses_sequence_parallel_world_size()
    if group is None:
return x
x = Gather.apply(group, x, gather_dim, grad_scaler)
if unpad_dim is not None:
assert isinstance(padding_size, int), 'padding size is not given or is not an integer'
if padding_size == 0:
return x
x = _unpad_tensor(x, unpad_dim, padding_size)
return x
def ulysses_pad_and_slice_inputs(input_ids_rmpad: torch.Tensor,
position_ids_rmpad: Optional[torch.Tensor] = None,
sp_size: int = 1):
"""
Pad and slice input_ids to be divisible by sp_size
Pad position_ids to be divisible by sp_size.
Note both input_ids_rmpad and position_ids_rmpad will be padded,
but only input_ids will be sliced.
    This is the pre-forward utility for ulysses sequence parallelism.
Args:
input_ids_rmpad: shape of [bsz, seqlen]
position_ids_rmpad: shape of [bsz, seqlen], where bsz must be 1
sp_size (int): ulysses sequence parallelism size
Returns:
torch.Tensor: padded and sliced input_ids
torch.Tensor: padded and sliced position_ids
int: pad size
"""
if position_ids_rmpad is not None:
assert position_ids_rmpad.size(0) == 1
assert input_ids_rmpad.size(1) == position_ids_rmpad.size(1)
if sp_size <= 1:
return input_ids_rmpad, position_ids_rmpad, 0
_, total_seq_len = input_ids_rmpad.shape
pad_size = (sp_size - total_seq_len % sp_size) % sp_size
if pad_size > 0:
input_ids_rmpad = torch.nn.functional.pad(input_ids_rmpad, (0, pad_size), value=0)
if position_ids_rmpad is not None:
pad_pos_ids = torch.arange(pad_size, device=position_ids_rmpad.device).unsqueeze(0)
position_ids_rmpad = torch.cat((position_ids_rmpad, pad_pos_ids), dim=-1)
# we don't need to slice position ids
input_ids_rmpad = slice_input_tensor(input_ids_rmpad, dim=1, padding=False)
return input_ids_rmpad, position_ids_rmpad, pad_size
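# Illustrative usage sketch (not authoritative): the collectives above need an
# initialized process group, but the pad/unpad helpers are pure tensor ops. Pad the
# sequence dim up to a multiple of sp_size, then strip the padding afterwards.
def _demo_pad_unpad_roundtrip():
    x = torch.randn(1, 10, 4)  # [bsz, seqlen, hidden]
    sp_size = 4
    pad_size = (sp_size - x.size(1) % sp_size) % sp_size  # 2 padding tokens here
    padded = _pad_tensor(x, dim=1, padding_size=pad_size)
    assert padded.size(1) % sp_size == 0
    assert torch.equal(_unpad_tensor(padded, dim=1, padding_size=pad_size), x)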