fix test dataloader shuffle bug

Author: PeterGriffinJin
Date:   2025-04-02 22:23:11 +00:00
Parent: 716cd73977
Commit: 7530318919


@@ -93,7 +93,7 @@ def apply_kl_penalty(data: DataProto, kl_ctrl: core_algos.AdaptiveKLController,
     response_length = responses.size(1)
     token_level_scores = data.batch['token_level_scores']
     batch_size = data.batch.batch_size[0]
-    attention_mask = data.batch['info_mask']
+    attention_mask = data.batch['info_mask'] if 'info_mask' in data.batch else data.batch['attention_mask']
     response_mask = attention_mask[:, -response_length:]
     # compute kl between ref_policy and current policy
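The added fallback lets apply_kl_penalty run on batches that do not carry an 'info_mask' by using the plain attention mask instead. A minimal sketch of that selection logic in isolation, assuming an ordinary dict of tensors in place of the repo's DataProto batch (select_response_mask is a hypothetical helper name, not code from this project):

import torch

def select_response_mask(batch: dict, response_length: int) -> torch.Tensor:
    # Prefer the info mask when the batch provides one; otherwise fall back
    # to the regular attention mask, mirroring the change in the hunk above.
    attention_mask = batch['info_mask'] if 'info_mask' in batch else batch['attention_mask']
    # Only the trailing response tokens enter the KL penalty.
    return attention_mask[:, -response_length:]

# Toy usage: a batch without 'info_mask' now works instead of raising KeyError.
batch = {'attention_mask': torch.ones(2, 10, dtype=torch.long)}
print(select_response_mask(batch, response_length=4).shape)  # torch.Size([2, 4])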
@@ -409,8 +409,8 @@ class RayPPOTrainer(object):
         self.val_dataloader = DataLoader(dataset=self.val_dataset,
                                          batch_size=self.config.data.val_batch_size,
-                                         shuffle=True,
-                                         drop_last=True,
+                                         shuffle=False,
+                                         drop_last=False,
                                          collate_fn=collate_fn)
         print(f'Size of train dataloader: {len(self.train_dataloader)}')
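Turning off shuffle and drop_last makes validation deterministic and evaluates every validation example exactly once, including the final partial batch when the dataset size is not a multiple of val_batch_size. A minimal standalone sketch of the resulting DataLoader behavior, with a toy TensorDataset and a pass-through collate_fn standing in for the trainer's own dataset and collate_fn:

import torch
from torch.utils.data import DataLoader, TensorDataset

val_dataset = TensorDataset(torch.arange(10))  # 10 toy validation examples

val_dataloader = DataLoader(dataset=val_dataset,
                            batch_size=4,
                            shuffle=False,    # fixed order -> reproducible validation metrics
                            drop_last=False,  # keep the final partial batch (2 examples here)
                            collate_fn=lambda items: items)  # stand-in for the project's collate_fn

print(len(val_dataloader))  # 3 batches: 4 + 4 + 2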