forked from tangger/lerobot
Refactor modeling_sac and parameter handling for clarity and reusability.
Co-authored-by: s1lent4gnt <kmeftah.khalil@gmail.com>
This commit is contained in:
committed by
Michel Aractingi
parent
267a837a2c
commit
9386892f8e
@@ -167,8 +167,12 @@ class SACPolicy(
|
||||
|
||||
def get_optim_params(self) -> dict:
|
||||
optim_params = {
|
||||
"actor": self.actor.parameters_to_optimize,
|
||||
"critic": self.critic_ensemble.parameters_to_optimize,
|
||||
"actor": [
|
||||
p
|
||||
for n, p in self.actor.named_parameters()
|
||||
if not n.startswith("encoder") or not self.shared_encoder
|
||||
],
|
||||
"critic": self.critic_ensemble.parameters(),
|
||||
"temperature": self.log_alpha,
|
||||
}
|
||||
if self.config.num_discrete_actions is not None:
|
||||
@@ -451,11 +455,11 @@ class SACPolicy(
|
||||
target_next_grasp_qs, dim=1, index=best_next_grasp_action
|
||||
).squeeze(-1)
|
||||
|
||||
# Compute target Q-value with Bellman equation
|
||||
rewards_gripper = rewards
|
||||
if gripper_penalties is not None:
|
||||
rewards_gripper = rewards + gripper_penalties
|
||||
target_grasp_q = rewards_gripper + (1 - done) * self.config.discount * target_next_grasp_q
|
||||
# Compute target Q-value with Bellman equation
|
||||
rewards_gripper = rewards
|
||||
if gripper_penalties is not None:
|
||||
rewards_gripper = rewards + gripper_penalties
|
||||
target_grasp_q = rewards_gripper + (1 - done) * self.config.discount * target_next_grasp_q
|
||||
|
||||
# Get predicted Q-values for current observations
|
||||
predicted_grasp_qs = self.grasp_critic_forward(
|
||||
@@ -510,7 +514,6 @@ class SACObservationEncoder(nn.Module):
|
||||
self.config = config
|
||||
self.input_normalization = input_normalizer
|
||||
self.has_pretrained_vision_encoder = False
|
||||
self.parameters_to_optimize = []
|
||||
|
||||
self.aggregation_size: int = 0
|
||||
if any("observation.image" in key for key in config.input_features):
|
||||
@@ -527,8 +530,6 @@ class SACObservationEncoder(nn.Module):
|
||||
if config.freeze_vision_encoder:
|
||||
freeze_image_encoder(self.image_enc_layers.image_enc_layers)
|
||||
|
||||
self.parameters_to_optimize += self.image_enc_layers.parameters_to_optimize
|
||||
|
||||
self.all_image_keys = [k for k in config.input_features if k.startswith("observation.image")]
|
||||
|
||||
if "observation.state" in config.input_features:
|
||||
@@ -542,8 +543,6 @@ class SACObservationEncoder(nn.Module):
|
||||
)
|
||||
self.aggregation_size += config.latent_dim
|
||||
|
||||
self.parameters_to_optimize += list(self.state_enc_layers.parameters())
|
||||
|
||||
if "observation.environment_state" in config.input_features:
|
||||
self.env_state_enc_layers = nn.Sequential(
|
||||
nn.Linear(
|
||||
@@ -554,10 +553,8 @@ class SACObservationEncoder(nn.Module):
|
||||
nn.Tanh(),
|
||||
)
|
||||
self.aggregation_size += config.latent_dim
|
||||
self.parameters_to_optimize += list(self.env_state_enc_layers.parameters())
|
||||
|
||||
self.aggregation_layer = nn.Linear(in_features=self.aggregation_size, out_features=config.latent_dim)
|
||||
self.parameters_to_optimize += list(self.aggregation_layer.parameters())
|
||||
|
||||
def forward(
|
||||
self, obs_dict: dict[str, Tensor], vision_encoder_cache: torch.Tensor | None = None
|
||||
@@ -737,12 +734,6 @@ class CriticEnsemble(nn.Module):
|
||||
self.output_normalization = output_normalization
|
||||
self.critics = nn.ModuleList(ensemble)
|
||||
|
||||
self.parameters_to_optimize = []
|
||||
# Handle the case where a part of the encoder if frozen
|
||||
if self.encoder is not None:
|
||||
self.parameters_to_optimize += list(self.encoder.parameters_to_optimize)
|
||||
self.parameters_to_optimize += list(self.critics.parameters())
|
||||
|
||||
def forward(
|
||||
self,
|
||||
observations: dict[str, torch.Tensor],
|
||||
@@ -805,10 +796,6 @@ class GraspCritic(nn.Module):
|
||||
else:
|
||||
orthogonal_init()(self.output_layer.weight)
|
||||
|
||||
self.parameters_to_optimize = []
|
||||
self.parameters_to_optimize += list(self.net.parameters())
|
||||
self.parameters_to_optimize += list(self.output_layer.parameters())
|
||||
|
||||
def forward(
|
||||
self, observations: torch.Tensor, observation_features: torch.Tensor | None = None
|
||||
) -> torch.Tensor:
|
||||
@@ -840,12 +827,8 @@ class Policy(nn.Module):
|
||||
self.log_std_max = log_std_max
|
||||
self.fixed_std = fixed_std
|
||||
self.use_tanh_squash = use_tanh_squash
|
||||
self.parameters_to_optimize = []
|
||||
self.encoder_is_shared = encoder_is_shared
|
||||
|
||||
self.parameters_to_optimize += list(self.network.parameters())
|
||||
|
||||
if self.encoder is not None and not encoder_is_shared:
|
||||
self.parameters_to_optimize += list(self.encoder.parameters())
|
||||
# Find the last Linear layer's output dimension
|
||||
for layer in reversed(network.net):
|
||||
if isinstance(layer, nn.Linear):
|
||||
@@ -859,7 +842,6 @@ class Policy(nn.Module):
|
||||
else:
|
||||
orthogonal_init()(self.mean_layer.weight)
|
||||
|
||||
self.parameters_to_optimize += list(self.mean_layer.parameters())
|
||||
# Standard deviation layer or parameter
|
||||
if fixed_std is None:
|
||||
self.std_layer = nn.Linear(out_features, action_dim)
|
||||
@@ -868,7 +850,6 @@ class Policy(nn.Module):
|
||||
nn.init.uniform_(self.std_layer.bias, -init_final, init_final)
|
||||
else:
|
||||
orthogonal_init()(self.std_layer.weight)
|
||||
self.parameters_to_optimize += list(self.std_layer.parameters())
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -877,6 +858,8 @@ class Policy(nn.Module):
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# Encode observations if encoder exists
|
||||
obs_enc = self.encoder(observations, vision_encoder_cache=observation_features)
|
||||
if self.encoder_is_shared:
|
||||
obs_enc = obs_enc.detach()
|
||||
|
||||
# Get network outputs
|
||||
outputs = self.network(obs_enc)
|
||||
@@ -966,13 +949,13 @@ class DefaultImageEncoder(nn.Module):
|
||||
nn.Tanh(),
|
||||
)
|
||||
|
||||
self.parameters_to_optimize = []
|
||||
if not config.freeze_vision_encoder:
|
||||
self.parameters_to_optimize += list(self.image_enc_layers.parameters())
|
||||
self.parameters_to_optimize += list(self.image_enc_proj.parameters())
|
||||
self.freeze_image_encoder = config.freeze_vision_encoder
|
||||
|
||||
def forward(self, x):
|
||||
return self.image_enc_proj(self.image_enc_layers(x))
|
||||
x = self.image_enc_layers(x)
|
||||
if self.freeze_image_encoder:
|
||||
x = x.detach()
|
||||
return self.image_enc_proj(x)
|
||||
|
||||
|
||||
class PretrainedImageEncoder(nn.Module):
|
||||
@@ -985,10 +968,7 @@ class PretrainedImageEncoder(nn.Module):
|
||||
nn.Tanh(),
|
||||
)
|
||||
|
||||
self.parameters_to_optimize = []
|
||||
if not config.freeze_vision_encoder:
|
||||
self.parameters_to_optimize += list(self.image_enc_layers.parameters())
|
||||
self.parameters_to_optimize += list(self.image_enc_proj.parameters())
|
||||
self.freeze_image_encoder = config.freeze_vision_encoder
|
||||
|
||||
def _load_pretrained_vision_encoder(self, config: SACConfig):
|
||||
"""Set up CNN encoder"""
|
||||
@@ -1009,6 +989,8 @@ class PretrainedImageEncoder(nn.Module):
|
||||
# TODO: (maractingi, azouitine) check the forward pass of the pretrained model
|
||||
# doesn't reach the classifier layer because we don't need it
|
||||
enc_feat = self.image_enc_layers(x).pooler_output
|
||||
if self.freeze_image_encoder:
|
||||
enc_feat = enc_feat.detach()
|
||||
enc_feat = self.image_enc_proj(enc_feat.view(enc_feat.shape[0], -1))
|
||||
return enc_feat
|
||||
|
||||
|
||||
Reference in New Issue
Block a user