From 8469d13681108f5f40c22595cc5a37cbcbe99803 Mon Sep 17 00:00:00 2001
From: Michel Aractingi
Date: Tue, 18 Feb 2025 08:28:13 +0000
Subject: [PATCH] Added the option to cache image embeddings when the vision
 encoder is pretrained and frozen

Co-authored-by: Adil Zouitine
---
 lerobot/common/policies/sac/modeling_sac.py | 52 +++++++++++----------
 lerobot/configs/policy/sac_maniskill.yaml   |  8 ++--
 lerobot/scripts/server/learner_server.py    | 46 ++++++++++++++++--
 3 files changed, 74 insertions(+), 32 deletions(-)

diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py
index 622919b9..02d08c35 100644
--- a/lerobot/common/policies/sac/modeling_sac.py
+++ b/lerobot/common/policies/sac/modeling_sac.py
@@ -148,7 +148,7 @@ class SACPolicy(
         return actions
 
     def critic_forward(
-        self, observations: dict[str, Tensor], actions: Tensor, use_target: bool = False
+        self, observations: dict[str, Tensor], actions: Tensor, use_target: bool = False, features: Optional[Tensor] = None
     ) -> Tensor:
         """Forward pass through a critic network ensemble
 
@@ -161,7 +161,7 @@
             Tensor of Q-values from all critics
         """
         critics = self.critic_target if use_target else self.critic_ensemble
-        q_values = critics(observations, actions)
+        q_values = critics(observations, actions, features)
         return q_values
 
     def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]: ...
@@ -175,14 +175,14 @@
                 + target_param.data * (1.0 - self.config.critic_target_update_weight)
             )
 
-    def compute_loss_critic(self, observations, actions, rewards, next_observations, done) -> Tensor:
+    def compute_loss_critic(self, observations, actions, rewards, next_observations, done, obs_features=None, next_obs_features=None) -> Tensor:
         temperature = self.log_alpha.exp().item()
         with torch.no_grad():
-            next_action_preds, next_log_probs, _ = self.actor(next_observations)
+            next_action_preds, next_log_probs, _ = self.actor(next_observations, next_obs_features)
 
             # 2- compute q targets
             q_targets = self.critic_forward(
-                observations=next_observations, actions=next_action_preds, use_target=True
+                observations=next_observations, actions=next_action_preds, use_target=True, features=next_obs_features
             )
 
             # subsample critics to prevent overfitting if use high UTD (update to date)
@@ -199,7 +199,7 @@ class SACPolicy(
             td_target = rewards + (1 - done) * self.config.discount * min_q
 
         # 3- compute predicted qs
-        q_preds = self.critic_forward(observations, actions, use_target=False)
+        q_preds = self.critic_forward(observations, actions, use_target=False, features=obs_features)
 
         # 4- Calculate loss
         # Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
@@ -214,18 +214,18 @@ class SACPolicy(
         ).sum()
         return critics_loss
 
-    def compute_loss_temperature(self, observations) -> Tensor:
+    def compute_loss_temperature(self, observations, obs_features=None) -> Tensor:
         """Compute the temperature loss"""
         # calculate temperature loss
         with torch.no_grad():
-            _, log_probs, _ = self.actor(observations)
+            _, log_probs, _ = self.actor(observations, obs_features)
         temperature_loss = (-self.log_alpha.exp() * (log_probs + self.config.target_entropy)).mean()
         return temperature_loss
 
-    def compute_loss_actor(self, observations) -> Tensor:
+    def compute_loss_actor(self, observations, obs_features=None) -> Tensor:
         temperature = self.log_alpha.exp().item()
 
-        actions_pi, log_probs, _ = self.actor(observations)
+        actions_pi, log_probs, _ = self.actor(observations, obs_features)
         q_preds = self.critic_forward(observations, actions_pi, use_target=False)
         min_q_preds = q_preds.min(dim=0)[0]
 
@@ -360,18 +360,19 @@ class CriticEnsemble(nn.Module):
         self,
         observations: dict[str, torch.Tensor],
         actions: torch.Tensor,
+        features: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
         device = get_device_from_parameters(self)
-        # Move each tensor in observations to device
+        # Move observations to the correct device
         observations = {k: v.to(device) for k, v in observations.items()}
-        # NOTE: We normalize actions it helps for sample efficiency
-        actions: dict[str, torch.tensor] = {"action": actions}
-        # NOTE: Normalization layer took dict in input and outputs a dict that why
+        # Normalize actions for sample efficiency
+        actions: dict[str, torch.Tensor] = {"action": actions}
         actions = self.output_normalization(actions)["action"]
         actions = actions.to(device)
-
-        obs_enc = observations if self.encoder is None else self.encoder(observations)
-
+
+        # Use precomputed features if provided; otherwise, encode observations.
+        obs_enc = features if features is not None else (observations if self.encoder is None else self.encoder(observations))
+
         inputs = torch.cat([obs_enc, actions], dim=-1)
         list_q_values = []
         for network, output_layer in zip(self.network_list, self.output_layers, strict=False):
@@ -435,19 +436,20 @@ class Policy(nn.Module):
     def forward(
         self,
         observations: torch.Tensor,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        # Encode observations if encoder exists
-        obs_enc = observations if self.encoder is None else self.encoder(observations)
-
+        features: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        # Use precomputed features if provided; otherwise compute encoder representations.
+        obs_enc = features if features is not None else (observations if self.encoder is None else self.encoder(observations))
+
         # Get network outputs
         outputs = self.network(obs_enc)
         means = self.mean_layer(outputs)
-
+
         # Compute standard deviations
         if self.fixed_std is None:
             log_std = self.std_layer(outputs)
             assert not torch.isnan(log_std).any(), "[ERROR] log_std became NaN after std_layer!"
-
+
             if self.use_tanh_squash:
                 log_std = torch.tanh(log_std)
                 log_std = self.log_std_min + 0.5 * (self.log_std_max - self.log_std_min) * (log_std + 1.0)
@@ -455,8 +457,8 @@ class Policy(nn.Module):
                 log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max)
         else:
             log_std = self.fixed_std.expand_as(means)
-
-        # uses tanh activation function to squash the action to be in the range of [-1, 1]
+
+        # Get distribution and sample actions
         normal = torch.distributions.Normal(means, torch.exp(log_std))
         x_t = normal.rsample()  # Reparameterization trick (mean + std * N(0,1))
         log_probs = normal.log_prob(x_t)  # Base log probability before Tanh
diff --git a/lerobot/configs/policy/sac_maniskill.yaml b/lerobot/configs/policy/sac_maniskill.yaml
index 8a36947c..6b2436a1 100644
--- a/lerobot/configs/policy/sac_maniskill.yaml
+++ b/lerobot/configs/policy/sac_maniskill.yaml
@@ -52,10 +52,10 @@ policy:
   n_action_steps: 1
 
   shared_encoder: true
-  vision_encoder_name: null
-  # vision_encoder_name: "helper2424/resnet10"
-  # freeze_vision_encoder: true
-  freeze_vision_encoder: false
+  # vision_encoder_name: null
+  vision_encoder_name: "helper2424/resnet10"
+  freeze_vision_encoder: true
+  # freeze_vision_encoder: false
   input_shapes:
     # # TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
     observation.state: ["${env.state_dim}"]
diff --git a/lerobot/scripts/server/learner_server.py b/lerobot/scripts/server/learner_server.py
index faa7a0e7..6cba2abe 100644
--- a/lerobot/scripts/server/learner_server.py
+++ b/lerobot/scripts/server/learner_server.py
@@ -384,6 +384,21 @@ def add_actor_information_and_train(
         done = batch["done"]
         check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations)
 
+        # Precompute encoder features from the frozen vision encoder if enabled
+        obs_features, next_obs_features = None, None
+        if policy.config.vision_encoder_name is not None and policy.config.freeze_vision_encoder:
+            with torch.no_grad():
+                obs_features = (
+                    policy.actor.encoder(observations)
+                    if policy.actor.encoder is not None
+                    else None
+                )
+                next_obs_features = (
+                    policy.actor.encoder(next_observations)
+                    if policy.actor.encoder is not None
+                    else None
+                )
+
         with policy_lock:
             loss_critic = policy.compute_loss_critic(
                 observations=observations,
@@ -391,6 +406,8 @@
                 rewards=rewards,
                 next_observations=next_observations,
                 done=done,
+                obs_features=obs_features,  # pass precomputed features
+                next_obs_features=next_obs_features,  # for target computation
             )
             optimizers["critic"].zero_grad()
             loss_critic.backward()
@@ -412,6 +429,21 @@
 
         check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations)
 
+        # Precompute encoder features from the frozen vision encoder if enabled
+        obs_features, next_obs_features = None, None
+        if policy.config.vision_encoder_name is not None and policy.config.freeze_vision_encoder:
+            with torch.no_grad():
+                obs_features = (
+                    policy.actor.encoder(observations)
+                    if policy.actor.encoder is not None
+                    else None
+                )
+                next_obs_features = (
+                    policy.actor.encoder(next_observations)
+                    if policy.actor.encoder is not None
+                    else None
+                )
+
         with policy_lock:
             loss_critic = policy.compute_loss_critic(
                 observations=observations,
@@ -419,6 +451,8 @@
                 rewards=rewards,
                 next_observations=next_observations,
                 done=done,
+                obs_features=obs_features,  # pass precomputed features
+                next_obs_features=next_obs_features,  # for target computation
             )
             optimizers["critic"].zero_grad()
             loss_critic.backward()
@@ -430,7 +464,10 @@
         if optimization_step % cfg.training.policy_update_freq == 0:
             for _ in range(cfg.training.policy_update_freq):
                 with policy_lock:
-                    loss_actor = policy.compute_loss_actor(observations=observations)
+                    loss_actor = policy.compute_loss_actor(
+                        observations=observations,
+                        obs_features=obs_features,  # reuse precomputed features here
+                    )
 
                 optimizers["actor"].zero_grad()
                 loss_actor.backward()
@@ -438,7 +475,10 @@
 
                 training_infos["loss_actor"] = loss_actor.item()
 
-                loss_temperature = policy.compute_loss_temperature(observations=observations)
+                loss_temperature = policy.compute_loss_temperature(
+                    observations=observations,
+                    obs_features=obs_features,  # and for temperature loss as well
+                )
                 optimizers["temperature"].zero_grad()
                 loss_temperature.backward()
                 optimizers["temperature"].step()
@@ -582,7 +622,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
         pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir) if cfg.resume else None,
     )
     # compile policy
-    policy = torch.compile(policy)
+    # policy = torch.compile(policy)
     assert isinstance(policy, nn.Module)
 
     optimizers, lr_scheduler = make_optimizers_and_scheduler(cfg, policy)
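
Note on the pattern this patch implements: a frozen, pretrained encoder produces the same embedding for a given batch regardless of how many loss terms consume it, so the learner computes obs_features and next_obs_features once per sampled batch under torch.no_grad() and threads them through compute_loss_critic, compute_loss_actor, and compute_loss_temperature. The sketch below illustrates that caching pattern in isolation; SmallEncoder and Head are hypothetical stand-ins, not the lerobot SACPolicy classes, so treat it as a minimal example of the idea rather than the actual implementation.

# Minimal sketch of the embedding-caching pattern, with hypothetical stand-in modules.
from typing import Optional

import torch
import torch.nn as nn


class SmallEncoder(nn.Module):
    """Stand-in for a pretrained, frozen vision encoder."""

    def __init__(self):
        super().__init__()
        self.net = nn.Linear(16, 8)

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        return self.net(obs)


class Head(nn.Module):
    """Stand-in for an actor or critic head that accepts optional cached features."""

    def __init__(self, encoder: nn.Module):
        super().__init__()
        self.encoder = encoder
        self.out = nn.Linear(8, 1)

    def forward(self, obs: torch.Tensor, features: Optional[torch.Tensor] = None) -> torch.Tensor:
        # Reuse the cached embedding when it is provided; otherwise encode here.
        enc = features if features is not None else self.encoder(obs)
        return self.out(enc)


encoder = SmallEncoder().requires_grad_(False)  # frozen, as with freeze_vision_encoder: true
actor, critic = Head(encoder), Head(encoder)

obs = torch.randn(4, 16)
with torch.no_grad():
    cached = encoder(obs)  # computed once per sampled batch

# Every loss term consumes the same cached embedding instead of re-running the encoder.
actor_out = actor(obs, cached)
critic_out = critic(obs, cached)

Because no gradient ever flows into the frozen encoder, computing the embedding under torch.no_grad() is safe, and each additional loss term reuses the cached tensor instead of paying for another encoder forward pass.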