fix(act): disable VAE during offline inference (#1588)

Skip the VAE encoder whenever the model is not in training mode, for example during offline inference. Batches drawn from a lerobot dataset contain an 'action' field, and its presence alone incorrectly triggers the VAE inference block. Concatenating cls_embed, robot_state_embed, and action_embed then raises a RuntimeError for mismatched tensor dimensions (3 vs 2), because action_embed lacks the chunk_size dimension. The change also aligns with the original paper, where variational inference is skipped during inference.
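The failure mode and the guard can be reproduced in isolation. The sketch below is not the lerobot implementation: `ToyPolicy`, `state_embed`, `action_proj`, and the tensor sizes are hypothetical stand-ins chosen only to mirror the shapes described above, assuming a chunked action of shape (B, S, A) during training and an unchunked action of shape (B, A) in an offline dataset batch.

```python
import torch
from torch import nn


class ToyPolicy(nn.Module):
    """Hypothetical stand-in for the ACT model; only the gating logic is mirrored."""

    def __init__(self, action_dim: int = 4, dim_model: int = 8, use_vae: bool = True):
        super().__init__()
        self.use_vae = use_vae
        self.cls_embed = nn.Embedding(1, dim_model)
        self.action_proj = nn.Linear(action_dim, dim_model)

    def forward(self, batch: dict) -> str:
        batch_size = batch["state_embed"].shape[0]
        # Same condition as the patched line: the VAE branch needs training mode.
        if self.use_vae and "action" in batch and self.training:
            cls = self.cls_embed.weight.unsqueeze(0).expand(batch_size, -1, -1)  # (B, 1, D)
            action_embed = self.action_proj(batch["action"])  # (B, S, D) for a chunked action
            torch.cat([cls, batch["state_embed"], action_embed], dim=1)  # (B, S + 2, D)
            return "vae"
        return "prior"


policy = ToyPolicy()
state_embed = torch.zeros(2, 1, 8)  # (B, 1, D)
train_batch = {"state_embed": state_embed, "action": torch.zeros(2, 5, 4)}  # (B, S, A)
offline_batch = {"state_embed": state_embed, "action": torch.zeros(2, 4)}  # (B, A), no chunk dim

policy.train()
train_path = policy(train_batch)  # takes the VAE branch

policy.eval()
eval_path = policy(offline_batch)  # skips the VAE branch despite the 'action' key

# What the unguarded branch would attempt on the offline batch:
# a 2-D action embedding concatenated with 3-D embeddings.
try:
    flat_action_embed = policy.action_proj(offline_batch["action"])  # (B, D)
    torch.cat([torch.zeros(2, 1, 8), state_embed, flat_action_embed], dim=1)
    mismatch_raised = False
except RuntimeError:
    mismatch_raised = True
```

Here `policy.eval()` sets `self.training` to `False`, so the added check is enough to keep an offline batch with an 'action' key away from the concatenation.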
Adil Zouitine
2025-07-24 17:09:12 +02:00
committed by GitHub
parent 989f3d05ba
commit 4c8f002055

@@ -420,7 +420,7 @@ class ACT(nn.Module):
batch_size = batch["observation.environment_state"].shape[0]
# Prepare the latent for input to the transformer encoder.
- if self.config.use_vae and "action" in batch:
+ if self.config.use_vae and "action" in batch and self.training:
# Prepare the input to the VAE encoder: [cls, *joint_space_configuration, *action_sequence].
cls_embed = einops.repeat(
self.vae_encoder_cls_embed.weight, "1 d -> b 1 d", b=batch_size