diff --git a/lerobot/common/policies/sac/configuration_sac.py b/lerobot/common/policies/sac/configuration_sac.py
index 318a8cd22..8ad07a733 100644
--- a/lerobot/common/policies/sac/configuration_sac.py
+++ b/lerobot/common/policies/sac/configuration_sac.py
@@ -56,7 +56,7 @@ class SACConfig:
     state_encoder_hidden_dim = 256
     latent_dim = 256
     target_entropy = None
-    backup_entropy = True
+    backup_entropy = False
     critic_network_kwargs = {
         "hidden_dims": [256, 256],
         "activate_final": True,
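For context on the `backup_entropy` flag flipped above: it gates whether the SAC entropy bonus is folded into the critic's bootstrap target, y = r + γ·(min_i Q̄_i(s', a') − α·log π(a'|s')), which the `modeling_sac.py` changes below implement. A minimal sketch of the two target variants with stand-in tensors (names, shapes, and values here are illustrative, not the repo's API):

```python
import torch

batch_size = 8
rewards = torch.randn(batch_size)
log_probs = torch.randn(batch_size)      # log pi(a'|s') for next actions sampled from the policy
q_targets = torch.randn(2, batch_size)   # one row per target critic in the ensemble
discount, temperature = 0.99, 0.1

min_q, _ = q_targets.min(dim=0)          # clipped double-Q: take the pessimistic estimate
td_target = rewards + discount * min_q                                   # backup_entropy = False
td_target_ent = rewards + discount * (min_q - temperature * log_probs)   # backup_entropy = True
```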
diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py
index 599308a5a..297e9f53c 100644
--- a/lerobot/common/policies/sac/modeling_sac.py
+++ b/lerobot/common/policies/sac/modeling_sac.py
@@ -90,7 +90,7 @@ class SACPolicy(
             **config.policy_kwargs
         )
         if config.target_entropy is None:
-            config.target_entropy = -np.prod(config.output_shapes["action"][0])  # (-dim(A))
+            config.target_entropy = -np.prod(config.output_shapes["action"][0])/2  # (-dim(A)/2)
         self.temperature = LagrangeMultiplier(init_value=config.temperature_init)

     def reset(self):
@@ -111,7 +111,7 @@ class SACPolicy(
     @torch.no_grad()
     def select_action(self, batch: dict[str, Tensor]) -> Tensor:
         """Select action for inference/evaluation"""
-        _, _, actions = self.actor(batch)
+        actions, _, _ = self.actor(batch)
         actions = self.unnormalize_outputs({"action": actions})["action"]
         return actions

@@ -155,23 +155,28 @@ class SACPolicy(

        # calculate critics loss
        # 1- compute actions from policy
-        action_preds, log_probs, _ = self.actor(next_observations)
+        with torch.no_grad():
+            action_preds, log_probs, _ = self.actor(next_observations)

-        # 2- compute q targets
-        q_targets = self.critic_forward(next_observations, action_preds, use_target=True)
+            # 2- compute q targets
+            q_targets = self.critic_forward(next_observations, action_preds, use_target=True)

-        # subsample critics to prevent overfitting if use high UTD (update to date)
-        if self.config.num_subsample_critics is not None:
-            indices = torch.randperm(self.config.num_critics)
-            indices = indices[:self.config.num_subsample_critics]
-            q_targets = q_targets[indices]
+            # subsample critics to prevent overfitting when using a high UTD (update-to-data) ratio
+            if self.config.num_subsample_critics is not None:
+                indices = torch.randperm(self.config.num_critics)
+                indices = indices[:self.config.num_subsample_critics]
+                q_targets = q_targets[indices]

-        # critics subsample size
-        min_q, _ = q_targets.min(dim=0)  # Get values from min operation
-
-        # compute td target
-        td_target = rewards + self.config.discount * min_q  # + self.config.discount * self.temperature() * log_probs  # add entropy term
+            # critics subsample size
+            min_q, _ = q_targets.min(dim=0)  # Get values from min operation
+            # compute td target
+            td_target = rewards + self.config.discount * min_q
+            if self.config.backup_entropy:
+                td_target -= self.config.discount * self.temperature() * log_probs \
+                    * ~batch["observation.state_is_pad"][:, 0] * ~batch["action_is_pad"][:, 0]  # shape: [batch_size, horizon]
+            # print(f"Target Q-values: mean={td_target.mean():.3f}, max={td_target.max():.3f}")
+
        # 3- compute predicted qs
        q_preds = self.critic_forward(observations, actions, use_target=False)

@@ -197,10 +202,15 @@ class SACPolicy(
        # 2- get actions (batch_size, action_dim) and log probs (batch_size,)
        actions, log_probs, _ = self.actor(observations)
        # 3- get q-value predictions
-        with torch.inference_mode():
-            q_preds = self.critic_forward(observations, actions, use_target=False)
+        # with torch.inference_mode():
+        q_preds = self.critic_forward(observations, actions, use_target=False)
+        # q_preds_min = torch.min(q_preds, axis=0)
+        min_q_preds = q_preds.min(dim=0)[0]
+        # print(f"Q-values stats: mean={min_q_preds.mean():.3f}, min={min_q_preds.min():.3f}, max={min_q_preds.max():.3f}")
+        # print(f"Log probs stats: mean={log_probs.mean():.3f}, min={log_probs.min():.3f}, max={log_probs.max():.3f}")
+        # breakpoint()
-        actor_loss = -(q_preds - temperature * log_probs).mean()
+        actor_loss = (
+            -(min_q_preds - temperature * log_probs)
+            * ~batch["observation.state_is_pad"][:, 0]  # shape: [batch_size, horizon+1]
+            * ~batch["action_is_pad"][:, 0]  # shape: [batch_size, horizon]
+        ).mean()

@@ -208,6 +218,8 @@ class SACPolicy(

        # calculate temperature loss
        # 1- calculate entropy
+        with torch.no_grad():
+            actions, log_probs, _ = self.actor(observations)
        entropy = -log_probs.mean()
        temperature_loss = self.temperature(
            lhs=entropy,
@@ -219,8 +231,17 @@ class SACPolicy(
        return {
            "critics_loss": critics_loss.item(),
            "actor_loss": actor_loss.item(),
+            "mean_q_predicts": min_q_preds.mean().item(),
+            "min_q_predicts": min_q_preds.min().item(),
+            "max_q_predicts": min_q_preds.max().item(),
            "temperature_loss": temperature_loss.item(),
            "temperature": temperature.item(),
+            "mean_log_probs": log_probs.mean().item(),
+            "min_log_probs": log_probs.min().item(),
+            "max_log_probs": log_probs.max().item(),
+            "td_target_mean": td_target.mean().item(),
+            "td_target_max": td_target.max().item(),
+            "action_mean": actions.mean().item(),
            "entropy": entropy.item(),
            "loss": loss,
        }
@@ -236,8 +257,8 @@ class SACPolicy(
        for target_critic, critic in zip(self.critic_target, self.critic_ensemble, strict=False):
            for target_param, param in zip(target_critic.parameters(), critic.parameters(), strict=False):
                target_param.data.copy_(
-                    target_param.data * self.config.critic_target_update_weight +
-                    param.data * (1.0 - self.config.critic_target_update_weight)
+                    param.data * self.config.critic_target_update_weight +
+                    target_param.data * (1.0 - self.config.critic_target_update_weight)
                )

 class MLP(nn.Module):
@@ -391,15 +412,16 @@ class Policy(nn.Module):
        # Compute standard deviations
        if self.fixed_std is None:
            log_std = self.std_layer(outputs)
+            log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max)
            if self.use_tanh_squash:
                log_std = torch.tanh(log_std)
-                log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max)
        else:
            log_std = self.fixed_std.expand_as(means)

        # uses tanh activation function to squash the action to be in the range of [-1, 1]
        normal = torch.distributions.Normal(means, torch.exp(log_std))
        x_t = normal.rsample()  # for reparameterization trick (mean + std * N(0,1))
+        x_t = torch.clamp(x_t, -2.0, 2.0)
        log_probs = normal.log_prob(x_t)
        if self.use_tanh_squash:
            actions = torch.tanh(x_t)
@@ -456,19 +478,15 @@ class SACObservationEncoder(nn.Module):
            )
        if "observation.state" in config.input_shapes:
            self.state_enc_layers = nn.Sequential(
-                nn.Linear(config.input_shapes["observation.state"][0], config.state_encoder_hidden_dim),
-                nn.ELU(),
-                nn.Linear(config.state_encoder_hidden_dim, config.latent_dim),
+                nn.Linear(config.input_shapes["observation.state"][0], config.latent_dim),
                nn.LayerNorm(config.latent_dim),
                nn.Tanh(),
            )
        if "observation.environment_state" in config.input_shapes:
            self.env_state_enc_layers = nn.Sequential(
                nn.Linear(
-                    config.input_shapes["observation.environment_state"][0], config.state_encoder_hidden_dim
+                    config.input_shapes["observation.environment_state"][0], config.latent_dim
                ),
-                nn.ELU(),
-                nn.Linear(config.state_encoder_hidden_dim, config.latent_dim),
                nn.LayerNorm(config.latent_dim),
                nn.Tanh(),
            )
@@ -506,26 +524,27 @@ class LagrangeMultiplier(nn.Module):
    ):
        super().__init__()
        self.device = torch.device(device)
-        init_value = torch.log(torch.exp(torch.tensor(init_value, device=self.device)) - 1)
+        # init_value = torch.log(torch.exp(torch.tensor(init_value, device=self.device)) - 1)
+        init_value = torch.tensor(init_value, device=self.device)
+
        # Initialize the Lagrange multiplier as a parameter
        self.lagrange = nn.Parameter(
            torch.full(constraint_shape, init_value, dtype=torch.float32, device=self.device)
        )
-
        self.to(self.device)

    def forward(
        self,
        lhs: Optional[torch.Tensor | float | int] = None,
        rhs: Optional[torch.Tensor | float | int] = None
    ) -> torch.Tensor:
        # Get the multiplier value based on parameterization
-        multiplier = torch.nn.functional.softplus(self.lagrange)
-
+        # multiplier = torch.nn.functional.softplus(self.lagrange)
+        log_multiplier = torch.log(self.lagrange)
+
        # Return the raw multiplier if no constraint values provided
        if lhs is None:
-            return multiplier
+            return log_multiplier.exp()

        # Convert inputs to tensors and move to device
        lhs = torch.tensor(lhs, device=self.device) if not isinstance(lhs, torch.Tensor) else lhs.to(self.device)
@@ -536,9 +555,9 @@ class LagrangeMultiplier(nn.Module):

        diff = lhs - rhs

-        assert diff.shape == multiplier.shape, f"Shape mismatch: {diff.shape} vs {multiplier.shape}"
+        assert diff.shape == log_multiplier.shape, f"Shape mismatch: {diff.shape} vs {log_multiplier.shape}"

-        return multiplier * diff
+        return log_multiplier.exp() * diff  # numerically better

 def orthogonal_init():
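The target-network change in the hunk at `@@ -236,8 +257,8 @@` above restores the usual Polyak soft update, with `critic_target_update_weight` playing the role of τ. A standalone sketch of that convention (the `soft_update` helper and its default τ are illustrative, not code from the repo):

```python
import torch
import torch.nn as nn

@torch.no_grad()
def soft_update(target: nn.Module, source: nn.Module, tau: float = 0.005) -> None:
    """Polyak averaging: target <- tau * source + (1 - tau) * target."""
    for target_param, param in zip(target.parameters(), source.parameters(), strict=True):
        target_param.data.mul_(1.0 - tau).add_(param.data, alpha=tau)
```

Before the fix the operands were swapped, so the target network kept only τ = 0.5% of its own weights each step and essentially tracked the online critic, defeating the purpose of a slowly moving bootstrap target.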
diff --git a/lerobot/configs/policy/sac_pusht_keypoints.yaml b/lerobot/configs/policy/sac_pusht_keypoints.yaml
index 97f040aaf..f4f4dba84 100644
--- a/lerobot/configs/policy/sac_pusht_keypoints.yaml
+++ b/lerobot/configs/policy/sac_pusht_keypoints.yaml
@@ -19,7 +19,7 @@ training:
   grad_clip_norm: 10.0
   lr: 3e-4

-  eval_freq: 10000
+  eval_freq: 2500
   log_freq: 500
   save_freq: 50000
@@ -29,7 +29,7 @@ training:
   online_steps_between_rollouts: 1000
   online_sampling_ratio: 1.0
   online_env_seed: 10000
-  online_buffer_capacity: 40000
+  online_buffer_capacity: 10000
   online_buffer_seed_size: 0
   do_online_rollout_async: false
@@ -70,9 +70,9 @@ policy:
   temperature_init: 1.0
   num_critics: 2
   num_subsample_critics: None
-  critic_lr: 3e-4
-  actor_lr: 3e-4
-  temperature_lr: 3e-4
+  critic_lr: 1e-4
+  actor_lr: 1e-4
+  temperature_lr: 1e-4
   critic_target_update_weight: 0.005
   utd_ratio: 2
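Finally, on the temperature hyperparameters above (`temperature_init: 1.0`, `temperature_lr: 1e-4`) and the `-dim(A)/2` target-entropy default: SAC tunes α by descending α·(H[π] − H_target), which is what `LagrangeMultiplier` evaluates with `lhs=entropy, rhs=target_entropy`. A toy single-step sketch with made-up numbers, using a log-parameterized α for positivity (a common alternative to the direct parameterization in the diff):

```python
import torch

log_alpha = torch.zeros(1, requires_grad=True)       # alpha = exp(0) = 1.0, cf. temperature_init: 1.0
optimizer = torch.optim.Adam([log_alpha], lr=1e-4)   # cf. temperature_lr: 1e-4

target_entropy = -1.0        # e.g. -dim(A)/2 for a 2-D action space
entropy = torch.tensor(0.5)  # stand-in for -log_probs.mean()

# Entropy above target -> positive gradient on alpha -> alpha shrinks
# (less exploration pressure); entropy below target -> alpha grows.
temperature_loss = log_alpha.exp() * (entropy - target_entropy)
optimizer.zero_grad()
temperature_loss.backward()
optimizer.step()
```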