fix grpo id bug
This commit is contained in:
@@ -739,8 +739,9 @@ class RayPPOTrainer(object):
|
||||
output = self.actor_rollout_wg.compute_log_prob(final_gen_batch_output)
|
||||
final_gen_batch_output = final_gen_batch_output.union(output)
|
||||
|
||||
batch.non_tensor_batch['uid'] = np.array([str(uuid.uuid4()) for _ in range(len(batch.batch))],
|
||||
dtype=object)
|
||||
# batch.non_tensor_batch['uid'] = np.array([str(uuid.uuid4()) for _ in range(len(batch.batch))],
|
||||
# dtype=object)
|
||||
batch.non_tensor_batch['uid'] = batch.non_tensor_batch['index'].copy()
|
||||
|
||||
# repeat to align with repeated responses in rollout
|
||||
batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)
|
||||
|
||||
Reference in New Issue
Block a user