diff --git a/lerobot/common/policies/sac/configuration_sac.py b/lerobot/common/policies/sac/configuration_sac.py
index 872d6dd8..c9bd90fc 100644
--- a/lerobot/common/policies/sac/configuration_sac.py
+++ b/lerobot/common/policies/sac/configuration_sac.py
@@ -179,7 +179,7 @@ class SACConfig(PreTrainedConfig):
     critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
     actor_network_kwargs: ActorNetworkConfig = field(default_factory=ActorNetworkConfig)
     policy_kwargs: PolicyConfig = field(default_factory=PolicyConfig)
-    grasp_critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
+    discrete_critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)

     actor_learner_config: ActorLearnerConfig = field(default_factory=ActorLearnerConfig)
     concurrency: ConcurrencyConfig = field(default_factory=ConcurrencyConfig)
diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py
index 9705d517..fdfee0f9 100644
--- a/lerobot/common/policies/sac/modeling_sac.py
+++ b/lerobot/common/policies/sac/modeling_sac.py
@@ -71,7 +71,7 @@ class SACPolicy(
             "temperature": self.log_alpha,
         }
         if self.config.num_discrete_actions is not None:
-            optim_params["grasp_critic"] = self.grasp_critic.parameters()
+            optim_params["discrete_critic"] = self.discrete_critic.parameters()
         return optim_params

     def reset(self):
@@ -90,7 +90,7 @@ class SACPolicy(
         actions = self.unnormalize_outputs({"action": actions})["action"]

         if self.config.num_discrete_actions is not None:
-            discrete_action_value = self.grasp_critic(batch, observations_features)
+            discrete_action_value = self.discrete_critic(batch, observations_features)
             discrete_action = torch.argmax(discrete_action_value, dim=-1, keepdim=True)
             actions = torch.cat([actions, discrete_action], dim=-1)

@@ -118,8 +118,10 @@ class SACPolicy(
         q_values = critics(observations, actions, observation_features)
         return q_values

-    def grasp_critic_forward(self, observations, use_target=False, observation_features=None) -> torch.Tensor:
-        """Forward pass through a grasp critic network
+    def discrete_critic_forward(
+        self, observations, use_target=False, observation_features=None
+    ) -> torch.Tensor:
+        """Forward pass through a discrete critic network

         Args:
             observations: Dictionary of observations
@@ -127,16 +129,16 @@ class SACPolicy(
             observation_features: Optional pre-computed observation features to avoid recomputing encoder output

         Returns:
-            Tensor of Q-values from the grasp critic network
+            Tensor of Q-values from the discrete critic network
         """
-        grasp_critic = self.grasp_critic_target if use_target else self.grasp_critic
-        q_values = grasp_critic(observations, observation_features)
+        discrete_critic = self.discrete_critic_target if use_target else self.discrete_critic
+        q_values = discrete_critic(observations, observation_features)
         return q_values

     def forward(
         self,
         batch: dict[str, Tensor | dict[str, Tensor]],
-        model: Literal["actor", "critic", "temperature", "grasp_critic"] = "critic",
+        model: Literal["actor", "critic", "temperature", "discrete_critic"] = "critic",
     ) -> dict[str, Tensor]:
         """Compute the loss for the given model

@@ -149,7 +151,7 @@ class SACPolicy(
                 - done: Done mask tensor
                 - observation_feature: Optional pre-computed observation features
                 - next_observation_feature: Optional pre-computed next observation features
-            model: Which model to compute the loss for ("actor", "critic", "grasp_critic", or "temperature")
+            model: Which model to compute the loss for ("actor", "critic", "discrete_critic", or "temperature")

         Returns:
             The computed loss tensor
@@ -178,14 +180,14 @@ class SACPolicy(
             return {"loss_critic": loss_critic}

-        if model == "grasp_critic" and self.config.num_discrete_actions is not None:
+        if model == "discrete_critic" and self.config.num_discrete_actions is not None:
             # Extract critic-specific components
             rewards: Tensor = batch["reward"]
             next_observations: dict[str, Tensor] = batch["next_state"]
             done: Tensor = batch["done"]
             next_observation_features: Tensor = batch.get("next_observation_feature")
             complementary_info = batch.get("complementary_info")
-            loss_grasp_critic = self.compute_loss_grasp_critic(
+            loss_discrete_critic = self.compute_loss_discrete_critic(
                 observations=observations,
                 actions=actions,
                 rewards=rewards,
@@ -195,7 +197,7 @@ class SACPolicy(
                 next_observation_features=next_observation_features,
                 complementary_info=complementary_info,
             )
-            return {"loss_grasp_critic": loss_grasp_critic}
+            return {"loss_discrete_critic": loss_discrete_critic}
         if model == "actor":
             return {
                 "loss_actor": self.compute_loss_actor(
@@ -227,8 +229,8 @@ class SACPolicy(
             )
         if self.config.num_discrete_actions is not None:
             for target_param, param in zip(
-                self.grasp_critic_target.parameters(),
-                self.grasp_critic.parameters(),
+                self.discrete_critic_target.parameters(),
+                self.discrete_critic.parameters(),
                 strict=False,
             ):
                 target_param.data.copy_(
@@ -302,7 +304,7 @@ class SACPolicy(
         ).sum()
         return critics_loss

-    def compute_loss_grasp_critic(
+    def compute_loss_discrete_critic(
         self,
         observations,
         actions,
@@ -320,46 +322,46 @@ class SACPolicy(
         actions_discrete = torch.round(actions_discrete)
         actions_discrete = actions_discrete.long()

-        gripper_penalties: Tensor | None = None
+        discrete_penalties: Tensor | None = None
         if complementary_info is not None:
-            gripper_penalties: Tensor | None = complementary_info.get("gripper_penalty")
+            discrete_penalties: Tensor | None = complementary_info.get("discrete_penalty")

         with torch.no_grad():
             # For DQN, select actions using online network, evaluate with target network
-            next_grasp_qs = self.grasp_critic_forward(
+            next_discrete_qs = self.discrete_critic_forward(
                 next_observations, use_target=False, observation_features=next_observation_features
             )
-            best_next_grasp_action = torch.argmax(next_grasp_qs, dim=-1, keepdim=True)
+            best_next_discrete_action = torch.argmax(next_discrete_qs, dim=-1, keepdim=True)

             # Get target Q-values from target network
-            target_next_grasp_qs = self.grasp_critic_forward(
+            target_next_discrete_qs = self.discrete_critic_forward(
                 observations=next_observations,
                 use_target=True,
                 observation_features=next_observation_features,
             )

             # Use gather to select Q-values for best actions
-            target_next_grasp_q = torch.gather(
-                target_next_grasp_qs, dim=1, index=best_next_grasp_action
+            target_next_discrete_q = torch.gather(
+                target_next_discrete_qs, dim=1, index=best_next_discrete_action
             ).squeeze(-1)

             # Compute target Q-value with Bellman equation
-            rewards_gripper = rewards
-            if gripper_penalties is not None:
-                rewards_gripper = rewards + gripper_penalties
-            target_grasp_q = rewards_gripper + (1 - done) * self.config.discount * target_next_grasp_q
+            rewards_discrete = rewards
+            if discrete_penalties is not None:
+                rewards_discrete = rewards + discrete_penalties
+            target_discrete_q = rewards_discrete + (1 - done) * self.config.discount * target_next_discrete_q

         # Get predicted Q-values for current observations
-        predicted_grasp_qs = self.grasp_critic_forward(
+        predicted_discrete_qs = self.discrete_critic_forward(
             observations=observations, use_target=False, observation_features=observation_features
         )

         # Use gather to select Q-values for taken actions
-        predicted_grasp_q = torch.gather(predicted_grasp_qs, dim=1, index=actions_discrete).squeeze(-1)
+        predicted_discrete_q = torch.gather(predicted_discrete_qs, dim=1, index=actions_discrete).squeeze(-1)

         # Compute MSE loss between predicted and target Q-values
-        grasp_critic_loss = F.mse_loss(input=predicted_grasp_q, target=target_grasp_q)
-        return grasp_critic_loss
+        discrete_critic_loss = F.mse_loss(input=predicted_discrete_q, target=target_discrete_q)
+        return discrete_critic_loss

     def compute_loss_temperature(self, observations, observation_features: Tensor | None = None) -> Tensor:
         """Compute the temperature loss"""
@@ -419,7 +421,7 @@ class SACPolicy(
         )

     def _init_critics(self, continuous_action_dim):
-        """Build critic ensemble, targets, and optional grasp critic."""
+        """Build critic ensemble, targets, and optional discrete critic."""
         heads = [
             CriticHead(
                 input_dim=self.encoder_critic.output_dim + continuous_action_dim,
@@ -446,25 +448,25 @@ class SACPolicy(
             self.critic_target = torch.compile(self.critic_target)

         if self.config.num_discrete_actions is not None:
-            self._init_grasp_critics()
+            self._init_discrete_critics()

-    def _init_grasp_critics(self):
-        """Build discrete grasp critic ensemble and target networks."""
-        self.grasp_critic = GraspCritic(
+    def _init_discrete_critics(self):
+        """Build discrete critic and target networks."""
+        self.discrete_critic = DiscreteCritic(
             encoder=self.encoder_critic,
             input_dim=self.encoder_critic.output_dim,
             output_dim=self.config.num_discrete_actions,
-            **asdict(self.config.grasp_critic_network_kwargs),
+            **asdict(self.config.discrete_critic_network_kwargs),
         )

-        self.grasp_critic_target = GraspCritic(
+        self.discrete_critic_target = DiscreteCritic(
             encoder=self.encoder_critic,
             input_dim=self.encoder_critic.output_dim,
             output_dim=self.config.num_discrete_actions,
-            **asdict(self.config.grasp_critic_network_kwargs),
+            **asdict(self.config.discrete_critic_network_kwargs),
         )
-        # TODO: (maractingi, azouitine) Compile the grasp critic
-        self.grasp_critic_target.load_state_dict(self.grasp_critic.state_dict())
+        # TODO: (maractingi, azouitine) Compile the discrete critic
+        self.discrete_critic_target.load_state_dict(self.discrete_critic.state_dict())

     def _init_actor(self, continuous_action_dim):
         """Initialize policy actor network and default target entropy."""
@@ -590,7 +592,7 @@ class SACObservationEncoder(nn.Module):
         This function processes image observations through the vision encoder once and returns
         the resulting features. When the image encoder is shared between actor and critics AND frozen,
         these features can be safely cached and
-        reused across policy components (actor, critic, grasp_critic), avoiding redundant forward passes.
+        reused across policy components (actor, critic, discrete_critic), avoiding redundant forward passes.

         Performance impact:
         - The vision encoder forward pass is typically the main computational bottleneck during training and inference
@@ -794,7 +796,7 @@ class CriticEnsemble(nn.Module):
         return q_values


-class GraspCritic(nn.Module):
+class DiscreteCritic(nn.Module):
     def __init__(
         self,
         encoder: nn.Module,
diff --git a/lerobot/scripts/server/gym_manipulator.py b/lerobot/scripts/server/gym_manipulator.py
index 5785815c..e95485de 100644
--- a/lerobot/scripts/server/gym_manipulator.py
+++ b/lerobot/scripts/server/gym_manipulator.py
@@ -833,7 +833,7 @@ class GripperPenaltyWrapper(gym.RewardWrapper):
         if self.gripper_penalty_in_reward:
             reward += gripper_penalty
         else:
-            info["gripper_penalty"] = gripper_penalty
+            info["discrete_penalty"] = gripper_penalty

         return obs, reward, terminated, truncated, info

diff --git a/lerobot/scripts/server/learner_server.py b/lerobot/scripts/server/learner_server.py
index 11b2b605..62beae7c 100644
--- a/lerobot/scripts/server/learner_server.py
+++ b/lerobot/scripts/server/learner_server.py
@@ -410,7 +410,7 @@ def add_actor_information_and_train(
             "complementary_info": batch["complementary_info"],
         }

-        # Use the forward method for critic loss (includes both main critic and grasp critic)
+        # Use the forward method for critic loss (includes both main critic and discrete critic)
         critic_output = policy.forward(forward_batch, model="critic")

         # Main critic optimization
@@ -422,16 +422,16 @@ def add_actor_information_and_train(
         )
         optimizers["critic"].step()

-        # Grasp critic optimization (if available)
+        # Discrete critic optimization (if available)
         if policy.config.num_discrete_actions is not None:
-            discrete_critic_output = policy.forward(forward_batch, model="grasp_critic")
-            loss_grasp_critic = discrete_critic_output["loss_grasp_critic"]
-            optimizers["grasp_critic"].zero_grad()
-            loss_grasp_critic.backward()
-            grasp_critic_grad_norm = torch.nn.utils.clip_grad_norm_(
-                parameters=policy.grasp_critic.parameters(), max_norm=clip_grad_norm_value
+            discrete_critic_output = policy.forward(forward_batch, model="discrete_critic")
+            loss_discrete_critic = discrete_critic_output["loss_discrete_critic"]
+            optimizers["discrete_critic"].zero_grad()
+            loss_discrete_critic.backward()
+            discrete_critic_grad_norm = torch.nn.utils.clip_grad_norm_(
+                parameters=policy.discrete_critic.parameters(), max_norm=clip_grad_norm_value
             )
-            optimizers["grasp_critic"].step()
+            optimizers["discrete_critic"].step()

         # Update target networks
         policy.update_target_networks()
@@ -468,7 +468,7 @@ def add_actor_information_and_train(
             "next_observation_feature": next_observation_features,
         }

-        # Use the forward method for critic loss (includes both main critic and grasp critic)
+        # Use the forward method for critic loss (includes both main critic and discrete critic)
         critic_output = policy.forward(forward_batch, model="critic")

         # Main critic optimization
@@ -486,20 +486,20 @@ def add_actor_information_and_train(
             "critic_grad_norm": critic_grad_norm,
         }

-        # Grasp critic optimization (if available)
+        # Discrete critic optimization (if available)
         if policy.config.num_discrete_actions is not None:
-            discrete_critic_output = policy.forward(forward_batch, model="grasp_critic")
-            loss_grasp_critic = discrete_critic_output["loss_grasp_critic"]
-            optimizers["grasp_critic"].zero_grad()
-            loss_grasp_critic.backward()
-            grasp_critic_grad_norm = torch.nn.utils.clip_grad_norm_(
-                parameters=policy.grasp_critic.parameters(), max_norm=clip_grad_norm_value
+            discrete_critic_output = policy.forward(forward_batch, model="discrete_critic")
+            loss_discrete_critic = discrete_critic_output["loss_discrete_critic"]
+            optimizers["discrete_critic"].zero_grad()
+            loss_discrete_critic.backward()
+            discrete_critic_grad_norm = torch.nn.utils.clip_grad_norm_(
+                parameters=policy.discrete_critic.parameters(), max_norm=clip_grad_norm_value
             ).item()
-            optimizers["grasp_critic"].step()
+            optimizers["discrete_critic"].step()

-            # Add grasp critic info to training info
-            training_infos["loss_grasp_critic"] = loss_grasp_critic.item()
-            training_infos["grasp_critic_grad_norm"] = grasp_critic_grad_norm
+            # Add discrete critic info to training info
+            training_infos["loss_discrete_critic"] = loss_discrete_critic.item()
+            training_infos["discrete_critic_grad_norm"] = discrete_critic_grad_norm

         # Actor and temperature optimization (at specified frequency)
         if optimization_step % policy_update_freq == 0:
@@ -782,8 +782,8 @@ def make_optimizers_and_scheduler(cfg: TrainPipelineConfig, policy: nn.Module):
     optimizer_critic = torch.optim.Adam(params=policy.critic_ensemble.parameters(), lr=cfg.policy.critic_lr)

     if cfg.policy.num_discrete_actions is not None:
-        optimizer_grasp_critic = torch.optim.Adam(
-            params=policy.grasp_critic.parameters(), lr=cfg.policy.critic_lr
+        optimizer_discrete_critic = torch.optim.Adam(
+            params=policy.discrete_critic.parameters(), lr=cfg.policy.critic_lr
         )
     optimizer_temperature = torch.optim.Adam(params=[policy.log_alpha], lr=cfg.policy.critic_lr)
     lr_scheduler = None
@@ -793,7 +793,7 @@ def make_optimizers_and_scheduler(cfg: TrainPipelineConfig, policy: nn.Module):
         "temperature": optimizer_temperature,
     }
     if cfg.policy.num_discrete_actions is not None:
-        optimizers["grasp_critic"] = optimizer_grasp_critic
+        optimizers["discrete_critic"] = optimizer_discrete_critic
     return optimizers, lr_scheduler
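For context on what the renamed compute_loss_discrete_critic implements, the double-DQN style target it uses (select the best next discrete action with the online network, evaluate that action with the target network) can be sketched in isolation as follows. This is a minimal, self-contained illustration: the linear networks, tensor shapes, and hyperparameter values below are made up for the example and are not the actual SACPolicy modules or signatures.

# Standalone sketch (illustrative only) of the double-DQN target used by the
# discrete critic loss. Networks and shapes here are placeholders.
import torch
import torch.nn.functional as F

batch_size, feat_dim, num_discrete_actions = 8, 16, 3
online_net = torch.nn.Linear(feat_dim, num_discrete_actions)
target_net = torch.nn.Linear(feat_dim, num_discrete_actions)
target_net.load_state_dict(online_net.state_dict())

obs = torch.randn(batch_size, feat_dim)
next_obs = torch.randn(batch_size, feat_dim)
actions_discrete = torch.randint(0, num_discrete_actions, (batch_size, 1))
rewards = torch.randn(batch_size)
done = torch.zeros(batch_size)
discount = 0.99

with torch.no_grad():
    # Select the best next action with the online network ...
    best_next_action = online_net(next_obs).argmax(dim=-1, keepdim=True)
    # ... then evaluate that action with the target network (double DQN).
    target_next_q = target_net(next_obs).gather(1, best_next_action).squeeze(-1)
    # Bellman target; an optional discrete penalty would be added to rewards here.
    target_q = rewards + (1 - done) * discount * target_next_q

# Q-values of the actions actually taken are regressed toward the target.
predicted_q = online_net(obs).gather(1, actions_discrete).squeeze(-1)
loss = F.mse_loss(predicted_q, target_q)
loss.backward()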