revision
This commit is contained in:
@@ -32,12 +32,14 @@ def make_env(cfg, num_parallel_envs=0) -> gym.Env | gym.vector.SyncVectorEnv:
|
||||
elif cfg.env.name == "aloha":
|
||||
import gym_aloha # noqa: F401
|
||||
|
||||
kwargs["task"] = cfg.env.task
|
||||
if cfg.env.task == "sim_transfer_cube":
|
||||
env_name = "gym_aloha/AlohaTransferCube-v0"
|
||||
elif cfg.env.task == "sim_insertion":
|
||||
env_name = "gym_aloha/AlohaInsertion-v0"
|
||||
else:
|
||||
raise ValueError(f"`{cfg.env.task}` has no environment implementation.")
|
||||
|
||||
env_fn = lambda: gym.make( # noqa: E731
|
||||
"gym_aloha/AlohaTransferCube-v0",
|
||||
**kwargs,
|
||||
)
|
||||
env_fn = lambda: gym.make(env_name, **kwargs) # noqa: E731
|
||||
else:
|
||||
raise ValueError(cfg.env.name)
|
||||
|
||||
|
||||
@@ -337,18 +337,21 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
||||
robot_state_embed = self.vae_encoder_robot_state_input_proj(robot_state).unsqueeze(1) # (B, 1, D)
|
||||
action_embed = self.vae_encoder_action_input_proj(actions) # (B, S, D)
|
||||
vae_encoder_input = torch.cat([cls_embed, robot_state_embed, action_embed], axis=1) # (B, S+2, D)
|
||||
# Note: detach() shouldn't be necessary but leaving it the same as the original code just in case.
|
||||
|
||||
# Prepare fixed positional embedding.
|
||||
# Note: detach() shouldn't be necessary but leaving it the same as the original code just in case.
|
||||
pos_embed = self.vae_encoder_pos_enc.clone().detach() # (1, S+2, D)
|
||||
# Forward pass through VAE encoder and sample the latent with the reparameterization trick.
|
||||
|
||||
# Forward pass through VAE encoder.
|
||||
cls_token_out = self.vae_encoder(
|
||||
vae_encoder_input.permute(1, 0, 2), pos_embed=pos_embed.permute(1, 0, 2)
|
||||
)[0] # (B, D)
|
||||
)[0] # select the class token, with shape (B, D)
|
||||
latent_pdf_params = self.vae_encoder_latent_output_proj(cls_token_out)
|
||||
|
||||
# Sample the latent with the reparameterization trick.
|
||||
mu = latent_pdf_params[:, : self.latent_dim]
|
||||
# This is 2log(sigma). Done this way to match the original implementation.
|
||||
log_sigma_x2 = latent_pdf_params[:, self.latent_dim :]
|
||||
# Use reparameterization trick to sample from the latent's PDF.
|
||||
latent_sample = mu + log_sigma_x2.div(2).exp() * torch.randn_like(mu)
|
||||
else:
|
||||
# When not using the VAE encoder, we set the latent to be all zeros.
|
||||
@@ -469,7 +472,7 @@ class _TransformerEncoderLayer(nn.Module):
|
||||
if self.normalize_before:
|
||||
x = self.norm1(x)
|
||||
q = k = x if pos_embed is None else x + pos_embed
|
||||
x = self.self_attn(q, k, value=x)[0]
|
||||
x = self.self_attn(q, k, value=x)[0] # select just the output, not the attention weights
|
||||
x = skip + self.dropout1(x)
|
||||
if self.normalize_before:
|
||||
skip = x
|
||||
@@ -563,7 +566,7 @@ class _TransformerDecoderLayer(nn.Module):
|
||||
if self.normalize_before:
|
||||
x = self.norm1(x)
|
||||
q = k = self.maybe_add_pos_embed(x, decoder_pos_embed)
|
||||
x = self.self_attn(q, k, value=x)[0]
|
||||
x = self.self_attn(q, k, value=x)[0] # select just the output, not the attention weights
|
||||
x = skip + self.dropout1(x)
|
||||
if self.normalize_before:
|
||||
skip = x
|
||||
@@ -575,7 +578,7 @@ class _TransformerDecoderLayer(nn.Module):
|
||||
query=self.maybe_add_pos_embed(x, decoder_pos_embed),
|
||||
key=self.maybe_add_pos_embed(encoder_out, encoder_pos_embed),
|
||||
value=encoder_out,
|
||||
)[0]
|
||||
)[0] # select just the output, not the attention weights
|
||||
x = skip + self.dropout2(x)
|
||||
if self.normalize_before:
|
||||
skip = x
|
||||
@@ -634,7 +637,7 @@ class _SinusoidalPositionEmbedding2D(nn.Module):
|
||||
Returns:
|
||||
A (1, C, H, W) batch of corresponding sinusoidal positional embeddings.
|
||||
"""
|
||||
not_mask = torch.ones_like(x[0, [0]]) # (1, H, W)
|
||||
not_mask = torch.ones_like(x[0, :1]) # (1, H, W)
|
||||
# Note: These are like range(1, H+1) and range(1, W+1) respectively, but in most implementations
|
||||
# they would be range(0, H) and range(0, W). Keeping it at as to match the original code.
|
||||
y_range = not_mask.cumsum(1, dtype=torch.float32)
|
||||
|
||||
Reference in New Issue
Block a user