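Fix validation (test) dataloader shuffle bug: disable shuffle and drop_last for the val dataloader, and fall back to attention_mask when info_mask is absent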
This commit is contained in:
@@ -93,7 +93,7 @@ def apply_kl_penalty(data: DataProto, kl_ctrl: core_algos.AdaptiveKLController,
|
|||||||
response_length = responses.size(1)
|
response_length = responses.size(1)
|
||||||
token_level_scores = data.batch['token_level_scores']
|
token_level_scores = data.batch['token_level_scores']
|
||||||
batch_size = data.batch.batch_size[0]
|
batch_size = data.batch.batch_size[0]
|
||||||
attention_mask = data.batch['info_mask']
|
attention_mask = data.batch['info_mask'] if 'info_mask' in data.batch else data.batch['attention_mask']
|
||||||
response_mask = attention_mask[:, -response_length:]
|
response_mask = attention_mask[:, -response_length:]
|
||||||
|
|
||||||
# compute kl between ref_policy and current policy
|
# compute kl between ref_policy and current policy
|
||||||
@@ -409,8 +409,8 @@ class RayPPOTrainer(object):
|
|||||||
|
|
||||||
self.val_dataloader = DataLoader(dataset=self.val_dataset,
|
self.val_dataloader = DataLoader(dataset=self.val_dataset,
|
||||||
batch_size=self.config.data.val_batch_size,
|
batch_size=self.config.data.val_batch_size,
|
||||||
shuffle=True,
|
shuffle=False,
|
||||||
drop_last=True,
|
drop_last=False,
|
||||||
collate_fn=collate_fn)
|
collate_fn=collate_fn)
|
||||||
|
|
||||||
print(f'Size of train dataloader: {len(self.train_dataloader)}')
|
print(f'Size of train dataloader: {len(self.train_dataloader)}')
|
||||||
|
|||||||
Reference in New Issue
Block a user