# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
A lightweight one-file FSDP SFT Trainer

TODO(zhangchi.usc1992)
- Add calculation of mfu
- Add validation
"""

import os

os.environ['NCCL_DEBUG'] = 'WARN'
os.environ['TOKENIZERS_PARALLELISM'] = 'true'

import logging
import re

import torch
import torch.distributed
from torch import nn, optim
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecision, ShardingStrategy, CPUOffload
from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedModel, AutoConfig
from verl.utils.torch_functional import get_cosine_schedule_with_warmup
from tensordict import TensorDict
from torch.utils.data import DataLoader, DistributedSampler

from verl.utils.fsdp_utils import get_fsdp_wrap_policy, init_fn, get_init_weight_context_manager
from verl.utils.dataset import SFTDataset
from verl.utils.fs import copy_local_path_from_hdfs
from verl.utils.tracking import Tracking

from torch.distributed.device_mesh import DeviceMesh

import verl.utils.hdfs_io as hdfs_io
from verl.utils.debug import log_gpu_memory_usage
from peft import LoraConfig, TaskType, get_peft_model

logger = logging.getLogger(__file__)
logger.setLevel(os.getenv('VERL_SFT_LOGGING_LEVEL', 'WARN'))

def extract_step(path):
    match = re.search(r'global_step_(\d+)', path)
    if match:
        return int(match.group(1))
    return None


def convert_to_regular_types(obj):
    """Convert Hydra configs and other special types to regular Python types."""
    from omegaconf import ListConfig, DictConfig
    if isinstance(obj, (ListConfig, DictConfig)):
        return {k: convert_to_regular_types(v) for k, v in obj.items()} if isinstance(obj, DictConfig) else list(obj)
    elif isinstance(obj, (list, tuple)):
        return [convert_to_regular_types(x) for x in obj]
    elif isinstance(obj, dict):
        return {k: convert_to_regular_types(v) for k, v in obj.items()}
    return obj


class FSDPSFTTrainer(object):

    def __init__(self, config, device_mesh: DeviceMesh):
        self.config = config
        self.device_mesh = device_mesh
        # build tokenizer first
        local_model_path = copy_local_path_from_hdfs(src=self.config.model.partial_pretrain, verbose=True)
        from verl.utils import hf_tokenizer
        self.tokenizer = hf_tokenizer(local_model_path, trust_remote_code=self.config.model.trust_remote_code)
        if self.config.data.chat_template is not None:
            raise ValueError('Applying a chat template from the config is not supported yet.')

        # normalize dp size
        self._normalize_config_bsz()

        self._build_dataloader()
        # build model
        self._build_model_optimizer()

        # TODO: add checkpoint manager
        if self.device_mesh.get_rank() == 0:
            print(self.config)

    def _normalize_config_bsz(self):
        dp_size = self.device_mesh.size()
        if self.device_mesh.get_rank() == 0:
            print(f'Normalize batch size by dp {dp_size}')

        assert self.config.data.train_batch_size % dp_size == 0
        assert self.config.data.micro_batch_size % dp_size == 0

        self.config.data.train_batch_size //= dp_size
        self.config.data.micro_batch_size //= dp_size

    def _build_dataloader(self):
        config = self.config
        # build dataset
        self.train_dataset = SFTDataset(parquet_files=config.data.train_files,
                                        tokenizer=self.tokenizer,
                                        prompt_key=config.data.prompt_key,
                                        prompt_dict_keys=config.data.get('prompt_dict_keys', None),
                                        response_key=config.data.response_key,
                                        response_dict_keys=config.data.get('response_dict_keys', None),
                                        max_length=config.data.max_length,
                                        truncation=config.data.truncation)
        self.val_dataset = SFTDataset(parquet_files=config.data.val_files,
                                      tokenizer=self.tokenizer,
                                      prompt_key=config.data.prompt_key,
                                      prompt_dict_keys=config.data.get('prompt_dict_keys', None),
                                      response_key=config.data.response_key,
                                      response_dict_keys=config.data.get('response_dict_keys', None),
                                      max_length=config.data.max_length,
                                      truncation=config.data.truncation)

        # build dataloader
        rank = self.device_mesh.get_rank()
        world_size = self.device_mesh.size()
        self.train_sampler = DistributedSampler(self.train_dataset,
                                                shuffle=True,
                                                num_replicas=world_size,
                                                rank=rank,
                                                drop_last=True)
        self.train_dataloader = DataLoader(dataset=self.train_dataset,
                                           batch_size=config.data.train_batch_size,
                                           sampler=self.train_sampler,
                                           num_workers=8,
                                           pin_memory=True,
                                           drop_last=True)

        self.val_sampler = DistributedSampler(self.val_dataset,
                                              shuffle=True,
                                              num_replicas=world_size,
                                              rank=rank,
                                              drop_last=True)
        self.val_dataloader = DataLoader(dataset=self.val_dataset,
                                         batch_size=config.data.micro_batch_size,
                                         sampler=self.val_sampler,
                                         num_workers=8,
                                         pin_memory=True,
                                         drop_last=True)

    def _build_model_optimizer(self):
        # TODO (zhangchi.usc1992):
        # 1. support pretrain from random weights
        # 2. support init directly from sharded weights
        local_model_path = copy_local_path_from_hdfs(src=self.config.model.partial_pretrain, verbose=True)

        if self.config.model.get('external_lib', None) is not None:
            # This is used to import external_lib into the huggingface systems
            import importlib
            importlib.import_module(self.config.model.external_lib)

        log_gpu_memory_usage('Before model allocation', logger=logger)

        trust_remote_code = self.config.model.trust_remote_code
        # load config first
        config = AutoConfig.from_pretrained(local_model_path, trust_remote_code=trust_remote_code)

        # This may be very large
        init_context = get_init_weight_context_manager(use_meta_tensor=not config.tie_word_embeddings)

        with init_context():
            self.model: PreTrainedModel = AutoModelForCausalLM.from_pretrained(local_model_path,
                                                                               config=config,
                                                                               torch_dtype=torch.float32,
                                                                               attn_implementation='flash_attention_2',
                                                                               trust_remote_code=trust_remote_code)
            if self.config.model.get('lora_rank', 0) > 0:
                self.model.enable_input_require_grads()
                # Convert config to regular Python types before creating PEFT model
                lora_config = {
                    'task_type': TaskType.CAUSAL_LM,
                    'r': self.config.model.lora_rank,
                    'lora_alpha': self.config.model.lora_alpha,
                    'target_modules': convert_to_regular_types(self.config.model.target_modules),
                    'bias': "none"
                }
                self.model = get_peft_model(self.model, LoraConfig(**lora_config))

        if self.config.model.enable_gradient_checkpointing:
            self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={'use_reentrant': False})

        log_gpu_memory_usage('After model allocation', logger=logger)

        mixed_precision = MixedPrecision(param_dtype=torch.bfloat16,
                                         reduce_dtype=torch.float32,
                                         buffer_dtype=torch.float32)

        auto_wrap_policy = get_fsdp_wrap_policy(self.model,
                                                config=self.config.model.fsdp_config.wrap_policy,
                                                is_lora=self.config.model.get('lora_rank', 0) > 0)
        if self.device_mesh.get_rank() == 0:
            print(auto_wrap_policy)

        if not self.config.model.fsdp_config.cpu_offload:
            cpu_offload = None
        else:
            cpu_offload = CPUOffload(offload_params=self.config.model.fsdp_config.offload_params)

        self.fsdp_model = FSDP(module=self.model,
                               auto_wrap_policy=auto_wrap_policy,
                               param_init_fn=init_fn,
                               sharding_strategy=ShardingStrategy.FULL_SHARD,
                               mixed_precision=mixed_precision,
                               device_mesh=self.device_mesh,
                               sync_module_states=True,
                               device_id=torch.cuda.current_device(),
                               cpu_offload=cpu_offload,
                               use_orig_params=False)

        log_gpu_memory_usage('After FSDP wrapping', logger=logger)

        self.optimizer = optim.AdamW(self.fsdp_model.parameters(),
                                     lr=self.config.optim.lr,
                                     betas=self.config.optim.betas,
                                     weight_decay=self.config.optim.weight_decay)

        log_gpu_memory_usage('After initialize optimizer', logger=logger)

        steps_per_epoch = len(self.train_dataloader)
        total_steps = steps_per_epoch * self.config.trainer.total_epochs

        if self.device_mesh.get_rank() == 0:
            print(
                f'Number of steps/epoch {steps_per_epoch}, number of epochs {self.config.trainer.total_epochs}, total number of steps {total_steps}'
            )

        num_warmup_steps = int(total_steps * self.config.optim.warmup_steps_ratio)

        self.lr_scheduler = get_cosine_schedule_with_warmup(optimizer=self.optimizer,
                                                            num_warmup_steps=num_warmup_steps,
                                                            num_training_steps=total_steps)

    def _compute_loss(self, batch):
        loss_mask = batch.pop('loss_mask')[:, :-1].reshape(-1).cuda()
        labels = batch['input_ids'][:, 1:].cuda()

        with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
            output = self.fsdp_model(input_ids=batch['input_ids'],
                                     attention_mask=batch['attention_mask'],
                                     position_ids=batch['position_ids'],
                                     use_cache=False)  # prevent the model from behaving as if it were generating

        logits = output.logits

        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels.contiguous()
        # Flatten the tokens
        loss_fct = nn.CrossEntropyLoss(reduction='none')
        shift_logits = shift_logits.view(-1, self.model.config.vocab_size)
        shift_labels = shift_labels.view(-1)
        # Enable model parallelism
        shift_labels = shift_labels.to(shift_logits.device)
        loss = loss_fct(shift_logits, shift_labels)
        loss = loss * loss_mask

        valid_token_this_rank = torch.sum(loss_mask)

        if self.config.data.balance_dp_token:
            torch.distributed.all_reduce(valid_token_this_rank)  # becomes the total number of valid tokens across all ranks
            dp_size = torch.distributed.get_world_size()
        else:
            dp_size = 1

        loss = torch.sum(loss) / valid_token_this_rank * dp_size  # possible bugs here for dp
        return loss

    def training_step(self, batch: TensorDict):
        self.fsdp_model.train()

        log_gpu_memory_usage('Before optimizer zero_grad', logger=logger)

        self.optimizer.zero_grad()

        log_gpu_memory_usage('After optimizer zero_grad', logger=logger)

        micro_batches = batch.split(self.config.data.micro_batch_size)
        n_micro_batches = len(micro_batches)
        step_loss = 0
        for micro_batch in micro_batches:
            loss = self._compute_loss(batch=micro_batch) / n_micro_batches
            loss.backward()
            step_loss += loss.item()

        self.fsdp_model.clip_grad_norm_(max_norm=self.config.optim.clip_grad)

        log_gpu_memory_usage('Before optimizer step', logger=logger)

        self.optimizer.step()

        log_gpu_memory_usage('After optimizer step', logger=logger)

        self.lr_scheduler.step()

        # reduce loss across dp ranks
        lr = self.lr_scheduler.get_last_lr()[0]

        log_gpu_memory_usage('After offload weights', logger=logger)

        step_loss = torch.tensor(step_loss).cuda()
        torch.distributed.all_reduce(step_loss, op=torch.distributed.ReduceOp.AVG)
        return {'train/loss': step_loss.detach().item(), 'train/lr(1e-3)': lr * 1e3}

    def validation_step(self, batch: TensorDict):
        self.fsdp_model.eval()
        with torch.no_grad():
            loss = self._compute_loss(batch)
            torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.AVG)
        return loss

    def save_checkpoint(self, step):
        # save checkpoint
        from torch.distributed.fsdp import FullStateDictConfig, StateDictType
        cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
        with FSDP.state_dict_type(self.fsdp_model, StateDictType.FULL_STATE_DICT, cfg):
            state_dict = self.fsdp_model.state_dict()

        path = os.path.join(self.config.trainer.default_local_dir, f'global_step_{step}')
        # save huggingface model
        if self.device_mesh.get_rank() == 0:
            os.makedirs(path, exist_ok=True)
            self.model.save_pretrained(path, state_dict=state_dict)
            self.tokenizer.save_pretrained(path)
            if self.config.trainer.default_hdfs_dir:
                hdfs_io.makedirs(self.config.trainer.default_hdfs_dir, exist_ok=True)
                hdfs_io.copy(src=path, dst=self.config.trainer.default_hdfs_dir, dirs_exist_ok=True)
        torch.distributed.barrier()

    def fit(self):
        rank = self.device_mesh.get_rank()

        # TODO: add a unified tracking
        if rank == 0:
            tracking = Tracking(project_name=self.config.trainer.project_name,
                                experiment_name=self.config.trainer.experiment_name,
                                default_backend=self.config.trainer.logger)

        global_step = 0
        # compute the total training steps.
        # the total training steps in SFT is mainly for early exit
        total_training_steps = len(self.train_dataloader) * self.config.trainer.total_epochs

        if self.config.trainer.total_training_steps is not None:
            total_training_steps = self.config.trainer.total_training_steps

        self.total_training_steps = total_training_steps
        print(f'Total training steps: {self.total_training_steps}')

        # TODO (zhangchi.usc1992) add back checkpoint manager. Currently it blocks when uploading to hdfs, so it is very slow.

        if self.config.trainer.validate_before_training:
            # validate before training
            val_losses = []
            for data in self.val_dataloader:
                data = TensorDict(data, batch_size=self.config.data.micro_batch_size).cuda()
                val_loss = self.validation_step(data)
                val_losses.append(val_loss)
            if rank == 0:
                val_loss = torch.mean(torch.stack(val_losses))
                metric = {'val/loss': val_loss.detach().item()}
                tracking.log(data=metric, step=global_step)
            torch.distributed.barrier()

        for epoch in range(self.config.trainer.total_epochs):
            self.train_sampler.set_epoch(epoch=epoch)
            for data in self.train_dataloader:
                data = TensorDict(data, batch_size=self.config.data.train_batch_size).cuda()
                metric = self.training_step(data)
                if rank == 0:
                    tracking.log(data=metric, step=global_step)
                global_step += 1

                # for early exit validation
                if global_step >= self.total_training_steps:
                    # Perform final validation
                    val_losses = []
                    for val_data in self.val_dataloader:
                        val_data = TensorDict(val_data, batch_size=self.config.data.micro_batch_size).cuda()
                        val_loss = self.validation_step(val_data)
                        val_losses.append(val_loss)
                    if rank == 0:
                        avg_val_loss = torch.mean(torch.stack(val_losses))
                        metric = {'val/loss': avg_val_loss.detach().item()}
                        tracking.log(data=metric, step=global_step)
                    torch.distributed.barrier()

                    # Save final checkpoint
                    self.save_checkpoint(step=global_step)
                    return

            # validation
            val_losses = []
            for data in self.val_dataloader:
                data = TensorDict(data, batch_size=self.config.data.micro_batch_size).cuda()
                val_loss = self.validation_step(data)
                val_losses.append(val_loss)
            if rank == 0:
                val_loss = torch.mean(torch.stack(val_losses))
                metric = {'val/loss': val_loss.detach().item()}
                tracking.log(data=metric, step=global_step)
            torch.distributed.barrier()

            # save checkpoint
            self.save_checkpoint(step=global_step)


from verl.trainer.fsdp_sft_trainer import FSDPSFTTrainer
import hydra

from torch.distributed.device_mesh import init_device_mesh

from verl.utils.distributed import initialize_global_process_group


@hydra.main(config_path='config', config_name='sft_trainer', version_base=None)
def main(config):
    local_rank, rank, world_size = initialize_global_process_group()

    device_mesh = init_device_mesh(device_type='cuda', mesh_shape=(world_size,), mesh_dim_names=('dp',))
    trainer = FSDPSFTTrainer(config=config, device_mesh=device_mesh)
    trainer.fit()


if __name__ == '__main__':
    main()
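# Launch note (a minimal sketch, not part of this module's logic): the Hydra keys below mirror
# the fields this trainer actually reads (data.train_files, data.val_files, model.partial_pretrain,
# trainer.default_local_dir); the dataset paths and model name are illustrative assumptions only.
#
#   torchrun --standalone --nnodes=1 --nproc_per_node=8 \
#       -m verl.trainer.fsdp_sft_trainer \
#       data.train_files=$HOME/data/sft/train.parquet \
#       data.val_files=$HOME/data/sft/val.parquet \
#       model.partial_pretrain=Qwen/Qwen2.5-0.5B-Instruct \
#       trainer.default_local_dir=$HOME/ckpts/sft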