remove unused debug lines
This commit is contained in:
@@ -186,9 +186,6 @@ class SACPolicy(
|
|||||||
* ~batch["action_is_pad"][:, 0]
|
* ~batch["action_is_pad"][:, 0]
|
||||||
) # shape: [batch_size, horizon]
|
) # shape: [batch_size, horizon]
|
||||||
td_target = rewards + self.config.discount * min_q * ~batch["next.done"]
|
td_target = rewards + self.config.discount * min_q * ~batch["next.done"]
|
||||||
# td_target -= self.config.discount * self.temperature() * log_probs \
|
|
||||||
# * ~batch["observation.state_is_pad"][:,0] * ~batch["action_is_pad"][:,0] # shape: [batch_size, horizon]
|
|
||||||
# print(f"Target Q-values: mean={td_target.mean():.3f}, max={td_target.max():.3f}")
|
|
||||||
|
|
||||||
# 3- compute predicted qs
|
# 3- compute predicted qs
|
||||||
q_preds = self.critic_forward(observations, actions, use_target=False)
|
q_preds = self.critic_forward(observations, actions, use_target=False)
|
||||||
@@ -219,9 +216,7 @@ class SACPolicy(
|
|||||||
q_preds = self.critic_forward(observations, actions, use_target=False)
|
q_preds = self.critic_forward(observations, actions, use_target=False)
|
||||||
# q_preds_min = torch.min(q_preds, axis=0)
|
# q_preds_min = torch.min(q_preds, axis=0)
|
||||||
min_q_preds = q_preds.min(dim=0)[0]
|
min_q_preds = q_preds.min(dim=0)[0]
|
||||||
# print(f"Q-values stats: mean={min_q_preds.mean():.3f}, min={min_q_preds.min():.3f}, max={min_q_preds.max():.3f}")
|
|
||||||
# print(f"Log probs stats: mean={log_probs.mean():.3f}, min={log_probs.min():.3f}, max={log_probs.max():.3f}")
|
|
||||||
# breakpoint()
|
|
||||||
actor_loss = (
|
actor_loss = (
|
||||||
-(min_q_preds - temperature * log_probs).mean()
|
-(min_q_preds - temperature * log_probs).mean()
|
||||||
* ~batch["observation.state_is_pad"][:, 0] # shape: [batch_size, horizon+1]
|
* ~batch["observation.state_is_pad"][:, 0] # shape: [batch_size, horizon+1]
|
||||||
|
|||||||
Reference in New Issue
Block a user