Initial commit
155
verl/utils/dataset/rl_dataset.py
Normal file
@@ -0,0 +1,155 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import List, Union

import numpy as np
import pandas as pd
import torch
from omegaconf import ListConfig
from torch.utils.data import Dataset
from transformers import PreTrainedTokenizer

import verl.utils.torch_functional as verl_F
from verl.utils.fs import copy_local_path_from_hdfs
from verl.utils.model import compute_position_id_with_mask


def collate_fn(data_list: list[dict]) -> dict:
    """Batch a list of samples: stack tensor fields, gather non-tensor fields into object arrays."""
    tensors = {}
    non_tensors = {}

    for data in data_list:
        for key, val in data.items():
            if isinstance(val, torch.Tensor):
                if key not in tensors:
                    tensors[key] = []
                tensors[key].append(val)
            else:
                if key not in non_tensors:
                    non_tensors[key] = []
                non_tensors[key].append(val)

    for key, val in tensors.items():
        tensors[key] = torch.stack(val, dim=0)

    for key, val in non_tensors.items():
        non_tensors[key] = np.array(val, dtype=object)

    output = {}
    output.update(tensors)
    output.update(non_tensors)
    return output


class RLHFDataset(Dataset):
    """
    We assume the dataset contains a column holding the prompts, alongside any other per-sample information.
    """

    def __init__(self,
                 parquet_files: Union[str, List[str]],
                 tokenizer: PreTrainedTokenizer,
                 prompt_key='prompt',
                 max_prompt_length=1024,
                 filter_prompts=True,
                 cache_dir='~/.cache/verl/rlhf',
                 chat_template_func=None,
                 return_raw_chat=False,
                 truncation='error'):
        if not isinstance(parquet_files, (list, ListConfig)):
            parquet_files = [parquet_files]

        self.parquet_files = parquet_files
        self.cache_dir = os.path.expanduser(cache_dir)
        self.tokenizer = tokenizer

        self.prompt_key = prompt_key
        self.max_prompt_length = max_prompt_length
        self.filter_prompts = filter_prompts

        self.return_raw_chat = return_raw_chat
        self.chat_template_func = chat_template_func
        self.truncation = truncation

        self._download()
        self._read_files_and_tokenize()

    def _download(self):
        # copy remote (e.g. HDFS) parquet files into the local cache directory
        for i, parquet_file in enumerate(self.parquet_files):
            self.parquet_files[i] = copy_local_path_from_hdfs(src=parquet_file, cache_dir=self.cache_dir)

    def _read_files_and_tokenize(self):
        dataframes = []
        for parquet_file in self.parquet_files:
            # read the locally cached parquet files
            dataframe = pd.read_parquet(parquet_file)
            dataframes.append(dataframe)
        self.dataframe = pd.concat(dataframes)

        print(f'original dataset len: {len(self.dataframe)}')

        # optionally drop prompts whose tokenized length exceeds max_prompt_length
        tokenizer = self.tokenizer
        prompt_key = self.prompt_key

        if self.filter_prompts:
            self.dataframe = self.dataframe[self.dataframe.apply(
                lambda doc: len(tokenizer.apply_chat_template(doc[prompt_key], add_generation_prompt=True))
                <= self.max_prompt_length,
                axis=1)]

        print(f'filtered dataset len: {len(self.dataframe)}')

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, item):
        """
        Note that we also return the raw chat (under `raw_prompt`, when `return_raw_chat` is set)
        so that it can be combined with other chat templates downstream.
        """
        row_dict = self.dataframe.iloc[item].to_dict()

        chat = row_dict.pop(self.prompt_key)

        if self.tokenizer.chat_template:
            prompt_with_chat_template = self.tokenizer.apply_chat_template(chat,
                                                                           add_generation_prompt=True,
                                                                           tokenize=False)
        else:
            # no chat template available: fall back to the raw content of the first turn
            prompt_with_chat_template = chat[0]['content']

        input_ids, attention_mask = verl_F.tokenize_and_postprocess_data(prompt=prompt_with_chat_template,
                                                                         tokenizer=self.tokenizer,
                                                                         max_length=self.max_prompt_length,
                                                                         pad_token_id=self.tokenizer.pad_token_id,
                                                                         left_pad=True,
                                                                         truncation=self.truncation)

        position_ids = compute_position_id_with_mask(attention_mask)

        row_dict['input_ids'] = input_ids[0]
        row_dict['attention_mask'] = attention_mask[0]
        row_dict['position_ids'] = position_ids[0]

        # also return the prompt without the chat template applied
        if self.return_raw_chat:
            # parquet columns come back as numpy arrays, so convert to a plain list
            row_dict['raw_prompt'] = chat.tolist()

        # add an index for each prompt
        index = row_dict.get("extra_info", {}).get("index", 0)
        row_dict["index"] = index

        return row_dict
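
For reference, below is a minimal usage sketch (not part of the commit) showing how `RLHFDataset` and `collate_fn` fit together with a PyTorch `DataLoader`. The parquet path and tokenizer name are illustrative assumptions, not values taken from this commit.

# Minimal usage sketch -- the parquet path and tokenizer name below are
# illustrative assumptions, not values from this commit.
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

from verl.utils.dataset.rl_dataset import RLHFDataset, collate_fn

tokenizer = AutoTokenizer.from_pretrained('Qwen/Qwen2.5-0.5B-Instruct')

dataset = RLHFDataset(parquet_files='train.parquet',
                      tokenizer=tokenizer,
                      prompt_key='prompt',
                      max_prompt_length=1024,
                      return_raw_chat=True)

dataloader = DataLoader(dataset, batch_size=8, shuffle=True, collate_fn=collate_fn)

batch = next(iter(dataloader))
# tensor fields are stacked to shape [batch, max_prompt_length]; non-tensor
# fields (e.g. raw_prompt, index) come back as object-dtype numpy arrays
print(batch['input_ids'].shape, batch['raw_prompt'].dtype)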