From b2896d38f55823b9e5c7e8fea810ab646d4ade01 Mon Sep 17 00:00:00 2001 From: Jack Vial Date: Mon, 2 Sep 2024 13:29:27 -0400 Subject: [PATCH] fix(act): n_vae_encoder_layers config parameter wasn't being used (#400) --- lerobot/common/policies/act/modeling_act.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/lerobot/common/policies/act/modeling_act.py b/lerobot/common/policies/act/modeling_act.py index 3427c482..418863a1 100644 --- a/lerobot/common/policies/act/modeling_act.py +++ b/lerobot/common/policies/act/modeling_act.py @@ -296,7 +296,7 @@ class ACT(nn.Module): self.use_images = any(k.startswith("observation.image") for k in config.input_shapes) self.use_env_state = "observation.environment_state" in config.input_shapes if self.config.use_vae: - self.vae_encoder = ACTEncoder(config) + self.vae_encoder = ACTEncoder(config, is_vae_encoder=True) self.vae_encoder_cls_embed = nn.Embedding(1, config.dim_model) # Projection layer for joint-space configuration to hidden dimension. if self.use_robot_state: @@ -521,9 +521,11 @@ class ACT(nn.Module): class ACTEncoder(nn.Module): """Convenience module for running multiple encoder layers, maybe followed by normalization.""" - def __init__(self, config: ACTConfig): + def __init__(self, config: ACTConfig, is_vae_encoder: bool = False): super().__init__() - self.layers = nn.ModuleList([ACTEncoderLayer(config) for _ in range(config.n_encoder_layers)]) + self.is_vae_encoder = is_vae_encoder + num_layers = config.n_vae_encoder_layers if self.is_vae_encoder else config.n_encoder_layers + self.layers = nn.ModuleList([ACTEncoderLayer(config) for _ in range(num_layers)]) self.norm = nn.LayerNorm(config.dim_model) if config.pre_norm else nn.Identity() def forward(