# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import logging

import torch
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.api import ShardingStrategy, ShardedStateDictConfig, StateDictType, FullStateDictConfig
from torch.distributed.device_mesh import DeviceMesh

from verl.third_party.vllm import LLM
from verl.third_party.vllm import parallel_state as vllm_ps
from verl import DataProto
from verl.utils.torch_functional import (broadcast_dict_tensor, allgather_dict_tensors)
from verl.utils.debug import log_gpu_memory_usage

from .base import BaseShardingManager

logger = logging.getLogger(__file__)
logger.setLevel(os.getenv('VERL_PPO_LOGGING_LEVEL', 'WARN'))


class FSDPVLLMShardingManager(BaseShardingManager):
    """Sync FSDP-trained weights into a vLLM inference engine on entry and
    offload them on exit, keeping the generation RNG state identical across
    tensor-parallel ranks."""

    def __init__(self,
                 module: FSDP,
                 inference_engine: LLM,
                 model_config,
                 full_params: bool = False,
                 device_mesh: DeviceMesh = None):
        self.module = module
        self.inference_engine = inference_engine
        self.model_config = model_config
        self.device_mesh = device_mesh

        # Full params
        self.full_params = full_params
        if full_params:
            FSDP.set_state_dict_type(self.module,
                                     state_dict_type=StateDictType.FULL_STATE_DICT,
                                     state_dict_config=FullStateDictConfig())
        else:
            FSDP.set_state_dict_type(self.module,
                                     state_dict_type=StateDictType.SHARDED_STATE_DICT,
                                     state_dict_config=ShardedStateDictConfig())

        # Note that torch_random_states may be different on each dp rank
        self.torch_random_states = torch.cuda.get_rng_state()
        # get a random rng state for generation
        if self.device_mesh is not None:
            gen_dp_rank = self.device_mesh['dp'].get_local_rank()
            torch.cuda.manual_seed(gen_dp_rank + 1000)  # make sure all tp ranks have the same random states
            self.gen_random_states = torch.cuda.get_rng_state()
            torch.cuda.set_rng_state(self.torch_random_states)
        else:
            self.gen_random_states = None

    def __enter__(self):
        log_gpu_memory_usage('Before state_dict() in sharding manager memory', logger=logger)
        params = self.module.state_dict()
        log_gpu_memory_usage('After state_dict() in sharding manager memory', logger=logger)
        # Copy, not share memory
        load_format = 'hf' if self.full_params else 'dtensor'
        self.inference_engine.sync_model_weights(params, load_format=load_format)
        log_gpu_memory_usage('After sync model weights in sharding manager', logger=logger)

        del params
        torch.cuda.empty_cache()
        log_gpu_memory_usage('After del state_dict and empty_cache in sharding manager', logger=logger)

        # TODO: offload FSDP model weights
        # self.module.cpu()
        # torch.cuda.empty_cache()
        # if torch.distributed.get_rank() == 0:
        #     print(f'after model to cpu in sharding manager memory allocated: {torch.cuda.memory_allocated() / 1e9}GB, reserved: {torch.cuda.memory_reserved() / 1e9}GB')

        # important: need to manually set the random states of each tp to be identical.
        if self.device_mesh is not None:
            self.torch_random_states = torch.cuda.get_rng_state()
            torch.cuda.set_rng_state(self.gen_random_states)

    def __exit__(self, exc_type, exc_value, traceback):
        log_gpu_memory_usage('Before vllm offload in sharding manager', logger=logger)
        self.inference_engine.offload_model_weights()
        log_gpu_memory_usage('After vllm offload in sharding manager', logger=logger)

        # self.module.to('cuda')
        # if torch.distributed.get_rank() == 0:
        #     print(f'after actor module to cuda in sharding manager memory allocated: {torch.cuda.memory_allocated() / 1e9}GB, reserved: {torch.cuda.memory_reserved() / 1e9}GB')

        self.module.train()

        # add empty cache after each compute
        torch.cuda.empty_cache()

        # restore random states
        if self.device_mesh is not None:
            self.gen_random_states = torch.cuda.get_rng_state()
            torch.cuda.set_rng_state(self.torch_random_states)

    def preprocess_data(self, data: DataProto) -> DataProto:
        # TODO: Current impl doesn't consider FSDP with torch micro-dp
        data.batch = allgather_dict_tensors(data.batch.contiguous(),
                                            size=vllm_ps.get_tensor_model_parallel_world_size(),
                                            group=vllm_ps.get_tensor_model_parallel_group(),
                                            dim=0)
        return data

    def postprocess_data(self, data: DataProto) -> DataProto:
        # TODO: Current impl doesn't consider FSDP with torch micro-dp
        broadcast_dict_tensor(data.batch,
                              src=vllm_ps.get_tensor_model_parallel_src_rank(),
                              group=vllm_ps.get_tensor_model_parallel_group())
        dp_rank = torch.distributed.get_rank()
        dp_size = torch.distributed.get_world_size()  # does not consider torch micro-dp
        tp_size = vllm_ps.get_tensor_model_parallel_world_size()
        if tp_size > 1:
            # TODO: shall we build a micro_dp group for vllm when integrating with vLLM?
            local_prompts = data.chunk(chunks=tp_size)
            data = local_prompts[dp_rank % tp_size]
        return data
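

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the module). The names `fsdp_module`,
# `vllm_engine`, `model_config`, `device_mesh`, and `prompts` are hypothetical
# placeholders, and `vllm_engine.generate(...)` stands in for whatever
# generation call the rollout actually issues. Entering the manager pushes the
# FSDP weights into vLLM and switches to the shared generation RNG state;
# exiting offloads the vLLM weights, puts the module back into train mode, and
# restores the training RNG state.
#
#   sharding_manager = FSDPVLLMShardingManager(module=fsdp_module,
#                                              inference_engine=vllm_engine,
#                                              model_config=model_config,
#                                              full_params=False,
#                                              device_mesh=device_mesh)
#   with sharding_manager:
#       prompts = sharding_manager.preprocess_data(prompts)   # all-gather across the vLLM tp group
#       output = vllm_engine.generate(prompts)                # hypothetical generation call
#       output = sharding_manager.postprocess_data(output)    # broadcast, then keep this rank's chunk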