remove unused debug lines
This commit is contained in:
@@ -186,9 +186,6 @@ class SACPolicy(
|
||||
* ~batch["action_is_pad"][:, 0]
|
||||
) # shape: [batch_size, horizon]
|
||||
td_target = rewards + self.config.discount * min_q * ~batch["next.done"]
|
||||
# td_target -= self.config.discount * self.temperature() * log_probs \
|
||||
# * ~batch["observation.state_is_pad"][:,0] * ~batch["action_is_pad"][:,0] # shape: [batch_size, horizon]
|
||||
# print(f"Target Q-values: mean={td_target.mean():.3f}, max={td_target.max():.3f}")
|
||||
|
||||
# 3- compute predicted qs
|
||||
q_preds = self.critic_forward(observations, actions, use_target=False)
|
||||
@@ -219,9 +216,7 @@ class SACPolicy(
|
||||
q_preds = self.critic_forward(observations, actions, use_target=False)
|
||||
# q_preds_min = torch.min(q_preds, axis=0)
|
||||
min_q_preds = q_preds.min(dim=0)[0]
|
||||
# print(f"Q-values stats: mean={min_q_preds.mean():.3f}, min={min_q_preds.min():.3f}, max={min_q_preds.max():.3f}")
|
||||
# print(f"Log probs stats: mean={log_probs.mean():.3f}, min={log_probs.min():.3f}, max={log_probs.max():.3f}")
|
||||
# breakpoint()
|
||||
|
||||
actor_loss = (
|
||||
-(min_q_preds - temperature * log_probs).mean()
|
||||
* ~batch["observation.state_is_pad"][:, 0] # shape: [batch_size, horizon+1]
|
||||
|
||||
Reference in New Issue
Block a user