Remove unnecessary code (debugging try/except wrappers and prints)
This commit is contained in:
@@ -524,11 +524,7 @@ class RayPPOTrainer(object):
|
||||
|
||||
# evaluate using reward_function
|
||||
# for certain reward function (e.g. sandbox), the generation can overlap with reward
|
||||
try:
|
||||
reward_tensor = self.val_reward_fn(test_batch)
|
||||
except:
|
||||
print(test_batch)
|
||||
exit()
|
||||
reward_tensor = self.val_reward_fn(test_batch)
|
||||
|
||||
reward_tensor_lst.append(reward_tensor)
|
||||
data_source_lst.append(test_batch.non_tensor_batch.get('data_source', ['unknown'] * reward_tensor.shape[0]))
|
||||
@@ -740,12 +736,8 @@ class RayPPOTrainer(object):
|
||||
final_gen_batch_output.batch[key] = final_gen_batch_output.batch[key].long()
|
||||
|
||||
with torch.no_grad():
|
||||
try:
|
||||
output = self.actor_rollout_wg.compute_log_prob(final_gen_batch_output)
|
||||
final_gen_batch_output = final_gen_batch_output.union(output)
|
||||
except:
|
||||
print('############### here ###################')
|
||||
print(final_gen_batch_output)
|
||||
output = self.actor_rollout_wg.compute_log_prob(final_gen_batch_output)
|
||||
final_gen_batch_output = final_gen_batch_output.union(output)
|
||||
|
||||
batch.non_tensor_batch['uid'] = np.array([str(uuid.uuid4()) for _ in range(len(batch.batch))],
|
||||
dtype=object)
|
||||
@@ -773,12 +765,8 @@ class RayPPOTrainer(object):
|
||||
if self.use_reference_policy:
|
||||
# compute reference log_prob
|
||||
with _timer('ref', timing_raw):
|
||||
try:
|
||||
ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)
|
||||
batch = batch.union(ref_log_prob)
|
||||
except:
|
||||
print('################## herehere ################')
|
||||
print(batch)
|
||||
ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)
|
||||
batch = batch.union(ref_log_prob)
|
||||
|
||||
# compute values
|
||||
if self.use_critic:
|
||||
|
||||
Reference in New Issue
Block a user