# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict

import functools
import json
import math
import itertools
import os
from contextlib import contextmanager

from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy, transformer_auto_wrap_policy
from transformers.trainer_pt_utils import get_module_class_from_name

import torch
import torch.nn as nn
import torch.distributed as dist


def init_fn(x: torch.nn.Module):
    if not torch.distributed.get_rank() == 0:
        x = x.to_empty(device=torch.cuda.current_device(), recurse=False)
        torch.cuda.empty_cache()
    return x


def get_init_weight_context_manager(use_meta_tensor=True):
    from accelerate import init_empty_weights
    cpu_init_weights = lambda: torch.device('cpu')
    if use_meta_tensor:
        init_context = init_empty_weights if torch.distributed.get_rank() != 0 else cpu_init_weights
    else:
        init_context = cpu_init_weights
    return init_context


# Copyright 2020-present the HuggingFace Inc. team.
# Adapted from https://github.com/huggingface/transformers/src/transformers/trainer.py
def get_fsdp_wrap_policy(module, config=None, is_lora=False):
    """Get FSDP wrap policy for the module.

    Args:
        module: The module to get wrap policy for
        config: Configuration for wrap policy
        is_lora: Whether to enable lambda policy for LoRA modules
    """
    if config is None:
        config = {}

    if config.get('disable', False):
        return None

    default_transformer_cls_names_to_wrap = getattr(module, "_no_split_modules", None)
    fsdp_transformer_layer_cls_to_wrap = config.get("transformer_layer_cls_to_wrap",
                                                    default_transformer_cls_names_to_wrap)
    min_num_params = config.get('min_num_params', 0)
    auto_wrap_policy = None

    policies = []

    from torch.distributed.fsdp.wrap import _or_policy, lambda_auto_wrap_policy, transformer_auto_wrap_policy

    # Add lambda policy for LoRA modules if is_lora is True
    if is_lora:

        def lambda_policy_fn(module):
            if (len(list(module.named_children())) == 0 and getattr(module, "weight", None) is not None and
                    module.weight.requires_grad):
                return True
            return False

        lambda_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn)
        policies.append(lambda_policy)

    if min_num_params > 0:
        size_policy = functools.partial(size_based_auto_wrap_policy, min_num_params=min_num_params)
        policies.append(size_policy)
    elif fsdp_transformer_layer_cls_to_wrap is not None:
        transformer_cls_to_wrap = set()
        for layer_class in fsdp_transformer_layer_cls_to_wrap:
            transformer_cls = get_module_class_from_name(module, layer_class)
            if transformer_cls is None:
                raise Exception("Could not find the transformer layer class to wrap in the model.")
            else:
                transformer_cls_to_wrap.add(transformer_cls)

        transformer_policy = functools.partial(
            transformer_auto_wrap_policy,
            transformer_layer_cls=transformer_cls_to_wrap,
        )
        policies.append(transformer_policy)

    if len(policies) > 0:
        auto_wrap_policy = functools.partial(_or_policy, policies=policies)

    return auto_wrap_policy
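
# Example (illustrative sketch, not part of this module's API): how the wrap
# policy built above is typically handed to FSDP. The model choice and the
# layer class name ("GPT2Block") are assumptions for demonstration only; the
# config key matches the one read by get_fsdp_wrap_policy.
#
#   from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
#   from transformers import AutoModelForCausalLM
#
#   model = AutoModelForCausalLM.from_pretrained("gpt2")  # hypothetical model
#   auto_wrap_policy = get_fsdp_wrap_policy(
#       model,
#       config={"transformer_layer_cls_to_wrap": ["GPT2Block"]},  # assumed layer class
#   )
#   fsdp_model = FSDP(model,
#                     auto_wrap_policy=auto_wrap_policy,
#                     device_id=torch.cuda.current_device())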
def offload_fsdp_grad(module):
    for _, param in module.named_parameters():
        if param.grad is not None:
            param.grad = param.grad.to("cpu", non_blocking=True)
    torch.cuda.empty_cache()


def load_fsdp_grad(module, device_id):
    for _, param in module.named_parameters():
        if param.grad is not None:
            param.grad = param.grad.to(device_id, non_blocking=True)
    torch.cuda.empty_cache()


def offload_fsdp_param_and_grad(module, offload_grad=False):
    for _, param in module.named_parameters():
        if hasattr(param, "_local_shard"):
            param._local_shard = param._local_shard.to("cpu", non_blocking=True)
        param.data = param.data.to('cpu', non_blocking=True)
        if offload_grad and param.grad is not None:
            param.grad = param.grad.to("cpu", non_blocking=True)
    torch.cuda.empty_cache()


def load_fsdp_param_and_grad(module, device_id, load_grad=False):
    for _, param in module.named_parameters():
        if hasattr(param, "_local_shard"):
            param._local_shard = param._local_shard.to(device_id, non_blocking=True)
        param.data = param.data.to(device_id, non_blocking=True)
        if load_grad and param.grad is not None:
            param.grad = param.grad.to(device_id, non_blocking=True)
    torch.cuda.empty_cache()


def offload_fsdp_optimizer(optimizer):
    for param_group in optimizer.param_groups:
        for param in param_group['params']:
            state = optimizer.state[param]
            for key, value in state.items():
                if isinstance(value, torch.Tensor):
                    state[key] = value.to("cpu", non_blocking=True)
    torch.cuda.empty_cache()


def load_fsdp_optimizer(optimizer, device_id):
    for param_group in optimizer.param_groups:
        for param in param_group['params']:
            state = optimizer.state[param]
            for key, value in state.items():
                if isinstance(value, torch.Tensor):
                    state[key] = value.to(device_id, non_blocking=True)
    torch.cuda.empty_cache()


@contextmanager
def meta_device_init():
    """
    Create model parameters on the meta device.

    Note that buffers in the model will still be initialized on the default device
    (e.g., CPU), since buffers can be non-persistent and filled with expected values
    that can NOT be captured on the meta device.
    """
    device = torch.device("meta")
    old_register_parameter = nn.Module.register_parameter
    registered = set()

    def register_empty_parameter(module, name, param):
        old_register_parameter(module, name, param)
        # skip re-registering shared parameters, as they were already
        # registered (and moved to meta) the first time they were seen
        if param is not None and param not in registered:
            param_cls = type(module._parameters[name])
            kwargs = module._parameters[name].__dict__
            kwargs["requires_grad"] = param.requires_grad
            module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs)
            registered.add(module._parameters[name])

    try:
        nn.Module.register_parameter = register_empty_parameter
        yield
    finally:
        registered.clear()
        nn.Module.register_parameter = old_register_parameter
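
# Example (illustrative sketch): a typical offload/reload cycle around a phase
# that needs the GPU memory for something else (e.g., rollout generation in a
# hybrid engine). `fsdp_model`, `optimizer`, and the surrounding loop are
# assumed to exist elsewhere; only the call pattern is shown.
#
#   device_id = torch.cuda.current_device()
#   # free GPU memory before the memory-heavy phase
#   offload_fsdp_param_and_grad(fsdp_model, offload_grad=True)
#   offload_fsdp_optimizer(optimizer)
#   ...  # run the memory-heavy phase here
#   # bring everything back before the next training step
#   load_fsdp_param_and_grad(fsdp_model, device_id, load_grad=True)
#   load_fsdp_optimizer(optimizer, device_id)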
""" from safetensors.torch import load_file safetensors2param = {} index_file = os.path.join(filepath, "model.safetensors.index.json") if os.path.exists(index_file): index = json.load(open(index_file, "rb")) for param_name, filename in index["weight_map"].items(): safetensors2param.setdefault(filename, []).append(param_name) else: # in this case, the model is small and we can load it all at once param_file = os.path.join(filepath, "model.safetensors") assert os.path.exists(param_file), f"Cannot find {param_file}" states = load_file(param_file) for param_name in states: safetensors2param.setdefault("model.safetensors", []).append(param_name) del states total_files = len(safetensors2param) ckpt_chunks = sorted(safetensors2param.keys()) world_size = dist.get_world_size() size = int(math.ceil(total_files / world_size)) ckpt_chunks = [ckpt_chunks[rank * size:rank * size + size] for rank in range(world_size)] shard_states = {} device = torch.cuda.current_device() for rank, files in enumerate(ckpt_chunks): if rank == dist.get_rank(): for file in files: file = os.path.join(filepath, file) states = load_file(file, device=device) # print(f"rank {rank} loading {file}...") shard_states.update(states) else: for file in files: for param_name in safetensors2param[file]: shard_states[param_name] = rank return shard_states def parallel_init_module_fn(module: torch.nn.Module, shard_states: Dict[str, torch.nn.Parameter]): """ Generate a function to initialize sub-modules in the `module` with `shard_states` from huggingface checkpoint. Args: module (torch.nn.Module): the global module to be initialized shard_states (Dict[str, torch.nn.Parameter]): the shard states from huggingface checkpoint Returns: init_fn (Callable): a function to initialize sub-modules in the `module` with `shard_states` """ state2fqn = {} for name, state in itertools.chain(module.named_parameters(remove_duplicate=False), module.named_buffers(remove_duplicate=False)): state2fqn.setdefault(state, []).append(name) # remove standalone parameters and buffers shared = {s for s, names in state2fqn.items() if len(names) > 1} materialized_states = {} @torch.no_grad() def create_and_sync_state(param_name, state, is_param): assert param_name in shard_states, f"{param_name} not loaded" device = torch.cuda.current_device() if is_param: param = torch.nn.Parameter(torch.empty_like(state.data, device=device), requires_grad=state.requires_grad) else: # buffer param = torch.empty_like(state.data, device=device) loaded = shard_states[param_name] if isinstance(loaded, (torch.nn.Parameter, torch.Tensor)): # NOTE: loaded.dtype can be different with param.dtype param.data.copy_(loaded.data) dist.broadcast(param.data, src=dist.get_rank()) else: assert isinstance(loaded, int) # the rank that holds the state dist.broadcast(param.data, src=loaded) shard_states.pop(param_name) del loaded return param def init_fn(sub_mod: torch.nn.Module, recurse: bool = True): param_and_buffers = tuple(sub_mod.named_parameters(recurse=False)) + tuple(sub_mod.named_buffers(recurse=False)) # param_and_buffers = sorted(sub_mod.named_parameters(recurse=False), key=lambda x: x[0]) for name, state in param_and_buffers: if not state.is_meta: continue is_param = name in sub_mod._parameters fqn = state2fqn[state].pop(0) # non-persistent buffers will not be saved in state dict, we can safely skip it if (not is_param) and fqn not in shard_states: if state.is_meta: raise RuntimeError( f"find a non-persistent buffer ({fqn}) initiated with device meta. 
" "Such buffer is not saved in checkpoint and user should guarantee to init in CPU / GPU device.") continue # for shared parameter, we get it from the first time it is created if state in shared: if state not in materialized_states: materialized_states[state] = create_and_sync_state(fqn, state, is_param) else: if fqn in shard_states: shard_states.pop(fqn) materialize_state = materialized_states[state] # for not shared parameter, we create it directly else: materialize_state = create_and_sync_state(fqn, state, is_param) if is_param: sub_mod._parameters[name] = materialize_state else: sub_mod._buffers[name] = materialize_state if recurse: for module in sub_mod.children(): init_fn(module, recurse=True) # for debug # if len(shard_states) == 0: print("clear") return sub_mod return init_fn