Remove unnecessary debug code (bare except blocks with prints)

This commit is contained in:
PeterGriffinJin
2025-03-17 16:08:33 +00:00
parent 118c6e7361
commit e85506f143

View File

@@ -524,11 +524,7 @@ class RayPPOTrainer(object):
# evaluate using reward_function
# for certain reward function (e.g. sandbox), the generation can overlap with reward
try:
reward_tensor = self.val_reward_fn(test_batch)
except:
print(test_batch)
exit()
reward_tensor = self.val_reward_fn(test_batch)
reward_tensor_lst.append(reward_tensor)
data_source_lst.append(test_batch.non_tensor_batch.get('data_source', ['unknown'] * reward_tensor.shape[0]))
@@ -740,12 +736,8 @@ class RayPPOTrainer(object):
final_gen_batch_output.batch[key] = final_gen_batch_output.batch[key].long()
with torch.no_grad():
try:
output = self.actor_rollout_wg.compute_log_prob(final_gen_batch_output)
final_gen_batch_output = final_gen_batch_output.union(output)
except:
print('############### here ###################')
print(final_gen_batch_output)
output = self.actor_rollout_wg.compute_log_prob(final_gen_batch_output)
final_gen_batch_output = final_gen_batch_output.union(output)
batch.non_tensor_batch['uid'] = np.array([str(uuid.uuid4()) for _ in range(len(batch.batch))],
dtype=object)
@@ -773,12 +765,8 @@ class RayPPOTrainer(object):
if self.use_reference_policy:
# compute reference log_prob
with _timer('ref', timing_raw):
try:
ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)
batch = batch.union(ref_log_prob)
except:
print('################## herehere ################')
print(batch)
ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)
batch = batch.union(ref_log_prob)
# compute values
if self.use_critic: