diff --git a/lerobot/common/policies/sac/configuration_sac.py b/lerobot/common/policies/sac/configuration_sac.py
index 159602e2c..318a8cd22 100644
--- a/lerobot/common/policies/sac/configuration_sac.py
+++ b/lerobot/common/policies/sac/configuration_sac.py
@@ -54,7 +54,7 @@ class SACConfig:
     critic_target_update_weight = 0.005
     utd_ratio = 2
     state_encoder_hidden_dim = 256
-    latent_dim = 128
+    latent_dim = 256
     target_entropy = None
     backup_entropy = True
     critic_network_kwargs = {
diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py
index 71083b57c..599308a5a 100644
--- a/lerobot/common/policies/sac/modeling_sac.py
+++ b/lerobot/common/policies/sac/modeling_sac.py
@@ -111,7 +111,7 @@ class SACPolicy(
     @torch.no_grad()
     def select_action(self, batch: dict[str, Tensor]) -> Tensor:
         """Select action for inference/evaluation"""
-        actions, _ = self.actor(batch)
+        _, _, actions = self.actor(batch)
         actions = self.unnormalize_outputs({"action": actions})["action"]
         return actions
@@ -155,7 +155,7 @@ class SACPolicy(

        # calculate critics loss
        # 1- compute actions from policy
-        action_preds, log_probs = self.actor(next_observations)
+        action_preds, log_probs, _ = self.actor(next_observations)
        # 2- compute q targets
        q_targets = self.critic_forward(next_observations, action_preds, use_target=True)

@@ -195,7 +195,7 @@ class SACPolicy(
        # 1- temperature
        temperature = self.temperature()
        # 2- get actions (batch_size, action_dim) and log probs (batch_size,)
-        actions, log_probs = self.actor(observations)
+        actions, log_probs, _ = self.actor(observations)
        # 3- get q-value predictions
        with torch.inference_mode():
            q_preds = self.critic_forward(observations, actions, use_target=False)
@@ -405,8 +405,9 @@ class Policy(nn.Module):
        actions = torch.tanh(x_t)
        log_probs -= torch.log((1 - actions.pow(2)) + 1e-6)
        log_probs = log_probs.sum(-1)  # sum over action dim
+        means = torch.tanh(means)

-        return actions, log_probs
+        return actions, log_probs, means

    def get_features(self, observations: torch.Tensor) -> torch.Tensor:
        """Get encoded features from observations"""
diff --git a/lerobot/configs/policy/sac_pusht_keypoints.yaml b/lerobot/configs/policy/sac_pusht_keypoints.yaml
index dad1508bf..97f040aaf 100644
--- a/lerobot/configs/policy/sac_pusht_keypoints.yaml
+++ b/lerobot/configs/policy/sac_pusht_keypoints.yaml
@@ -15,7 +15,7 @@ training:
  # Offline training dataloader
  num_workers: 4

-  batch_size: 128
+  batch_size: 256
  grad_clip_norm: 10.0
  lr: 3e-4