diff --git a/lerobot/common/policies/act/modeling_act.py b/lerobot/common/policies/act/modeling_act.py
index 3427c4829..418863a14 100644
--- a/lerobot/common/policies/act/modeling_act.py
+++ b/lerobot/common/policies/act/modeling_act.py
@@ -296,7 +296,7 @@ class ACT(nn.Module):
         self.use_images = any(k.startswith("observation.image") for k in config.input_shapes)
         self.use_env_state = "observation.environment_state" in config.input_shapes
         if self.config.use_vae:
-            self.vae_encoder = ACTEncoder(config)
+            self.vae_encoder = ACTEncoder(config, is_vae_encoder=True)
             self.vae_encoder_cls_embed = nn.Embedding(1, config.dim_model)
             # Projection layer for joint-space configuration to hidden dimension.
             if self.use_robot_state:
@@ -521,9 +521,11 @@ class ACT(nn.Module):
 class ACTEncoder(nn.Module):
     """Convenience module for running multiple encoder layers, maybe followed by normalization."""
 
-    def __init__(self, config: ACTConfig):
+    def __init__(self, config: ACTConfig, is_vae_encoder: bool = False):
         super().__init__()
-        self.layers = nn.ModuleList([ACTEncoderLayer(config) for _ in range(config.n_encoder_layers)])
+        self.is_vae_encoder = is_vae_encoder
+        num_layers = config.n_vae_encoder_layers if self.is_vae_encoder else config.n_encoder_layers
+        self.layers = nn.ModuleList([ACTEncoderLayer(config) for _ in range(num_layers)])
         self.norm = nn.LayerNorm(config.dim_model) if config.pre_norm else nn.Identity()
 
     def forward(
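
The patch makes the VAE encoder's depth configurable independently of the main transformer encoder's depth, and it assumes an `n_vae_encoder_layers` field on `ACTConfig`. As a rough illustration only, here is a minimal, self-contained sketch of the same pattern: `DummyConfig` and `DummyEncoder` are hypothetical stand-ins (not the lerobot classes), and `nn.TransformerEncoderLayer` substitutes for `ACTEncoderLayer`.

```python
# Minimal sketch of the layer-count selection introduced by the patch.
# DummyConfig / DummyEncoder are illustrative names, not lerobot code.
from dataclasses import dataclass

import torch
from torch import nn


@dataclass
class DummyConfig:
    # Field names mirror the ACTConfig fields the patch relies on; values are illustrative.
    dim_model: int = 512
    n_encoder_layers: int = 4
    n_vae_encoder_layers: int = 2
    pre_norm: bool = False


class DummyEncoder(nn.Module):
    def __init__(self, config: DummyConfig, is_vae_encoder: bool = False):
        super().__init__()
        self.is_vae_encoder = is_vae_encoder
        # The flag picks which depth to use, so the VAE encoder and the main
        # encoder can be sized independently.
        num_layers = config.n_vae_encoder_layers if self.is_vae_encoder else config.n_encoder_layers
        self.layers = nn.ModuleList(
            [nn.TransformerEncoderLayer(config.dim_model, nhead=8, batch_first=True) for _ in range(num_layers)]
        )
        self.norm = nn.LayerNorm(config.dim_model) if config.pre_norm else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for layer in self.layers:
            x = layer(x)
        return self.norm(x)


if __name__ == "__main__":
    cfg = DummyConfig()
    vae_encoder = DummyEncoder(cfg, is_vae_encoder=True)  # uses n_vae_encoder_layers
    main_encoder = DummyEncoder(cfg)                       # uses n_encoder_layers
    print(len(vae_encoder.layers), len(main_encoder.layers))  # 2 4
```

Keeping the flag as a constructor default (`is_vae_encoder: bool = False`) preserves the existing call sites of `ACTEncoder(config)`, so only the VAE path needs to opt in.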