Initial commit
This commit is contained in:
18
verl/utils/__init__.py
Normal file
18
verl/utils/__init__.py
Normal file
@@ -0,0 +1,18 @@
|
||||
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from . import tokenizer
|
||||
from .tokenizer import *
|
||||
|
||||
__all__ = tokenizer.__all__
|
||||
23
verl/utils/config.py
Normal file
23
verl/utils/config.py
Normal file
@@ -0,0 +1,23 @@
|
||||
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Dict
|
||||
|
||||
from omegaconf import DictConfig
|
||||
|
||||
|
||||
def update_dict_with_config(dictionary: Dict, config: DictConfig):
|
||||
for key in dictionary:
|
||||
if hasattr(config, key):
|
||||
dictionary[key] = getattr(config, key)
|
||||
16
verl/utils/dataset/README.md
Normal file
16
verl/utils/dataset/README.md
Normal file
@@ -0,0 +1,16 @@
|
||||
# Dataset Format
|
||||
## RLHF dataset
|
||||
We combine all the data sources into a single parquet files. We directly organize the prompt into the chat format so that multi-turn chats can be easily incorporated. In the prompt, we may add instruction following texts to guide the model output the answers in a particular format so that we can extract the answers.
|
||||
|
||||
Math problems
|
||||
```json
|
||||
{
|
||||
"data_source": "openai/gsm8k",
|
||||
"prompt": [{"role": "user", "content": "Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May? Let's think step by step and output the final answer after \"####\""}],
|
||||
"ability": "math",
|
||||
"reward_model": {
|
||||
"style": "rule",
|
||||
"ground_truth": ["72"]
|
||||
},
|
||||
}
|
||||
```
|
||||
16
verl/utils/dataset/__init__.py
Normal file
16
verl/utils/dataset/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .rl_dataset import RLHFDataset
|
||||
from .rm_dataset import RMDataset
|
||||
155
verl/utils/dataset/rl_dataset.py
Normal file
155
verl/utils/dataset/rl_dataset.py
Normal file
@@ -0,0 +1,155 @@
|
||||
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from omegaconf import ListConfig
|
||||
import os
|
||||
from typing import List, Union
|
||||
|
||||
import pandas as pd
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
from transformers import AutoTokenizer, PreTrainedTokenizer
|
||||
from verl.utils.fs import copy_local_path_from_hdfs
|
||||
|
||||
from verl.utils.model import compute_position_id_with_mask
|
||||
import verl.utils.torch_functional as verl_F
|
||||
|
||||
|
||||
def collate_fn(data_list: list[dict]) -> dict:
|
||||
tensors = {}
|
||||
non_tensors = {}
|
||||
|
||||
for data in data_list:
|
||||
for key, val in data.items():
|
||||
if isinstance(val, torch.Tensor):
|
||||
if key not in tensors:
|
||||
tensors[key] = []
|
||||
tensors[key].append(val)
|
||||
else:
|
||||
if key not in non_tensors:
|
||||
non_tensors[key] = []
|
||||
non_tensors[key].append(val)
|
||||
|
||||
for key, val in tensors.items():
|
||||
tensors[key] = torch.stack(val, dim=0)
|
||||
|
||||
for key, val in non_tensors.items():
|
||||
non_tensors[key] = np.array(val, dtype=object)
|
||||
|
||||
output = {}
|
||||
output.update(tensors)
|
||||
output.update(non_tensors)
|
||||
return output
|
||||
|
||||
|
||||
class RLHFDataset(Dataset):
|
||||
"""
|
||||
We assume the dataset contains a column that contains prompts and other information
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
parquet_files: Union[str, List[str]],
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
prompt_key='prompt',
|
||||
max_prompt_length=1024,
|
||||
filter_prompts=True,
|
||||
cache_dir='~/.cache/verl/rlhf',
|
||||
chat_template_func=None,
|
||||
return_raw_chat=False,
|
||||
truncation='error'):
|
||||
if not isinstance(parquet_files, (List, ListConfig)):
|
||||
parquet_files = [parquet_files]
|
||||
|
||||
self.parquet_files = parquet_files
|
||||
self.cache_dir = os.path.expanduser(cache_dir)
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
self.prompt_key = prompt_key
|
||||
self.max_prompt_length = max_prompt_length
|
||||
self.filter_prompts = filter_prompts
|
||||
|
||||
self.return_raw_chat = return_raw_chat
|
||||
self.chat_template_func = chat_template_func
|
||||
self.truncation = truncation
|
||||
|
||||
self._download()
|
||||
self._read_files_and_tokenize()
|
||||
|
||||
def _download(self):
|
||||
from verl.utils.fs import copy_local_path_from_hdfs
|
||||
for i, parquet_file in enumerate(self.parquet_files):
|
||||
self.parquet_files[i] = copy_local_path_from_hdfs(src=parquet_file, cache_dir=self.cache_dir)
|
||||
|
||||
def _read_files_and_tokenize(self):
|
||||
dataframes = []
|
||||
for parquet_file in self.parquet_files:
|
||||
# read parquet files and cache
|
||||
dataframe = pd.read_parquet(parquet_file)
|
||||
dataframes.append(dataframe)
|
||||
self.dataframe = pd.concat(dataframes)
|
||||
|
||||
print(f'original dataset len: {len(self.dataframe)}')
|
||||
|
||||
# filter out too long prompts
|
||||
tokenizer = self.tokenizer
|
||||
prompt_key = self.prompt_key
|
||||
|
||||
# nvm if prompt is too long
|
||||
# self.dataframe = self.dataframe[self.dataframe.apply(lambda doc: len(
|
||||
# tokenizer.apply_chat_template(doc[prompt_key], add_generation_prompt=True)) <= self.max_prompt_length,
|
||||
# axis=1)]
|
||||
|
||||
print(f'filter dataset len: {len(self.dataframe)}')
|
||||
|
||||
def __len__(self):
|
||||
return len(self.dataframe)
|
||||
|
||||
def __getitem__(self, item):
|
||||
"""
|
||||
Note that we also return the raw_input_ids so that it can be combined with other chat template
|
||||
"""
|
||||
row_dict = self.dataframe.iloc[item].to_dict()
|
||||
|
||||
chat = row_dict.pop(self.prompt_key)
|
||||
|
||||
if self.tokenizer.chat_template:
|
||||
prompt_with_chat_template = self.tokenizer.apply_chat_template(chat, add_generation_prompt=True, tokenize=False)
|
||||
else:
|
||||
prompt_with_chat_template = chat[0]['content']
|
||||
# prompt_with_chat_template = chat
|
||||
|
||||
input_ids, attention_mask = verl_F.tokenize_and_postprocess_data(prompt=prompt_with_chat_template,
|
||||
tokenizer=self.tokenizer,
|
||||
max_length=self.max_prompt_length,
|
||||
pad_token_id=self.tokenizer.pad_token_id,
|
||||
left_pad=True,
|
||||
truncation=self.truncation)
|
||||
|
||||
position_ids = compute_position_id_with_mask(attention_mask)
|
||||
|
||||
row_dict['input_ids'] = input_ids[0]
|
||||
row_dict['attention_mask'] = attention_mask[0]
|
||||
row_dict['position_ids'] = position_ids[0]
|
||||
|
||||
# encode prompts without chat template
|
||||
if self.return_raw_chat:
|
||||
row_dict['raw_prompt'] = chat.tolist()
|
||||
|
||||
# add index for each prompt
|
||||
index = row_dict.get("extra_info", {}).get("index", 0)
|
||||
row_dict["index"] = index
|
||||
|
||||
return row_dict
|
||||
143
verl/utils/dataset/rm_dataset.py
Normal file
143
verl/utils/dataset/rm_dataset.py
Normal file
@@ -0,0 +1,143 @@
|
||||
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
from typing import List, Union
|
||||
|
||||
import pandas as pd
|
||||
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from verl.utils import hf_tokenizer
|
||||
|
||||
|
||||
def download_files_distributed(download_fn):
|
||||
import torch.distributed
|
||||
if torch.distributed.is_initialized():
|
||||
if torch.distributed.get_rank() == 0:
|
||||
# download files
|
||||
download_fn()
|
||||
|
||||
torch.distributed.barrier()
|
||||
else:
|
||||
# download anyway
|
||||
download_fn()
|
||||
|
||||
|
||||
class RMDataset(Dataset):
|
||||
|
||||
def __init__(self,
|
||||
parquet_files: Union[str, List[str]],
|
||||
tokenizer,
|
||||
prompt_key='prompt',
|
||||
chosen_key='chosen',
|
||||
rejected_key='rejected',
|
||||
max_length=1024,
|
||||
add_eos=True,
|
||||
cache_dir='~/.cache/verl/rm'):
|
||||
if not isinstance(parquet_files, List):
|
||||
parquet_files = [parquet_files]
|
||||
|
||||
self.parquet_files = parquet_files
|
||||
self.cache_dir = os.path.expanduser(cache_dir)
|
||||
if isinstance(tokenizer, str):
|
||||
tokenizer = hf_tokenizer(tokenizer)
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
self.prompt_key = prompt_key
|
||||
self.chosen_key = chosen_key
|
||||
self.rejected_key = rejected_key
|
||||
|
||||
self.add_eos = add_eos
|
||||
self.max_length = max_length
|
||||
|
||||
self._download()
|
||||
self._read_files_and_tokenize()
|
||||
|
||||
def _download(self):
|
||||
|
||||
def _download_files():
|
||||
from verl.utils.fs import copy, _is_non_local
|
||||
os.makedirs(self.cache_dir, exist_ok=True)
|
||||
assert os.path.exists(self.cache_dir)
|
||||
for i, parquet_file in enumerate(self.parquet_files):
|
||||
if _is_non_local(parquet_file):
|
||||
dst = os.path.join(self.cache_dir, os.path.basename(parquet_file))
|
||||
if not os.path.exists(dst):
|
||||
copy(src=parquet_file, dst=dst)
|
||||
self.parquet_files[i] = dst
|
||||
|
||||
download_files_distributed(_download_files)
|
||||
|
||||
def _read_files_and_tokenize(self):
|
||||
dataframes = []
|
||||
for parquet_file in self.parquet_files:
|
||||
# read parquet files and cache
|
||||
dataframe = pd.read_parquet(parquet_file)
|
||||
dataframes.append(dataframe)
|
||||
self.dataframe = pd.concat(dataframes)
|
||||
self.prompts = self.dataframe[self.prompt_key].tolist()
|
||||
self.chosen_responses = self.dataframe[self.chosen_key].tolist()
|
||||
self.rejected_responses = self.dataframe[self.rejected_key].tolist()
|
||||
|
||||
def __len__(self):
|
||||
return len(self.prompts)
|
||||
|
||||
def _pad_to_length(self, input_ids, attention_mask):
|
||||
curr_length = input_ids.shape[-1]
|
||||
|
||||
if curr_length < self.max_length:
|
||||
input_ids = torch.cat(
|
||||
(input_ids, torch.zeros(size=(self.max_length - curr_length,), dtype=input_ids.dtype)), dim=-1)
|
||||
attention_mask = torch.cat(
|
||||
(attention_mask, torch.zeros(size=(self.max_length - curr_length,), dtype=attention_mask.dtype)),
|
||||
dim=-1)
|
||||
elif curr_length > self.max_length:
|
||||
input_ids = input_ids[:self.max_length]
|
||||
attention_mask = attention_mask[:self.max_length]
|
||||
|
||||
return input_ids, attention_mask
|
||||
|
||||
def __getitem__(self, item):
|
||||
prompt = self.prompts[item]
|
||||
chosen_response = self.chosen_responses[item]
|
||||
rejected_response = self.rejected_responses[item]
|
||||
|
||||
prompt_ids = self.tokenizer(prompt, return_tensors='pt')['input_ids'][0]
|
||||
chosen_response_ids = self.tokenizer(chosen_response, return_tensors='pt')['input_ids'][0]
|
||||
rejected_response_ids = self.tokenizer(rejected_response, return_tensors='pt')['input_ids'][0]
|
||||
|
||||
if self.add_eos:
|
||||
chosen_response_ids = torch.cat((chosen_response_ids, torch.tensor([self.tokenizer.eos_token_id])), dim=-1)
|
||||
rejected_response_ids = torch.cat((rejected_response_ids, torch.tensor([self.tokenizer.eos_token_id])),
|
||||
dim=-1)
|
||||
|
||||
chosen_input_ids = torch.cat((prompt_ids, chosen_response_ids), dim=-1)
|
||||
chosen_attention_mask = torch.ones_like(chosen_input_ids)
|
||||
|
||||
rejected_input_ids = torch.cat((prompt_ids, rejected_response_ids), dim=-1)
|
||||
rejected_attention_mask = torch.ones_like(rejected_input_ids)
|
||||
|
||||
chosen_input_ids, chosen_attention_mask = self._pad_to_length(chosen_input_ids, chosen_attention_mask)
|
||||
rejected_input_ids, rejected_attention_mask = self._pad_to_length(rejected_input_ids, rejected_attention_mask)
|
||||
|
||||
input_ids = torch.stack((chosen_input_ids, rejected_input_ids), dim=0)
|
||||
attention_mask = torch.stack((rejected_input_ids, rejected_attention_mask), dim=0)
|
||||
|
||||
return {
|
||||
'input_ids': input_ids,
|
||||
'attention_mask': attention_mask,
|
||||
}
|
||||
15
verl/utils/debug/__init__.py
Normal file
15
verl/utils/debug/__init__.py
Normal file
@@ -0,0 +1,15 @@
|
||||
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .performance import log_gpu_memory_usage
|
||||
30
verl/utils/debug/performance.py
Normal file
30
verl/utils/debug/performance.py
Normal file
@@ -0,0 +1,30 @@
|
||||
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import logging
|
||||
|
||||
|
||||
def log_gpu_memory_usage(head: str, logger: logging.Logger = None, level=logging.DEBUG, rank: int = 0):
|
||||
if (not dist.is_initialized()) or (rank is None) or (dist.get_rank() == rank):
|
||||
memory_allocated = torch.cuda.memory_allocated() / 1024**3
|
||||
memory_reserved = torch.cuda.memory_reserved() / 1024**3
|
||||
|
||||
message = f'{head}, memory allocated (GB): {memory_allocated}, memory reserved (GB): {memory_reserved}'
|
||||
|
||||
if logger is None:
|
||||
print(message)
|
||||
else:
|
||||
logger.log(msg=message, level=level)
|
||||
108
verl/utils/debug/trajectory_tracker.py
Normal file
108
verl/utils/debug/trajectory_tracker.py
Normal file
@@ -0,0 +1,108 @@
|
||||
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Trajectory tracker can be inserted into code to save the intermediate results.
|
||||
The results will be dump to hdfs for offline comparison.
|
||||
Each process will have a client that first move all the tensors to CPU
|
||||
"""
|
||||
|
||||
from verl.utils.hdfs_io import makedirs, copy
|
||||
import torch
|
||||
import os
|
||||
import ray
|
||||
import io
|
||||
import tempfile
|
||||
|
||||
from collections import deque
|
||||
|
||||
remote_copy = ray.remote(copy)
|
||||
|
||||
|
||||
@ray.remote
|
||||
def save_to_hdfs(data: io.BytesIO, name, hdfs_dir, verbose):
|
||||
filename = name + '.pth'
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
local_filepath = os.path.join(tmpdirname, filename)
|
||||
with open(local_filepath, 'wb') as f:
|
||||
f.write(data.getbuffer())
|
||||
# upload to hdfs
|
||||
|
||||
if verbose:
|
||||
print(f'Saving {local_filepath} to {hdfs_dir}')
|
||||
try:
|
||||
copy(local_filepath, hdfs_dir)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
|
||||
|
||||
@ray.remote
|
||||
class TrajectoryTracker():
|
||||
|
||||
def __init__(self, hdfs_dir, verbose) -> None:
|
||||
self.hdfs_dir = hdfs_dir
|
||||
makedirs(hdfs_dir)
|
||||
self.verbose = verbose
|
||||
|
||||
self.handle = deque()
|
||||
|
||||
def dump(self, data: io.BytesIO, name):
|
||||
# get a temp file and write to it
|
||||
self.handle.append(save_to_hdfs.remote(data, name, self.hdfs_dir, self.verbose))
|
||||
|
||||
def wait_for_hdfs(self):
|
||||
while len(self.handle) != 0:
|
||||
future = self.handle.popleft()
|
||||
ray.get(future)
|
||||
|
||||
|
||||
def dump_data(data, name):
|
||||
enable = os.getenv('VERL_ENABLE_TRACKER', '0') == '1'
|
||||
if not enable:
|
||||
return
|
||||
buffer = io.BytesIO()
|
||||
torch.save(data, buffer)
|
||||
tracker = get_trajectory_tracker()
|
||||
ray.get(tracker.dump.remote(buffer, name))
|
||||
|
||||
|
||||
def get_trajectory_tracker():
|
||||
hdfs_dir = os.getenv('VERL_TRACKER_HDFS_DIR', default=None)
|
||||
verbose = os.getenv('VERL_TRACKER_VERBOSE', default='0') == '1'
|
||||
assert hdfs_dir is not None
|
||||
tracker = TrajectoryTracker.options(name="global_tracker", get_if_exists=True,
|
||||
lifetime="detached").remote(hdfs_dir, verbose)
|
||||
return tracker
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# testing
|
||||
os.environ['VERL_ENABLE_TRACKER'] = '1'
|
||||
os.environ['VERL_TRACKER_HDFS_DIR'] = '~/debug/test'
|
||||
|
||||
@ray.remote
|
||||
def process(iter):
|
||||
data = {'obs': torch.randn(10, 20)}
|
||||
dump_data(data, f'process_{iter}_obs')
|
||||
|
||||
ray.init()
|
||||
|
||||
output_lst = []
|
||||
|
||||
for i in range(10):
|
||||
output_lst.append(process.remote(i))
|
||||
|
||||
out = ray.get(output_lst)
|
||||
|
||||
tracker = get_trajectory_tracker()
|
||||
ray.get(tracker.wait_for_hdfs.remote())
|
||||
28
verl/utils/distributed.py
Normal file
28
verl/utils/distributed.py
Normal file
@@ -0,0 +1,28 @@
|
||||
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Utilities for distributed training."""
|
||||
import os
|
||||
|
||||
|
||||
def initialize_global_process_group(timeout_second=36000):
|
||||
import torch.distributed
|
||||
from datetime import timedelta
|
||||
torch.distributed.init_process_group('nccl', timeout=timedelta(seconds=timeout_second))
|
||||
local_rank = int(os.environ["LOCAL_RANK"])
|
||||
rank = int(os.environ["RANK"])
|
||||
world_size = int(os.environ["WORLD_SIZE"])
|
||||
|
||||
if torch.distributed.is_initialized():
|
||||
torch.cuda.set_device(local_rank)
|
||||
return local_rank, rank, world_size
|
||||
123
verl/utils/flops_counter.py
Normal file
123
verl/utils/flops_counter.py
Normal file
@@ -0,0 +1,123 @@
|
||||
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
from transformers import PretrainedConfig, Qwen2Config, LlamaConfig
|
||||
|
||||
VALID_CONFIG_TYPE = (Qwen2Config, LlamaConfig)
|
||||
|
||||
|
||||
def get_device_flops(unit="T"):
|
||||
|
||||
def unit_convert(number, level):
|
||||
units = ["B", "K", "M", "G", "T", "P"]
|
||||
if number <= 0:
|
||||
return number
|
||||
ptr = 0
|
||||
while ptr < len(units) and units[ptr] != level:
|
||||
number /= 1000
|
||||
ptr += 1
|
||||
return number
|
||||
|
||||
device_name = torch.cuda.get_device_name()
|
||||
flops = float("inf") # INF flops for unkown gpu type
|
||||
if "H100" in device_name or "H800" in device_name:
|
||||
flops = 989e12
|
||||
elif "A100" in device_name or "A800" in device_name:
|
||||
flops = 312e12
|
||||
elif "L40" in device_name:
|
||||
flops = 181.05e12
|
||||
elif "L20" in device_name:
|
||||
flops = 119.5e12
|
||||
elif "H20" in device_name:
|
||||
flops = 148e12
|
||||
elif "910B" in device_name:
|
||||
flops = 354e12
|
||||
flops_unit = unit_convert(flops, unit)
|
||||
return flops_unit
|
||||
|
||||
|
||||
class FlopsCounter:
|
||||
"""
|
||||
Used to count mfu during training loop
|
||||
|
||||
Example:
|
||||
flops_counter = FlopsCounter(config)
|
||||
flops_achieved, flops_promised = flops_counter.estimate_flops(tokens_list, delta_time)
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, config: PretrainedConfig):
|
||||
if not isinstance(config, VALID_CONFIG_TYPE):
|
||||
print(f"Only support config type of {VALID_CONFIG_TYPE}, but got {type(config)}. "
|
||||
f"MFU will always be zero.")
|
||||
|
||||
self.estimate_func = {"qwen2": self._estimate_qwen2_flops, 'llama': self._estimate_qwen2_flops}
|
||||
self.config = config
|
||||
|
||||
def _estimate_unknown_flops(self, tokens_sum, batch_seqlens, delta_time):
|
||||
return 0
|
||||
|
||||
def _estimate_qwen2_flops(self, tokens_sum, batch_seqlens, delta_time):
|
||||
assert isinstance(self.config, (Qwen2Config, LlamaConfig))
|
||||
hidden_size = self.config.hidden_size
|
||||
vocab_size = self.config.vocab_size
|
||||
num_hidden_layers = self.config.num_hidden_layers
|
||||
num_key_value_heads = self.config.num_key_value_heads
|
||||
num_attention_heads = self.config.num_attention_heads
|
||||
intermediate_size = self.config.intermediate_size
|
||||
|
||||
head_dim = hidden_size // num_attention_heads
|
||||
q_size = num_attention_heads * head_dim
|
||||
k_size = num_key_value_heads * head_dim
|
||||
v_size = num_key_value_heads * head_dim
|
||||
|
||||
# non-attn per layer parm
|
||||
# Qwen2/LLama use SwiGelu, gate, having up and down linear layer in mlp
|
||||
mlp_N = hidden_size * intermediate_size * 3
|
||||
attn_linear_N = hidden_size * (q_size + k_size + v_size + num_attention_heads * head_dim)
|
||||
emd_and_lm_head_N = vocab_size * hidden_size * 2
|
||||
# non-attn all_layer parm
|
||||
dense_N = (mlp_N + attn_linear_N) * num_hidden_layers + emd_and_lm_head_N
|
||||
# non-attn all_layer & all_token fwd & bwd flops
|
||||
dense_N_flops = 6 * dense_N * tokens_sum
|
||||
|
||||
# attn all_layer & all_token fwd & bwd flops
|
||||
seqlen_square_sum = 0
|
||||
for seqlen in batch_seqlens:
|
||||
seqlen_square_sum += seqlen * seqlen
|
||||
attn_qkv_flops = 12 * seqlen_square_sum * head_dim * num_attention_heads * num_hidden_layers
|
||||
|
||||
# all_layer & all_token fwd & bwd flops
|
||||
flops_all_token = dense_N_flops + attn_qkv_flops
|
||||
flops_achieved = flops_all_token * (1.0 / delta_time) / 1e12
|
||||
return flops_achieved
|
||||
|
||||
def estimate_flops(self, batch_seqlens, delta_time):
|
||||
"""
|
||||
Estimate the FLOPS based on the number of valid tokens in the current batch and the time taken.
|
||||
|
||||
Args:
|
||||
batch_seqlens (List[int]): A list where each element represents the number of valid tokens in the current batch.
|
||||
delta_time (float): The time taken to process the batch, in seconds.
|
||||
|
||||
Returns:
|
||||
estimated_flops (float): The estimated FLOPS based on the input tokens and time.
|
||||
promised_flops (float): The expected FLOPS of the current device.
|
||||
"""
|
||||
tokens_sum = sum(batch_seqlens)
|
||||
func = self.estimate_func.get(self.config.model_type, self._estimate_unknown_flops)
|
||||
estimated_flops = func(tokens_sum, batch_seqlens, delta_time)
|
||||
promised_flops = get_device_flops()
|
||||
return estimated_flops, promised_flops
|
||||
88
verl/utils/fs.py
Normal file
88
verl/utils/fs.py
Normal file
@@ -0,0 +1,88 @@
|
||||
#!/usr/bin/env python
|
||||
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# -*- coding: utf-8 -*-
|
||||
"""File-system agnostic IO APIs"""
|
||||
import os
|
||||
import tempfile
|
||||
import hashlib
|
||||
|
||||
from .hdfs_io import copy, makedirs, exists
|
||||
|
||||
__all__ = ["copy", "exists", "makedirs"]
|
||||
|
||||
_HDFS_PREFIX = "hdfs://"
|
||||
|
||||
|
||||
def _is_non_local(path):
|
||||
return path.startswith(_HDFS_PREFIX)
|
||||
|
||||
|
||||
def md5_encode(path: str) -> str:
|
||||
return hashlib.md5(path.encode()).hexdigest()
|
||||
|
||||
|
||||
def get_local_temp_path(hdfs_path: str, cache_dir: str) -> str:
|
||||
"""Return a local temp path that joins cache_dir and basename of hdfs_path
|
||||
|
||||
Args:
|
||||
hdfs_path:
|
||||
cache_dir:
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
# make a base64 encoding of hdfs_path to avoid directory conflict
|
||||
encoded_hdfs_path = md5_encode(hdfs_path)
|
||||
temp_dir = os.path.join(cache_dir, encoded_hdfs_path)
|
||||
os.makedirs(temp_dir, exist_ok=True)
|
||||
dst = os.path.join(temp_dir, os.path.basename(hdfs_path))
|
||||
return dst
|
||||
|
||||
|
||||
def copy_local_path_from_hdfs(src: str, cache_dir=None, filelock='.file.lock', verbose=False) -> str:
|
||||
"""Copy src from hdfs to local if src is on hdfs or directly return src.
|
||||
If cache_dir is None, we will use the default cache dir of the system. Note that this may cause conflicts if
|
||||
the src name is the same between calls
|
||||
|
||||
Args:
|
||||
src (str): a HDFS path of a local path
|
||||
|
||||
Returns:
|
||||
a local path of the copied file
|
||||
"""
|
||||
from filelock import FileLock
|
||||
|
||||
assert src[-1] != '/', f'Make sure the last char in src is not / because it will cause error. Got {src}'
|
||||
|
||||
if _is_non_local(src):
|
||||
# download from hdfs to local
|
||||
if cache_dir is None:
|
||||
# get a temp folder
|
||||
cache_dir = tempfile.gettempdir()
|
||||
os.makedirs(cache_dir, exist_ok=True)
|
||||
assert os.path.exists(cache_dir)
|
||||
local_path = get_local_temp_path(src, cache_dir)
|
||||
# get a specific lock
|
||||
filelock = md5_encode(src) + '.lock'
|
||||
lock_file = os.path.join(cache_dir, filelock)
|
||||
with FileLock(lock_file=lock_file):
|
||||
if not os.path.exists(local_path):
|
||||
if verbose:
|
||||
print(f'Copy from {src} to {local_path}')
|
||||
copy(src, local_path)
|
||||
return local_path
|
||||
else:
|
||||
return src
|
||||
329
verl/utils/fsdp_utils.py
Normal file
329
verl/utils/fsdp_utils.py
Normal file
@@ -0,0 +1,329 @@
|
||||
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Dict
|
||||
import functools
|
||||
import json
|
||||
import math
|
||||
import itertools
|
||||
import os
|
||||
from contextlib import contextmanager
|
||||
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy, transformer_auto_wrap_policy
|
||||
from transformers.trainer_pt_utils import get_module_class_from_name
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.distributed as dist
|
||||
|
||||
|
||||
def init_fn(x: torch.nn.Module):
|
||||
if not torch.distributed.get_rank() == 0:
|
||||
x = x.to_empty(device=torch.cuda.current_device(), recurse=False)
|
||||
torch.cuda.empty_cache()
|
||||
return x
|
||||
|
||||
|
||||
def get_init_weight_context_manager(use_meta_tensor=True):
|
||||
from accelerate import init_empty_weights
|
||||
cpu_init_weights = lambda: torch.device('cpu')
|
||||
if use_meta_tensor:
|
||||
init_context = init_empty_weights if torch.distributed.get_rank() != 0 else cpu_init_weights
|
||||
else:
|
||||
init_context = cpu_init_weights
|
||||
return init_context
|
||||
|
||||
|
||||
# Copyright 2020-present the HuggingFace Inc. team.
|
||||
# Adapted from https://github.com/huggingface/transformers/src/transformers/trainer.py
|
||||
def get_fsdp_wrap_policy(module, config=None, is_lora=False):
|
||||
"""Get FSDP wrap policy for the module.
|
||||
|
||||
Args:
|
||||
module: The module to get wrap policy for
|
||||
config: Configuration for wrap policy
|
||||
is_lora: Whether to enable lambda policy for LoRA modules
|
||||
"""
|
||||
if config is None:
|
||||
config = {}
|
||||
|
||||
if config.get('disable', False):
|
||||
return None
|
||||
|
||||
default_transformer_cls_names_to_wrap = getattr(module, "_no_split_modules", None)
|
||||
fsdp_transformer_layer_cls_to_wrap = config.get("transformer_layer_cls_to_wrap",
|
||||
default_transformer_cls_names_to_wrap)
|
||||
min_num_params = config.get('min_num_params', 0)
|
||||
auto_wrap_policy = None
|
||||
|
||||
policies = []
|
||||
|
||||
from torch.distributed.fsdp.wrap import _or_policy, lambda_auto_wrap_policy, transformer_auto_wrap_policy
|
||||
|
||||
# Add lambda policy for LoRA modules if is_lora is True
|
||||
if is_lora:
|
||||
|
||||
def lambda_policy_fn(module):
|
||||
if (len(list(module.named_children())) == 0 and getattr(module, "weight", None) is not None and
|
||||
module.weight.requires_grad):
|
||||
return True
|
||||
return False
|
||||
|
||||
lambda_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn)
|
||||
policies.append(lambda_policy)
|
||||
|
||||
if min_num_params > 0:
|
||||
size_policy = functools.partial(size_based_auto_wrap_policy, min_num_params=min_num_params)
|
||||
policies.append(size_policy)
|
||||
elif fsdp_transformer_layer_cls_to_wrap is not None:
|
||||
transformer_cls_to_wrap = set()
|
||||
for layer_class in fsdp_transformer_layer_cls_to_wrap:
|
||||
transformer_cls = get_module_class_from_name(module, layer_class)
|
||||
if transformer_cls is None:
|
||||
raise Exception("Could not find the transformer layer class to wrap in the model.")
|
||||
else:
|
||||
transformer_cls_to_wrap.add(transformer_cls)
|
||||
|
||||
transformer_policy = functools.partial(
|
||||
transformer_auto_wrap_policy,
|
||||
transformer_layer_cls=transformer_cls_to_wrap,
|
||||
)
|
||||
policies.append(transformer_policy)
|
||||
|
||||
if len(policies) > 0:
|
||||
auto_wrap_policy = functools.partial(_or_policy, policies=policies)
|
||||
|
||||
return auto_wrap_policy
|
||||
|
||||
|
||||
def offload_fsdp_grad(module):
|
||||
for _, param in module.named_parameters():
|
||||
if param.grad is not None:
|
||||
param.grad = param.grad.to("cpu", non_blocking=True)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
def load_fsdp_grad(module, device_id):
|
||||
for _, param in module.named_parameters():
|
||||
if param.grad is not None:
|
||||
param.grad = param.grad.to(device_id, non_blocking=True)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
def offload_fsdp_param_and_grad(module, offload_grad=False):
|
||||
for _, param in module.named_parameters():
|
||||
if hasattr(param, "_local_shard"):
|
||||
param._local_shard = param._local_shard.to("cpu", non_blocking=True)
|
||||
param.data = param.data.to('cpu', non_blocking=True)
|
||||
if offload_grad and param.grad is not None:
|
||||
param.grad = param.grad.to("cpu", non_blocking=True)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
def load_fsdp_param_and_grad(module, device_id, load_grad=False):
|
||||
for _, param in module.named_parameters():
|
||||
if hasattr(param, "_local_shard"):
|
||||
param._local_shard = param._local_shard.to(device_id, non_blocking=True)
|
||||
param.data = param.data.to(device_id, non_blocking=True)
|
||||
if load_grad and param.grad is not None:
|
||||
param.grad = param.grad.to(device_id, non_blocking=True)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
def offload_fsdp_optimizer(optimizer):
|
||||
for param_group in optimizer.param_groups:
|
||||
for param in param_group['params']:
|
||||
state = optimizer.state[param]
|
||||
for key, value in state.items():
|
||||
if isinstance(value, torch.Tensor):
|
||||
state[key] = value.to("cpu", non_blocking=True)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
def load_fsdp_optimizer(optimizer, device_id):
|
||||
for param_group in optimizer.param_groups:
|
||||
for param in param_group['params']:
|
||||
state = optimizer.state[param]
|
||||
for key, value in state.items():
|
||||
if isinstance(value, torch.Tensor):
|
||||
state[key] = value.to(device_id, non_blocking=True)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
@contextmanager
|
||||
def meta_device_init():
|
||||
"""
|
||||
Create model parameters with meta device.
|
||||
|
||||
Note buffers in model will still be initialized in default device (e.g., CPU),
|
||||
since the buffers can be non-persistent and filled with expected values that can
|
||||
NOT be captured in meta device.
|
||||
"""
|
||||
device = torch.device("meta")
|
||||
old_register_parameter = nn.Module.register_parameter
|
||||
registered = set()
|
||||
|
||||
def register_empty_parameter(module, name, param):
|
||||
old_register_parameter(module, name, param)
|
||||
# we will skip register shared parameters as it
|
||||
# is already registered previously
|
||||
if param is not None and param not in registered:
|
||||
param_cls = type(module._parameters[name])
|
||||
kwargs = module._parameters[name].__dict__
|
||||
kwargs["requires_grad"] = param.requires_grad
|
||||
module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs)
|
||||
registered.add(module._parameters[name])
|
||||
|
||||
try:
|
||||
nn.Module.register_parameter = register_empty_parameter
|
||||
yield
|
||||
finally:
|
||||
registered.clear()
|
||||
nn.Module.register_parameter = old_register_parameter
|
||||
|
||||
|
||||
def parallel_load_safetensors(filepath):
|
||||
"""
|
||||
Parallel load safetensors from huggingface checkpoint
|
||||
|
||||
Huggingface checkpoint contains:
|
||||
|
||||
- config.json: a json file for model configuration
|
||||
- model.safetensor.index.json: a json file for safetensors (parameters & buffers) index
|
||||
- model-000x-of-ooxx.safetensors: a binary file for safetensors (parameters & buffers) chunks
|
||||
|
||||
Or (when model is small),
|
||||
|
||||
- model.safetensors: a binary file for all parameters and buffers
|
||||
|
||||
Each rank will own a part of model chunks and load them directly into GPU memory.
|
||||
"""
|
||||
from safetensors.torch import load_file
|
||||
|
||||
safetensors2param = {}
|
||||
|
||||
index_file = os.path.join(filepath, "model.safetensors.index.json")
|
||||
if os.path.exists(index_file):
|
||||
index = json.load(open(index_file, "rb"))
|
||||
for param_name, filename in index["weight_map"].items():
|
||||
safetensors2param.setdefault(filename, []).append(param_name)
|
||||
else:
|
||||
# in this case, the model is small and we can load it all at once
|
||||
param_file = os.path.join(filepath, "model.safetensors")
|
||||
assert os.path.exists(param_file), f"Cannot find {param_file}"
|
||||
states = load_file(param_file)
|
||||
for param_name in states:
|
||||
safetensors2param.setdefault("model.safetensors", []).append(param_name)
|
||||
del states
|
||||
|
||||
total_files = len(safetensors2param)
|
||||
ckpt_chunks = sorted(safetensors2param.keys())
|
||||
world_size = dist.get_world_size()
|
||||
size = int(math.ceil(total_files / world_size))
|
||||
ckpt_chunks = [ckpt_chunks[rank * size:rank * size + size] for rank in range(world_size)]
|
||||
|
||||
shard_states = {}
|
||||
device = torch.cuda.current_device()
|
||||
for rank, files in enumerate(ckpt_chunks):
|
||||
if rank == dist.get_rank():
|
||||
for file in files:
|
||||
file = os.path.join(filepath, file)
|
||||
states = load_file(file, device=device)
|
||||
# print(f"rank {rank} loading {file}...")
|
||||
shard_states.update(states)
|
||||
else:
|
||||
for file in files:
|
||||
for param_name in safetensors2param[file]:
|
||||
shard_states[param_name] = rank
|
||||
return shard_states
|
||||
|
||||
|
||||
def parallel_init_module_fn(module: torch.nn.Module, shard_states: Dict[str, torch.nn.Parameter]):
|
||||
"""
|
||||
Generate a function to initialize sub-modules in the `module` with `shard_states`
|
||||
from huggingface checkpoint.
|
||||
|
||||
Args:
|
||||
module (torch.nn.Module): the global module to be initialized
|
||||
shard_states (Dict[str, torch.nn.Parameter]): the shard states from huggingface checkpoint
|
||||
|
||||
Returns:
|
||||
init_fn (Callable): a function to initialize sub-modules in the `module` with `shard_states`
|
||||
"""
|
||||
|
||||
state2fqn = {}
|
||||
for name, state in itertools.chain(module.named_parameters(remove_duplicate=False),
|
||||
module.named_buffers(remove_duplicate=False)):
|
||||
state2fqn.setdefault(state, []).append(name)
|
||||
# remove standalone parameters and buffers
|
||||
shared = {s for s, names in state2fqn.items() if len(names) > 1}
|
||||
materialized_states = {}
|
||||
|
||||
@torch.no_grad()
|
||||
def create_and_sync_state(param_name, state, is_param):
|
||||
assert param_name in shard_states, f"{param_name} not loaded"
|
||||
device = torch.cuda.current_device()
|
||||
if is_param:
|
||||
param = torch.nn.Parameter(torch.empty_like(state.data, device=device), requires_grad=state.requires_grad)
|
||||
else: # buffer
|
||||
param = torch.empty_like(state.data, device=device)
|
||||
loaded = shard_states[param_name]
|
||||
if isinstance(loaded, (torch.nn.Parameter, torch.Tensor)):
|
||||
# NOTE: loaded.dtype can be different with param.dtype
|
||||
param.data.copy_(loaded.data)
|
||||
dist.broadcast(param.data, src=dist.get_rank())
|
||||
else:
|
||||
assert isinstance(loaded, int) # the rank that holds the state
|
||||
dist.broadcast(param.data, src=loaded)
|
||||
shard_states.pop(param_name)
|
||||
del loaded
|
||||
return param
|
||||
|
||||
def init_fn(sub_mod: torch.nn.Module, recurse: bool = True):
|
||||
param_and_buffers = tuple(sub_mod.named_parameters(recurse=False)) + tuple(sub_mod.named_buffers(recurse=False))
|
||||
# param_and_buffers = sorted(sub_mod.named_parameters(recurse=False), key=lambda x: x[0])
|
||||
for name, state in param_and_buffers:
|
||||
if not state.is_meta:
|
||||
continue
|
||||
is_param = name in sub_mod._parameters
|
||||
fqn = state2fqn[state].pop(0)
|
||||
# non-persistent buffers will not be saved in state dict, we can safely skip it
|
||||
if (not is_param) and fqn not in shard_states:
|
||||
if state.is_meta:
|
||||
raise RuntimeError(
|
||||
f"find a non-persistent buffer ({fqn}) initiated with device meta. "
|
||||
"Such buffer is not saved in checkpoint and user should guarantee to init in CPU / GPU device.")
|
||||
continue
|
||||
# for shared parameter, we get it from the first time it is created
|
||||
if state in shared:
|
||||
if state not in materialized_states:
|
||||
materialized_states[state] = create_and_sync_state(fqn, state, is_param)
|
||||
else:
|
||||
if fqn in shard_states:
|
||||
shard_states.pop(fqn)
|
||||
materialize_state = materialized_states[state]
|
||||
# for not shared parameter, we create it directly
|
||||
else:
|
||||
materialize_state = create_and_sync_state(fqn, state, is_param)
|
||||
if is_param:
|
||||
sub_mod._parameters[name] = materialize_state
|
||||
else:
|
||||
sub_mod._buffers[name] = materialize_state
|
||||
if recurse:
|
||||
for module in sub_mod.children():
|
||||
init_fn(module, recurse=True)
|
||||
|
||||
# for debug
|
||||
# if len(shard_states) == 0: print("clear")
|
||||
return sub_mod
|
||||
|
||||
return init_fn
|
||||
144
verl/utils/hdfs_io.py
Normal file
144
verl/utils/hdfs_io.py
Normal file
@@ -0,0 +1,144 @@
|
||||
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__file__)
|
||||
logger.setLevel(os.getenv('VERL_SFT_LOGGING_LEVEL', 'WARN'))
|
||||
|
||||
_HDFS_PREFIX = "hdfs://"
|
||||
|
||||
_HDFS_BIN_PATH = shutil.which('hdfs')
|
||||
|
||||
|
||||
def exists(path: str, **kwargs) -> bool:
|
||||
r"""Works like os.path.exists() but supports hdfs.
|
||||
|
||||
Test whether a path exists. Returns False for broken symbolic links.
|
||||
|
||||
Args:
|
||||
path (str): path to test
|
||||
|
||||
Returns:
|
||||
bool: True if the path exists, False otherwise
|
||||
"""
|
||||
if _is_non_local(path):
|
||||
return _exists(path, **kwargs)
|
||||
return os.path.exists(path)
|
||||
|
||||
|
||||
def _exists(file_path: str):
|
||||
""" hdfs capable to check whether a file_path is exists """
|
||||
if file_path.startswith("hdfs"):
|
||||
return _run_cmd(_hdfs_cmd(f"-test -e {file_path}")) == 0
|
||||
return os.path.exists(file_path)
|
||||
|
||||
|
||||
def makedirs(name, mode=0o777, exist_ok=False, **kwargs) -> None:
|
||||
r"""Works like os.makedirs() but supports hdfs.
|
||||
|
||||
Super-mkdir; create a leaf directory and all intermediate ones. Works like
|
||||
mkdir, except that any intermediate path segment (not just the rightmost)
|
||||
will be created if it does not exist. If the target directory already
|
||||
exists, raise an OSError if exist_ok is False. Otherwise no exception is
|
||||
raised. This is recursive.
|
||||
|
||||
Args:
|
||||
name (str): directory to create
|
||||
mode (int): file mode bits
|
||||
exist_ok (bool): if True, do not raise an exception if the directory already exists
|
||||
kwargs: keyword arguments for hdfs
|
||||
|
||||
"""
|
||||
if _is_non_local(name):
|
||||
# TODO(haibin.lin):
|
||||
# - handle OSError for hdfs(?)
|
||||
# - support exist_ok for hdfs(?)
|
||||
_mkdir(name, **kwargs)
|
||||
else:
|
||||
os.makedirs(name, mode=mode, exist_ok=exist_ok)
|
||||
|
||||
|
||||
def _mkdir(file_path: str) -> bool:
|
||||
"""hdfs mkdir"""
|
||||
if file_path.startswith("hdfs"):
|
||||
_run_cmd(_hdfs_cmd(f"-mkdir -p {file_path}"))
|
||||
else:
|
||||
os.makedirs(file_path, exist_ok=True)
|
||||
return True
|
||||
|
||||
|
||||
def copy(src: str, dst: str, **kwargs) -> bool:
|
||||
r"""Works like shutil.copy() for file, and shutil.copytree for dir, and supports hdfs.
|
||||
|
||||
Copy data and mode bits ("cp src dst"). Return the file's destination.
|
||||
The destination may be a directory.
|
||||
If source and destination are the same file, a SameFileError will be
|
||||
raised.
|
||||
|
||||
Arg:
|
||||
src (str): source file path
|
||||
dst (str): destination file path
|
||||
kwargs: keyword arguments for hdfs copy
|
||||
|
||||
Returns:
|
||||
str: destination file path
|
||||
|
||||
"""
|
||||
if _is_non_local(src) or _is_non_local(dst):
|
||||
# TODO(haibin.lin):
|
||||
# - handle SameFileError for hdfs files(?)
|
||||
# - return file destination for hdfs files
|
||||
return _copy(src, dst)
|
||||
else:
|
||||
if os.path.isdir(src):
|
||||
return shutil.copytree(src, dst, **kwargs)
|
||||
else:
|
||||
return shutil.copy(src, dst, **kwargs)
|
||||
|
||||
|
||||
def _copy(from_path: str, to_path: str, timeout: int = None) -> bool:
|
||||
if to_path.startswith("hdfs"):
|
||||
if from_path.startswith("hdfs"):
|
||||
returncode = _run_cmd(_hdfs_cmd(f"-cp -f {from_path} {to_path}"), timeout=timeout)
|
||||
else:
|
||||
returncode = _run_cmd(_hdfs_cmd(f"-put -f {from_path} {to_path}"), timeout=timeout)
|
||||
else:
|
||||
if from_path.startswith("hdfs"):
|
||||
returncode = _run_cmd(_hdfs_cmd(f"-get \
|
||||
{from_path} {to_path}"), timeout=timeout)
|
||||
else:
|
||||
try:
|
||||
shutil.copy(from_path, to_path)
|
||||
returncode = 0
|
||||
except shutil.SameFileError:
|
||||
returncode = 0
|
||||
except Exception as e:
|
||||
logger.warning(f"copy {from_path} {to_path} failed: {e}")
|
||||
returncode = -1
|
||||
return returncode == 0
|
||||
|
||||
|
||||
def _run_cmd(cmd: str, timeout=None):
|
||||
return os.system(cmd)
|
||||
|
||||
|
||||
def _hdfs_cmd(cmd: str) -> str:
|
||||
return f"{_HDFS_BIN_PATH} dfs {cmd}"
|
||||
|
||||
|
||||
def _is_non_local(path: str):
|
||||
return path.startswith(_HDFS_PREFIX)
|
||||
48
verl/utils/import_utils.py
Normal file
48
verl/utils/import_utils.py
Normal file
@@ -0,0 +1,48 @@
|
||||
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Utilities to check if packages are available.
|
||||
We assume package availability won't change during runtime.
|
||||
"""
|
||||
|
||||
from functools import cache
|
||||
from typing import List
|
||||
|
||||
|
||||
@cache
|
||||
def is_megatron_core_available():
|
||||
try:
|
||||
from megatron.core import parallel_state as mpu
|
||||
return True
|
||||
except ImportError:
|
||||
return False
|
||||
|
||||
|
||||
@cache
|
||||
def is_vllm_available():
|
||||
try:
|
||||
import vllm
|
||||
return True
|
||||
except ImportError:
|
||||
return False
|
||||
|
||||
|
||||
def import_external_libs(external_libs=None):
|
||||
if external_libs is None:
|
||||
return
|
||||
if not isinstance(external_libs, List):
|
||||
external_libs = [external_libs]
|
||||
import importlib
|
||||
for external_lib in external_libs:
|
||||
importlib.import_module(external_lib)
|
||||
13
verl/utils/megatron/__init__.py
Normal file
13
verl/utils/megatron/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
41
verl/utils/megatron/memory.py
Normal file
41
verl/utils/megatron/memory.py
Normal file
@@ -0,0 +1,41 @@
|
||||
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class MemoryBuffer:
|
||||
|
||||
def __init__(self, numel, numel_padded, dtype):
|
||||
self.numel = numel
|
||||
self.numel_padded = numel_padded
|
||||
self.dtype = dtype
|
||||
self.data = torch.zeros(self.numel_padded,
|
||||
dtype=self.dtype,
|
||||
device=torch.cuda.current_device(),
|
||||
requires_grad=False)
|
||||
|
||||
def zero(self):
|
||||
"""Reset the buffer to zero."""
|
||||
self.data.zero_()
|
||||
|
||||
def get(self, shape, start_index):
|
||||
"""Return a tensor with the input `shape` as a view into the
|
||||
1-D data starting at `start_index`."""
|
||||
end_index = start_index + shape.numel()
|
||||
assert end_index <= self.numel, \
|
||||
'requested tensor is out of the buffer range.'
|
||||
buffer_tensor = self.data[start_index:end_index]
|
||||
buffer_tensor = buffer_tensor.view(shape)
|
||||
return buffer_tensor
|
||||
92
verl/utils/megatron/optimizer.py
Normal file
92
verl/utils/megatron/optimizer.py
Normal file
@@ -0,0 +1,92 @@
|
||||
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from apex.optimizers import FusedAdam as Adam
|
||||
from apex.optimizers import FusedSGD as SGD
|
||||
from megatron.optimizer.distrib_optimizer import DistributedOptimizer
|
||||
from megatron.optimizer.grad_scaler import ConstantGradScaler, DynamicGradScaler
|
||||
from megatron.optimizer import Float16OptimizerWithFloat16Params, FP32Optimizer
|
||||
from megatron.optimizer import get_param_groups
|
||||
|
||||
from verl.utils.megatron.optimizer_config import OptimizerConfig
|
||||
|
||||
|
||||
def get_megatron_optimizer(
|
||||
model,
|
||||
config: OptimizerConfig,
|
||||
no_weight_decay_cond=None,
|
||||
scale_lr_cond=None,
|
||||
lr_mult=1.0,
|
||||
check_for_nan_in_loss_and_grad=False,
|
||||
overlap_param_gather=False # add for verl
|
||||
):
|
||||
# Base optimizer.
|
||||
param_groups = get_param_groups(model, no_weight_decay_cond, scale_lr_cond, lr_mult)
|
||||
|
||||
if config.optimizer == 'adam':
|
||||
optimizer = Adam(param_groups,
|
||||
lr=config.lr,
|
||||
weight_decay=config.weight_decay,
|
||||
betas=(config.adam_beta1, config.adam_beta2),
|
||||
eps=config.adam_eps)
|
||||
elif config.optimizer == 'sgd':
|
||||
optimizer = SGD(param_groups, lr=config.lr, weight_decay=config.weight_decay, momentum=config.sgd_momentum)
|
||||
else:
|
||||
raise Exception('{} optimizer is not supported.'.format(config.optimizer))
|
||||
|
||||
# Determine whether the params have main-grad field.
|
||||
params_have_main_grad = True
|
||||
|
||||
# Mixed precision optimizer.
|
||||
# - Note: both the Float16Optimizer and the DistributedOptimizer inherit
|
||||
# from the MixedPrecisionOptimizer, which manages any optimizer where
|
||||
# the model params and main params are distinct.
|
||||
if config.fp16 or config.bf16 or config.use_distributed_optimizer:
|
||||
|
||||
# Grad scaler:
|
||||
# if loss-scale is provided, instantiate the constant scaler.
|
||||
# if we are using fp16 and loss-scale is not present, use a
|
||||
# dynamic scaler.
|
||||
# otherwise we are running in bf16 with no loss-scale so
|
||||
# leave it as None.
|
||||
grad_scaler = None
|
||||
|
||||
# Constant loss scale.
|
||||
if config.loss_scale:
|
||||
grad_scaler = ConstantGradScaler(config.loss_scale)
|
||||
|
||||
# Dynamic loss scale.
|
||||
else:
|
||||
if config.fp16:
|
||||
grad_scaler = DynamicGradScaler(initial_scale=config.initial_loss_scale,
|
||||
min_scale=config.min_loss_scale,
|
||||
growth_factor=2.0,
|
||||
backoff_factor=0.5,
|
||||
growth_interval=config.loss_scale_window,
|
||||
hysteresis=config.hysteresis)
|
||||
|
||||
# Megatron optimizer.
|
||||
if config.use_distributed_optimizer:
|
||||
return DistributedOptimizer(optimizer, config.clip_grad, config.log_num_zeros_in_grad,
|
||||
check_for_nan_in_loss_and_grad, params_have_main_grad, config.fp16, config.bf16,
|
||||
config.params_dtype, grad_scaler, model, overlap_param_gather)
|
||||
else:
|
||||
return Float16OptimizerWithFloat16Params(optimizer, config.clip_grad, config.log_num_zeros_in_grad,
|
||||
check_for_nan_in_loss_and_grad, params_have_main_grad, config.fp16,
|
||||
config.bf16, config.params_dtype, grad_scaler, model)
|
||||
|
||||
# FP32.
|
||||
return FP32Optimizer(optimizer, config.clip_grad, config.log_num_zeros_in_grad, check_for_nan_in_loss_and_grad,
|
||||
params_have_main_grad, model)
|
||||
129
verl/utils/megatron/optimizer_config.py
Normal file
129
verl/utils/megatron/optimizer_config.py
Normal file
@@ -0,0 +1,129 @@
|
||||
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
@dataclass
|
||||
class OptimizerConfig:
|
||||
"""Configuration for optimizer."""
|
||||
|
||||
##############
|
||||
# General
|
||||
##############
|
||||
optimizer: str = 'adam'
|
||||
"""Optimizer to use (one of Adam or SGD)."""
|
||||
|
||||
lr: Optional[float] = None
|
||||
"""Initial learning rate. Depending on decay style and initial warmup, the learning rate at each
|
||||
iteration would be different.
|
||||
"""
|
||||
|
||||
min_lr: Optional[float] = None
|
||||
"""Minumum value for learning rate. The scheduler clip values below this threshold."""
|
||||
|
||||
decoupled_lr: Optional[float] = None
|
||||
"""Separate learning rate for the input and output layer."""
|
||||
|
||||
decoupled_min_lr: Optional[float] = None
|
||||
"""Minimum value for learning rate for the input and output layer. The scheduler clip values
|
||||
below this threshold.
|
||||
"""
|
||||
|
||||
weight_decay: float = 0.01
|
||||
"""Weight decay coefficient for L2 regularization."""
|
||||
|
||||
##############
|
||||
# Precision
|
||||
##############
|
||||
fp16: bool = False
|
||||
"""If true, train with fp16 mixed precision training. Defaults to False."""
|
||||
|
||||
bf16: bool = False
|
||||
"""If true, train with bf16 mixed precision training. Defaults to False."""
|
||||
|
||||
params_dtype: torch.dtype = torch.float32
|
||||
"""dtype used when intializing the weights. Defaults to torch.float32."""
|
||||
|
||||
###############
|
||||
# Loss scaling
|
||||
###############
|
||||
loss_scale: Optional[float] = None
|
||||
"""Static loss scaling, positive power of 2 values can improve fp16 convergence. If None,
|
||||
dynamic loss scaling is used.
|
||||
"""
|
||||
|
||||
initial_loss_scale: float = 2**32
|
||||
"""Initial loss-scale for dynamic loss scaling."""
|
||||
|
||||
min_loss_scale: float = 1.0
|
||||
"""Minimum loss scale for dynamic loss scaling."""
|
||||
|
||||
loss_scale_window: float = 1000
|
||||
"""Window over which to raise/lower dynamic scale."""
|
||||
|
||||
hysteresis: int = 2
|
||||
"""Hysteresis for dynamic loss scaling."""
|
||||
|
||||
##############
|
||||
# Optimizer
|
||||
##############
|
||||
# Adam
|
||||
adam_beta1: float = 0.9
|
||||
"""First coefficient for computing running averages of gradient and its square in Adam
|
||||
optimizer.
|
||||
"""
|
||||
|
||||
adam_beta2: float = 0.999
|
||||
"""Second coefficient for computing running averages of gradient and its square in Adam
|
||||
optimizer.
|
||||
"""
|
||||
|
||||
adam_eps: float = 1e-08
|
||||
"""Term added to the denominator to improve numerical stability in Adam optimizer."""
|
||||
|
||||
# SGD.
|
||||
sgd_momentum: float = 0.9
|
||||
"""Momentum factor for SGD optimizer."""
|
||||
|
||||
#######################
|
||||
# Distributed optimizer
|
||||
#######################
|
||||
use_distributed_optimizer: bool = False
|
||||
"""Distribute optimizer state over data-parallel replicas."""
|
||||
|
||||
overlap_grad_reduce: bool = False
|
||||
"""If true, overlap grad reduce-scatter with backward compute in distributed optimizer."""
|
||||
|
||||
overlap_param_gather: bool = False
|
||||
"""If true, overlap param all-gather with forward compute in distributed optimizer."""
|
||||
|
||||
################
|
||||
# Miscellaneous
|
||||
################
|
||||
clip_grad: float = 1.0
|
||||
"""Gradient clipping based on global L2 norm."""
|
||||
|
||||
log_num_zeros_in_grad: bool = False
|
||||
"""If true, calculate and log the number of zeros in gradient."""
|
||||
|
||||
barrier_with_L1_time: bool = False
|
||||
"""If true, use barrier with level 1 time measurements."""
|
||||
|
||||
timers: Callable = None
|
||||
"""Function to get timers."""
|
||||
51
verl/utils/megatron/pipeline_parallel.py
Normal file
51
verl/utils/megatron/pipeline_parallel.py
Normal file
@@ -0,0 +1,51 @@
|
||||
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
from megatron.core import parallel_state as mpu
|
||||
|
||||
from .sequence_parallel import pad_to_sequence_parallel
|
||||
|
||||
|
||||
def compute_transformers_input_shapes(batches, meta_info):
|
||||
from flash_attn.bert_padding import unpad_input # flash 2 is a must for Megatron
|
||||
# pre-compute input shapes for each micro-batch at each pp stage
|
||||
input_shapes = []
|
||||
for model_inputs in batches:
|
||||
input_ids = model_inputs['input_ids']
|
||||
attention_mask = model_inputs['attention_mask']
|
||||
input_ids_rmpad = unpad_input(input_ids.unsqueeze(dim=-1), attention_mask)[0] # (total_nnz, 1)
|
||||
if meta_info['sequence_parallel']:
|
||||
input_ids_rmpad = pad_to_sequence_parallel(input_ids_rmpad)
|
||||
# compute shapes for model_inputs
|
||||
input_shapes.append(
|
||||
torch.Size([
|
||||
input_ids_rmpad.shape[0] // mpu.get_tensor_model_parallel_world_size(), 1, meta_info['hidden_size']
|
||||
]))
|
||||
else:
|
||||
# compute shapes for model_inputs
|
||||
input_shapes.append(torch.Size([input_ids_rmpad.shape[0], 1, meta_info['hidden_size']]))
|
||||
return input_shapes
|
||||
|
||||
|
||||
def make_batch_generator(batches, vpp_size):
|
||||
if vpp_size > 1:
|
||||
# has vpp
|
||||
batch_generator = [batches] * vpp_size # number of vpp chunks
|
||||
batch_generator = [iter(b) for b in batch_generator]
|
||||
else:
|
||||
# no vpp
|
||||
batch_generator = iter(batches)
|
||||
return batch_generator
|
||||
54
verl/utils/megatron/sequence_parallel.py
Normal file
54
verl/utils/megatron/sequence_parallel.py
Normal file
@@ -0,0 +1,54 @@
|
||||
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from megatron.core import parallel_state as mpu
|
||||
|
||||
|
||||
def mark_parameter_as_sequence_parallel(parameter):
|
||||
setattr(parameter, 'sequence_parallel', True)
|
||||
|
||||
|
||||
def is_sequence_parallel_param(param):
|
||||
return hasattr(param, 'sequence_parallel') and param.sequence_parallel
|
||||
|
||||
|
||||
def pad_to_sequence_parallel(unpad_tokens: torch.Tensor):
|
||||
"""pad the tokens such that the total length is a multiple of sp world size
|
||||
|
||||
Args:
|
||||
unpad_tokens: (total_nnz, ...). Tokens after removing padding
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
total_nnz = unpad_tokens.shape[0]
|
||||
sp_world_size = mpu.get_tensor_model_parallel_world_size()
|
||||
|
||||
if total_nnz % sp_world_size == 0:
|
||||
pad_size = 0
|
||||
else:
|
||||
pad_size = sp_world_size - total_nnz % sp_world_size
|
||||
|
||||
if pad_size > 0:
|
||||
if unpad_tokens.ndim == 1:
|
||||
unpad_tokens = F.pad(unpad_tokens, (0, pad_size))
|
||||
elif unpad_tokens.ndim == 2:
|
||||
unpad_tokens = F.pad(unpad_tokens, (0, 0, 0, pad_size))
|
||||
else:
|
||||
raise NotImplementedError(f'Padding dim {unpad_tokens.ndim()} is not supported')
|
||||
|
||||
return unpad_tokens
|
||||
184
verl/utils/megatron/tensor_parallel.py
Normal file
184
verl/utils/megatron/tensor_parallel.py
Normal file
@@ -0,0 +1,184 @@
|
||||
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Utilities for using tensor_parallel in megatron
|
||||
"""
|
||||
from typing import Dict
|
||||
import torch
|
||||
from torch.nn import init
|
||||
import torch.distributed as dist
|
||||
from megatron.core import ModelParallelConfig
|
||||
from megatron.core import parallel_state as mpu, tensor_parallel
|
||||
import verl.utils.torch_functional as verl_F
|
||||
|
||||
|
||||
def update_kwargs_with_config(dictionary: Dict, config: ModelParallelConfig):
|
||||
dictionary['config'] = config
|
||||
return dictionary
|
||||
|
||||
|
||||
def get_default_kwargs_for_model_parallel_config():
|
||||
model_parallel_config_kwargs = {
|
||||
'params_dtype': torch.float32,
|
||||
'use_cpu_initialization': False,
|
||||
'perform_initialization': True,
|
||||
'gradient_accumulation_fusion': False,
|
||||
'sequence_parallel': False,
|
||||
}
|
||||
return model_parallel_config_kwargs
|
||||
|
||||
|
||||
def get_default_model_parallel_config():
|
||||
return ModelParallelConfig(**get_default_kwargs_for_model_parallel_config())
|
||||
|
||||
|
||||
def get_common_default_kwargs_for_parallel_linear():
|
||||
default_model_parallel_config = get_default_model_parallel_config()
|
||||
common_default_kwargs = {
|
||||
'init_method': init.xavier_normal_,
|
||||
'stride': 1,
|
||||
'keep_master_weight_for_test': False,
|
||||
'config': default_model_parallel_config,
|
||||
}
|
||||
return common_default_kwargs
|
||||
|
||||
|
||||
def get_default_kwargs_for_column_parallel_linear():
|
||||
model_parallel_config_kwargs = get_default_kwargs_for_model_parallel_config()
|
||||
column_parallel_config_kwargs = {
|
||||
'async_tensor_model_parallel_allreduce': False,
|
||||
}
|
||||
model_parallel_config_kwargs.update(column_parallel_config_kwargs)
|
||||
column_default_kwargs = {
|
||||
'config': ModelParallelConfig(**model_parallel_config_kwargs),
|
||||
}
|
||||
common_default_kwargs = get_common_default_kwargs_for_parallel_linear()
|
||||
common_default_kwargs.update(column_default_kwargs)
|
||||
return common_default_kwargs
|
||||
|
||||
|
||||
def get_default_kwargs_for_row_parallel_linear():
|
||||
common_default_kwargs = get_common_default_kwargs_for_parallel_linear()
|
||||
return common_default_kwargs
|
||||
|
||||
|
||||
def get_default_kwargs_for_parallel_embedding():
|
||||
model_parallel_config_kwargs = get_default_kwargs_for_model_parallel_config()
|
||||
embedding_default_kwargs = {
|
||||
'init_method': init.xavier_normal_,
|
||||
'config': ModelParallelConfig(**model_parallel_config_kwargs),
|
||||
}
|
||||
return embedding_default_kwargs
|
||||
|
||||
|
||||
def is_tensor_parallel_param(param):
|
||||
return (hasattr(param, 'tensor_model_parallel') and param.tensor_model_parallel)
|
||||
|
||||
|
||||
def get_tensor_parallel_partition_dim(param):
|
||||
assert is_tensor_parallel_param(param)
|
||||
return param.partition_dim
|
||||
|
||||
|
||||
def get_tensor_parallel_partition_stride(param):
|
||||
assert is_tensor_parallel_param(param)
|
||||
return param.partition_stride
|
||||
|
||||
|
||||
class _VocabParallelEntropy(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, vocab_parallel_logits: torch.Tensor) -> torch.Tensor:
|
||||
logits_max = vocab_parallel_logits.max(dim=-1, keepdim=True).values
|
||||
dist.all_reduce(logits_max, op=dist.ReduceOp.MAX, group=mpu.get_tensor_model_parallel_group())
|
||||
normalized_vocab_parallel_logits = vocab_parallel_logits - logits_max
|
||||
normalized_exp_logits = normalized_vocab_parallel_logits.exp()
|
||||
normalized_sum_exp_logits = normalized_exp_logits.sum(dim=-1, keepdim=True)
|
||||
dist.all_reduce(normalized_sum_exp_logits, group=mpu.get_tensor_model_parallel_group())
|
||||
softmax_logits = normalized_exp_logits / normalized_sum_exp_logits
|
||||
sum_softmax_times_logits = (softmax_logits * vocab_parallel_logits).sum(dim=-1, keepdim=True)
|
||||
dist.all_reduce(sum_softmax_times_logits, group=mpu.get_tensor_model_parallel_group())
|
||||
entropy = logits_max + normalized_sum_exp_logits.log() - sum_softmax_times_logits
|
||||
ctx.save_for_backward(vocab_parallel_logits, softmax_logits, sum_softmax_times_logits)
|
||||
return entropy.squeeze(dim=-1)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
|
||||
vocab_parallel_logits, softmax_logits, sum_softmax_times_logits = ctx.saved_tensors
|
||||
grad_input = grad_output.unsqueeze(dim=-1) * softmax_logits * (sum_softmax_times_logits - vocab_parallel_logits)
|
||||
return grad_input
|
||||
|
||||
|
||||
def vocab_parallel_entropy(vocab_parallel_logits: torch.Tensor) -> torch.Tensor:
|
||||
"""Compute entropy when the logits are sharded in tp ranks
|
||||
|
||||
Args:
|
||||
vocab_parallel_logits: (total_nnz, vocab_size // tp_size)
|
||||
|
||||
Returns: (total_nnz,)
|
||||
|
||||
"""
|
||||
return _VocabParallelEntropy.apply(vocab_parallel_logits)
|
||||
|
||||
|
||||
def vocab_parallel_log_probs_from_logits(logits, labels):
|
||||
"""TODO(zhangchi.usc1992): We may change the implementation later"""
|
||||
return -tensor_parallel.vocab_parallel_cross_entropy(vocab_parallel_logits=logits, target=labels)
|
||||
|
||||
|
||||
def vocab_parallel_log_probs_from_logits_response_rmpad(input_ids, attention_mask, logits_rmpad, response_length):
|
||||
"""Similar to log_probs_from_logits_response_rmpad, but the logits_rmpad is now spliited across tensor parallel region.
|
||||
This will further reduce the peak memory usage during training
|
||||
|
||||
Args:
|
||||
input_ids: [batch_size, seqlen]
|
||||
attention_mask: [batch_size, seqlen]
|
||||
logits_rmpad: [total_nnz, vocab_size // tp_size]
|
||||
response_length: int
|
||||
|
||||
"""
|
||||
from flash_attn.bert_padding import pad_input, unpad_input
|
||||
|
||||
batch_size, seqlen = input_ids.shape
|
||||
input_ids_rmpad, indices, *_ = unpad_input(input_ids.unsqueeze(-1), attention_mask=attention_mask)
|
||||
input_ids_rmpad = input_ids_rmpad.squeeze(-1)
|
||||
input_ids_rmpad_rolled = torch.roll(input_ids_rmpad, shifts=-1, dims=0)
|
||||
full_log_probs_rmpad = vocab_parallel_log_probs_from_logits(logits=logits_rmpad,
|
||||
labels=input_ids_rmpad_rolled) # (total_nnz,)
|
||||
full_output = pad_input(hidden_states=full_log_probs_rmpad.unsqueeze(-1),
|
||||
indices=indices,
|
||||
batch=batch_size,
|
||||
seqlen=seqlen)
|
||||
output = full_output.squeeze(-1)[:, -response_length - 1:-1] # [batch_size, response_length]
|
||||
return output
|
||||
|
||||
|
||||
def vocab_parallel_compute_entropy_loss(logits, eos_mask):
|
||||
"""Compute Categorical entropy loss
|
||||
|
||||
Args:
|
||||
logits: `(torch.Tensor)`
|
||||
shape: (bs, response_length, vocab_size)
|
||||
eos_mask: `(torch.Tensor)`
|
||||
shape: (bs, response_length)
|
||||
|
||||
Returns:
|
||||
entropy: a scalar torch.Tensor
|
||||
|
||||
"""
|
||||
# compute entropy
|
||||
entropy = vocab_parallel_entropy(logits)
|
||||
entropy_loss = verl_F.masked_mean(entropy, mask=eos_mask)
|
||||
return entropy_loss
|
||||
253
verl/utils/megatron_utils.py
Normal file
253
verl/utils/megatron_utils.py
Normal file
@@ -0,0 +1,253 @@
|
||||
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Pretrain utilities."""
|
||||
from typing import Any, Dict
|
||||
import time
|
||||
from omegaconf import DictConfig
|
||||
from verl.utils.torch_dtypes import PrecisionType
|
||||
from verl.utils.memory_buffer import build_memory_reference_from_module
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from megatron.core import mpu, tensor_parallel
|
||||
from megatron.core.utils import get_model_config
|
||||
from megatron.core.transformer import TransformerConfig
|
||||
from megatron.core.transformer.module import Float16Module
|
||||
# from megatron.core.distributed import DistributedDataParallelConfig
|
||||
from megatron.core.distributed import DistributedDataParallel as DDP
|
||||
from megatron.core.enums import ModelType
|
||||
|
||||
|
||||
def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap_with_ddp=True):
|
||||
"""Build the model."""
|
||||
# Build model.
|
||||
if mpu.get_pipeline_model_parallel_world_size() > 1 and \
|
||||
mpu.get_virtual_pipeline_model_parallel_world_size() is not None:
|
||||
assert model_type != ModelType.encoder_and_decoder, \
|
||||
"Interleaved schedule not supported for model with both encoder and decoder"
|
||||
model = []
|
||||
for i in range(mpu.get_virtual_pipeline_model_parallel_world_size()):
|
||||
mpu.set_virtual_pipeline_model_parallel_rank(i)
|
||||
# Set pre_process and post_process only after virtual rank is set.
|
||||
pre_process = mpu.is_pipeline_first_stage()
|
||||
post_process = mpu.is_pipeline_last_stage()
|
||||
this_model = model_provider_func(pre_process=pre_process, post_process=post_process)
|
||||
this_model.model_type = model_type
|
||||
model.append(this_model)
|
||||
else:
|
||||
pre_process = mpu.is_pipeline_first_stage()
|
||||
post_process = mpu.is_pipeline_last_stage()
|
||||
add_encoder = True
|
||||
add_decoder = True
|
||||
if model_type == ModelType.encoder_and_decoder:
|
||||
if mpu.get_pipeline_model_parallel_world_size() > 1:
|
||||
assert mpu.get_pipeline_model_parallel_split_rank() is not None, \
|
||||
"Split rank needs to be specified for model with both encoder and decoder"
|
||||
rank = mpu.get_pipeline_model_parallel_rank()
|
||||
split_rank = mpu.get_pipeline_model_parallel_split_rank()
|
||||
world_size = mpu.get_pipeline_model_parallel_world_size()
|
||||
pre_process = rank == 0 or rank == split_rank
|
||||
post_process = (rank == (split_rank - 1)) or (rank == (world_size - 1))
|
||||
add_encoder = mpu.is_pipeline_stage_before_split()
|
||||
add_decoder = mpu.is_pipeline_stage_after_split()
|
||||
model = model_provider_func(pre_process=pre_process,
|
||||
post_process=post_process,
|
||||
add_encoder=add_encoder,
|
||||
add_decoder=add_decoder)
|
||||
else:
|
||||
model = model_provider_func(pre_process=pre_process, post_process=post_process)
|
||||
model.model_type = model_type
|
||||
|
||||
if not isinstance(model, list):
|
||||
model = [model]
|
||||
|
||||
# Set tensor model parallel attributes if not set.
|
||||
# Only parameters that are already tensor model parallel have these
|
||||
# attributes set for them. We should make sure the default attributes
|
||||
# are set for all params so the optimizer can use them.
|
||||
for model_module in model:
|
||||
for param in model_module.parameters():
|
||||
tensor_parallel.set_defaults_if_not_set_tensor_model_parallel_attributes(param)
|
||||
|
||||
# Print number of parameters.
|
||||
if mpu.get_data_parallel_rank() == 0:
|
||||
print(' > number of parameters on (tensor, pipeline) '
|
||||
'model parallel rank ({}, {}): {}'.format(
|
||||
mpu.get_tensor_model_parallel_rank(), mpu.get_pipeline_model_parallel_rank(),
|
||||
sum([sum([p.nelement() for p in model_module.parameters()]) for model_module in model])),
|
||||
flush=True)
|
||||
|
||||
# GPU allocation.
|
||||
for model_module in model:
|
||||
model_module.cuda(torch.cuda.current_device())
|
||||
|
||||
# Fp16 conversion.
|
||||
config = get_model_config(model[0])
|
||||
if config.fp16 or config.bf16: # the ModelParallelConfig in GPTModel
|
||||
model = [Float16Module(config, model_module) for model_module in model]
|
||||
|
||||
if wrap_with_ddp:
|
||||
model = [
|
||||
DDP(config=config,
|
||||
module=model_chunk,
|
||||
data_parallel_group=mpu.get_data_parallel_group(with_context_parallel=True),
|
||||
accumulate_allreduce_grads_in_fp32=True,
|
||||
overlap_grad_reduce=False,
|
||||
use_distributed_optimizer=True,
|
||||
disable_bucketing=(model_chunk_idx > 0)) for (model_chunk_idx, model_chunk) in enumerate(model)
|
||||
]
|
||||
# # Broadcast params from data parallel src rank to other data parallel ranks.
|
||||
# if args.data_parallel_random_init:
|
||||
for model_module in model:
|
||||
model_module.broadcast_params()
|
||||
return model
|
||||
|
||||
|
||||
ALL_MODULE_WRAPPER_CLASSNAMES = (DDP, Float16Module)
|
||||
|
||||
|
||||
def unwrap_model(model, module_instances=ALL_MODULE_WRAPPER_CLASSNAMES):
|
||||
return_list = True
|
||||
if not isinstance(model, list):
|
||||
model = [model]
|
||||
return_list = False
|
||||
unwrapped_model = []
|
||||
for model_module in model:
|
||||
while isinstance(model_module, module_instances):
|
||||
model_module = model_module.module
|
||||
unwrapped_model.append(model_module)
|
||||
if not return_list:
|
||||
return unwrapped_model[0]
|
||||
return unwrapped_model
|
||||
|
||||
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
|
||||
def convert_config(hf_config: PretrainedConfig, megatron_config) -> TransformerConfig:
|
||||
print(f'megatron config {megatron_config}')
|
||||
dt = PrecisionType.to_dtype(megatron_config['param_dtype'])
|
||||
print(f'pipeline_dtype=megatron_config {dt}')
|
||||
transformer_config = TransformerConfig(
|
||||
num_layers=hf_config.num_hidden_layers,
|
||||
hidden_size=hf_config.hidden_size,
|
||||
num_attention_heads=hf_config.num_attention_heads,
|
||||
num_query_groups=hf_config.num_key_value_heads,
|
||||
ffn_hidden_size=hf_config.intermediate_size,
|
||||
# max_position_embeddings=hf_config.max_position_embeddings,
|
||||
activation_func=F.silu,
|
||||
normalization='RMSNorm',
|
||||
# rotary_percent=False, # default,
|
||||
gated_linear_unit=True, # for llama
|
||||
use_cpu_initialization=True,
|
||||
apply_residual_connection_post_layernorm=False, # check what's this mean
|
||||
add_bias_linear=False,
|
||||
tensor_model_parallel_size=mpu.get_tensor_model_parallel_world_size(),
|
||||
pipeline_model_parallel_size=mpu.get_pipeline_model_parallel_world_size(),
|
||||
virtual_pipeline_model_parallel_size=mpu.get_virtual_pipeline_model_parallel_world_size(),
|
||||
pipeline_dtype=PrecisionType.to_dtype(megatron_config['param_dtype']),
|
||||
params_dtype=PrecisionType.to_dtype(megatron_config['param_dtype']),
|
||||
sequence_parallel=megatron_config['sequence_parallel_enabled'],
|
||||
variable_seq_lengths=True,
|
||||
masked_softmax_fusion=True,
|
||||
bf16=PrecisionType.to_dtype(megatron_config['param_dtype']) is torch.bfloat16)
|
||||
if torch.distributed.get_rank() == 0:
|
||||
print(f'tensor_parallel_size={transformer_config.tensor_model_parallel_size} \n \
|
||||
pipeline_model_parallel_size={transformer_config.pipeline_model_parallel_size} \n \
|
||||
virtual_pipeline_model_parallel_size={transformer_config.virtual_pipeline_model_parallel_size} \n \
|
||||
pipeline_dtype={transformer_config.pipeline_dtype} \n \
|
||||
params_dtype={transformer_config.params_dtype} \n \
|
||||
sequence_parallel={transformer_config.sequence_parallel} \n \
|
||||
variable_seq_lengths={transformer_config.variable_seq_lengths} \n \
|
||||
masked_softmax_fusion={transformer_config.masked_softmax_fusion} \n ')
|
||||
|
||||
return transformer_config
|
||||
|
||||
|
||||
# from megatron.core.optimizer import OptimizerConfig
|
||||
|
||||
from verl.utils.megatron.optimizer_config import OptimizerConfig
|
||||
|
||||
|
||||
def init_megatron_optim_config(optim_config: Dict) -> OptimizerConfig:
|
||||
config = OptimizerConfig(
|
||||
optimizer='adam',
|
||||
lr=optim_config.get('lr'),
|
||||
clip_grad=optim_config.get('clip_grad'),
|
||||
weight_decay=1e-2,
|
||||
bf16=True,
|
||||
params_dtype=torch.bfloat16,
|
||||
use_distributed_optimizer=True,
|
||||
)
|
||||
return config
|
||||
|
||||
|
||||
from megatron.core import ModelParallelConfig
|
||||
|
||||
|
||||
def init_model_parallel_config(config: DictConfig) -> ModelParallelConfig:
|
||||
# TODO(sgm): check how to disable megatron timers
|
||||
timers = FakeTimers()
|
||||
return ModelParallelConfig(tensor_model_parallel_size=config.get('tensor_model_parallel_size'),
|
||||
pipeline_model_parallel_size=config.get('pipeline_model_parallel_size'),
|
||||
virtual_pipeline_model_parallel_size=config.get('virtual_pipeline_model_parallel_size'),
|
||||
sequence_parallel=config.get('sequence_parallel'),
|
||||
params_dtype=PrecisionType.to_dtype(config.get('param_dtype')),
|
||||
pipeline_dtype=PrecisionType.to_dtype(config.get('param_dtype')),
|
||||
bf16=True,
|
||||
fp16=False,
|
||||
timers=timers)
|
||||
|
||||
|
||||
class FakeTimers:
|
||||
"""Disable All Megatron Timing with FakeTimers"""
|
||||
|
||||
def __init__(self):
|
||||
from megatron.timers import DummyTimer
|
||||
self.dummy_timer = DummyTimer()
|
||||
|
||||
def __call__(self, *args: Any, **kwds: Any) -> Any:
|
||||
return self.dummy_timer
|
||||
|
||||
|
||||
def offload_megatron_param_and_grad(module_list: nn.ModuleList, offload_grad=False, hybrid_engine=None):
|
||||
if hybrid_engine is not None:
|
||||
pp_rank = mpu.get_pipeline_model_parallel_rank()
|
||||
for buffer in hybrid_engine.memory_buffers[pp_rank].values():
|
||||
buffer.data = buffer.data.to('cpu', non_blocking=True)
|
||||
build_memory_reference_from_module(module_list, hybrid_engine.memory_buffers[pp_rank], maintain_weight=True)
|
||||
else:
|
||||
for module in module_list:
|
||||
for _, param in module.named_parameters():
|
||||
param.data = param.data.to('cpu', non_blocking=True)
|
||||
if offload_grad and param.grad is not None:
|
||||
param.grad = param.grad.to("cpu", non_blocking=True)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
def load_megatron_param_and_grad(module_list: nn.ModuleList, device_id, load_grad=False, hybrid_engine=None):
|
||||
if hybrid_engine is not None:
|
||||
pp_rank = mpu.get_pipeline_model_parallel_rank()
|
||||
for buffer in hybrid_engine.memory_buffers[pp_rank].values():
|
||||
buffer.data = buffer.data.to(device_id, non_blocking=True)
|
||||
build_memory_reference_from_module(module_list, hybrid_engine.memory_buffers[pp_rank], maintain_weight=True)
|
||||
else:
|
||||
for module in module_list:
|
||||
for _, param in module.named_parameters():
|
||||
param.data = param.data.to(device_id, non_blocking=True)
|
||||
if load_grad and param.grad is not None:
|
||||
param.grad = param.grad.to(device_id, non_blocking=True)
|
||||
torch.cuda.empty_cache()
|
||||
214
verl/utils/memory_buffer.py
Normal file
214
verl/utils/memory_buffer.py
Normal file
@@ -0,0 +1,214 @@
|
||||
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
This file contains utilities to manipulate torch memory buffers
|
||||
"""
|
||||
|
||||
from typing import Dict, List
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
class MemoryBuffer:
|
||||
"""
|
||||
A memory buffer is a contiguous torch tensor that may combine multiple tensors sharing with the underlying
|
||||
memory. It must have a unique type to support this behavior.
|
||||
"""
|
||||
|
||||
def __init__(self, numel: int, numel_padded: int, dtype: torch.dtype):
|
||||
self.numel = numel
|
||||
self.numel_padded = numel_padded
|
||||
self.dtype = dtype
|
||||
self.data = torch.zeros(self.numel_padded, dtype=self.dtype, device='cuda', requires_grad=False)
|
||||
|
||||
def zero(self):
|
||||
"""Reset the buffer to zero."""
|
||||
self.data.zero_()
|
||||
|
||||
def get(self, shape, start_index):
|
||||
"""Return a tensor with the input `shape` as a view into the
|
||||
1-D data starting at `start_index`."""
|
||||
end_index = start_index + shape.numel()
|
||||
assert end_index <= self.numel, \
|
||||
'requested tensor is out of the buffer range.'
|
||||
buffer_tensor = self.data[start_index:end_index]
|
||||
buffer_tensor = buffer_tensor.view(shape)
|
||||
return buffer_tensor
|
||||
|
||||
|
||||
def calc_padded_numel(shape: torch.Size, dtype: torch.dtype):
|
||||
"""for cuda memory alignment, make sure alignment by 128-bits"""
|
||||
align_numel = 128 // torch.finfo(dtype).bits
|
||||
numel = shape.numel()
|
||||
return (numel + align_numel - 1) // align_numel * align_numel
|
||||
|
||||
|
||||
def get_weight_buffer_meta_from_module(module: nn.Module) -> Dict[str, Dict]:
|
||||
"""
|
||||
Return a dictionary containing name to a shape and dtype.
|
||||
"""
|
||||
weight_buffer_meta = {}
|
||||
for name, param in sorted(module.named_parameters()):
|
||||
weight_buffer_meta[name] = {'shape': param.shape, 'dtype': param.dtype}
|
||||
return weight_buffer_meta
|
||||
|
||||
|
||||
def build_memory_buffer(weight_buffer_meta: Dict[str, Dict]) -> Dict[torch.dtype, MemoryBuffer]:
|
||||
"""Build the memory buffer given weight_buffer_meta
|
||||
|
||||
Args:
|
||||
weight_buffer_meta: contains mapping from name to a dictionary containing shape and dtype of the tensors
|
||||
|
||||
Returns: a large memory buffer for each dtype that can hold all the tensors
|
||||
|
||||
"""
|
||||
memory_buffers = {}
|
||||
total_numel_map = {} # map from dtype to the total numel
|
||||
for name, meta_info in sorted(weight_buffer_meta.items()):
|
||||
shape = meta_info['shape']
|
||||
dtype = meta_info['dtype']
|
||||
|
||||
assert isinstance(shape, torch.Size)
|
||||
assert isinstance(dtype, torch.dtype)
|
||||
|
||||
if dtype not in total_numel_map:
|
||||
total_numel_map[dtype] = 0
|
||||
|
||||
total_numel_map[dtype] += calc_padded_numel(shape, dtype)
|
||||
|
||||
for dtype, total_numel in total_numel_map.items():
|
||||
memory_buffers[dtype] = MemoryBuffer(total_numel, total_numel, dtype)
|
||||
|
||||
return memory_buffers
|
||||
|
||||
|
||||
def build_memory_reference_from_module(module: torch.nn.Module,
|
||||
memory_buffers: Dict[torch.dtype, MemoryBuffer],
|
||||
maintain_weight=True):
|
||||
start_index = {}
|
||||
for dtype in memory_buffers.keys():
|
||||
start_index[dtype] = 0
|
||||
for name, param in sorted(module.named_parameters()):
|
||||
memory_buffer = memory_buffers[param.dtype]
|
||||
buffer = memory_buffer.get(shape=param.shape, start_index=start_index[param.dtype])
|
||||
# need to increment start_index
|
||||
start_index[param.dtype] += calc_padded_numel(param.shape, dtype)
|
||||
if maintain_weight:
|
||||
buffer.copy_(param.data)
|
||||
param.data = buffer
|
||||
|
||||
|
||||
def build_memory_reference(weight_buffer_meta: Dict[str, Dict], memory_buffers: Dict[torch.dtype, MemoryBuffer]):
|
||||
"""Build the memory references. The memory buffers are built using the build_memory_buffer API.
|
||||
This API will allocate a weight buffer pointer to the memory buffer according to the weight_buffer_meta.
|
||||
|
||||
Args:
|
||||
weight_buffer_meta:
|
||||
memory_buffers:
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
start_idx = {}
|
||||
weight_buffers = {}
|
||||
for dtype in memory_buffers.keys():
|
||||
start_idx[dtype] = 0
|
||||
|
||||
for name, meta_info in sorted(weight_buffer_meta.items()):
|
||||
shape = meta_info['shape']
|
||||
dtype = meta_info['dtype']
|
||||
|
||||
buffer = memory_buffers[dtype].get(shape, start_index=start_idx[dtype])
|
||||
start_idx[dtype] += calc_padded_numel(shape, dtype)
|
||||
weight_buffers[name] = buffer
|
||||
|
||||
return weight_buffers
|
||||
|
||||
|
||||
class MemoryBufferModuleWrapper:
|
||||
"""
|
||||
Note that we do not design MemoryBufferModuleWrapper as an nn.Module due to
|
||||
- It will change the checkpoint name
|
||||
"""
|
||||
|
||||
def __init__(self, module: nn.Module):
|
||||
super().__init__()
|
||||
self.module = module
|
||||
self.weight_buffer_meta = get_weight_buffer_meta_from_module(self.module)
|
||||
self.memory_buffers = build_memory_buffer(self.weight_buffer_meta)
|
||||
build_memory_reference_from_module(self.module, self.memory_buffers)
|
||||
|
||||
def get_memory_buffers(self):
|
||||
return self.memory_buffers
|
||||
|
||||
def get_weight_buffer_meta(self):
|
||||
return self.weight_buffer_meta
|
||||
|
||||
|
||||
class MegatronMemoryBufferForRollout(object):
|
||||
"""
|
||||
We assume that
|
||||
- inference engine has tp + dp
|
||||
- actor has tp + pp + dp
|
||||
- the tp between inference engine and actor should be the same
|
||||
- memory_buffers: contains a list of memory_buffers, each is a dict from dtype to MemoryBuffer
|
||||
- weight_buffers: contains a list of weight_buffers, each is a dict from name to param
|
||||
- named_parameters: a dict from name to parameter that normalizes the names from pp and vpp. Note that
|
||||
the named_parameters may not be directly compatible with inference engine. User has to take care of
|
||||
this part such as the layout mismatches. (e.g. qkv transpose)
|
||||
- Note that weight_buffer, named_parameters and memory_buffers share the same underlying GPU memory.
|
||||
- When doing weight sync, the data is transfer via memory buffers
|
||||
"""
|
||||
|
||||
def __init__(self, transform_memory_param_fn):
|
||||
self._memory_buffers = []
|
||||
self._weight_buffers = []
|
||||
self._named_parameters = {}
|
||||
self.transform_memory_param_fn = transform_memory_param_fn
|
||||
|
||||
def initialize_weight_buffer(self, weight_buffer_meta_pp: List[Dict[str, Dict]]):
|
||||
"""
|
||||
Initialize the weight buffer. The weight buffer is obtained according to the actor. We will construct
|
||||
a large buffer for each dtype in the weight_buffer.
|
||||
|
||||
Args:
|
||||
weight_buffer_meta: contains pp models, each pp models contains a dictionary of mapping from
|
||||
|
||||
Returns: None
|
||||
|
||||
"""
|
||||
self.weight_buffer_meta_pp = weight_buffer_meta_pp
|
||||
|
||||
for weight_buffer_meta in self.weight_buffer_meta_pp:
|
||||
memory_buffer = build_memory_buffer(weight_buffer_meta)
|
||||
self._memory_buffers.append(memory_buffer)
|
||||
self._weight_buffers.append(None)
|
||||
|
||||
def build_memory_reference(self):
|
||||
for i, weight_buffer_meta in enumerate(self.weight_buffer_meta_pp):
|
||||
self._weight_buffers[i] = build_memory_reference(weight_buffer_meta, self._memory_buffers[i])
|
||||
self._named_parameters = self.transform_memory_param_fn(self._weight_buffers)
|
||||
|
||||
@property
|
||||
def named_parameters(self):
|
||||
return self._named_parameters
|
||||
|
||||
@property
|
||||
def weight_buffers(self):
|
||||
return self._weight_buffers
|
||||
|
||||
@property
|
||||
def memory_buffers(self):
|
||||
return self._memory_buffers
|
||||
332
verl/utils/model.py
Normal file
332
verl/utils/model.py
Normal file
@@ -0,0 +1,332 @@
|
||||
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Utilities to create common models from huggingface
|
||||
"""
|
||||
import os
|
||||
import warnings
|
||||
from typing import Dict, Type
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import AutoConfig, AutoModelForCausalLM, PretrainedConfig, MistralForSequenceClassification
|
||||
from verl.models.registry import ModelRegistry
|
||||
|
||||
|
||||
class LambdaLayer(nn.Module):
|
||||
|
||||
def __init__(self, fn):
|
||||
super().__init__()
|
||||
self.fn = fn
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
return self.fn(*args, **kwargs)
|
||||
|
||||
|
||||
def squeeze(x):
|
||||
return torch.squeeze(x, dim=-1)
|
||||
|
||||
|
||||
def update_model_config(module_config, override_config_kwargs):
|
||||
for key, val in override_config_kwargs.items():
|
||||
setattr(module_config, key, val)
|
||||
|
||||
|
||||
def get_huggingface_actor_config(model_name: str, override_config_kwargs=None, trust_remote_code=False) -> Dict:
|
||||
if override_config_kwargs is None:
|
||||
override_config_kwargs = {}
|
||||
assert isinstance(override_config_kwargs, Dict), \
|
||||
f'override_config_kwargs must be a dict, got {type(override_config_kwargs)}'
|
||||
module_config = AutoConfig.from_pretrained(model_name, trust_remote_code=trust_remote_code)
|
||||
update_model_config(module_config, override_config_kwargs)
|
||||
|
||||
return module_config
|
||||
|
||||
|
||||
def create_huggingface_actor(model_name: str, override_config_kwargs=None, automodel_kwargs=None) -> nn.Module:
|
||||
"""
|
||||
|
||||
Args:
|
||||
model_name:
|
||||
actor_override_config_kwargs:
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
if override_config_kwargs is None:
|
||||
override_config_kwargs = {}
|
||||
if automodel_kwargs is None:
|
||||
automodel_kwargs = {}
|
||||
assert isinstance(override_config_kwargs, Dict), \
|
||||
f'override_config_kwargs must be a dict, got {type(override_config_kwargs)}'
|
||||
module_config = get_huggingface_actor_config(model_name,
|
||||
override_config_kwargs,
|
||||
trust_remote_code=automodel_kwargs.get('trust_remote_code', False))
|
||||
module: nn.Module = AutoModelForCausalLM.from_config(module_config, **automodel_kwargs)
|
||||
return module
|
||||
|
||||
|
||||
def create_huggingface_critic(model_name: str, override_config_kwargs=None, automodel_kwargs=None) -> nn.Module:
|
||||
"""
|
||||
|
||||
Args:
|
||||
model_name:
|
||||
override_config_kwargs:
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
critic_module: nn.Module = create_huggingface_actor(model_name,
|
||||
override_config_kwargs=override_config_kwargs,
|
||||
automodel_kwargs=automodel_kwargs)
|
||||
if automodel_kwargs is None:
|
||||
automodel_kwargs = {}
|
||||
torch_dtype = automodel_kwargs.get('torch_dtype', torch.float32)
|
||||
critic_module.lm_head = nn.Sequential(nn.Linear(critic_module.config.hidden_size, 1, dtype=torch_dtype),
|
||||
LambdaLayer(fn=squeeze))
|
||||
return critic_module
|
||||
|
||||
|
||||
def get_model_size(model: nn.Module, scale='auto'):
|
||||
n_params = sum(p.numel() for p in model.parameters())
|
||||
|
||||
if scale == 'auto':
|
||||
if n_params > 1e9:
|
||||
scale = 'B'
|
||||
elif n_params > 1e6:
|
||||
scale = 'M'
|
||||
elif n_params > 1e3:
|
||||
scale = 'K'
|
||||
else:
|
||||
scale = ''
|
||||
|
||||
if scale == 'B':
|
||||
n_params = n_params / 1e9
|
||||
elif scale == 'M':
|
||||
n_params = n_params / 1e6
|
||||
elif scale == 'K':
|
||||
n_params = n_params / 1e3
|
||||
elif scale == '':
|
||||
pass
|
||||
else:
|
||||
raise NotImplemented(f'Unknown scale {scale}')
|
||||
|
||||
return n_params, scale
|
||||
|
||||
|
||||
def print_model_size(model: nn.Module, name: str = None):
|
||||
n_params, scale = get_model_size(model, scale='auto')
|
||||
if name is None:
|
||||
name = model.__class__.__name__
|
||||
print(f'{name} contains {n_params:.2f}{scale} parameters')
|
||||
|
||||
|
||||
def create_random_mask(input_ids: torch.Tensor,
|
||||
max_ratio_of_valid_token: float,
|
||||
max_ratio_of_left_padding: float,
|
||||
min_ratio_of_valid_token: float = 0):
|
||||
"""Create a random mask given input_ids. Support left padding and right padding.
|
||||
Process:
|
||||
- Sample valid token length
|
||||
- Sample left_padding length
|
||||
- Generate padding
|
||||
|
||||
Args:
|
||||
input_ids:
|
||||
shape (batch_size, seq_len)
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
assert max_ratio_of_valid_token > 0 and max_ratio_of_valid_token <= 1.
|
||||
assert max_ratio_of_left_padding >= 0 and max_ratio_of_left_padding < 1.
|
||||
assert min_ratio_of_valid_token <= max_ratio_of_valid_token
|
||||
|
||||
batch_size, sequence_length = input_ids.shape
|
||||
max_num_valid_tokens = int(sequence_length * max_ratio_of_valid_token)
|
||||
min_num_valid_tokens = max(1, int(sequence_length * min_ratio_of_valid_token))
|
||||
max_left_padding = int(sequence_length * max_ratio_of_left_padding)
|
||||
assert max_num_valid_tokens + max_left_padding <= sequence_length
|
||||
assert max_num_valid_tokens > 0 and max_ratio_of_valid_token <= sequence_length
|
||||
masks = torch.ones_like(input_ids, dtype=torch.int64)
|
||||
# TODO: we can make this faster
|
||||
for i in range(batch_size):
|
||||
num_left_padding = np.random.randint(low=0, high=max_left_padding + 1, dtype=np.int64)
|
||||
num_valid = np.random.randint(low=min_num_valid_tokens, high=max_num_valid_tokens + 1, dtype=np.int64)
|
||||
|
||||
for index in range(num_left_padding):
|
||||
masks[i, index] = 0
|
||||
|
||||
for index in range(num_left_padding + num_valid, sequence_length):
|
||||
masks[i, index] = 0
|
||||
return masks
|
||||
|
||||
|
||||
def compute_position_id_with_mask(mask):
|
||||
return torch.clip(torch.cumsum(mask, dim=-1) - 1, min=0, max=None)
|
||||
|
||||
|
||||
def normalize_pp_vpp_params(params, num_hidden_layers, layer_name='layers'):
|
||||
"""
|
||||
Normalize the pp vpp params into a complete named parameters.
|
||||
This is useful when gather parameters from pp ranks and passed to a model without pp
|
||||
|
||||
params: List[List[Dict[str, param]]]
|
||||
params contains a list of pp, with a list of vpp named_parameters in each vpp chunk.
|
||||
output: Dict[str, param]
|
||||
|
||||
"""
|
||||
|
||||
def normalize_model_name(name, pp_rank, vpp_rank, pp_size, vpp_size, num_layers):
|
||||
"""
|
||||
Transform the model name in each model_chunk in each pp stage into the name in inference engine
|
||||
"""
|
||||
if vpp_size > 1:
|
||||
# print(f'try to bind vpp params to inference engine...')
|
||||
layers_per_pp = num_layers // pp_size
|
||||
layers_per_vpp = layers_per_pp // vpp_size
|
||||
pp_offset = layers_per_vpp * pp_rank
|
||||
vpp_offset = (layers_per_vpp * pp_size) * vpp_rank
|
||||
layer_offset = pp_offset + vpp_offset
|
||||
else:
|
||||
layers_per_pp = num_layers // pp_size
|
||||
layer_offset = layers_per_pp * pp_rank
|
||||
|
||||
if layer_name in name: # belong to an intermediate layer
|
||||
split_name = name.split('.')
|
||||
# find the num next to split_name
|
||||
for i, name in enumerate(split_name):
|
||||
if name == layer_name:
|
||||
break
|
||||
layer_num_idx = i + 1
|
||||
# check the name
|
||||
assert len(split_name) >= layer_num_idx + 1, f'split_name = {split_name}'
|
||||
assert split_name[layer_num_idx].isdigit(), f'split_name = {split_name}'
|
||||
# increment layer_num_idx by layer_offset
|
||||
split_name[layer_num_idx] = str(int(split_name[layer_num_idx]) + layer_offset)
|
||||
name = '.'.join(split_name) # weight name in inference_tp_model
|
||||
return name
|
||||
|
||||
pp_size = len(params)
|
||||
normalized_name_to_param = {}
|
||||
for pp_rank in range(len(params)):
|
||||
vpp_size = len(params[pp_rank])
|
||||
for vpp_rank in range(vpp_size):
|
||||
for name, param in params[pp_rank][vpp_rank].items():
|
||||
normalized_name = normalize_model_name(name, pp_rank, vpp_rank, pp_size, vpp_size, num_hidden_layers)
|
||||
normalized_name_to_param[normalized_name] = param
|
||||
|
||||
return normalized_name_to_param
|
||||
|
||||
|
||||
def get_parallel_model_from_config(config, megatron_config, pre_process=None, post_process=None, value=False):
|
||||
from megatron.core import ModelParallelConfig
|
||||
assert isinstance(megatron_config, ModelParallelConfig)
|
||||
model_class = _get_parallel_model_architecture_from_config(config, value)
|
||||
|
||||
model = model_class(config, megatron_config, pre_process=pre_process, post_process=post_process)
|
||||
return model
|
||||
|
||||
|
||||
def _get_parallel_model_architecture_from_config(config: PretrainedConfig, value=False) -> Type[nn.Module]:
|
||||
architectures = getattr(config, "architectures", [])
|
||||
for arch in architectures:
|
||||
model_cls = ModelRegistry.load_model_cls(arch, value)
|
||||
if model_cls is not None:
|
||||
return model_cls
|
||||
raise ValueError(f"Model architectures {architectures} are not supported for now. "
|
||||
f"Supported architectures: {ModelRegistry.get_supported_archs()}")
|
||||
|
||||
|
||||
def load_megatron_model_weights(config,
|
||||
model_config,
|
||||
parallel_model,
|
||||
params_dtype,
|
||||
is_value_model=False,
|
||||
local_cache_path='~/.cache/verl/rlhf'):
|
||||
assert hasattr(model_config, "architectures"), "architectures cannot be empty when load weight!"
|
||||
architectures = getattr(model_config, "architectures", [])
|
||||
local_cache_path = os.path.expanduser(local_cache_path)
|
||||
|
||||
if config.model.path.startswith("hdfs:"):
|
||||
from verl.utils.fs import copy_local_path_from_hdfs
|
||||
print(f'start download from {config.model.path}')
|
||||
local_model_path = copy_local_path_from_hdfs(src=config.model.path, cache_dir=local_cache_path)
|
||||
print('finish download')
|
||||
else:
|
||||
print(f"load from local dir {config.model.path}")
|
||||
local_model_path = config.model.path
|
||||
|
||||
# TODO: to find a better way to load mistral7b-rm lm_head
|
||||
if 'mistral7b-rm' in config.model.path:
|
||||
model = MistralForSequenceClassification.from_pretrained(local_model_path) # use score head instead of lm_head
|
||||
state_dict = model.state_dict()
|
||||
state_dict['lm_head.weight'] = state_dict['score.weight']
|
||||
state_dict['model.embed_tokens.weight'] = state_dict[
|
||||
'model.embed_tokens.weight'][:32000] # workaround, 32001 -> 32000
|
||||
is_value_model = True
|
||||
else:
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
model = AutoModelForCausalLM.from_pretrained(local_model_path)
|
||||
state_dict = model.state_dict()
|
||||
|
||||
from verl.models.weight_loader_registry import get_weight_loader
|
||||
print(f'before weight loader: architectures = {architectures}...')
|
||||
for arch in architectures:
|
||||
print(f'call weight loader arch = {arch}, model config = {model.config}')
|
||||
weight_loader = get_weight_loader(arch)
|
||||
weight_loader(state_dict=state_dict,
|
||||
wrapped_models=parallel_model,
|
||||
config=model.config,
|
||||
params_dtype=params_dtype,
|
||||
is_value_model=is_value_model)
|
||||
|
||||
|
||||
# pad input_ids_rmpad, cu_seqlens and max_seqlen_in_batch to be divisible by tp
|
||||
def pad_packed_inputs(unpad_tokens: torch.Tensor, cu_seqlens, max_seqlen_in_batch, size):
|
||||
"""pad the tokens such that the total length is a multiple of size.
|
||||
This function is useful when applying sequence parallel and context parallel
|
||||
|
||||
Args:
|
||||
unpad_tokens: (total_nnz, ...). Tokens after removing padding
|
||||
cu_seqlens: (total_nnz + 1,)
|
||||
max_seqlen_in_batch: int
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
F = nn.functional
|
||||
|
||||
total_nnz = unpad_tokens.shape[0]
|
||||
|
||||
if total_nnz % size == 0:
|
||||
pad_size = 0
|
||||
else:
|
||||
pad_size = size - total_nnz % size
|
||||
|
||||
# we assume adding a new data in the batch with seqlen pad_size
|
||||
if pad_size > 0:
|
||||
if unpad_tokens.ndim == 1:
|
||||
unpad_tokens = F.pad(unpad_tokens, (0, pad_size))
|
||||
elif unpad_tokens.ndim == 2:
|
||||
unpad_tokens = F.pad(unpad_tokens, (0, 0, 0, pad_size))
|
||||
else:
|
||||
raise NotImplementedError(f'Padding dim {unpad_tokens.ndim()} is not supported')
|
||||
|
||||
cu_seqlens = F.pad(cu_seqlens, (0, 1), value=pad_size + cu_seqlens[-1])
|
||||
max_seqlen_in_batch = max(max_seqlen_in_batch, pad_size)
|
||||
|
||||
return unpad_tokens, cu_seqlens, max_seqlen_in_batch
|
||||
56
verl/utils/py_functional.py
Normal file
56
verl/utils/py_functional.py
Normal file
@@ -0,0 +1,56 @@
|
||||
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Contain small python utility functions
|
||||
"""
|
||||
|
||||
from typing import Dict
|
||||
from types import SimpleNamespace
|
||||
|
||||
|
||||
def union_two_dict(dict1: Dict, dict2: Dict):
|
||||
"""Union two dict. Will throw an error if there is an item not the same object with the same key.
|
||||
|
||||
Args:
|
||||
dict1:
|
||||
dict2:
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
for key, val in dict2.items():
|
||||
if key in dict1:
|
||||
assert dict2[key] == dict1[key], \
|
||||
f'{key} in meta_dict1 and meta_dict2 are not the same object'
|
||||
dict1[key] = val
|
||||
|
||||
return dict1
|
||||
|
||||
|
||||
def append_to_dict(data: Dict, new_data: Dict):
|
||||
for key, val in new_data.items():
|
||||
if key not in data:
|
||||
data[key] = []
|
||||
data[key].append(val)
|
||||
|
||||
|
||||
class NestedNamespace(SimpleNamespace):
|
||||
|
||||
def __init__(self, dictionary, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
for key, value in dictionary.items():
|
||||
if isinstance(value, dict):
|
||||
self.__setattr__(key, NestedNamespace(value))
|
||||
else:
|
||||
self.__setattr__(key, value)
|
||||
43
verl/utils/ray_utils.py
Normal file
43
verl/utils/ray_utils.py
Normal file
@@ -0,0 +1,43 @@
|
||||
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Contains commonly used utilities for ray
|
||||
"""
|
||||
|
||||
import ray
|
||||
|
||||
import concurrent.futures
|
||||
|
||||
|
||||
def parallel_put(data_list, max_workers=None):
|
||||
|
||||
def put_data(index, data):
|
||||
return index, ray.put(data)
|
||||
|
||||
if max_workers is None:
|
||||
max_workers = min(len(data_list), 16)
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||
data_list_f = [executor.submit(put_data, i, data) for i, data in enumerate(data_list)]
|
||||
res_lst = []
|
||||
for future in concurrent.futures.as_completed(data_list_f):
|
||||
res_lst.append(future.result())
|
||||
|
||||
# reorder based on index
|
||||
output = [None for _ in range(len(data_list))]
|
||||
for res in res_lst:
|
||||
index, data_ref = res
|
||||
output[index] = data_ref
|
||||
|
||||
return output
|
||||
13
verl/utils/rendezvous/__init__.py
Normal file
13
verl/utils/rendezvous/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
77
verl/utils/rendezvous/ray_backend.py
Normal file
77
verl/utils/rendezvous/ray_backend.py
Normal file
@@ -0,0 +1,77 @@
|
||||
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import time
|
||||
|
||||
from cupy.cuda.nccl import NcclCommunicator, get_unique_id
|
||||
|
||||
import ray
|
||||
from ray.util import list_named_actors
|
||||
|
||||
|
||||
@ray.remote
|
||||
class NCCLIDStore:
|
||||
|
||||
def __init__(self, nccl_id):
|
||||
self._nccl_id = nccl_id
|
||||
|
||||
def get(self):
|
||||
return self._nccl_id
|
||||
|
||||
|
||||
def get_nccl_id_store_by_name(name):
|
||||
all_actors = list_named_actors(all_namespaces=True)
|
||||
matched_actors = [actor for actor in all_actors if actor.get("name", None) == name]
|
||||
if len(matched_actors) == 1:
|
||||
actor = matched_actors[0]
|
||||
return ray.get_actor(**actor)
|
||||
elif len(matched_actors) > 1:
|
||||
logging.warning(f"multiple actors with same name found: {matched_actors}")
|
||||
elif len(matched_actors) == 0:
|
||||
logging.info(f"failed to get any actor named {name}")
|
||||
return None
|
||||
|
||||
|
||||
def create_nccl_communicator_in_ray(rank: int,
|
||||
world_size: int,
|
||||
group_name: str,
|
||||
max_retries: int = 100,
|
||||
interval_s: int = 5):
|
||||
if rank == 0:
|
||||
nccl_id = get_unique_id()
|
||||
nccl_id_store = NCCLIDStore.options(name=group_name).remote(nccl_id)
|
||||
|
||||
assert ray.get(nccl_id_store.get.remote()) == nccl_id
|
||||
communicator = NcclCommunicator(
|
||||
ndev=world_size,
|
||||
commId=nccl_id,
|
||||
rank=0,
|
||||
)
|
||||
return communicator
|
||||
else:
|
||||
for i in range(max_retries):
|
||||
nccl_id_store = get_nccl_id_store_by_name(group_name)
|
||||
if nccl_id_store is not None:
|
||||
logging.info(f"nccl_id_store {group_name} got")
|
||||
nccl_id = ray.get(nccl_id_store.get.remote())
|
||||
logging.info(f"nccl id for {group_name} got: {nccl_id}")
|
||||
communicator = NcclCommunicator(
|
||||
ndev=world_size,
|
||||
commId=nccl_id,
|
||||
rank=rank,
|
||||
)
|
||||
return communicator
|
||||
logging.info(f"failed to get nccl_id for {i+1} time, sleep for {interval_s} seconds")
|
||||
time.sleep(interval_s)
|
||||
13
verl/utils/reward_score/__init__.py
Normal file
13
verl/utils/reward_score/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
111
verl/utils/reward_score/countdown.py
Normal file
111
verl/utils/reward_score/countdown.py
Normal file
@@ -0,0 +1,111 @@
|
||||
import re
|
||||
import random
|
||||
import ast
|
||||
import operator
|
||||
|
||||
|
||||
def extract_solution(solution_str):
|
||||
"""Extract the equation from the solution string."""
|
||||
# Remove everything before the first "Assistant:"
|
||||
if "Assistant:" in solution_str:
|
||||
solution_str = solution_str.split("Assistant:", 1)[1]
|
||||
elif "<|im_start|>assistant" in solution_str:
|
||||
solution_str = solution_str.split("<|im_start|>assistant", 1)[1]
|
||||
else:
|
||||
return None
|
||||
solution_str = solution_str.split('\n')[-1]
|
||||
|
||||
answer_pattern = r'<answer>(.*?)</answer>'
|
||||
match = re.finditer(answer_pattern, solution_str)
|
||||
matches = list(match)
|
||||
if matches:
|
||||
final_answer = matches[-1].group(1).strip()
|
||||
else:
|
||||
final_answer = None
|
||||
return final_answer
|
||||
|
||||
|
||||
def validate_equation(equation_str, available_numbers):
|
||||
"""Validate that equation only uses available numbers and each number once."""
|
||||
try:
|
||||
# Extract all numbers from the equation
|
||||
numbers_in_eq = [int(n) for n in re.findall(r'\d+', equation_str)]
|
||||
|
||||
# Check if all numbers in equation are available
|
||||
available_numbers = sorted(available_numbers)
|
||||
numbers_in_eq = sorted(numbers_in_eq)
|
||||
|
||||
# Each number should be used exactly once
|
||||
return numbers_in_eq == available_numbers
|
||||
except:
|
||||
return False
|
||||
|
||||
|
||||
def evaluate_equation(equation_str):
|
||||
"""Safely evaluate the arithmetic equation using eval() with precautions."""
|
||||
try:
|
||||
# Define a regex pattern that only allows numbers, operators, parentheses, and whitespace
|
||||
allowed_pattern = r'^[\d+\-*/().\s]+$'
|
||||
if not re.match(allowed_pattern, equation_str):
|
||||
raise ValueError("Invalid characters in equation.")
|
||||
|
||||
# Evaluate the equation with restricted globals and locals
|
||||
result = eval(equation_str, {"__builtins__": None}, {})
|
||||
return result
|
||||
except Exception as e:
|
||||
return None
|
||||
|
||||
|
||||
def compute_score(solution_str, ground_truth, method='strict', format_score=0.1, score=1.):
|
||||
"""The scoring function for countdown task.
|
||||
|
||||
Args:
|
||||
solution_str: the solution text
|
||||
ground_truth: dictionary containing target number and available numbers
|
||||
method: the method to extract the solution
|
||||
format_score: the score for correct format but wrong answer
|
||||
score: the score for the correct answer
|
||||
"""
|
||||
target = ground_truth['target']
|
||||
numbers = ground_truth['numbers']
|
||||
|
||||
equation = extract_solution(solution_str=solution_str)
|
||||
do_print = random.randint(1, 64) == 1
|
||||
|
||||
if do_print:
|
||||
print(f"--------------------------------")
|
||||
print(f"Target: {target} | Numbers: {numbers}")
|
||||
print(f"Extracted equation: {equation}")
|
||||
print(f"Solution string: {solution_str}")
|
||||
|
||||
if equation is None:
|
||||
if do_print:
|
||||
print(f"No equation found")
|
||||
return 0
|
||||
|
||||
# Validate equation uses correct numbers
|
||||
if not validate_equation(equation, numbers):
|
||||
if do_print:
|
||||
print(f"Invalid equation")
|
||||
return format_score
|
||||
|
||||
# Evaluate equation
|
||||
try:
|
||||
result = evaluate_equation(equation)
|
||||
if result is None:
|
||||
if do_print:
|
||||
print(f"Could not evaluate equation")
|
||||
return format_score
|
||||
|
||||
if abs(result - target) < 1e-5: # Account for floating point precision
|
||||
if do_print:
|
||||
print(f"Correct equation: {equation} = {result}")
|
||||
return score
|
||||
else:
|
||||
if do_print:
|
||||
print(f"Wrong result: equation = {result}, target = {target}")
|
||||
return format_score
|
||||
except:
|
||||
if do_print:
|
||||
print(f"Error evaluating equation")
|
||||
return format_score
|
||||
63
verl/utils/reward_score/gsm8k.py
Normal file
63
verl/utils/reward_score/gsm8k.py
Normal file
@@ -0,0 +1,63 @@
|
||||
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import re
|
||||
|
||||
|
||||
def extract_solution(solution_str, method='strict'):
|
||||
assert method in ['strict', 'flexible']
|
||||
|
||||
if method == 'strict':
|
||||
# this also tests the formatting of the model
|
||||
solution = re.search("#### (\\-?[0-9\\.\\,]+)", solution_str)
|
||||
if solution is None:
|
||||
final_answer = None
|
||||
else:
|
||||
final_answer = solution.group(0)
|
||||
final_answer = final_answer.split('#### ')[1].replace(',', '').replace('$', '')
|
||||
elif method == 'flexible':
|
||||
answer = re.findall("(\\-?[0-9\\.\\,]+)", solution_str)
|
||||
final_answer = None
|
||||
if len(answer) == 0:
|
||||
# no reward is there is no answer
|
||||
pass
|
||||
else:
|
||||
invalid_str = ['', '.']
|
||||
# find the last number that is not '.'
|
||||
for final_answer in reversed(answer):
|
||||
if final_answer not in invalid_str:
|
||||
break
|
||||
return final_answer
|
||||
|
||||
|
||||
def compute_score(solution_str, ground_truth, method='strict', format_score=0., score=1.):
|
||||
"""The scoring function for GSM8k.
|
||||
|
||||
Reference: Trung, Luong, et al. "Reft: Reasoning with reinforced fine-tuning." Proceedings of the 62nd Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers). 2024.
|
||||
|
||||
Args:
|
||||
solution_str: the solution text
|
||||
ground_truth: the ground truth
|
||||
method: the method to extract the solution, choices are 'strict' and 'flexible'
|
||||
format_score: the score for the format
|
||||
score: the score for the correct answer
|
||||
"""
|
||||
answer = extract_solution(solution_str=solution_str, method=method)
|
||||
if answer is None:
|
||||
return 0
|
||||
else:
|
||||
if answer == ground_truth:
|
||||
return score
|
||||
else:
|
||||
return format_score
|
||||
227
verl/utils/reward_score/math.py
Normal file
227
verl/utils/reward_score/math.py
Normal file
@@ -0,0 +1,227 @@
|
||||
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# Adapted from https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/tasks/hendrycks_math/utils.py
|
||||
|
||||
|
||||
def compute_score(solution_str, ground_truth) -> float:
|
||||
retval = 0.
|
||||
try:
|
||||
string_in_last_boxed = last_boxed_only_string(solution_str)
|
||||
if string_in_last_boxed is not None:
|
||||
answer = remove_boxed(string_in_last_boxed)
|
||||
if is_equiv(answer, ground_truth):
|
||||
retval = 1.
|
||||
except Exception as e:
|
||||
print(e)
|
||||
|
||||
return retval
|
||||
|
||||
|
||||
# string normalization from https://github.com/EleutherAI/lm-evaluation-harness/blob/master/lm_eval/tasks/hendrycks_math.py
|
||||
def is_equiv(str1, str2, verbose=False):
|
||||
if str1 is None and str2 is None:
|
||||
print("WARNING: Both None")
|
||||
return True
|
||||
if str1 is None or str2 is None:
|
||||
return False
|
||||
|
||||
try:
|
||||
ss1 = strip_string(str1)
|
||||
ss2 = strip_string(str2)
|
||||
if verbose:
|
||||
print(ss1, ss2)
|
||||
return ss1 == ss2
|
||||
except Exception:
|
||||
return str1 == str2
|
||||
|
||||
|
||||
def remove_boxed(s):
|
||||
if "\\boxed " in s:
|
||||
left = "\\boxed "
|
||||
assert s[:len(left)] == left
|
||||
return s[len(left):]
|
||||
|
||||
left = "\\boxed{"
|
||||
|
||||
assert s[:len(left)] == left
|
||||
assert s[-1] == "}"
|
||||
|
||||
return s[len(left):-1]
|
||||
|
||||
|
||||
def last_boxed_only_string(string):
|
||||
idx = string.rfind("\\boxed")
|
||||
if "\\boxed " in string:
|
||||
return "\\boxed " + string.split("\\boxed ")[-1].split("$")[0]
|
||||
if idx < 0:
|
||||
idx = string.rfind("\\fbox")
|
||||
if idx < 0:
|
||||
return None
|
||||
|
||||
i = idx
|
||||
right_brace_idx = None
|
||||
num_left_braces_open = 0
|
||||
while i < len(string):
|
||||
if string[i] == "{":
|
||||
num_left_braces_open += 1
|
||||
if string[i] == "}":
|
||||
num_left_braces_open -= 1
|
||||
if num_left_braces_open == 0:
|
||||
right_brace_idx = i
|
||||
break
|
||||
i += 1
|
||||
|
||||
if right_brace_idx is None:
|
||||
retval = None
|
||||
else:
|
||||
retval = string[idx:right_brace_idx + 1]
|
||||
|
||||
return retval
|
||||
|
||||
|
||||
def fix_fracs(string):
|
||||
substrs = string.split("\\frac")
|
||||
new_str = substrs[0]
|
||||
if len(substrs) > 1:
|
||||
substrs = substrs[1:]
|
||||
for substr in substrs:
|
||||
new_str += "\\frac"
|
||||
if substr[0] == "{":
|
||||
new_str += substr
|
||||
else:
|
||||
try:
|
||||
assert len(substr) >= 2
|
||||
except AssertionError:
|
||||
return string
|
||||
a = substr[0]
|
||||
b = substr[1]
|
||||
if b != "{":
|
||||
if len(substr) > 2:
|
||||
post_substr = substr[2:]
|
||||
new_str += "{" + a + "}{" + b + "}" + post_substr
|
||||
else:
|
||||
new_str += "{" + a + "}{" + b + "}"
|
||||
else:
|
||||
if len(substr) > 2:
|
||||
post_substr = substr[2:]
|
||||
new_str += "{" + a + "}" + b + post_substr
|
||||
else:
|
||||
new_str += "{" + a + "}" + b
|
||||
string = new_str
|
||||
return string
|
||||
|
||||
|
||||
def fix_a_slash_b(string):
|
||||
if len(string.split("/")) != 2:
|
||||
return string
|
||||
a = string.split("/")[0]
|
||||
b = string.split("/")[1]
|
||||
try:
|
||||
a = int(a)
|
||||
b = int(b)
|
||||
assert string == "{}/{}".format(a, b)
|
||||
new_string = "\\frac{" + str(a) + "}{" + str(b) + "}"
|
||||
return new_string
|
||||
except AssertionError:
|
||||
return string
|
||||
|
||||
|
||||
def remove_right_units(string):
|
||||
# "\\text{ " only ever occurs (at least in the val set) when describing units
|
||||
if "\\text{ " in string:
|
||||
splits = string.split("\\text{ ")
|
||||
assert len(splits) == 2
|
||||
return splits[0]
|
||||
else:
|
||||
return string
|
||||
|
||||
|
||||
def fix_sqrt(string):
|
||||
if "\\sqrt" not in string:
|
||||
return string
|
||||
splits = string.split("\\sqrt")
|
||||
new_string = splits[0]
|
||||
for split in splits[1:]:
|
||||
if split[0] != "{":
|
||||
a = split[0]
|
||||
new_substr = "\\sqrt{" + a + "}" + split[1:]
|
||||
else:
|
||||
new_substr = "\\sqrt" + split
|
||||
new_string += new_substr
|
||||
return new_string
|
||||
|
||||
|
||||
def strip_string(string):
|
||||
# linebreaks
|
||||
string = string.replace("\n", "")
|
||||
|
||||
# remove inverse spaces
|
||||
string = string.replace("\\!", "")
|
||||
|
||||
# replace \\ with \
|
||||
string = string.replace("\\\\", "\\")
|
||||
|
||||
# replace tfrac and dfrac with frac
|
||||
string = string.replace("tfrac", "frac")
|
||||
string = string.replace("dfrac", "frac")
|
||||
|
||||
# remove \left and \right
|
||||
string = string.replace("\\left", "")
|
||||
string = string.replace("\\right", "")
|
||||
|
||||
# Remove circ (degrees)
|
||||
string = string.replace("^{\\circ}", "")
|
||||
string = string.replace("^\\circ", "")
|
||||
|
||||
# remove dollar signs
|
||||
string = string.replace("\\$", "")
|
||||
|
||||
# remove units (on the right)
|
||||
string = remove_right_units(string)
|
||||
|
||||
# remove percentage
|
||||
string = string.replace("\\%", "")
|
||||
string = string.replace("\%", "") # noqa: W605
|
||||
|
||||
# " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string
|
||||
string = string.replace(" .", " 0.")
|
||||
string = string.replace("{.", "{0.")
|
||||
# if empty, return empty string
|
||||
if len(string) == 0:
|
||||
return string
|
||||
if string[0] == ".":
|
||||
string = "0" + string
|
||||
|
||||
# to consider: get rid of e.g. "k = " or "q = " at beginning
|
||||
if len(string.split("=")) == 2:
|
||||
if len(string.split("=")[0]) <= 2:
|
||||
string = string.split("=")[1]
|
||||
|
||||
# fix sqrt3 --> sqrt{3}
|
||||
string = fix_sqrt(string)
|
||||
|
||||
# remove spaces
|
||||
string = string.replace(" ", "")
|
||||
|
||||
# \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b}
|
||||
string = fix_fracs(string)
|
||||
|
||||
# manually change 0.5 --> \frac{1}{2}
|
||||
if string == "0.5":
|
||||
string = "\\frac{1}{2}"
|
||||
|
||||
# NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y
|
||||
string = fix_a_slash_b(string)
|
||||
|
||||
return string
|
||||
58
verl/utils/reward_score/multiply.py
Normal file
58
verl/utils/reward_score/multiply.py
Normal file
@@ -0,0 +1,58 @@
|
||||
import re
|
||||
import random
|
||||
|
||||
|
||||
def extract_solution(solution_str):
|
||||
# Remove everything before the first "Assistant:"
|
||||
if "Assistant:" in solution_str:
|
||||
solution_str = solution_str.split("Assistant:", 1)[1]
|
||||
else:
|
||||
return None
|
||||
|
||||
answer_pattern = r'<answer>(.*?)</answer>'
|
||||
match = re.finditer(answer_pattern, solution_str)
|
||||
matches = list(match)
|
||||
if matches:
|
||||
final_answer = matches[-1].group(1).strip()
|
||||
else:
|
||||
final_answer = None
|
||||
if final_answer is not None:
|
||||
try:
|
||||
int_final_answer = int(final_answer)
|
||||
except ValueError:
|
||||
final_answer = None
|
||||
return final_answer
|
||||
|
||||
|
||||
def compute_score(solution_str, ground_truth, method='strict', format_score=0.1, score=1.):
|
||||
"""The scoring function for GSM8k.
|
||||
|
||||
Reference: Trung, Luong, et al. "Reft: Reasoning with reinforced fine-tuning." Proceedings of the 62nd Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers). 2024.
|
||||
|
||||
Args:
|
||||
solution_str: the solution text
|
||||
ground_truth: the ground truth
|
||||
method: the method to extract the solution, choices are 'strict' and 'flexible'
|
||||
format_score: the score for the format
|
||||
score: the score for the correct answer
|
||||
"""
|
||||
answer = extract_solution(solution_str=solution_str)
|
||||
do_print = random.randint(1, 64) == 1
|
||||
if do_print:
|
||||
print(f"--------------------------------")
|
||||
print(f"Ground truth: {ground_truth} | Extracted answer: {answer}")
|
||||
print(f"Solution string: {solution_str}")
|
||||
|
||||
if answer is None:
|
||||
if do_print:
|
||||
print(f"No answer found")
|
||||
return 0
|
||||
else:
|
||||
if int(answer) == int(ground_truth):
|
||||
if do_print:
|
||||
print(f"Correct answer: {answer}")
|
||||
return score
|
||||
else:
|
||||
if do_print:
|
||||
print(f"Incorrect answer {answer} | Ground truth: {ground_truth}")
|
||||
return format_score
|
||||
138
verl/utils/reward_score/qa_em.py
Normal file
138
verl/utils/reward_score/qa_em.py
Normal file
@@ -0,0 +1,138 @@
|
||||
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import re
|
||||
import string
|
||||
import random
|
||||
|
||||
def normalize_answer(s):
|
||||
def remove_articles(text):
|
||||
return re.sub(r"\b(a|an|the)\b", " ", text)
|
||||
|
||||
def white_space_fix(text):
|
||||
return " ".join(text.split())
|
||||
|
||||
def remove_punc(text):
|
||||
exclude = set(string.punctuation)
|
||||
return "".join(ch for ch in text if ch not in exclude)
|
||||
|
||||
def lower(text):
|
||||
return text.lower()
|
||||
|
||||
return white_space_fix(remove_articles(remove_punc(lower(s))))
|
||||
|
||||
|
||||
def em_check(prediction, golden_answers):
|
||||
if isinstance(golden_answers, str):
|
||||
golden_answers = [golden_answers]
|
||||
normalized_prediction = normalize_answer(prediction)
|
||||
score = 0
|
||||
for golden_answer in golden_answers:
|
||||
golden_answer = normalize_answer(golden_answer)
|
||||
if golden_answer == normalized_prediction:
|
||||
score = 1
|
||||
break
|
||||
return score
|
||||
|
||||
|
||||
def subem_check(prediction, golden_answers):
|
||||
if isinstance(golden_answers, str):
|
||||
golden_answers = [golden_answers]
|
||||
normalized_prediction = normalize_answer(prediction)
|
||||
score = 0
|
||||
for golden_answer in golden_answers:
|
||||
golden_answer = normalize_answer(golden_answer)
|
||||
if golden_answer in normalized_prediction:
|
||||
score = 1
|
||||
break
|
||||
return score
|
||||
|
||||
|
||||
def extract_solution(solution_str):
|
||||
"""Extract the equation from the solution string."""
|
||||
# Remove everything before the first "Assistant:"
|
||||
# if "Assistant:" in solution_str:
|
||||
# solution_str = solution_str.split("Assistant:", 1)[1]
|
||||
# elif "<|im_start|>assistant" in solution_str:
|
||||
# solution_str = solution_str.split("<|im_start|>assistant", 1)[1]
|
||||
# else:
|
||||
# return None
|
||||
# solution_str = solution_str.split('\n')[-1]
|
||||
|
||||
answer_pattern = r'<answer>(.*?)</answer>'
|
||||
match = re.finditer(answer_pattern, solution_str, re.DOTALL)
|
||||
matches = list(match)
|
||||
|
||||
# If there are 0 or exactly 1 matches, return None
|
||||
if len(matches) <= 1:
|
||||
return None
|
||||
|
||||
# If there are 2 or more matches, return the last one
|
||||
return matches[-1].group(1).strip()
|
||||
|
||||
|
||||
def compute_score_em(solution_str, ground_truth, method='strict', format_score=0., score=1.):
|
||||
"""The scoring function for exact match (EM).
|
||||
|
||||
Args:
|
||||
solution_str: the solution text
|
||||
ground_truth: the ground truth
|
||||
method: the method to extract the solution, choices are 'strict' and 'flexible'
|
||||
format_score: the score for the format
|
||||
score: the score for the correct answer
|
||||
"""
|
||||
answer = extract_solution(solution_str=solution_str)
|
||||
do_print = random.randint(1, 64) == 1
|
||||
|
||||
if do_print:
|
||||
print(f"--------------------------------")
|
||||
print(f"Golden answers: {ground_truth['target']}")
|
||||
print(f"Extracted answer: {answer}")
|
||||
print(f"Solution string: {solution_str}")
|
||||
|
||||
if answer is None:
|
||||
return 0
|
||||
else:
|
||||
if em_check(answer, ground_truth['target']):
|
||||
return score
|
||||
else:
|
||||
return format_score
|
||||
|
||||
|
||||
def compute_score_subem(solution_str, ground_truth, method='strict', format_score=0., score=1.):
|
||||
"""The scoring function for substring exact match (EM).
|
||||
|
||||
Args:
|
||||
solution_str: the solution text
|
||||
ground_truth: the ground truth
|
||||
method: the method to extract the solution, choices are 'strict' and 'flexible'
|
||||
format_score: the score for the format
|
||||
score: the score for the correct answer
|
||||
"""
|
||||
answer = extract_solution(solution_str=solution_str)
|
||||
do_print = random.randint(1, 64) == 1
|
||||
|
||||
if do_print:
|
||||
print(f"--------------------------------")
|
||||
print(f"Golden answers: {ground_truth['target']}")
|
||||
print(f"Extracted answer: {answer}")
|
||||
print(f"Solution string: {solution_str}")
|
||||
|
||||
if answer is None:
|
||||
return 0
|
||||
else:
|
||||
if subem_check(answer, ground_truth['target']):
|
||||
return score
|
||||
else:
|
||||
return format_score
|
||||
265
verl/utils/seqlen_balancing.py
Normal file
265
verl/utils/seqlen_balancing.py
Normal file
@@ -0,0 +1,265 @@
|
||||
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import List, Tuple, Callable
|
||||
import heapq
|
||||
|
||||
import torch
|
||||
from torch import distributed as dist
|
||||
|
||||
from tensordict import TensorDict
|
||||
import copy
|
||||
|
||||
|
||||
def karmarkar_karp(seqlen_list: List[int], k_partitions: int, equal_size: bool):
|
||||
# see: https://en.wikipedia.org/wiki/Largest_differencing_method
|
||||
class Set:
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.sum = 0
|
||||
self.items = []
|
||||
|
||||
def add(self, idx: int, val: int):
|
||||
self.items.append((idx, val))
|
||||
self.sum += val
|
||||
|
||||
def merge(self, other):
|
||||
for idx, val in other.items:
|
||||
self.items.append((idx, val))
|
||||
self.sum += val
|
||||
|
||||
def __lt__(self, other):
|
||||
if self.sum != other.sum:
|
||||
return self.sum < other.sum
|
||||
if len(self.items) != len(other.items):
|
||||
return len(self.items) < len(other.items)
|
||||
return self.items < other.items
|
||||
|
||||
class State:
|
||||
|
||||
def __init__(self, items: List[Tuple[int, int]], k: int) -> None:
|
||||
self.k = k
|
||||
# sets should always be decreasing order
|
||||
self.sets = [Set() for _ in range(k)]
|
||||
assert len(items) in [1, k], f"{len(items)} not in [1, {k}]"
|
||||
for i, (idx, seqlen) in enumerate(items):
|
||||
self.sets[i].add(idx=idx, val=seqlen)
|
||||
self.sets = sorted(self.sets, reverse=True)
|
||||
|
||||
def spread(self):
|
||||
return self.sets[0].sum - self.sets[-1].sum
|
||||
|
||||
def get_partitions(self):
|
||||
partitions = []
|
||||
for i in range(len(self.sets)):
|
||||
cur_partition = []
|
||||
for idx, _ in self.sets[i].items:
|
||||
cur_partition.append(idx)
|
||||
partitions.append(cur_partition)
|
||||
return partitions
|
||||
|
||||
def merge(self, other):
|
||||
for i in range(self.k):
|
||||
self.sets[i].merge(other.sets[self.k - 1 - i])
|
||||
self.sets = sorted(self.sets, reverse=True)
|
||||
|
||||
@property
|
||||
def spread(self) -> int:
|
||||
return self.sets[0].sum - self.sets[-1].sum
|
||||
|
||||
def __lt__(self, other):
|
||||
# least heap, let the state with largest spread to be popped first,
|
||||
# if the spread is the same, let the state who has the largest set
|
||||
# to be popped first.
|
||||
if self.spread != other.spread:
|
||||
return self.spread > other.spread
|
||||
return self.sets[0] > other.sets[0]
|
||||
|
||||
def __repr__(self) -> str:
|
||||
repr_str = "["
|
||||
for i in range(self.k):
|
||||
if i > 0:
|
||||
repr_str += ","
|
||||
repr_str += "{"
|
||||
for j, (_, seqlen) in enumerate(self.sets[i].items):
|
||||
if j > 0:
|
||||
repr_str += ","
|
||||
repr_str += str(seqlen)
|
||||
repr_str += "}"
|
||||
repr_str += "]"
|
||||
return repr_str
|
||||
|
||||
sorted_seqlen_list = sorted([(seqlen, i) for i, seqlen in enumerate(seqlen_list)])
|
||||
states_pq = []
|
||||
if equal_size:
|
||||
assert len(seqlen_list) % k_partitions == 0, f"{len(seqlen_list)} % {k_partitions} != 0"
|
||||
for offset in range(0, len(sorted_seqlen_list), k_partitions):
|
||||
items = []
|
||||
for i in range(k_partitions):
|
||||
seqlen, idx = sorted_seqlen_list[offset + i]
|
||||
items.append((idx, seqlen))
|
||||
heapq.heappush(states_pq, State(items=items, k=k_partitions))
|
||||
else:
|
||||
for seqlen, idx in sorted_seqlen_list:
|
||||
heapq.heappush(states_pq, State(items=[(idx, seqlen)], k=k_partitions))
|
||||
|
||||
while len(states_pq) > 1:
|
||||
state0 = heapq.heappop(states_pq)
|
||||
state1 = heapq.heappop(states_pq)
|
||||
# merge states
|
||||
state0.merge(state1)
|
||||
heapq.heappush(states_pq, state0)
|
||||
|
||||
final_state = states_pq[0]
|
||||
partitions = final_state.get_partitions()
|
||||
if equal_size:
|
||||
for i, partition in enumerate(partitions):
|
||||
assert len(partition) * \
|
||||
k_partitions == len(seqlen_list), f"{len(partition)} * {k_partitions} != {len(seqlen_list)}"
|
||||
return partitions
|
||||
|
||||
|
||||
def greedy_partition(seqlen_list: List[int], k_partitions: int, equal_size: bool):
|
||||
bias = sum(seqlen_list) + 1 if equal_size else 0
|
||||
sorted_seqlen = [(seqlen + bias, i) for i, seqlen in enumerate(seqlen_list)]
|
||||
partitions = [[] for _ in range(k_partitions)]
|
||||
partition_sums = [0 for _ in range(k_partitions)]
|
||||
for seqlen, i in sorted_seqlen:
|
||||
min_idx = None
|
||||
for j in range(k_partitions):
|
||||
if min_idx is None or partition_sums[j] < partition_sums[min_idx]:
|
||||
min_idx = j
|
||||
partitions[min_idx].append(i)
|
||||
partition_sums[min_idx] += seqlen
|
||||
if equal_size:
|
||||
for i, partition in enumerate(partitions):
|
||||
assert len(partition) * \
|
||||
k_partitions == len(seqlen_list), f"{len(partition)} * {k_partitions} != {len(seqlen_list)}"
|
||||
return partitions
|
||||
|
||||
|
||||
def get_seqlen_balanced_partitions(seqlen_list: List[int], k_partitions: int, equal_size: bool):
|
||||
""" get order of seq lengths to make partitions balanced, this is
|
||||
used in balacing sum of seqlength across dp ranks and microbatches
|
||||
Parameters:
|
||||
seqlen_list (List[int]):
|
||||
seq lengths of each items
|
||||
k_partitions (int):
|
||||
resulting number of partitions
|
||||
equal_size (bool):
|
||||
if True, number of items in each partitions must be equal.
|
||||
if False, only consider balancing the sum, each partition can have
|
||||
variable number of items
|
||||
Returns:
|
||||
partitions (List[List[int]]):
|
||||
return k_partitions list containing the index of items.
|
||||
"""
|
||||
assert len(seqlen_list) >= k_partitions, f"number of items:[{len(seqlen_list)}] < k_partitions:[{k_partitions}]"
|
||||
|
||||
def _check_and_sort_partitions(partitions):
|
||||
assert len(partitions) == k_partitions, f"{len(partitions)} != {k_partitions}"
|
||||
seen_idx = set()
|
||||
sorted_partitions = [None] * k_partitions
|
||||
for i, partition in enumerate(partitions):
|
||||
assert len(partition) > 0, f"the {i}-th partition is empty"
|
||||
for idx in partition:
|
||||
seen_idx.add(idx)
|
||||
sorted_partitions[i] = sorted(partition)
|
||||
assert seen_idx == set(range(len(seqlen_list)))
|
||||
return sorted_partitions
|
||||
|
||||
partitions = karmarkar_karp(seqlen_list=seqlen_list, k_partitions=k_partitions, equal_size=equal_size)
|
||||
return _check_and_sort_partitions(partitions)
|
||||
|
||||
|
||||
def log_seqlen_unbalance(seqlen_list: List[int], partitions: List[List[int]], prefix):
|
||||
# add some metrics of seqlen sum on dp ranks
|
||||
k_partition = len(partitions)
|
||||
# assert len(seqlen_list) % k_partition == 0
|
||||
batch_size = len(seqlen_list) // k_partition
|
||||
min_sum_seqlen = None
|
||||
max_sum_seqlen = None
|
||||
total_sum_seqlen = 0
|
||||
for offset in range(0, len(seqlen_list), batch_size):
|
||||
cur_sum_seqlen = sum(seqlen_list[offset:offset + batch_size])
|
||||
if min_sum_seqlen is None or cur_sum_seqlen < min_sum_seqlen:
|
||||
min_sum_seqlen = cur_sum_seqlen
|
||||
if max_sum_seqlen is None or cur_sum_seqlen > max_sum_seqlen:
|
||||
max_sum_seqlen = cur_sum_seqlen
|
||||
total_sum_seqlen += cur_sum_seqlen
|
||||
|
||||
balanced_sum_seqlen_list = []
|
||||
for partition in partitions:
|
||||
cur_sum_seqlen_balanced = sum([seqlen_list[i] for i in partition])
|
||||
balanced_sum_seqlen_list.append(cur_sum_seqlen_balanced)
|
||||
# print("balanced_sum_seqlen_list: ", balanced_sum_seqlen_list)
|
||||
min_sum_seqlen_balanced = min(balanced_sum_seqlen_list)
|
||||
max_sum_seqlen_balanced = max(balanced_sum_seqlen_list)
|
||||
|
||||
return {
|
||||
f'{prefix}/min': min_sum_seqlen,
|
||||
f'{prefix}/max': max_sum_seqlen,
|
||||
f'{prefix}/minmax_diff': max_sum_seqlen - min_sum_seqlen,
|
||||
f'{prefix}/balanced_min': min_sum_seqlen_balanced,
|
||||
f'{prefix}/balanced_max': max_sum_seqlen_balanced,
|
||||
f'{prefix}/mean': total_sum_seqlen / len(partitions)
|
||||
}
|
||||
|
||||
|
||||
def ceildiv(a, b):
|
||||
return -(a // -b)
|
||||
|
||||
|
||||
def rearrange_micro_batches(batch: TensorDict, max_token_len, dp_group=None):
|
||||
"""Split the batch into a list of micro_batches, where the max_token_len is smaller than max_token_len
|
||||
and the number of valid tokens in each micro batch is well balanced.
|
||||
"""
|
||||
# this is per local micro_bsz
|
||||
max_seq_len = batch['attention_mask'].shape[-1]
|
||||
assert max_token_len >= max_seq_len, \
|
||||
f'max_token_len must be greater than the sequence length. Got {max_token_len=} and {max_seq_len=}'
|
||||
|
||||
seq_len_effective: torch.Tensor = batch['attention_mask'].sum(dim=1)
|
||||
total_seqlen = seq_len_effective.sum().item()
|
||||
num_micro_batches = ceildiv(total_seqlen, max_token_len)
|
||||
if dist.is_initialized():
|
||||
num_micro_batches = torch.tensor([num_micro_batches], device='cuda')
|
||||
dist.all_reduce(num_micro_batches, op=dist.ReduceOp.MAX, group=dp_group)
|
||||
num_micro_batches = num_micro_batches.cpu().item()
|
||||
|
||||
seq_len_effective = seq_len_effective.tolist()
|
||||
assert num_micro_batches <= len(seq_len_effective)
|
||||
|
||||
micro_bsz_idx = get_seqlen_balanced_partitions(seq_len_effective, num_micro_batches, equal_size=False)
|
||||
|
||||
micro_batches = []
|
||||
|
||||
for partition in micro_bsz_idx:
|
||||
curr_micro_batch = []
|
||||
for idx in partition:
|
||||
curr_micro_batch.append(batch[idx:idx + 1])
|
||||
curr_micro_batch = torch.cat(curr_micro_batch)
|
||||
|
||||
micro_batches.append(curr_micro_batch)
|
||||
|
||||
return micro_batches, micro_bsz_idx
|
||||
|
||||
|
||||
def get_reverse_idx(idx_map):
|
||||
reverse_idx_map = copy.deepcopy(idx_map)
|
||||
|
||||
for i, idx in enumerate(idx_map):
|
||||
reverse_idx_map[idx] = i
|
||||
|
||||
return reverse_idx_map
|
||||
58
verl/utils/tokenizer.py
Normal file
58
verl/utils/tokenizer.py
Normal file
@@ -0,0 +1,58 @@
|
||||
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Utils for tokenization."""
|
||||
import warnings
|
||||
|
||||
__all__ = ['hf_tokenizer']
|
||||
|
||||
|
||||
def set_pad_token_id(tokenizer):
|
||||
"""Set pad_token_id to eos_token_id if it is None.
|
||||
|
||||
Args:
|
||||
tokenizer (transformers.PreTrainedTokenizer): The tokenizer to be set.
|
||||
|
||||
"""
|
||||
if tokenizer.pad_token_id is None:
|
||||
tokenizer.pad_token_id = tokenizer.eos_token_id
|
||||
warnings.warn(f'tokenizer.pad_token_id is None. Now set to {tokenizer.eos_token_id}')
|
||||
if tokenizer.pad_token is None:
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
warnings.warn(f'tokenizer.pad_token is None. Now set to {tokenizer.eos_token}')
|
||||
|
||||
|
||||
def hf_tokenizer(name_or_path, correct_pad_token=True, correct_gemma2=True, **kwargs):
|
||||
"""Create a huggingface pretrained tokenizer.
|
||||
|
||||
Args:
|
||||
name (str): The name of the tokenizer.
|
||||
correct_pad_token (bool): Whether to correct the pad token id.
|
||||
correct_gemma2 (bool): Whether to correct the gemma2 tokenizer.
|
||||
**kwargs: The keyword arguments for the tokenizer.
|
||||
|
||||
Returns:
|
||||
transformers.PreTrainedTokenizer: The pretrained tokenizer.
|
||||
|
||||
"""
|
||||
from transformers import AutoTokenizer
|
||||
if correct_gemma2 and isinstance(name_or_path, str) and 'gemma-2-2b-it' in name_or_path:
|
||||
# the EOS token in gemma2 is ambiguious, which may worsen RL performance.
|
||||
# https://huggingface.co/google/gemma-2-2b-it/commit/17a01657f5c87135bcdd0ec7abb4b2dece04408a
|
||||
warnings.warn('Found gemma-2-2b-it tokenizer. Set eos_token and eos_token_id to <end_of_turn> and 107.')
|
||||
kwargs['eos_token'] = '<end_of_turn>'
|
||||
kwargs['eos_token_id'] = 107
|
||||
tokenizer = AutoTokenizer.from_pretrained(name_or_path, **kwargs)
|
||||
if correct_pad_token:
|
||||
set_pad_token_id(tokenizer)
|
||||
return tokenizer
|
||||
82
verl/utils/torch_dtypes.py
Normal file
82
verl/utils/torch_dtypes.py
Normal file
@@ -0,0 +1,82 @@
|
||||
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Adapted from Cruise.
|
||||
"""
|
||||
|
||||
import torch
|
||||
|
||||
from typing import Union
|
||||
|
||||
HALF_LIST = [16, "16", "fp16", "float16"]
|
||||
FLOAT_LIST = [32, "32", "fp32", "float32"]
|
||||
BFLOAT_LIST = ["bf16", "bfloat16"]
|
||||
|
||||
|
||||
class PrecisionType(object):
|
||||
"""Type of precision used.
|
||||
|
||||
>>> PrecisionType.HALF == 16
|
||||
True
|
||||
>>> PrecisionType.HALF in (16, "16")
|
||||
True
|
||||
"""
|
||||
|
||||
HALF = "16"
|
||||
FLOAT = "32"
|
||||
FULL = "64"
|
||||
BFLOAT = "bf16"
|
||||
MIXED = "mixed"
|
||||
|
||||
@staticmethod
|
||||
def supported_type(precision: Union[str, int]) -> bool:
|
||||
return any(x == precision for x in PrecisionType)
|
||||
|
||||
@staticmethod
|
||||
def supported_types() -> list[str]:
|
||||
return [x.value for x in PrecisionType]
|
||||
|
||||
@staticmethod
|
||||
def is_fp16(precision):
|
||||
return precision in HALF_LIST
|
||||
|
||||
@staticmethod
|
||||
def is_fp32(precision):
|
||||
return precision in FLOAT_LIST
|
||||
|
||||
@staticmethod
|
||||
def is_bf16(precision):
|
||||
return precision in BFLOAT_LIST
|
||||
|
||||
@staticmethod
|
||||
def to_dtype(precision):
|
||||
if precision in HALF_LIST:
|
||||
return torch.float16
|
||||
elif precision in FLOAT_LIST:
|
||||
return torch.float32
|
||||
elif precision in BFLOAT_LIST:
|
||||
return torch.bfloat16
|
||||
else:
|
||||
raise RuntimeError(f"unexpected precision: {precision}")
|
||||
|
||||
@staticmethod
|
||||
def to_str(precision):
|
||||
if precision == torch.float16:
|
||||
return 'fp16'
|
||||
elif precision == torch.float32:
|
||||
return 'fp32'
|
||||
elif precision == torch.bfloat16:
|
||||
return 'bf16'
|
||||
else:
|
||||
raise RuntimeError(f"unexpected precision: {precision}")
|
||||
492
verl/utils/torch_functional.py
Normal file
492
verl/utils/torch_functional.py
Normal file
@@ -0,0 +1,492 @@
|
||||
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Contain small torch utilities
|
||||
"""
|
||||
|
||||
from typing import Dict, Union, List, Optional
|
||||
|
||||
import os
|
||||
import torch
|
||||
import torch.distributed
|
||||
import torch.nn.functional as F
|
||||
from tensordict import TensorDict
|
||||
from torch import nn
|
||||
|
||||
try:
|
||||
from flash_attn.ops.triton.cross_entropy import cross_entropy_loss
|
||||
FLAH_ATTN_CROSS_ENTROPY_LOSS_AVAILABLE = True
|
||||
except ImportError:
|
||||
FLAH_ATTN_CROSS_ENTROPY_LOSS_AVAILABLE = False
|
||||
|
||||
|
||||
def gather_from_labels(data, label):
|
||||
"""Gather the label from data. The value in label should be [0, vocab_size)
|
||||
|
||||
Args:
|
||||
data: (..., vocab_size)
|
||||
label (torch.IntTensor) : (...,)
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
|
||||
output = torch.gather(data, -1, label.unsqueeze(-1)).squeeze(-1)
|
||||
return output
|
||||
|
||||
|
||||
def logprobs_from_logits(logits, labels):
|
||||
"""
|
||||
See: https://github.com/pytorch/pytorch/issues/563#issuecomment-330103591
|
||||
"""
|
||||
if FLAH_ATTN_CROSS_ENTROPY_LOSS_AVAILABLE:
|
||||
batch_dim = logits.shape[:-1]
|
||||
last_dim = logits.shape[-1]
|
||||
logits = logits.reshape(-1, last_dim)
|
||||
labels = labels.reshape(-1)
|
||||
output = logprobs_from_logits_flash_attn(logits, labels)
|
||||
output = output.view(*batch_dim)
|
||||
else:
|
||||
output = logprobs_from_logits_naive(logits, labels)
|
||||
return output
|
||||
|
||||
|
||||
def logprobs_from_logits_flash_attn(logits, labels):
|
||||
output = -cross_entropy_loss(logits, labels)[0]
|
||||
return output
|
||||
|
||||
|
||||
def logprobs_from_logits_naive(logits, labels):
|
||||
logp = F.log_softmax(logits, dim=-1)
|
||||
logpy = gather_from_labels(logp, labels)
|
||||
return logpy
|
||||
|
||||
|
||||
def logprobs_of_labels_v2(logits: torch.FloatTensor, labels):
|
||||
"""
|
||||
A memory efficient implementation of logprobs_from_logits
|
||||
"""
|
||||
assert logits.dtype == torch.float32, 'Using bf16 logits with logprobs_of_labels_v2 may lead to divergence'
|
||||
logprobs_labels = torch.gather(logits, dim=-1, index=labels.unsqueeze(-1))
|
||||
logprobs_labels = logprobs_labels - torch.logsumexp(logits, dim=-1, keepdim=True)
|
||||
return logprobs_labels.squeeze(-1)
|
||||
|
||||
|
||||
def clip_by_value(x, tensor_min, tensor_max):
|
||||
"""
|
||||
Tensor extenstion to torch.clamp
|
||||
https://github.com/pytorch/pytorch/issues/2793#issuecomment-428784713
|
||||
"""
|
||||
clipped = torch.max(torch.min(x, tensor_max), tensor_min)
|
||||
return clipped
|
||||
|
||||
|
||||
def entropy_from_logits(logits: torch.Tensor):
|
||||
"""Calculate entropy from logits."""
|
||||
pd = torch.nn.functional.softmax(logits, dim=-1)
|
||||
entropy = torch.logsumexp(logits, dim=-1) - torch.sum(pd * logits, dim=-1)
|
||||
return entropy
|
||||
|
||||
|
||||
def masked_sum(values, mask, axis=None):
|
||||
"""Compute mean of tensor with a masked values."""
|
||||
return (values * mask).sum(axis=axis)
|
||||
|
||||
|
||||
def masked_mean(values, mask, axis=None):
|
||||
"""Compute mean of tensor with a masked values."""
|
||||
return (values * mask).sum(axis=axis) / mask.sum(axis=axis)
|
||||
|
||||
|
||||
def masked_var(values, mask, unbiased=True):
|
||||
"""Compute variance of tensor with masked values."""
|
||||
mean = masked_mean(values, mask)
|
||||
centered_values = values - mean
|
||||
variance = masked_mean(centered_values**2, mask)
|
||||
if unbiased:
|
||||
mask_sum = mask.sum()
|
||||
if mask_sum == 0:
|
||||
raise ValueError("At least one element in the mask has to be 1.")
|
||||
# note that if mask_sum == 1, then there is a division by zero issue
|
||||
# to avoid it you just need to use a larger minibatch_size
|
||||
if mask_sum == 1:
|
||||
raise ValueError("The sum of the mask is one, which can cause a division by zero.")
|
||||
bessel_correction = mask_sum / (mask_sum - 1)
|
||||
variance = variance * bessel_correction
|
||||
return variance
|
||||
|
||||
|
||||
def masked_whiten(values, mask, shift_mean=True):
|
||||
"""Whiten values with masked values."""
|
||||
mean, var = masked_mean(values, mask), masked_var(values, mask)
|
||||
whitened = (values - mean) * torch.rsqrt(var + 1e-8)
|
||||
if not shift_mean:
|
||||
whitened += mean
|
||||
return whitened
|
||||
|
||||
|
||||
def get_eos_mask(response_id: torch.Tensor, eos_token: int = 2, dtype=torch.int64):
|
||||
'''
|
||||
e.g. end of sentence token=1
|
||||
response_id: [0, 0, 2, 42, 3, 5, 1, 0, 0]
|
||||
eos_mask: [1, 1, 1, 1, 1, 1, 1, 0, 0]
|
||||
'''
|
||||
eos_mask = response_id.eq(eos_token).long()
|
||||
eos_mask = (torch.cumsum(eos_mask, dim=1) - eos_mask).bool()
|
||||
eos_mask = torch.logical_not(eos_mask).to(dtype)
|
||||
return eos_mask
|
||||
|
||||
|
||||
def compute_grad_norm(model: nn.Module):
|
||||
total_grad_square = 0
|
||||
total_params = 0
|
||||
for param in model.parameters():
|
||||
if param.grad is not None:
|
||||
total_grad_square += torch.sum(torch.square(param.grad.detach())).item()
|
||||
return total_grad_square
|
||||
|
||||
|
||||
def broadcast_dict_tensor(tensors: Union[Dict[str, torch.Tensor], TensorDict], src, group):
|
||||
"""
|
||||
TODO: optimize this. Technically, we only need one broadcast
|
||||
"""
|
||||
|
||||
for key in tensors.sorted_keys:
|
||||
torch.distributed.broadcast(tensors[key], src=src, group=group, async_op=False)
|
||||
|
||||
|
||||
def allgather_dict_tensors(tensors: Union[Dict[str, torch.Tensor], TensorDict], size, group, dim=0):
|
||||
"""
|
||||
TODO: optimize this.
|
||||
- We can use async ops
|
||||
- We can use only one allgather
|
||||
Args:
|
||||
tensors:
|
||||
size:
|
||||
group:
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
if isinstance(tensors, TensorDict):
|
||||
is_tensor_dict = True
|
||||
tensors_as_dict = tensors.to_dict()
|
||||
else:
|
||||
tensors_as_dict = tensors
|
||||
is_tensor_dict = False
|
||||
|
||||
output = {}
|
||||
sorted_keys = sorted(tensors_as_dict.keys())
|
||||
for key in sorted_keys:
|
||||
val = tensors_as_dict[key]
|
||||
output[key] = [torch.empty_like(val) for _ in range(size)]
|
||||
torch.distributed.all_gather(output[key], val, group=group, async_op=False)
|
||||
output[key] = torch.cat(output[key], dim=dim)
|
||||
|
||||
if is_tensor_dict:
|
||||
output = TensorDict(source=output, batch_size=tensors.batch_size[0] * size)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def split_dict_tensor_into_batches(tensors: TensorDict, batch_size) -> List[TensorDict]:
|
||||
assert tensors.batch_size[0] % batch_size == 0, \
|
||||
f'input data batch size: {tensors.batch_size[0]}, split batch size: {batch_size}'
|
||||
return tensors.split(batch_size)
|
||||
|
||||
|
||||
def pad_sequence_to_length(tensors, max_seq_len, pad_token_id, left_pad=False):
|
||||
"""
|
||||
pad a 2D tensors (e.g. responses, logprobs) in the last dim to max_seq_length.
|
||||
input shape: [bs, seq_length]
|
||||
output shape: [bs, max_seq_length]
|
||||
(0, max_seq_len - tensors.shape[-1]) means right pad to max_seq_length and no left pad
|
||||
"""
|
||||
if tensors.shape[-1] >= max_seq_len:
|
||||
return tensors
|
||||
pad_tuple = (max_seq_len - tensors.shape[-1], 0) if left_pad else (0, max_seq_len - tensors.shape[-1])
|
||||
return F.pad(tensors, pad_tuple, 'constant', pad_token_id)
|
||||
|
||||
|
||||
from transformers import PreTrainedTokenizer
|
||||
|
||||
|
||||
def tokenize_and_postprocess_data(prompt: str,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
max_length: int,
|
||||
pad_token_id: int,
|
||||
left_pad=True,
|
||||
truncation='error'):
|
||||
"""
|
||||
input_data is the output from tokenizer.
|
||||
"""
|
||||
assert truncation in ['left', 'right', 'error']
|
||||
|
||||
input_data = tokenizer(prompt, return_tensors='pt', add_special_tokens=False)
|
||||
|
||||
input_ids = input_data['input_ids']
|
||||
attention_mask = input_data['attention_mask']
|
||||
|
||||
assert input_ids.ndim == 2
|
||||
|
||||
sequence_length = input_ids.shape[-1]
|
||||
if sequence_length < max_length:
|
||||
input_ids = pad_sequence_to_length(input_ids,
|
||||
max_seq_len=max_length,
|
||||
pad_token_id=pad_token_id,
|
||||
left_pad=left_pad)
|
||||
attention_mask = pad_sequence_to_length(attention_mask,
|
||||
max_seq_len=max_length,
|
||||
pad_token_id=0,
|
||||
left_pad=left_pad)
|
||||
elif sequence_length > max_length:
|
||||
if truncation == 'left':
|
||||
# actually, left truncation may not be reasonable
|
||||
input_ids = input_ids[:, -max_length:]
|
||||
attention_mask = attention_mask[:, -max_length:]
|
||||
elif truncation == 'right':
|
||||
input_ids = input_ids[:, :max_length]
|
||||
attention_mask = attention_mask[:, :max_length]
|
||||
elif truncation == 'error':
|
||||
raise NotImplementedError(f'{sequence_length=} is larger than {max_length=}')
|
||||
else:
|
||||
raise NotImplementedError(f'Unknown truncation method {truncation}')
|
||||
|
||||
return input_ids, attention_mask
|
||||
|
||||
|
||||
def remove_pad_token(input_ids: torch.Tensor, attention_mask: torch.Tensor):
|
||||
""" Remove the pad token.
|
||||
|
||||
Args:
|
||||
input_ids shape: [bs, seq_length]
|
||||
attention_mask shape: [bs, seq_length]
|
||||
Returns:
|
||||
no_padding_batch(List[List[int]]): contains the rmpad token ids per query.
|
||||
"""
|
||||
no_padding_batch = []
|
||||
for ids, mask in zip(input_ids, attention_mask):
|
||||
no_padding_batch.append((ids[len(ids) - mask.sum():]).cpu().numpy().tolist())
|
||||
return no_padding_batch
|
||||
|
||||
|
||||
def log_probs_from_logits_response(input_ids, logits, response_length):
|
||||
"""Compute the response log_probs from full logits. Note that logits = model(input_ids)
|
||||
|
||||
Args:
|
||||
input_ids: [batch_size, seqlen]
|
||||
logits: [batch_size, seqlen, vocab_size]
|
||||
|
||||
Returns:
|
||||
response_log_prob:
|
||||
"""
|
||||
response_logits = logits[:, -response_length - 1:-1]
|
||||
response = input_ids[:, -response_length:]
|
||||
response_log_prob = logprobs_from_logits(logits=response_logits, labels=response)
|
||||
return response_log_prob
|
||||
|
||||
|
||||
def log_probs_from_logits_response_rmpad(input_ids, attention_mask, logits_rmpad, response_length):
|
||||
"""Compute the log_probs from logits with rmpad logits and pad input. Note that
|
||||
logits_rmpad = model(input_ids_rmpad). For each sentences, there is a shift between
|
||||
logits and input_ids.
|
||||
The reason for this function to is to compute logprobs_from_logits in rmpad mode because it is memory-intensive
|
||||
for large vocab_size
|
||||
|
||||
Args:
|
||||
input_ids: [batch_size, seqlen]
|
||||
attention_mask: [batch_size, seqlen]
|
||||
logits_rmpad: [total_nnz, vocab_size]
|
||||
response_length: int
|
||||
"""
|
||||
from flash_attn.bert_padding import pad_input, unpad_input
|
||||
|
||||
batch_size, seqlen = input_ids.shape
|
||||
input_ids_rmpad, indices, *_ = unpad_input(input_ids.unsqueeze(-1), attention_mask=attention_mask)
|
||||
input_ids_rmpad = input_ids_rmpad.squeeze(-1)
|
||||
input_ids_rmpad_rolled = torch.roll(input_ids_rmpad, shifts=-1, dims=0)
|
||||
full_log_probs_rmpad = logprobs_from_logits(logits=logits_rmpad, labels=input_ids_rmpad_rolled) # (total_nnz,)
|
||||
full_output = pad_input(hidden_states=full_log_probs_rmpad.unsqueeze(-1),
|
||||
indices=indices,
|
||||
batch=batch_size,
|
||||
seqlen=seqlen)
|
||||
output = full_output.squeeze(-1)[:, -response_length - 1:-1] # [batch_size, response_length]
|
||||
return output
|
||||
|
||||
|
||||
def log_probs_from_logits_all_rmpad(input_ids_rmpad, logits_rmpad, indices, batch_size, seqlen, response_length):
|
||||
"""Compute the log_probs from logits with rmpad input_ids and logits. Note that
|
||||
logits_rmpad = model(input_ids_rmpad). For each sentences, there is a shift between
|
||||
logits and input_ids.
|
||||
The reason for this function to is to compute logprobs_from_logits in rmpad mode because it is memory-intensive
|
||||
for large vocab_size
|
||||
|
||||
Args:
|
||||
input_ids_rmpad: [1, total_nnz]
|
||||
logits_rmpad: [total_nnz, vocab_size]
|
||||
indices: [total_nnz]
|
||||
batch_size: int
|
||||
seqlen: int
|
||||
response_length: int
|
||||
"""
|
||||
from flash_attn.bert_padding import pad_input
|
||||
input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # transpose back to [total_nnz, 1]
|
||||
input_ids_rmpad = input_ids_rmpad.squeeze(-1)
|
||||
input_ids_rmpad_rolled = torch.roll(input_ids_rmpad, shifts=-1, dims=0)
|
||||
full_log_probs_rmpad = logprobs_from_logits(logits=logits_rmpad, labels=input_ids_rmpad_rolled) # (total_nnz,)
|
||||
full_output = pad_input(hidden_states=full_log_probs_rmpad.unsqueeze(-1),
|
||||
indices=indices,
|
||||
batch=batch_size,
|
||||
seqlen=seqlen)
|
||||
output = full_output.squeeze(-1)[:, -response_length - 1:-1] # [batch_size, response_length]
|
||||
return output
|
||||
|
||||
|
||||
from transformers.generation.logits_process import (TemperatureLogitsWarper, TopKLogitsWarper, TopPLogitsWarper)
|
||||
|
||||
|
||||
def post_process_logits(input_ids, logits, temperature, top_k, top_p):
|
||||
if temperature != 1.:
|
||||
logits = logits.div_(temperature) # inplace operation to avoid OOM
|
||||
# TODO: add them back
|
||||
# if top_k is not None and top_k > 0:
|
||||
# logits = TopKLogitsWarper(top_k=top_k)(input_ids, logits)
|
||||
# if top_p is not None and top_p < 1.0 and top_p > 0.0:
|
||||
# logits = TopPLogitsWarper(top_p=top_p)(input_ids, logits)
|
||||
return logits
|
||||
|
||||
|
||||
"""
|
||||
Optimizer related
|
||||
"""
|
||||
|
||||
from torch.optim import Optimizer
|
||||
from torch.optim.lr_scheduler import LambdaLR
|
||||
import math
|
||||
|
||||
|
||||
def get_cosine_schedule_with_warmup(
|
||||
optimizer: Optimizer,
|
||||
num_warmup_steps: int,
|
||||
num_training_steps: int,
|
||||
min_lr_ratio: float = 0.0,
|
||||
num_cycles: float = 0.5,
|
||||
last_epoch: int = -1,
|
||||
):
|
||||
"""
|
||||
Create a schedule with a learning rate that decreases following the values of the cosine function between the
|
||||
initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the
|
||||
initial lr set in the optimizer.
|
||||
Args:
|
||||
optimizer (:class:`~torch.optim.Optimizer`):
|
||||
The optimizer for which to schedule the learning rate.
|
||||
num_warmup_steps (:obj:`int`):
|
||||
The number of steps for the warmup phase.
|
||||
num_training_steps (:obj:`int`):
|
||||
The total number of training steps.
|
||||
min_lr_ratio (:obj:`float`, `optional`, defaults to 0.0):
|
||||
The minimum lr ratio w.r.t the maximum.
|
||||
num_cycles (:obj:`float`, `optional`, defaults to 0.5):
|
||||
The number of waves in the cosine schedule (the defaults is to just decrease from the max value to 0
|
||||
following a half-cosine).
|
||||
last_epoch (:obj:`int`, `optional`, defaults to -1):
|
||||
The index of the last epoch when resuming training.
|
||||
Return:
|
||||
:obj:`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
|
||||
"""
|
||||
assert min_lr_ratio >= 0 and min_lr_ratio <= 1.
|
||||
coef = (1 - min_lr_ratio) * 0.5
|
||||
intercept = (1 + min_lr_ratio) * 0.5
|
||||
|
||||
def lr_lambda(current_step):
|
||||
if current_step < num_warmup_steps:
|
||||
return float(current_step) / float(max(1, num_warmup_steps))
|
||||
progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
|
||||
x = math.cos(math.pi * float(num_cycles) * 2.0 * progress)
|
||||
return max(0.0, x * coef + intercept)
|
||||
|
||||
return LambdaLR(optimizer, lr_lambda, last_epoch)
|
||||
|
||||
|
||||
def get_constant_schedule_with_warmup(
|
||||
optimizer: Optimizer,
|
||||
num_warmup_steps: int,
|
||||
last_epoch: int = -1,
|
||||
):
|
||||
|
||||
def lr_lambda(current_step):
|
||||
return min(1, float(current_step) / float(max(1, num_warmup_steps)))
|
||||
|
||||
return LambdaLR(optimizer, lr_lambda, last_epoch)
|
||||
|
||||
|
||||
def prepare_decoder_attention_mask(attention_mask, input_shape, inputs_embeds):
|
||||
# create causal mask
|
||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||
combined_attention_mask = None
|
||||
if input_shape[-1] > 1:
|
||||
combined_attention_mask = _make_causal_mask(
|
||||
input_shape,
|
||||
inputs_embeds.dtype,
|
||||
device=inputs_embeds.device,
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||
expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype,
|
||||
tgt_len=input_shape[-1]).to(inputs_embeds.device)
|
||||
combined_attention_mask = (expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask +
|
||||
combined_attention_mask)
|
||||
|
||||
return combined_attention_mask
|
||||
|
||||
|
||||
# Copied from transformers.models.bart.modeling_bart._make_causal_mask
|
||||
def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device):
|
||||
"""
|
||||
Make causal mask used for bi-directional self-attention.
|
||||
"""
|
||||
bsz, tgt_len = input_ids_shape
|
||||
mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
|
||||
mask_cond = torch.arange(mask.size(-1), device=device)
|
||||
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
|
||||
mask = mask.to(dtype)
|
||||
return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len)
|
||||
|
||||
|
||||
# Copied from transformers.models.bart.modeling_bart._expand_mask
|
||||
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
|
||||
"""
|
||||
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
|
||||
"""
|
||||
bsz, src_len = mask.size()
|
||||
tgt_len = tgt_len if tgt_len is not None else src_len
|
||||
|
||||
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
|
||||
|
||||
inverted_mask = 1.0 - expanded_mask
|
||||
|
||||
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
|
||||
|
||||
|
||||
def get_unpad_data(attention_mask):
|
||||
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
|
||||
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
|
||||
max_seqlen_in_batch = seqlens_in_batch.max().item()
|
||||
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
|
||||
return (
|
||||
indices,
|
||||
cu_seqlens,
|
||||
max_seqlen_in_batch,
|
||||
)
|
||||
103
verl/utils/tracking.py
Normal file
103
verl/utils/tracking.py
Normal file
@@ -0,0 +1,103 @@
|
||||
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
A unified tracking interface that supports logging data to different backend
|
||||
"""
|
||||
import dataclasses
|
||||
from enum import Enum
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from typing import List, Union, Dict, Any
|
||||
|
||||
|
||||
class Tracking(object):
|
||||
supported_backend = ['wandb', 'mlflow', 'console']
|
||||
|
||||
def __init__(self, project_name, experiment_name, default_backend: Union[str, List[str]] = 'console', config=None):
|
||||
if isinstance(default_backend, str):
|
||||
default_backend = [default_backend]
|
||||
for backend in default_backend:
|
||||
if backend == 'tracking':
|
||||
import warnings
|
||||
warnings.warn("`tracking` logger is deprecated. use `wandb` instead.", DeprecationWarning)
|
||||
else:
|
||||
assert backend in self.supported_backend, f'{backend} is not supported'
|
||||
|
||||
self.logger = {}
|
||||
|
||||
if 'tracking' in default_backend or 'wandb' in default_backend:
|
||||
import wandb
|
||||
import os
|
||||
WANDB_API_KEY = os.environ.get("WANDB_API_KEY", None)
|
||||
if WANDB_API_KEY:
|
||||
wandb.login(key=WANDB_API_KEY)
|
||||
wandb.init(project=project_name, name=experiment_name, config=config)
|
||||
self.logger['wandb'] = wandb
|
||||
|
||||
if 'mlflow' in default_backend:
|
||||
import mlflow
|
||||
mlflow.start_run(run_name=experiment_name)
|
||||
mlflow.log_params(_compute_mlflow_params_from_objects(config))
|
||||
self.logger['mlflow'] = _MlflowLoggingAdapter()
|
||||
|
||||
if 'console' in default_backend:
|
||||
from verl.utils.logger.aggregate_logger import LocalLogger
|
||||
self.console_logger = LocalLogger(print_to_console=True)
|
||||
self.logger['console'] = self.console_logger
|
||||
|
||||
def log(self, data, step, backend=None):
|
||||
for default_backend, logger_instance in self.logger.items():
|
||||
if backend is None or default_backend in backend:
|
||||
logger_instance.log(data=data, step=step)
|
||||
|
||||
|
||||
class _MlflowLoggingAdapter:
|
||||
|
||||
def log(self, data, step):
|
||||
import mlflow
|
||||
mlflow.log_metrics(metrics=data, step=step)
|
||||
|
||||
|
||||
def _compute_mlflow_params_from_objects(params) -> Dict[str, Any]:
|
||||
if params is None:
|
||||
return {}
|
||||
|
||||
return _flatten_dict(_transform_params_to_json_serializable(params, convert_list_to_dict=True), sep='/')
|
||||
|
||||
|
||||
def _transform_params_to_json_serializable(x, convert_list_to_dict: bool):
|
||||
_transform = partial(_transform_params_to_json_serializable, convert_list_to_dict=convert_list_to_dict)
|
||||
|
||||
if dataclasses.is_dataclass(x):
|
||||
return _transform(dataclasses.asdict(x))
|
||||
if isinstance(x, dict):
|
||||
return {k: _transform(v) for k, v in x.items()}
|
||||
if isinstance(x, list):
|
||||
if convert_list_to_dict:
|
||||
return {'list_len': len(x)} | {f'{i}': _transform(v) for i, v in enumerate(x)}
|
||||
else:
|
||||
return [_transform(v) for v in x]
|
||||
if isinstance(x, Path):
|
||||
return str(x)
|
||||
if isinstance(x, Enum):
|
||||
return x.value
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def _flatten_dict(raw: Dict[str, Any], *, sep: str) -> Dict[str, Any]:
|
||||
import pandas as pd
|
||||
ans = pd.json_normalize(raw, sep=sep).to_dict(orient='records')[0]
|
||||
assert isinstance(ans, dict)
|
||||
return ans
|
||||
288
verl/utils/ulysses.py
Normal file
288
verl/utils/ulysses.py
Normal file
@@ -0,0 +1,288 @@
|
||||
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Utilities for DeepSpeed Ulysses Sequence Parallelism.
|
||||
DeepSpeed Ulysses Paper: https://arxiv.org/abs/2309.14509
|
||||
Inspired from: https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/sequence/layer.py
|
||||
"""
|
||||
from typing import Any, Optional, List, Tuple
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
import torch.distributed as dist
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
_ULYSSES_SEQUENCE_PARALLEL_GROUP = None
|
||||
|
||||
|
||||
def set_ulysses_sequence_parallel_group(group: dist.ProcessGroup):
|
||||
"""
|
||||
Set ulysses sequence parallel process group.
|
||||
"""
|
||||
global _ULYSSES_SEQUENCE_PARALLEL_GROUP
|
||||
_ULYSSES_SEQUENCE_PARALLEL_GROUP = group
|
||||
|
||||
|
||||
def get_ulysses_sequence_parallel_group() -> Optional[dist.ProcessGroup]:
|
||||
"""
|
||||
Get ulysses sequence parallel process group.
|
||||
"""
|
||||
global _ULYSSES_SEQUENCE_PARALLEL_GROUP
|
||||
return _ULYSSES_SEQUENCE_PARALLEL_GROUP
|
||||
|
||||
|
||||
def get_ulysses_sequence_parallel_world_size(group: ProcessGroup = None) -> int:
|
||||
"""
|
||||
Get ulysses sequence parallel world size.
|
||||
"""
|
||||
group = get_ulysses_sequence_parallel_group() if group is None else group
|
||||
return dist.get_world_size(group) if group else 1
|
||||
|
||||
|
||||
def get_ulysses_sequence_parallel_rank(group: ProcessGroup = None) -> int:
|
||||
"""
|
||||
Get ulysses sequence parallel rank.
|
||||
"""
|
||||
group = get_ulysses_sequence_parallel_group() if group is None else group
|
||||
return dist.get_rank(group) if group else 0
|
||||
|
||||
|
||||
def gather_seq_scatter_heads(
|
||||
x: Tensor,
|
||||
seq_dim: int,
|
||||
head_dim: int,
|
||||
unpadded_dim_size: int = 0,
|
||||
group: ProcessGroup = None,
|
||||
) -> Tensor:
|
||||
"""
|
||||
A func to sync embedding input with alltoall in sequence parallel
|
||||
gather sequence dimension and scatter head dim:
|
||||
e.g. seq_dim: 1, head_dim: 2
|
||||
[bsz, seq/n, h, ...] -> [bsz, seq, h/n, ...]
|
||||
"""
|
||||
group = get_ulysses_sequence_parallel_group() if group is None else group
|
||||
if not group:
|
||||
return x
|
||||
sp_world = get_ulysses_sequence_parallel_world_size(group)
|
||||
x = SeqAllToAll.apply(group, x, head_dim, seq_dim)
|
||||
if unpadded_dim_size and unpadded_dim_size % sp_world != 0:
|
||||
padding_size = x.size(seq_dim) - unpadded_dim_size
|
||||
x = _unpad_tensor(x, seq_dim, padding_size)
|
||||
return x
|
||||
|
||||
|
||||
def gather_heads_scatter_seq(x: Tensor, head_dim: int, seq_dim: int, group: ProcessGroup = None) -> Tensor:
|
||||
"""
|
||||
A func to sync attention result with alltoall in sequence parallel
|
||||
gather head dimension and scatter seq dim:
|
||||
e.g. seq_dim: 1, head_dim: 2
|
||||
[bsz, seq, h/n, ...] -> [bsz, seq/n, h, ...]
|
||||
"""
|
||||
group = get_ulysses_sequence_parallel_group() if group is None else group
|
||||
if not group:
|
||||
return x
|
||||
dim_size = x.size(seq_dim)
|
||||
sp_world = get_ulysses_sequence_parallel_world_size(group)
|
||||
if dim_size % sp_world != 0:
|
||||
padding_size = sp_world - (dim_size % sp_world)
|
||||
x = _pad_tensor(x, seq_dim, padding_size)
|
||||
return SeqAllToAll.apply(group, x, seq_dim, head_dim, False)
|
||||
|
||||
|
||||
def _pad_tensor(x: Tensor, dim: int, padding_size: int) -> Tensor:
|
||||
shape = list(x.shape)
|
||||
shape[dim] = padding_size
|
||||
pad = torch.zeros(shape, dtype=x.dtype, device=x.device)
|
||||
return torch.cat([x, pad], dim=dim)
|
||||
|
||||
|
||||
def _unpad_tensor(x: Tensor, dim: int, padding_size: int) -> Tensor:
|
||||
slc = [slice(None)] * len(x.shape)
|
||||
slc[dim] = slice(0, -padding_size)
|
||||
return x[slc]
|
||||
|
||||
|
||||
def slice_input_tensor(x: Tensor, dim: int, padding: bool = True, group: ProcessGroup = None) -> Tensor:
|
||||
group = get_ulysses_sequence_parallel_group() if group is None else group
|
||||
sp_world_size = dist.get_world_size(group)
|
||||
sp_rank = get_ulysses_sequence_parallel_rank()
|
||||
dim_size = x.size(dim)
|
||||
# pad before slice
|
||||
if padding and dim_size % sp_world_size:
|
||||
padding_size = sp_world_size - (dim_size % sp_world_size)
|
||||
x = _pad_tensor(x, dim, padding_size)
|
||||
# slice the input tensor
|
||||
parts = x.size(dim) // sp_world_size
|
||||
slc = [slice(None)] * len(x.shape)
|
||||
slc[dim] = slice(sp_rank * parts, (sp_rank + 1) * parts)
|
||||
return x[slc].contiguous()
|
||||
|
||||
|
||||
def all_to_all_tensor(
|
||||
local_input: Tensor,
|
||||
scatter_dim: int,
|
||||
gather_dim: int,
|
||||
group: Optional[dist.ProcessGroup] = None,
|
||||
async_op: bool = False,
|
||||
):
|
||||
group = get_ulysses_sequence_parallel_group() if group is None else group
|
||||
seq_world_size = dist.get_world_size(group)
|
||||
input_list = [t.contiguous() for t in torch.tensor_split(local_input, seq_world_size, scatter_dim)]
|
||||
output_list = [torch.empty_like(input_list[0]) for _ in range(seq_world_size)]
|
||||
comm = dist.all_to_all(output_list, input_list, group=group, async_op=async_op)
|
||||
if async_op:
|
||||
|
||||
def wait():
|
||||
comm.wait()
|
||||
return torch.cat(output_list, dim=gather_dim).contiguous()
|
||||
|
||||
return wait
|
||||
return torch.cat(output_list, dim=gather_dim).contiguous()
|
||||
|
||||
|
||||
def all_gather_tensor(local_tensor: Tensor, group: Optional[dist.ProcessGroup] = None, async_op: bool = False):
|
||||
group = get_ulysses_sequence_parallel_group() if group is None else group
|
||||
sp_world_size = dist.get_world_size(group=group)
|
||||
output_shape = list(local_tensor.shape)
|
||||
output_shape[0] = output_shape[0] * sp_world_size
|
||||
output = torch.empty(output_shape, dtype=local_tensor.dtype, device=local_tensor.device)
|
||||
dist.all_gather_into_tensor(output, local_tensor, group=group, async_op=async_op)
|
||||
return output
|
||||
|
||||
|
||||
class SeqAllToAll(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(
|
||||
ctx: Any,
|
||||
group: dist.ProcessGroup,
|
||||
local_input: Tensor,
|
||||
scatter_dim: int,
|
||||
gather_dim: int,
|
||||
async_op: bool = False,
|
||||
) -> Tensor:
|
||||
ctx.group = group
|
||||
ctx.scatter_dim = scatter_dim
|
||||
ctx.gather_dim = gather_dim
|
||||
ctx.async_op = async_op
|
||||
return all_to_all_tensor(local_input, scatter_dim, gather_dim, group, async_op)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor, None, None]:
|
||||
if ctx.async_op:
|
||||
input_t = torch.cat(grad_output[1:], dim=ctx.gather_dim).contiguous()
|
||||
else:
|
||||
input_t = grad_output[0]
|
||||
return (
|
||||
None,
|
||||
all_to_all_tensor(input_t, ctx.gather_dim, ctx.scatter_dim, ctx.group, False),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
|
||||
|
||||
class Gather(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx: Any,
|
||||
group: dist.ProcessGroup,
|
||||
local_tensor: Tensor,
|
||||
gather_dim: int,
|
||||
grad_scaler: bool = True,
|
||||
async_op=False) -> Tensor:
|
||||
ctx.group = group
|
||||
ctx.gather_dim = gather_dim
|
||||
ctx.grad_scaler = grad_scaler
|
||||
ctx.async_op = async_op
|
||||
|
||||
sp_world_size = dist.get_world_size(group=group)
|
||||
ctx.sp_world_size = sp_world_size
|
||||
|
||||
sp_rank = dist.get_rank(group=group)
|
||||
ctx.sp_rank = sp_rank
|
||||
|
||||
local_shape = list(local_tensor.size())
|
||||
split_size = local_shape[0]
|
||||
part_size = local_shape[gather_dim] # store original size
|
||||
ctx.part_size = part_size
|
||||
|
||||
output = all_gather_tensor(local_tensor, group, async_op)
|
||||
return torch.cat(output.split(split_size, dim=0), dim=gather_dim)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx: Any, grad_output: Tensor) -> Any:
|
||||
if ctx.grad_scaler:
|
||||
grad_output = grad_output * ctx.sp_world_size
|
||||
return (None, grad_output.split(ctx.part_size,
|
||||
dim=ctx.gather_dim)[ctx.sp_rank].contiguous(), None, None, None, None)
|
||||
|
||||
|
||||
def gather_outpus_and_unpad(x: Tensor,
|
||||
gather_dim: int,
|
||||
unpad_dim: int = None,
|
||||
padding_size: int = 0,
|
||||
grad_scaler: bool = True,
|
||||
group: Optional[dist.ProcessGroup] = None):
|
||||
group = get_ulysses_sequence_parallel_group() if group is None else group
|
||||
sp_size = get_ulysses_sequence_parallel_world_size()
|
||||
if group == None:
|
||||
return x
|
||||
x = Gather.apply(group, x, gather_dim, grad_scaler)
|
||||
if unpad_dim is not None:
|
||||
assert isinstance(padding_size, int), 'padding size is not given or is not an integer'
|
||||
if padding_size == 0:
|
||||
return x
|
||||
x = _unpad_tensor(x, unpad_dim, padding_size)
|
||||
return x
|
||||
|
||||
|
||||
def ulysses_pad_and_slice_inputs(input_ids_rmpad: torch.Tensor,
|
||||
position_ids_rmpad: Optional[torch.Tensor] = None,
|
||||
sp_size: int = 1):
|
||||
"""
|
||||
Pad and slice input_ids to be divisible by sp_size
|
||||
Pad position_ids to be divisible by sp_size.
|
||||
|
||||
Note both input_ids_rmpad and position_ids_rmpad will be padded,
|
||||
but only input_ids will be sliced.
|
||||
|
||||
The is the utility of pre-forward for ulysses sequence parallelism
|
||||
|
||||
Args:
|
||||
input_ids_rmpad: shape of [bsz, seqlen]
|
||||
position_ids_rmpad: shape of [bsz, seqlen], where bsz must be 1
|
||||
sp_size (int): ulysses sequence parallelism size
|
||||
|
||||
Returns:
|
||||
torch.Tensor: padded and sliced input_ids
|
||||
torch.Tensor: padded and sliced position_ids
|
||||
int: pad size
|
||||
"""
|
||||
if position_ids_rmpad is not None:
|
||||
assert position_ids_rmpad.size(0) == 1
|
||||
assert input_ids_rmpad.size(1) == position_ids_rmpad.size(1)
|
||||
if sp_size <= 1:
|
||||
return input_ids_rmpad, position_ids_rmpad, 0
|
||||
_, total_seq_len = input_ids_rmpad.shape
|
||||
pad_size = (sp_size - total_seq_len % sp_size) % sp_size
|
||||
if pad_size > 0:
|
||||
input_ids_rmpad = torch.nn.functional.pad(input_ids_rmpad, (0, pad_size), value=0)
|
||||
if position_ids_rmpad is not None:
|
||||
pad_pos_ids = torch.arange(pad_size, device=position_ids_rmpad.device).unsqueeze(0)
|
||||
position_ids_rmpad = torch.cat((position_ids_rmpad, pad_pos_ids), dim=-1)
|
||||
# we don't need to slice position ids
|
||||
input_ids_rmpad = slice_input_tensor(input_ids_rmpad, dim=1, padding=False)
|
||||
return input_ids_rmpad, position_ids_rmpad, pad_size
|
||||
Reference in New Issue
Block a user