Remove unnecessary code

This commit is contained in:
PeterGriffinJin
2025-03-17 16:08:33 +00:00
parent 118c6e7361
commit e85506f143

View File

@@ -524,11 +524,7 @@ class RayPPOTrainer(object):
# evaluate using reward_function # evaluate using reward_function
# for certain reward function (e.g. sandbox), the generation can overlap with reward # for certain reward function (e.g. sandbox), the generation can overlap with reward
try: reward_tensor = self.val_reward_fn(test_batch)
reward_tensor = self.val_reward_fn(test_batch)
except:
print(test_batch)
exit()
reward_tensor_lst.append(reward_tensor) reward_tensor_lst.append(reward_tensor)
data_source_lst.append(test_batch.non_tensor_batch.get('data_source', ['unknown'] * reward_tensor.shape[0])) data_source_lst.append(test_batch.non_tensor_batch.get('data_source', ['unknown'] * reward_tensor.shape[0]))
@@ -740,12 +736,8 @@ class RayPPOTrainer(object):
final_gen_batch_output.batch[key] = final_gen_batch_output.batch[key].long() final_gen_batch_output.batch[key] = final_gen_batch_output.batch[key].long()
with torch.no_grad(): with torch.no_grad():
try: output = self.actor_rollout_wg.compute_log_prob(final_gen_batch_output)
output = self.actor_rollout_wg.compute_log_prob(final_gen_batch_output) final_gen_batch_output = final_gen_batch_output.union(output)
final_gen_batch_output = final_gen_batch_output.union(output)
except:
print('############### here ###################')
print(final_gen_batch_output)
batch.non_tensor_batch['uid'] = np.array([str(uuid.uuid4()) for _ in range(len(batch.batch))], batch.non_tensor_batch['uid'] = np.array([str(uuid.uuid4()) for _ in range(len(batch.batch))],
dtype=object) dtype=object)
@@ -773,12 +765,8 @@ class RayPPOTrainer(object):
if self.use_reference_policy: if self.use_reference_policy:
# compute reference log_prob # compute reference log_prob
with _timer('ref', timing_raw): with _timer('ref', timing_raw):
try: ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)
ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch) batch = batch.union(ref_log_prob)
batch = batch.union(ref_log_prob)
except:
print('################## herehere ################')
print(batch)
# compute values # compute values
if self.use_critic: if self.use_critic: