fix test dataloader shuffle bug
@@ -93,7 +93,7 @@ def apply_kl_penalty(data: DataProto, kl_ctrl: core_algos.AdaptiveKLController,
     response_length = responses.size(1)
     token_level_scores = data.batch['token_level_scores']
     batch_size = data.batch.batch_size[0]
-    attention_mask = data.batch['info_mask']
+    attention_mask = data.batch['info_mask'] if 'info_mask' in data.batch else data.batch['attention_mask']
     response_mask = attention_mask[:, -response_length:]
 
     # compute kl between ref_policy and current policy
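For context, here is a minimal sketch (not verl code) of the fallback mask selection introduced above, assuming a plain dict of torch tensors in place of DataProto.batch; the helper name and the usage values are hypothetical.

import torch

def select_response_mask(batch: dict, response_length: int) -> torch.Tensor:
    # Prefer 'info_mask' when present; otherwise fall back to 'attention_mask',
    # mirroring the conditional added in this commit.
    attention_mask = batch['info_mask'] if 'info_mask' in batch else batch['attention_mask']
    # Keep only the positions that belong to the response: the last
    # `response_length` tokens of the concatenated prompt+response sequence.
    return attention_mask[:, -response_length:]

# Hypothetical usage: a batch of 2 sequences of length 6 with 4 response tokens.
batch = {'attention_mask': torch.ones(2, 6, dtype=torch.long)}
print(select_response_mask(batch, response_length=4).shape)  # torch.Size([2, 4])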
@@ -409,8 +409,8 @@ class RayPPOTrainer(object):
 
         self.val_dataloader = DataLoader(dataset=self.val_dataset,
                                          batch_size=self.config.data.val_batch_size,
-                                         shuffle=True,
-                                         drop_last=True,
+                                         shuffle=False,
+                                         drop_last=False,
                                          collate_fn=collate_fn)
 
         print(f'Size of train dataloader: {len(self.train_dataloader)}')
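A minimal sketch (not the trainer's actual setup) of a validation DataLoader configured the way the new code does it: shuffle=False keeps evaluation order deterministic across runs, and drop_last=False keeps the final partial batch so every validation example is scored. The dataset and batch size here are hypothetical stand-ins.

import torch
from torch.utils.data import DataLoader, TensorDataset

val_dataset = TensorDataset(torch.arange(10).unsqueeze(1))
val_dataloader = DataLoader(dataset=val_dataset,
                            batch_size=4,
                            shuffle=False,    # fixed, reproducible evaluation order
                            drop_last=False)  # keep the last partial batch (2 examples)

print(len(val_dataloader))  # 3 batches: 4 + 4 + 2 examples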