Refactor modeling_sac and parameter handling for clarity and reusability.

Co-authored-by: s1lent4gnt <kmeftah.khalil@gmail.com>
This commit is contained in:
AdilZouitine
2025-04-14 14:00:57 +00:00
committed by Michel Aractingi
parent 267a837a2c
commit 9386892f8e
2 changed files with 67 additions and 45 deletions

View File

@@ -167,8 +167,12 @@ class SACPolicy(
def get_optim_params(self) -> dict:
optim_params = {
"actor": self.actor.parameters_to_optimize,
"critic": self.critic_ensemble.parameters_to_optimize,
"actor": [
p
for n, p in self.actor.named_parameters()
if not n.startswith("encoder") or not self.shared_encoder
],
"critic": self.critic_ensemble.parameters(),
"temperature": self.log_alpha,
}
if self.config.num_discrete_actions is not None:
@@ -451,11 +455,11 @@ class SACPolicy(
target_next_grasp_qs, dim=1, index=best_next_grasp_action
).squeeze(-1)
# Compute target Q-value with Bellman equation
rewards_gripper = rewards
if gripper_penalties is not None:
rewards_gripper = rewards + gripper_penalties
target_grasp_q = rewards_gripper + (1 - done) * self.config.discount * target_next_grasp_q
# Compute target Q-value with Bellman equation
rewards_gripper = rewards
if gripper_penalties is not None:
rewards_gripper = rewards + gripper_penalties
target_grasp_q = rewards_gripper + (1 - done) * self.config.discount * target_next_grasp_q
# Get predicted Q-values for current observations
predicted_grasp_qs = self.grasp_critic_forward(
@@ -510,7 +514,6 @@ class SACObservationEncoder(nn.Module):
self.config = config
self.input_normalization = input_normalizer
self.has_pretrained_vision_encoder = False
self.parameters_to_optimize = []
self.aggregation_size: int = 0
if any("observation.image" in key for key in config.input_features):
@@ -527,8 +530,6 @@ class SACObservationEncoder(nn.Module):
if config.freeze_vision_encoder:
freeze_image_encoder(self.image_enc_layers.image_enc_layers)
self.parameters_to_optimize += self.image_enc_layers.parameters_to_optimize
self.all_image_keys = [k for k in config.input_features if k.startswith("observation.image")]
if "observation.state" in config.input_features:
@@ -542,8 +543,6 @@ class SACObservationEncoder(nn.Module):
)
self.aggregation_size += config.latent_dim
self.parameters_to_optimize += list(self.state_enc_layers.parameters())
if "observation.environment_state" in config.input_features:
self.env_state_enc_layers = nn.Sequential(
nn.Linear(
@@ -554,10 +553,8 @@ class SACObservationEncoder(nn.Module):
nn.Tanh(),
)
self.aggregation_size += config.latent_dim
self.parameters_to_optimize += list(self.env_state_enc_layers.parameters())
self.aggregation_layer = nn.Linear(in_features=self.aggregation_size, out_features=config.latent_dim)
self.parameters_to_optimize += list(self.aggregation_layer.parameters())
def forward(
self, obs_dict: dict[str, Tensor], vision_encoder_cache: torch.Tensor | None = None
@@ -737,12 +734,6 @@ class CriticEnsemble(nn.Module):
self.output_normalization = output_normalization
self.critics = nn.ModuleList(ensemble)
self.parameters_to_optimize = []
# Handle the case where a part of the encoder if frozen
if self.encoder is not None:
self.parameters_to_optimize += list(self.encoder.parameters_to_optimize)
self.parameters_to_optimize += list(self.critics.parameters())
def forward(
self,
observations: dict[str, torch.Tensor],
@@ -805,10 +796,6 @@ class GraspCritic(nn.Module):
else:
orthogonal_init()(self.output_layer.weight)
self.parameters_to_optimize = []
self.parameters_to_optimize += list(self.net.parameters())
self.parameters_to_optimize += list(self.output_layer.parameters())
def forward(
self, observations: torch.Tensor, observation_features: torch.Tensor | None = None
) -> torch.Tensor:
@@ -840,12 +827,8 @@ class Policy(nn.Module):
self.log_std_max = log_std_max
self.fixed_std = fixed_std
self.use_tanh_squash = use_tanh_squash
self.parameters_to_optimize = []
self.encoder_is_shared = encoder_is_shared
self.parameters_to_optimize += list(self.network.parameters())
if self.encoder is not None and not encoder_is_shared:
self.parameters_to_optimize += list(self.encoder.parameters())
# Find the last Linear layer's output dimension
for layer in reversed(network.net):
if isinstance(layer, nn.Linear):
@@ -859,7 +842,6 @@ class Policy(nn.Module):
else:
orthogonal_init()(self.mean_layer.weight)
self.parameters_to_optimize += list(self.mean_layer.parameters())
# Standard deviation layer or parameter
if fixed_std is None:
self.std_layer = nn.Linear(out_features, action_dim)
@@ -868,7 +850,6 @@ class Policy(nn.Module):
nn.init.uniform_(self.std_layer.bias, -init_final, init_final)
else:
orthogonal_init()(self.std_layer.weight)
self.parameters_to_optimize += list(self.std_layer.parameters())
def forward(
self,
@@ -877,6 +858,8 @@ class Policy(nn.Module):
) -> Tuple[torch.Tensor, torch.Tensor]:
# Encode observations if encoder exists
obs_enc = self.encoder(observations, vision_encoder_cache=observation_features)
if self.encoder_is_shared:
obs_enc = obs_enc.detach()
# Get network outputs
outputs = self.network(obs_enc)
@@ -966,13 +949,13 @@ class DefaultImageEncoder(nn.Module):
nn.Tanh(),
)
self.parameters_to_optimize = []
if not config.freeze_vision_encoder:
self.parameters_to_optimize += list(self.image_enc_layers.parameters())
self.parameters_to_optimize += list(self.image_enc_proj.parameters())
self.freeze_image_encoder = config.freeze_vision_encoder
def forward(self, x):
return self.image_enc_proj(self.image_enc_layers(x))
x = self.image_enc_layers(x)
if self.freeze_image_encoder:
x = x.detach()
return self.image_enc_proj(x)
class PretrainedImageEncoder(nn.Module):
@@ -985,10 +968,7 @@ class PretrainedImageEncoder(nn.Module):
nn.Tanh(),
)
self.parameters_to_optimize = []
if not config.freeze_vision_encoder:
self.parameters_to_optimize += list(self.image_enc_layers.parameters())
self.parameters_to_optimize += list(self.image_enc_proj.parameters())
self.freeze_image_encoder = config.freeze_vision_encoder
def _load_pretrained_vision_encoder(self, config: SACConfig):
"""Set up CNN encoder"""
@@ -1009,6 +989,8 @@ class PretrainedImageEncoder(nn.Module):
# TODO: (maractingi, azouitine) check the forward pass of the pretrained model
# doesn't reach the classifier layer because we don't need it
enc_feat = self.image_enc_layers(x).pooler_output
if self.freeze_image_encoder:
enc_feat = enc_feat.detach()
enc_feat = self.image_enc_proj(enc_feat.view(enc_feat.shape[0], -1))
return enc_feat