# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Generate responses given a dataset of prompts.
"""
import os

import hydra
import numpy as np
import ray

# Set environment variables before the heavy imports below pick them up.
os.environ['NCCL_DEBUG'] = 'WARN'
os.environ['TOKENIZERS_PARALLELISM'] = 'true'
# os.environ['TORCH_COMPILE_DISABLE'] = '1'

import pandas as pd

from verl import DataProto
from verl.utils.fs import copy_local_path_from_hdfs
from verl.utils.hdfs_io import makedirs
from verl.utils.model import compute_position_id_with_mask
from verl.workers.fsdp_workers import ActorRolloutRefWorker
from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup


@hydra.main(config_path='config', config_name='generation', version_base=None)
def main(config):
    from pprint import pprint
    from omegaconf import OmegaConf
    pprint(OmegaConf.to_container(config, resolve=True))  # resolve=True will eval symbol values
    OmegaConf.resolve(config)

    local_path = copy_local_path_from_hdfs(config.model.path)
    from verl.utils import hf_tokenizer
    tokenizer = hf_tokenizer(local_path)

    if config.rollout.temperature == 0.:
        assert config.data.n_samples == 1, 'When temperature=0, n_samples must be 1.'

    # Read the dataset. Each prompt must already be in chat-template format,
    # i.e. a list of {'role': ..., 'content': ...} dictionaries.
    dataset = pd.read_parquet(config.data.path)
    chat_lst = dataset[config.data.prompt_key].tolist()
    chat_lst = [chat.tolist() for chat in chat_lst]

    tokenizer.padding_side = 'left'
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Spin up one rollout worker per GPU and load the model weights.
    ray_cls_with_init = RayClassWithInitArgs(cls=ray.remote(ActorRolloutRefWorker), config=config, role='rollout')
    resource_pool = RayResourcePool(process_on_nodes=[config.trainer.n_gpus_per_node] * config.trainer.nnodes)
    wg = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=ray_cls_with_init)
    wg.init_model()

    total_samples = len(dataset)
    config_batch_size = config.data.batch_size
    # Data-parallel size: GPUs not consumed by tensor parallelism split the batch.
    dp_size = wg.world_size // config.rollout.tensor_model_parallel_size
    # Ceiling division, so the last partial batch is still processed and no
    # empty batch is dispatched when total_samples divides evenly.
    num_batch = -(-total_samples // config_batch_size)
    output_lst = [[] for _ in range(config.data.n_samples)]

    for batch_idx in range(num_batch):
        print(f'[{batch_idx+1}/{num_batch}] Start to process.')
        batch_chat_lst = chat_lst[batch_idx * config_batch_size:(batch_idx + 1) * config_batch_size]
        inputs = tokenizer.apply_chat_template(batch_chat_lst,
                                               add_generation_prompt=True,
                                               padding=True,
                                               truncation=True,
                                               max_length=config.rollout.prompt_length,
                                               return_tensors='pt',
                                               return_dict=True,
                                               tokenize=True)
        input_ids = inputs['input_ids']
        attention_mask = inputs['attention_mask']
        position_ids = compute_position_id_with_mask(attention_mask)
        batch_dict = {'input_ids': input_ids, 'attention_mask': attention_mask, 'position_ids': position_ids}

        data = DataProto.from_dict(batch_dict)
        real_batch_size = data.batch['input_ids'].shape[0]
        if real_batch_size % dp_size != 0:
            dummy_data_size = dp_size - real_batch_size % dp_size
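            # Pad the batch up to a multiple of dp_size by repeating its
            # leading rows, so every data-parallel rank gets an equal slice;
            # the dummy rows are sliced off again right after generation.
            # Note this implicitly assumes dummy_data_size <= real_batch_size.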
            dummy_data = data[:dummy_data_size]
            data = DataProto.concat([data, dummy_data])
            print(
                f'real_batch_size {real_batch_size} is not divisible by dp_size {dp_size}, add {dummy_data_size} dummy data'
            )

        batch_size = data.batch['input_ids'].shape[0]
        assert batch_size % dp_size == 0, f'batch_size {batch_size} is not divisible by dp_size {dp_size}'

        print(f'[{batch_idx+1}/{num_batch}] Start to generate.')

        # Generate n_samples independent completions for every prompt.
        for i in range(config.data.n_samples):
            output = wg.generate_sequences(data)
            # Remove the dummy rows that were added for data-parallel padding.
            output = output[:real_batch_size]
            output_text = tokenizer.batch_decode(output.batch['input_ids'][:, -config.rollout.response_length:],
                                                 skip_special_tokens=False)

            # Strip padding tokens from the decoded text.
            pad_token = tokenizer.pad_token
            output_text_unpad = []
            for text in output_text:
                output_text_unpad.append(text.replace(pad_token, ''))

            output_lst[i].extend(output_text_unpad)

    # Convert output_lst from (n_samples, n_data) to (n_data, n_samples).
    output_lst = np.array(output_lst, dtype=object)
    output_lst = np.transpose(output_lst, axes=(1, 0)).tolist()

    # Attach the responses to the data frame.
    dataset['responses'] = output_lst

    # Write the augmented dataset to a new parquet file.
    output_dir = os.path.dirname(config.data.output_path)
    makedirs(output_dir, exist_ok=True)
    dataset.to_parquet(config.data.output_path)


if __name__ == '__main__':
    main()
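# Example invocation (a sketch: the Hydra overrides below mirror the config
# fields this script reads; the script filename, paths, and values are
# placeholders to be replaced with your own):
#
#   python <this_script>.py \
#       model.path=<local_or_hdfs_model_path> \
#       data.path=<prompts.parquet> \
#       data.prompt_key=prompt \
#       data.n_samples=4 \
#       data.batch_size=128 \
#       data.output_path=<responses.parquet> \
#       rollout.temperature=1.0 \
#       trainer.nnodes=1 \
#       trainer.n_gpus_per_node=8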