Fix GRPO uid bug: take each sample's `uid` from its dataset `index` instead of a random UUID

This commit is contained in:
PeterGriffinJin
2025-03-19 18:59:19 +00:00
parent 8c7f04ca45
commit 9ec2fa9892
2 changed files with 7 additions and 6 deletions

View File

@@ -739,8 +739,9 @@ class RayPPOTrainer(object):
output = self.actor_rollout_wg.compute_log_prob(final_gen_batch_output)
final_gen_batch_output = final_gen_batch_output.union(output)
batch.non_tensor_batch['uid'] = np.array([str(uuid.uuid4()) for _ in range(len(batch.batch))],
dtype=object)
# batch.non_tensor_batch['uid'] = np.array([str(uuid.uuid4()) for _ in range(len(batch.batch))],
# dtype=object)
batch.non_tensor_batch['uid'] = batch.non_tensor_batch['index'].copy()
# repeat to align with repeated responses in rollout
batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)