From 89d8189d8bd1e4996ce5c6da0968e0863c46cbfc Mon Sep 17 00:00:00 2001
From: Ke-Wang1017
Date: Mon, 6 Jan 2025 10:18:40 +0000
Subject: [PATCH] remove unused debug lines

---
 lerobot/common/policies/sac/modeling_sac.py | 7 +------
 1 file changed, 1 insertion(+), 6 deletions(-)

diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py
index 84860b35..ff021956 100644
--- a/lerobot/common/policies/sac/modeling_sac.py
+++ b/lerobot/common/policies/sac/modeling_sac.py
@@ -186,9 +186,6 @@ class SACPolicy(
             * ~batch["action_is_pad"][:, 0]
         )  # shape: [batch_size, horizon]
         td_target = rewards + self.config.discount * min_q * ~batch["next.done"]
-        # td_target -= self.config.discount * self.temperature() * log_probs \
-        #     * ~batch["observation.state_is_pad"][:,0] * ~batch["action_is_pad"][:,0]  # shape: [batch_size, horizon]
-        # print(f"Target Q-values: mean={td_target.mean():.3f}, max={td_target.max():.3f}")
 
         # 3- compute predicted qs
         q_preds = self.critic_forward(observations, actions, use_target=False)
@@ -219,9 +216,7 @@ class SACPolicy(
         q_preds = self.critic_forward(observations, actions, use_target=False)
         # q_preds_min = torch.min(q_preds, axis=0)
         min_q_preds = q_preds.min(dim=0)[0]
-        # print(f"Q-values stats: mean={min_q_preds.mean():.3f}, min={min_q_preds.min():.3f}, max={min_q_preds.max():.3f}")
-        # print(f"Log probs stats: mean={log_probs.mean():.3f}, min={log_probs.min():.3f}, max={log_probs.max():.3f}")
-        # breakpoint()
+
         actor_loss = (
             -(min_q_preds - temperature * log_probs).mean()
             * ~batch["observation.state_is_pad"][:, 0]  # shape: [batch_size, horizon+1]
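
For context, the two hunks above touch SAC's critic TD target and actor loss. The sketch below is a minimal, self-contained restatement of that logic using hypothetical placeholder tensors (`rewards`, `min_q`, `min_q_preds`, `log_probs`, `done`, `state_is_pad`); the shapes and the order in which the padding mask is applied are illustrative assumptions, not the exact lerobot implementation.

```python
import torch

# Hypothetical batch of transitions standing in for what the replay buffer provides.
batch_size = 4
rewards = torch.randn(batch_size)
min_q = torch.randn(batch_size)        # min over the critic ensemble at the next state
min_q_preds = torch.randn(batch_size)  # min over the critic ensemble at the current state
log_probs = torch.randn(batch_size)    # log pi(a|s) for actions sampled from the actor
done = torch.zeros(batch_size, dtype=torch.bool)
state_is_pad = torch.zeros(batch_size, dtype=torch.bool)
discount = 0.99
temperature = 0.2

# Critic target: reward plus discounted bootstrapped value, zeroed where the episode ended.
td_target = rewards + discount * min_q * ~done

# Actor loss: maximize the entropy-regularized Q-value, ignoring padded states.
actor_loss = (-(min_q_preds - temperature * log_probs) * ~state_is_pad).mean()
```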