remove unused debug lines

This commit is contained in:
Ke-Wang1017
2025-01-06 10:18:40 +00:00
parent 8b70b129dc
commit 89d8189d8b

View File

@@ -186,9 +186,6 @@ class SACPolicy(
* ~batch["action_is_pad"][:, 0]
) # shape: [batch_size, horizon]
td_target = rewards + self.config.discount * min_q * ~batch["next.done"]
# td_target -= self.config.discount * self.temperature() * log_probs \
# * ~batch["observation.state_is_pad"][:,0] * ~batch["action_is_pad"][:,0] # shape: [batch_size, horizon]
# print(f"Target Q-values: mean={td_target.mean():.3f}, max={td_target.max():.3f}")
# 3- compute predicted qs
q_preds = self.critic_forward(observations, actions, use_target=False)
@@ -219,9 +216,7 @@ class SACPolicy(
q_preds = self.critic_forward(observations, actions, use_target=False)
# q_preds_min = torch.min(q_preds, axis=0)
min_q_preds = q_preds.min(dim=0)[0]
# print(f"Q-values stats: mean={min_q_preds.mean():.3f}, min={min_q_preds.min():.3f}, max={min_q_preds.max():.3f}")
# print(f"Log probs stats: mean={log_probs.mean():.3f}, min={log_probs.min():.3f}, max={log_probs.max():.3f}")
# breakpoint()
actor_loss = (
-(min_q_preds - temperature * log_probs).mean()
* ~batch["observation.state_is_pad"][:, 0] # shape: [batch_size, horizon+1]