Bug fix: missing attention mask in VAE encoder in ACT policy (#279)
Co-authored-by: Alexander Soare <alexander.soare159@gmail.com>
This commit is contained in:
@@ -314,9 +314,23 @@ class ACT(nn.Module):
|
||||
# Note: detach() shouldn't be necessary but leaving it the same as the original code just in case.
|
||||
pos_embed = self.vae_encoder_pos_enc.clone().detach() # (1, S+2, D)
|
||||
|
||||
# Prepare key padding mask for the transformer encoder. We have 1 or 2 extra tokens at the start of the
|
||||
# sequence depending on whether we use the input states or not (cls and robot state)
|
||||
# False means not a padding token.
|
||||
cls_joint_is_pad = torch.full(
|
||||
(batch_size, 2 if self.use_input_state else 1),
|
||||
False,
|
||||
device=batch["observation.state"].device,
|
||||
)
|
||||
key_padding_mask = torch.cat(
|
||||
[cls_joint_is_pad, batch["action_is_pad"]], axis=1
|
||||
) # (bs, seq+1 or 2)
|
||||
|
||||
# Forward pass through VAE encoder to get the latent PDF parameters.
|
||||
cls_token_out = self.vae_encoder(
|
||||
vae_encoder_input.permute(1, 0, 2), pos_embed=pos_embed.permute(1, 0, 2)
|
||||
vae_encoder_input.permute(1, 0, 2),
|
||||
pos_embed=pos_embed.permute(1, 0, 2),
|
||||
key_padding_mask=key_padding_mask,
|
||||
)[0] # select the class token, with shape (B, D)
|
||||
latent_pdf_params = self.vae_encoder_latent_output_proj(cls_token_out)
|
||||
mu = latent_pdf_params[:, : self.config.latent_dim]
|
||||
@@ -402,9 +416,11 @@ class ACTEncoder(nn.Module):
|
||||
self.layers = nn.ModuleList([ACTEncoderLayer(config) for _ in range(config.n_encoder_layers)])
|
||||
self.norm = nn.LayerNorm(config.dim_model) if config.pre_norm else nn.Identity()
|
||||
|
||||
def forward(self, x: Tensor, pos_embed: Tensor | None = None) -> Tensor:
|
||||
def forward(
|
||||
self, x: Tensor, pos_embed: Tensor | None = None, key_padding_mask: Tensor | None = None
|
||||
) -> Tensor:
|
||||
for layer in self.layers:
|
||||
x = layer(x, pos_embed=pos_embed)
|
||||
x = layer(x, pos_embed=pos_embed, key_padding_mask=key_padding_mask)
|
||||
x = self.norm(x)
|
||||
return x
|
||||
|
||||
@@ -427,12 +443,13 @@ class ACTEncoderLayer(nn.Module):
|
||||
self.activation = get_activation_fn(config.feedforward_activation)
|
||||
self.pre_norm = config.pre_norm
|
||||
|
||||
def forward(self, x, pos_embed: Tensor | None = None) -> Tensor:
|
||||
def forward(self, x, pos_embed: Tensor | None = None, key_padding_mask: Tensor | None = None) -> Tensor:
|
||||
skip = x
|
||||
if self.pre_norm:
|
||||
x = self.norm1(x)
|
||||
q = k = x if pos_embed is None else x + pos_embed
|
||||
x = self.self_attn(q, k, value=x)[0] # select just the output, not the attention weights
|
||||
x = self.self_attn(q, k, value=x, key_padding_mask=key_padding_mask)
|
||||
x = x[0] # note: [0] to select just the output, not the attention weights
|
||||
x = skip + self.dropout1(x)
|
||||
if self.pre_norm:
|
||||
skip = x
|
||||
|
||||
Reference in New Issue
Block a user