diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py
index 98b0df935..05c0e02a4 100644
--- a/lerobot/common/policies/sac/modeling_sac.py
+++ b/lerobot/common/policies/sac/modeling_sac.py
@@ -52,115 +52,13 @@ class SACPolicy(
         config.validate_features()
         self.config = config
 
+        # Determine action dimension and initialize all components
         continuous_action_dim = config.output_features["action"].shape[0]
-
-        # Default to identity normalizations
-        self.normalize_inputs = nn.Identity()
-        self.normalize_targets = nn.Identity()
-        self.unnormalize_outputs = nn.Identity()
-        # Apply normalization if dataset stats provided
-        if config.dataset_stats:
-            params = _convert_normalization_params_to_tensor(config.dataset_stats)
-            self.normalize_inputs = Normalize(
-                config.input_features,
-                config.normalization_mapping,
-                params,
-            )
-            stats = dataset_stats or params
-            self.normalize_targets = Normalize(
-                config.output_features,
-                config.normalization_mapping,
-                stats,
-            )
-            self.unnormalize_outputs = Unnormalize(
-                config.output_features,
-                config.normalization_mapping,
-                stats,
-            )
-
-        # NOTE: For images the encoder should be shared between the actor and critic
-        self.shared_encoder = config.shared_encoder
-        encoder_critic = SACObservationEncoder(config, self.normalize_inputs)
-        encoder_actor = (
-            encoder_critic if self.shared_encoder else SACObservationEncoder(config, self.normalize_inputs)
-        )
-
-        # Create a list of critic heads
-        critic_heads = [
-            CriticHead(
-                input_dim=encoder_critic.output_dim + continuous_action_dim,
-                **asdict(config.critic_network_kwargs),
-            )
-            for _ in range(config.num_critics)
-        ]
-
-        self.critic_ensemble = CriticEnsemble(
-            encoder=encoder_critic,
-            ensemble=critic_heads,
-            output_normalization=self.normalize_targets,
-        )
-
-        # Create target critic heads as deepcopies of the original critic heads
-        target_critic_heads = [
-            CriticHead(
-                input_dim=encoder_critic.output_dim + continuous_action_dim,
-                **asdict(config.critic_network_kwargs),
-            )
-            for _ in range(config.num_critics)
-        ]
-
-        self.critic_target = CriticEnsemble(
-            encoder=encoder_critic,
-            ensemble=target_critic_heads,
-            output_normalization=self.normalize_targets,
-        )
-
-        self.critic_target.load_state_dict(self.critic_ensemble.state_dict())
-
-        self.critic_ensemble = torch.compile(self.critic_ensemble)
-        self.critic_target = torch.compile(self.critic_target)
-
-        self.grasp_critic = None
-        self.grasp_critic_target = None
-
-        if config.num_discrete_actions is not None:
-            # Create grasp critic
-            self.grasp_critic = GraspCritic(
-                encoder=encoder_critic,
-                input_dim=encoder_critic.output_dim,
-                output_dim=config.num_discrete_actions,
-                **asdict(config.grasp_critic_network_kwargs),
-            )
-
-            # Create target grasp critic
-            self.grasp_critic_target = GraspCritic(
-                encoder=encoder_critic,
-                input_dim=encoder_critic.output_dim,
-                output_dim=config.num_discrete_actions,
-                **asdict(config.grasp_critic_network_kwargs),
-            )
-
-            self.grasp_critic_target.load_state_dict(self.grasp_critic.state_dict())
-
-            self.grasp_critic = torch.compile(self.grasp_critic)
-            self.grasp_critic_target = torch.compile(self.grasp_critic_target)
-
-        self.actor = Policy(
-            encoder=encoder_actor,
-            network=MLP(input_dim=encoder_actor.output_dim, **asdict(config.actor_network_kwargs)),
-            action_dim=continuous_action_dim,
-            encoder_is_shared=config.shared_encoder,
-            **asdict(config.policy_kwargs),
-        )
-        if config.target_entropy is None:
-            discrete_actions_dim: Literal[1] | Literal[0] = (
-                1 if config.num_discrete_actions is not None else 0
-            )
-            config.target_entropy = -np.prod(continuous_action_dim + discrete_actions_dim) / 2  # (-dim(A)/2)
-
-        temperature_init = config.temperature_init
-        self.log_alpha = nn.Parameter(torch.tensor([math.log(temperature_init)]))
-        self.temperature = self.log_alpha.exp().item()
+        self._init_normalization(dataset_stats)
+        self._init_encoders()
+        self._init_critics(continuous_action_dim)
+        self._init_actor(continuous_action_dim)
+        self._init_temperature()
 
     def get_optim_params(self) -> dict:
         optim_params = {
@@ -492,6 +390,104 @@ class SACPolicy(
         actor_loss = ((self.temperature * log_probs) - min_q_preds).mean()
         return actor_loss
 
+    def _init_normalization(self, dataset_stats):
+        """Initialize input/output normalization modules."""
+        self.normalize_inputs = nn.Identity()
+        self.normalize_targets = nn.Identity()
+        self.unnormalize_outputs = nn.Identity()
+        if self.config.dataset_stats:
+            params = _convert_normalization_params_to_tensor(self.config.dataset_stats)
+            self.normalize_inputs = Normalize(
+                self.config.input_features, self.config.normalization_mapping, params
+            )
+            stats = dataset_stats or params
+            self.normalize_targets = Normalize(
+                self.config.output_features, self.config.normalization_mapping, stats
+            )
+            self.unnormalize_outputs = Unnormalize(
+                self.config.output_features, self.config.normalization_mapping, stats
+            )
+
+    def _init_encoders(self):
+        """Initialize shared or separate encoders for actor and critic."""
+        self.shared_encoder = self.config.shared_encoder
+        self.encoder_critic = SACObservationEncoder(self.config, self.normalize_inputs)
+        self.encoder_actor = (
+            self.encoder_critic
+            if self.shared_encoder
+            else SACObservationEncoder(self.config, self.normalize_inputs)
+        )
+
+    def _init_critics(self, continuous_action_dim):
+        """Build critic ensemble, targets, and optional grasp critic."""
+        heads = [
+            CriticHead(
+                input_dim=self.encoder_critic.output_dim + continuous_action_dim,
+                **asdict(self.config.critic_network_kwargs),
+            )
+            for _ in range(self.config.num_critics)
+        ]
+        self.critic_ensemble = CriticEnsemble(
+            encoder=self.encoder_critic, ensemble=heads, output_normalization=self.normalize_targets
+        )
+        target_heads = [
+            CriticHead(
+                input_dim=self.encoder_critic.output_dim + continuous_action_dim,
+                **asdict(self.config.critic_network_kwargs),
+            )
+            for _ in range(self.config.num_critics)
+        ]
+        self.critic_target = CriticEnsemble(
+            encoder=self.encoder_critic, ensemble=target_heads, output_normalization=self.normalize_targets
+        )
+        self.critic_target.load_state_dict(self.critic_ensemble.state_dict())
+
+        self.critic_ensemble = torch.compile(self.critic_ensemble)
+        self.critic_target = torch.compile(self.critic_target)
+
+        # Keep grasp critic attributes defined even when discrete actions are disabled
+        self.grasp_critic = None
+        self.grasp_critic_target = None
+        if self.config.num_discrete_actions is not None:
+            self._init_grasp_critics()
+
+    def _init_grasp_critics(self):
+        """Build the discrete grasp critic and its target network."""
+        self.grasp_critic = GraspCritic(
+            encoder=self.encoder_critic,
+            input_dim=self.encoder_critic.output_dim,
+            output_dim=self.config.num_discrete_actions,
+            **asdict(self.config.grasp_critic_network_kwargs),
+        )
+        self.grasp_critic_target = GraspCritic(
+            encoder=self.encoder_critic,
+            input_dim=self.encoder_critic.output_dim,
+            output_dim=self.config.num_discrete_actions,
+            **asdict(self.config.grasp_critic_network_kwargs),
+        )
+
+        # TODO: (maractingi, azouitine) Compile the grasp critic
+        self.grasp_critic_target.load_state_dict(self.grasp_critic.state_dict())
+
+    def _init_actor(self, continuous_action_dim):
+        """Initialize the policy actor network and default target entropy."""
+        self.actor = Policy(
+            encoder=self.encoder_actor,
+            network=MLP(input_dim=self.encoder_actor.output_dim, **asdict(self.config.actor_network_kwargs)),
+            action_dim=continuous_action_dim,
+            encoder_is_shared=self.shared_encoder,
+            **asdict(self.config.policy_kwargs),
+        )
+        if self.config.target_entropy is None:
+            dim = continuous_action_dim + (1 if self.config.num_discrete_actions is not None else 0)
+            self.config.target_entropy = -np.prod(dim) / 2  # (-dim(A)/2)
+
+    def _init_temperature(self):
+        """Set up the learnable temperature parameter from its initial value."""
+        temp_init = self.config.temperature_init
+        self.log_alpha = nn.Parameter(torch.tensor([math.log(temp_init)]))
+        self.temperature = self.log_alpha.exp().item()
+
 
 class SACObservationEncoder(nn.Module):
     """Encode image and/or state vector observations."""