# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from megatron.core import parallel_state as mpu

from .sequence_parallel import pad_to_sequence_parallel


def compute_transformers_input_shapes(batches, meta_info):
    """Pre-compute the Transformer input shape of each micro-batch.

    After padding tokens are removed, each pipeline stage receives a packed
    (total_nnz, 1, hidden_size) activation; with sequence parallelism the
    packed length is first padded to a multiple of, then divided by, the
    tensor-parallel world size.
    """
    from flash_attn.bert_padding import unpad_input  # flash-attn 2 is a must for Megatron

    # pre-compute input shapes for each micro-batch at each pp stage
    input_shapes = []
    for model_inputs in batches:
        input_ids = model_inputs['input_ids']
        attention_mask = model_inputs['attention_mask']
        # remove padding tokens: (batch, seqlen) -> (total_nnz, 1)
        input_ids_rmpad = unpad_input(input_ids.unsqueeze(dim=-1), attention_mask)[0]
        if meta_info['sequence_parallel']:
            # pad total_nnz so the packed sequence splits evenly across
            # tensor-parallel ranks, each of which holds one slice
            input_ids_rmpad = pad_to_sequence_parallel(input_ids_rmpad)
            input_shapes.append(
                torch.Size([
                    input_ids_rmpad.shape[0] // mpu.get_tensor_model_parallel_world_size(),
                    1,
                    meta_info['hidden_size'],
                ]))
        else:
            input_shapes.append(torch.Size([input_ids_rmpad.shape[0], 1, meta_info['hidden_size']]))
    return input_shapes
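

# Illustrative sketch, not part of the original module: one way to drive the
# helper above. The micro-batch dicts and meta_info keys mirror the lookups in
# compute_transformers_input_shapes; the tensor sizes and hidden_size value are
# made up. With sequence_parallel=False it never touches Megatron's parallel
# state, though flash-attn must still be installed for unpad_input.
def _example_compute_input_shapes():  # pragma: no cover - illustrative only
    batches = [
        {
            'input_ids': torch.randint(0, 100, (2, 8)),
            'attention_mask': torch.ones(2, 8, dtype=torch.bool),
        },
        {
            'input_ids': torch.randint(0, 100, (4, 16)),
            'attention_mask': torch.ones(4, 16, dtype=torch.bool),
        },
    ]
    meta_info = {'sequence_parallel': False, 'hidden_size': 4096}
    # with full attention masks, total_nnz is 16 and 64, so this returns
    # [torch.Size([16, 1, 4096]), torch.Size([64, 1, 4096])]
    return compute_transformers_input_shapes(batches, meta_info)

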
def make_batch_generator(batches, vpp_size):
    """Wrap micro-batches into the iterator form Megatron's pipeline schedule expects."""
    if vpp_size > 1:
        # with virtual pipeline parallelism, build one independent iterator
        # over the same micro-batches per vpp model chunk
        batch_generator = [batches] * vpp_size  # number of vpp chunks
        batch_generator = [iter(b) for b in batch_generator]
    else:
        # no vpp: a single iterator over the micro-batches is enough
        batch_generator = iter(batches)
    return batch_generator
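

# Usage sketch, not part of the original file: Megatron's forward-backward
# schedule consumes one micro-batch iterator per virtual pipeline (vpp) model
# chunk, so each chunk independently replays the same list of batches.
def _example_make_batch_generator():  # pragma: no cover - illustrative only
    batches = ['micro_batch_0', 'micro_batch_1']
    vpp_gen = make_batch_generator(batches, vpp_size=2)
    # one independent iterator per chunk; both start at the first micro-batch
    assert len(vpp_gen) == 2
    assert next(vpp_gen[0]) == next(vpp_gen[1]) == 'micro_batch_0'
    flat_gen = make_batch_generator(batches, vpp_size=1)
    # without vpp, a single flat iterator is returned
    assert next(flat_gen) == 'micro_batch_0'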