remove unnecessary code
This commit is contained in:
@@ -524,11 +524,7 @@ class RayPPOTrainer(object):
|
|||||||
|
|
||||||
# evaluate using reward_function
|
# evaluate using reward_function
|
||||||
# for certain reward function (e.g. sandbox), the generation can overlap with reward
|
# for certain reward function (e.g. sandbox), the generation can overlap with reward
|
||||||
try:
|
reward_tensor = self.val_reward_fn(test_batch)
|
||||||
reward_tensor = self.val_reward_fn(test_batch)
|
|
||||||
except:
|
|
||||||
print(test_batch)
|
|
||||||
exit()
|
|
||||||
|
|
||||||
reward_tensor_lst.append(reward_tensor)
|
reward_tensor_lst.append(reward_tensor)
|
||||||
data_source_lst.append(test_batch.non_tensor_batch.get('data_source', ['unknown'] * reward_tensor.shape[0]))
|
data_source_lst.append(test_batch.non_tensor_batch.get('data_source', ['unknown'] * reward_tensor.shape[0]))
|
||||||
@@ -740,12 +736,8 @@ class RayPPOTrainer(object):
|
|||||||
final_gen_batch_output.batch[key] = final_gen_batch_output.batch[key].long()
|
final_gen_batch_output.batch[key] = final_gen_batch_output.batch[key].long()
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
try:
|
output = self.actor_rollout_wg.compute_log_prob(final_gen_batch_output)
|
||||||
output = self.actor_rollout_wg.compute_log_prob(final_gen_batch_output)
|
final_gen_batch_output = final_gen_batch_output.union(output)
|
||||||
final_gen_batch_output = final_gen_batch_output.union(output)
|
|
||||||
except:
|
|
||||||
print('############### here ###################')
|
|
||||||
print(final_gen_batch_output)
|
|
||||||
|
|
||||||
batch.non_tensor_batch['uid'] = np.array([str(uuid.uuid4()) for _ in range(len(batch.batch))],
|
batch.non_tensor_batch['uid'] = np.array([str(uuid.uuid4()) for _ in range(len(batch.batch))],
|
||||||
dtype=object)
|
dtype=object)
|
||||||
@@ -773,12 +765,8 @@ class RayPPOTrainer(object):
|
|||||||
if self.use_reference_policy:
|
if self.use_reference_policy:
|
||||||
# compute reference log_prob
|
# compute reference log_prob
|
||||||
with _timer('ref', timing_raw):
|
with _timer('ref', timing_raw):
|
||||||
try:
|
ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)
|
||||||
ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)
|
batch = batch.union(ref_log_prob)
|
||||||
batch = batch.union(ref_log_prob)
|
|
||||||
except:
|
|
||||||
print('################## herehere ################')
|
|
||||||
print(batch)
|
|
||||||
|
|
||||||
# compute values
|
# compute values
|
||||||
if self.use_critic:
|
if self.use_critic:
|
||||||
|
|||||||
Reference in New Issue
Block a user