Initial commit
16
verl/utils/dataset/README.md
Normal file
@@ -0,0 +1,16 @@
# Dataset Format

## RLHF dataset

We combine all data sources into a single parquet file. Prompts are organized directly in the chat format so that multi-turn chats can be easily incorporated. In the prompt, we may add instruction-following text to guide the model to output the answer in a particular format so that we can extract it.

An example entry for math problems:

```json
{
    "data_source": "openai/gsm8k",
    "prompt": [{"role": "user", "content": "Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May? Let's think step by step and output the final answer after \"####\""}],
    "ability": "math",
    "reward_model": {
        "style": "rule",
        "ground_truth": ["72"]
    }
}
```
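
For reference, a minimal sketch of how such a parquet file could be produced with pandas (the output file name is hypothetical):

```python
# Hypothetical sketch: write a one-row RLHF dataset in the format above.
import pandas as pd

example = {
    "data_source": "openai/gsm8k",
    "prompt": [{
        "role": "user",
        "content": "Natalia sold clips to 48 of her friends in April, and then she sold "
                   "half as many clips in May. How many clips did Natalia sell altogether "
                   "in April and May? Let's think step by step and output the final answer "
                   "after \"####\""
    }],
    "ability": "math",
    "reward_model": {"style": "rule", "ground_truth": ["72"]},
}
# pyarrow serializes the nested prompt list and the reward_model struct
pd.DataFrame([example]).to_parquet("rlhf_example.parquet")
```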
16
verl/utils/dataset/__init__.py
Normal file
@@ -0,0 +1,16 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .rl_dataset import RLHFDataset
from .rm_dataset import RMDataset
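
With these re-exports, both dataset classes can be imported from the package directly:

```python
from verl.utils.dataset import RLHFDataset, RMDataset
```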
155
verl/utils/dataset/rl_dataset.py
Normal file
@@ -0,0 +1,155 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from typing import List, Union

import numpy as np
import pandas as pd
import torch
from omegaconf import ListConfig
from torch.utils.data import Dataset
from transformers import PreTrainedTokenizer

import verl.utils.torch_functional as verl_F
from verl.utils.fs import copy_local_path_from_hdfs
from verl.utils.model import compute_position_id_with_mask

def collate_fn(data_list: list[dict]) -> dict:
    """Collate a list of dataset items into a single batch.

    Tensor values are stacked along a new batch dimension; non-tensor values
    are collected into object-dtype numpy arrays so that ragged data
    (e.g. raw chat lists) can pass through the DataLoader unchanged.
    """
    tensors = {}
    non_tensors = {}

    for data in data_list:
        for key, val in data.items():
            if isinstance(val, torch.Tensor):
                tensors.setdefault(key, []).append(val)
            else:
                non_tensors.setdefault(key, []).append(val)

    for key, val in tensors.items():
        tensors[key] = torch.stack(val, dim=0)

    for key, val in non_tensors.items():
        non_tensors[key] = np.array(val, dtype=object)

    output = {}
    output.update(tensors)
    output.update(non_tensors)
    return output


class RLHFDataset(Dataset):
    """
    We assume the dataset contains a column that holds the prompts and other information.
    """

    def __init__(self,
                 parquet_files: Union[str, List[str]],
                 tokenizer: PreTrainedTokenizer,
                 prompt_key='prompt',
                 max_prompt_length=1024,
                 filter_prompts=True,
                 cache_dir='~/.cache/verl/rlhf',
                 chat_template_func=None,
                 return_raw_chat=False,
                 truncation='error'):
        if not isinstance(parquet_files, (List, ListConfig)):
            parquet_files = [parquet_files]

        self.parquet_files = parquet_files
        self.cache_dir = os.path.expanduser(cache_dir)
        self.tokenizer = tokenizer

        self.prompt_key = prompt_key
        self.max_prompt_length = max_prompt_length
        self.filter_prompts = filter_prompts

        self.return_raw_chat = return_raw_chat
        self.chat_template_func = chat_template_func
        self.truncation = truncation

        self._download()
        self._read_files_and_tokenize()

    def _download(self):
        # copy each parquet file from (possibly remote) storage into the local cache
        for i, parquet_file in enumerate(self.parquet_files):
            self.parquet_files[i] = copy_local_path_from_hdfs(src=parquet_file, cache_dir=self.cache_dir)

    def _read_files_and_tokenize(self):
        dataframes = []
        for parquet_file in self.parquet_files:
            # read the locally cached parquet files
            dataframe = pd.read_parquet(parquet_file)
            dataframes.append(dataframe)
        self.dataframe = pd.concat(dataframes)

        print(f'original dataset len: {len(self.dataframe)}')

        # Prompt-length filtering (the filter_prompts flag) is currently disabled;
        # over-long prompts are handled by the truncation policy in __getitem__.
        # self.dataframe = self.dataframe[self.dataframe.apply(lambda doc: len(
        #     self.tokenizer.apply_chat_template(doc[self.prompt_key], add_generation_prompt=True)) <= self.max_prompt_length,
        #     axis=1)]

        print(f'filtered dataset len: {len(self.dataframe)}')

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, item):
        """
        Note that we also return the raw input_ids so that they can be combined with other chat templates.
        """
        row_dict = self.dataframe.iloc[item].to_dict()

        chat = row_dict.pop(self.prompt_key)

        if self.tokenizer.chat_template:
            prompt_with_chat_template = self.tokenizer.apply_chat_template(chat, add_generation_prompt=True, tokenize=False)
        else:
            # no chat template available: fall back to the raw content of the first turn
            prompt_with_chat_template = chat[0]['content']

        input_ids, attention_mask = verl_F.tokenize_and_postprocess_data(prompt=prompt_with_chat_template,
                                                                         tokenizer=self.tokenizer,
                                                                         max_length=self.max_prompt_length,
                                                                         pad_token_id=self.tokenizer.pad_token_id,
                                                                         left_pad=True,
                                                                         truncation=self.truncation)

        position_ids = compute_position_id_with_mask(attention_mask)

        row_dict['input_ids'] = input_ids[0]
        row_dict['attention_mask'] = attention_mask[0]
        row_dict['position_ids'] = position_ids[0]

        # optionally return the prompt as a raw chat list (parquet loads it as a numpy array)
        if self.return_raw_chat:
            row_dict['raw_prompt'] = chat.tolist()

        # add an index for each prompt, defaulting to 0 when no extra_info is present
        index = row_dict.get("extra_info", {}).get("index", 0)
        row_dict["index"] = index

        return row_dict
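
For reference, a minimal usage sketch of `RLHFDataset` together with `collate_fn`. The parquet path and model name below are hypothetical, and we assume `tokenize_and_postprocess_data` pads every prompt to `max_prompt_length`:

```python
# Hypothetical usage sketch; the parquet path and model name are placeholders.
from torch.utils.data import DataLoader

from verl.utils import hf_tokenizer
from verl.utils.dataset.rl_dataset import RLHFDataset, collate_fn

tokenizer = hf_tokenizer('deepseek-ai/deepseek-llm-7b-chat')  # placeholder model
dataset = RLHFDataset(parquet_files='gsm8k_train.parquet',    # placeholder path
                      tokenizer=tokenizer,
                      max_prompt_length=1024,
                      truncation='error')

loader = DataLoader(dataset, batch_size=8, shuffle=True, collate_fn=collate_fn)
batch = next(iter(loader))
# tensor entries are stacked: (8, 1024) if prompts are left-padded to max_prompt_length
print(batch['input_ids'].shape, batch['position_ids'].shape)
```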
143
verl/utils/dataset/rm_dataset.py
Normal file
@@ -0,0 +1,143 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from typing import List, Union

import pandas as pd
import torch
from torch.utils.data import Dataset

from verl.utils import hf_tokenizer


def download_files_distributed(download_fn):
    """Run ``download_fn`` once on rank 0 under torch.distributed.

    The remaining ranks wait at a barrier until the download finishes; when
    no process group is initialized, every caller downloads directly.
    """
    import torch.distributed
    if torch.distributed.is_initialized():
        if torch.distributed.get_rank() == 0:
            # only rank 0 downloads
            download_fn()

        torch.distributed.barrier()
    else:
        # not running under torch.distributed: download unconditionally
        download_fn()


class RMDataset(Dataset):

    def __init__(self,
                 parquet_files: Union[str, List[str]],
                 tokenizer,
                 prompt_key='prompt',
                 chosen_key='chosen',
                 rejected_key='rejected',
                 max_length=1024,
                 add_eos=True,
                 cache_dir='~/.cache/verl/rm'):
        if not isinstance(parquet_files, List):
            parquet_files = [parquet_files]

        self.parquet_files = parquet_files
        self.cache_dir = os.path.expanduser(cache_dir)
        if isinstance(tokenizer, str):
            tokenizer = hf_tokenizer(tokenizer)
        self.tokenizer = tokenizer

        self.prompt_key = prompt_key
        self.chosen_key = chosen_key
        self.rejected_key = rejected_key

        self.add_eos = add_eos
        self.max_length = max_length

        self._download()
        self._read_files_and_tokenize()

    def _download(self):

        def _download_files():
            from verl.utils.fs import copy, _is_non_local
            os.makedirs(self.cache_dir, exist_ok=True)
            assert os.path.exists(self.cache_dir)
            for i, parquet_file in enumerate(self.parquet_files):
                if _is_non_local(parquet_file):
                    dst = os.path.join(self.cache_dir, os.path.basename(parquet_file))
                    if not os.path.exists(dst):
                        copy(src=parquet_file, dst=dst)
                    self.parquet_files[i] = dst

        download_files_distributed(_download_files)

    def _read_files_and_tokenize(self):
        dataframes = []
        for parquet_file in self.parquet_files:
            # read the locally cached parquet files
            dataframe = pd.read_parquet(parquet_file)
            dataframes.append(dataframe)
        self.dataframe = pd.concat(dataframes)
        self.prompts = self.dataframe[self.prompt_key].tolist()
        self.chosen_responses = self.dataframe[self.chosen_key].tolist()
        self.rejected_responses = self.dataframe[self.rejected_key].tolist()

    def __len__(self):
        return len(self.prompts)

    def _pad_to_length(self, input_ids, attention_mask):
        curr_length = input_ids.shape[-1]

        if curr_length < self.max_length:
            # right-pad with zeros; the attention mask marks the padding as inactive
            input_ids = torch.cat(
                (input_ids, torch.zeros(size=(self.max_length - curr_length,), dtype=input_ids.dtype)), dim=-1)
            attention_mask = torch.cat(
                (attention_mask, torch.zeros(size=(self.max_length - curr_length,), dtype=attention_mask.dtype)),
                dim=-1)
        elif curr_length > self.max_length:
            input_ids = input_ids[:self.max_length]
            attention_mask = attention_mask[:self.max_length]

        return input_ids, attention_mask

    def __getitem__(self, item):
        prompt = self.prompts[item]
        chosen_response = self.chosen_responses[item]
        rejected_response = self.rejected_responses[item]

        prompt_ids = self.tokenizer(prompt, return_tensors='pt')['input_ids'][0]
        chosen_response_ids = self.tokenizer(chosen_response, return_tensors='pt')['input_ids'][0]
        rejected_response_ids = self.tokenizer(rejected_response, return_tensors='pt')['input_ids'][0]

        if self.add_eos:
            chosen_response_ids = torch.cat((chosen_response_ids, torch.tensor([self.tokenizer.eos_token_id])), dim=-1)
            rejected_response_ids = torch.cat((rejected_response_ids, torch.tensor([self.tokenizer.eos_token_id])),
                                              dim=-1)

        chosen_input_ids = torch.cat((prompt_ids, chosen_response_ids), dim=-1)
        chosen_attention_mask = torch.ones_like(chosen_input_ids)

        rejected_input_ids = torch.cat((prompt_ids, rejected_response_ids), dim=-1)
        rejected_attention_mask = torch.ones_like(rejected_input_ids)

        chosen_input_ids, chosen_attention_mask = self._pad_to_length(chosen_input_ids, chosen_attention_mask)
        rejected_input_ids, rejected_attention_mask = self._pad_to_length(rejected_input_ids, rejected_attention_mask)

        # stack chosen and rejected along dim 0 so each item has shape (2, max_length);
        # note the mask must pair chosen with chosen, rejected with rejected
        input_ids = torch.stack((chosen_input_ids, rejected_input_ids), dim=0)
        attention_mask = torch.stack((chosen_attention_mask, rejected_attention_mask), dim=0)

        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
        }
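
For reference, a minimal usage sketch of `RMDataset` (the parquet path and tokenizer name are hypothetical):

```python
# Hypothetical usage sketch; the parquet path and tokenizer name are placeholders.
from verl.utils.dataset.rm_dataset import RMDataset

dataset = RMDataset(parquet_files='preference_pairs.parquet',      # placeholder path
                    tokenizer='deepseek-ai/deepseek-llm-7b-chat',  # resolved via hf_tokenizer
                    max_length=1024)

item = dataset[0]
print(item['input_ids'].shape)       # (2, 1024): row 0 = chosen, row 1 = rejected
print(item['attention_mask'].shape)  # (2, 1024)
```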