Add real-world support for ACT on Aloha/Aloha2 (#228)

Co-authored-by: Alexander Soare <alexander.soare159@gmail.com>
Authored by Remi on 2024-05-31 15:31:02 +02:00, committed by GitHub
parent 504d2aaf48
commit d585c73f9f
22 changed files with 525 additions and 140 deletions

@@ -200,25 +200,29 @@ class ACT(nn.Module):
         self.config = config
         # BERT style VAE encoder with input tokens [cls, robot_state, *action_sequence].
         # The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]).
+        self.use_input_state = "observation.state" in config.input_shapes
         if self.config.use_vae:
             self.vae_encoder = ACTEncoder(config)
             self.vae_encoder_cls_embed = nn.Embedding(1, config.dim_model)
             # Projection layer for joint-space configuration to hidden dimension.
-            self.vae_encoder_robot_state_input_proj = nn.Linear(
-                config.input_shapes["observation.state"][0], config.dim_model
-            )
+            if self.use_input_state:
+                self.vae_encoder_robot_state_input_proj = nn.Linear(
+                    config.input_shapes["observation.state"][0], config.dim_model
+                )
             # Projection layer for action (joint-space target) to hidden dimension.
             self.vae_encoder_action_input_proj = nn.Linear(
-                config.input_shapes["observation.state"][0], config.dim_model
+                config.output_shapes["action"][0], config.dim_model
             )
-            self.latent_dim = config.latent_dim
             # Projection layer from the VAE encoder's output to the latent distribution's parameter space.
-            self.vae_encoder_latent_output_proj = nn.Linear(config.dim_model, self.latent_dim * 2)
+            self.vae_encoder_latent_output_proj = nn.Linear(config.dim_model, config.latent_dim * 2)
             # Fixed sinusoidal positional embedding for the input to the VAE encoder. Unsqueeze for batch
             # dimension.
+            num_input_token_encoder = 1 + config.chunk_size
+            if self.use_input_state:
+                num_input_token_encoder += 1
             self.register_buffer(
                 "vae_encoder_pos_enc",
-                create_sinusoidal_pos_embedding(1 + 1 + config.chunk_size, config.dim_model).unsqueeze(0),
+                create_sinusoidal_pos_embedding(num_input_token_encoder, config.dim_model).unsqueeze(0),
             )
         # Backbone for image feature extraction.
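Note on the hunk above: the VAE encoder now only gets a robot-state token when "observation.state" is present, so the number of positional-embedding slots must follow suit. A minimal sketch of that bookkeeping, with an assumed chunk_size (not the repo's default):

# Illustrative only: token layout of the VAE encoder input.
# With robot state:    [cls, robot_state, action_1 .. action_chunk_size] -> 2 + chunk_size tokens
# Without robot state: [cls, action_1 .. action_chunk_size]              -> 1 + chunk_size tokens
chunk_size = 100                 # assumed value for the example
use_input_state = False          # e.g. a dataset without an "observation.state" key
num_input_token_encoder = 1 + chunk_size
if use_input_state:
    num_input_token_encoder += 1
assert num_input_token_encoder == 101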
@@ -238,15 +242,17 @@ class ACT(nn.Module):
         # Transformer encoder input projections. The tokens will be structured like
         # [latent, robot_state, image_feature_map_pixels].
-        self.encoder_robot_state_input_proj = nn.Linear(
-            config.input_shapes["observation.state"][0], config.dim_model
-        )
-        self.encoder_latent_input_proj = nn.Linear(self.latent_dim, config.dim_model)
+        if self.use_input_state:
+            self.encoder_robot_state_input_proj = nn.Linear(
+                config.input_shapes["observation.state"][0], config.dim_model
+            )
+        self.encoder_latent_input_proj = nn.Linear(config.latent_dim, config.dim_model)
         self.encoder_img_feat_input_proj = nn.Conv2d(
             backbone_model.fc.in_features, config.dim_model, kernel_size=1
         )
         # Transformer encoder positional embeddings.
-        self.encoder_robot_and_latent_pos_embed = nn.Embedding(2, config.dim_model)
+        num_input_token_decoder = 2 if self.use_input_state else 1
+        self.encoder_robot_and_latent_pos_embed = nn.Embedding(num_input_token_decoder, config.dim_model)
         self.encoder_cam_feat_pos_embed = ACTSinusoidalPositionEmbedding2d(config.dim_model // 2)
         # Transformer decoder.
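For context on the projections in this hunk: every modality is mapped to the shared model width before being used as encoder tokens. A small runnable sketch with assumed dimensions (dim_model, state_dim, latent_dim, and backbone channel count are illustrative, not the repo's config):

import torch
from torch import nn

dim_model, state_dim, latent_dim, backbone_channels = 512, 14, 32, 512
latent_proj = nn.Linear(latent_dim, dim_model)
state_proj = nn.Linear(state_dim, dim_model)   # only instantiated when proprioception is available
img_proj = nn.Conv2d(backbone_channels, dim_model, kernel_size=1)

latent_tok = latent_proj(torch.randn(2, latent_dim))             # (B, dim_model)
state_tok = state_proj(torch.randn(2, state_dim))                # (B, dim_model)
img_feats = img_proj(torch.randn(2, backbone_channels, 15, 20))  # (B, dim_model, h, w)
print(latent_tok.shape, state_tok.shape, img_feats.shape)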
@@ -285,7 +291,7 @@ class ACT(nn.Module):
                 "action" in batch
             ), "actions must be provided when using the variational objective in training mode."
-        batch_size = batch["observation.state"].shape[0]
+        batch_size = batch["observation.images"].shape[0]
         # Prepare the latent for input to the transformer encoder.
         if self.config.use_vae and "action" in batch:
@@ -293,11 +299,16 @@ class ACT(nn.Module):
             cls_embed = einops.repeat(
                 self.vae_encoder_cls_embed.weight, "1 d -> b 1 d", b=batch_size
             )  # (B, 1, D)
-            robot_state_embed = self.vae_encoder_robot_state_input_proj(batch["observation.state"]).unsqueeze(
-                1
-            )  # (B, 1, D)
+            if self.use_input_state:
+                robot_state_embed = self.vae_encoder_robot_state_input_proj(batch["observation.state"])
+                robot_state_embed = robot_state_embed.unsqueeze(1)  # (B, 1, D)
             action_embed = self.vae_encoder_action_input_proj(batch["action"])  # (B, S, D)
-            vae_encoder_input = torch.cat([cls_embed, robot_state_embed, action_embed], axis=1)  # (B, S+2, D)
+            if self.use_input_state:
+                vae_encoder_input = [cls_embed, robot_state_embed, action_embed]  # (B, S+2, D)
+            else:
+                vae_encoder_input = [cls_embed, action_embed]
+            vae_encoder_input = torch.cat(vae_encoder_input, axis=1)
             # Prepare fixed positional embedding.
             # Note: detach() shouldn't be necessary but leaving it the same as the original code just in case.
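A minimal sketch of the branching above, using stand-in tensors instead of the learned projections, to show the resulting sequence lengths with and without the robot-state token (B, S, D values are assumed):

import torch

B, S, D = 2, 100, 512  # assumed batch size, chunk size, model dim
cls_embed = torch.randn(B, 1, D)
action_embed = torch.randn(B, S, D)

use_input_state = True
if use_input_state:
    robot_state_embed = torch.randn(B, 1, D)
    vae_encoder_input = torch.cat([cls_embed, robot_state_embed, action_embed], dim=1)  # (B, S + 2, D)
else:
    vae_encoder_input = torch.cat([cls_embed, action_embed], dim=1)                     # (B, S + 1, D)
print(vae_encoder_input.shape)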
@@ -308,16 +319,17 @@ class ACT(nn.Module):
                 vae_encoder_input.permute(1, 0, 2), pos_embed=pos_embed.permute(1, 0, 2)
             )[0]  # select the class token, with shape (B, D)
             latent_pdf_params = self.vae_encoder_latent_output_proj(cls_token_out)
-            mu = latent_pdf_params[:, : self.latent_dim]
+            mu = latent_pdf_params[:, : self.config.latent_dim]
             # This is 2log(sigma). Done this way to match the original implementation.
-            log_sigma_x2 = latent_pdf_params[:, self.latent_dim :]
+            log_sigma_x2 = latent_pdf_params[:, self.config.latent_dim :]
             # Sample the latent with the reparameterization trick.
             latent_sample = mu + log_sigma_x2.div(2).exp() * torch.randn_like(mu)
         else:
             # When not using the VAE encoder, we set the latent to be all zeros.
             mu = log_sigma_x2 = None
-            latent_sample = torch.zeros([batch_size, self.latent_dim], dtype=torch.float32).to(
+            # TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use buffer
+            latent_sample = torch.zeros([batch_size, self.config.latent_dim], dtype=torch.float32).to(
                 batch["observation.state"].device
             )
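The changed lines above only swap self.latent_dim for config.latent_dim; the sampling step itself is the standard reparameterization trick. A self-contained sketch of that step, with a random stand-in for the projection output:

import torch

latent_dim = 32  # assumed value
# Stand-in for the output of vae_encoder_latent_output_proj: [*means, *2*log(sigma)].
latent_pdf_params = torch.randn(2, latent_dim * 2)
mu = latent_pdf_params[:, :latent_dim]
log_sigma_x2 = latent_pdf_params[:, latent_dim:]  # 2 * log(sigma)
# Reparameterization: z = mu + sigma * eps with eps ~ N(0, I), keeping the sample differentiable w.r.t. mu and sigma.
latent_sample = mu + log_sigma_x2.div(2).exp() * torch.randn_like(mu)
print(latent_sample.shape)  # (2, latent_dim)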
@@ -326,8 +338,10 @@ class ACT(nn.Module):
         all_cam_features = []
         all_cam_pos_embeds = []
         images = batch["observation.images"]
         for cam_index in range(images.shape[-4]):
             cam_features = self.backbone(images[:, cam_index])["feature_map"]
+            # TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use buffer
             cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype)
             cam_features = self.encoder_img_feat_input_proj(cam_features)  # (B, C, h, w)
             all_cam_features.append(cam_features)
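A toy sketch of the per-camera loop above: images arrive stacked on a camera axis and each view is encoded independently before the feature maps are combined. The average pooling stands in for the ResNet backbone, and the image resolution and final concat axis are assumptions for illustration:

import torch
import torch.nn.functional as F

B, num_cameras = 2, 2
images = torch.randn(B, num_cameras, 3, 480, 640)  # (B, num_cameras, C, H, W)

all_cam_features = []
for cam_index in range(images.shape[-4]):
    # Stand-in for self.backbone(...)["feature_map"]: downsample each camera's image.
    cam_features = F.avg_pool2d(images[:, cam_index], kernel_size=32)  # (B, 3, 15, 20)
    all_cam_features.append(cam_features)

fused = torch.cat(all_cam_features, dim=-1)  # concatenated along width here, mirroring the axis=-1 concat above
print(fused.shape)  # (2, 3, 15, 40)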
@@ -337,13 +351,15 @@ class ACT(nn.Module):
         cam_pos_embed = torch.cat(all_cam_pos_embeds, axis=-1)
         # Get positional embeddings for robot state and latent.
-        robot_state_embed = self.encoder_robot_state_input_proj(batch["observation.state"])  # (B, C)
+        if self.use_input_state:
+            robot_state_embed = self.encoder_robot_state_input_proj(batch["observation.state"])  # (B, C)
         latent_embed = self.encoder_latent_input_proj(latent_sample)  # (B, C)
         # Stack encoder input and positional embeddings moving to (S, B, C).
+        encoder_in_feats = [latent_embed, robot_state_embed] if self.use_input_state else [latent_embed]
         encoder_in = torch.cat(
             [
-                torch.stack([latent_embed, robot_state_embed], axis=0),
+                torch.stack(encoder_in_feats, axis=0),
                 einops.rearrange(encoder_in, "b c h w -> (h w) b c"),
             ]
         )
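To make the stacking step concrete: a sketch with random tensors (shapes are assumed) of how the encoder sequence ends up in (S, B, C) layout, with one or two leading tokens depending on whether the robot state is available:

import torch
import einops

B, C, h, w = 2, 512, 15, 20  # assumed sizes
latent_embed = torch.randn(B, C)
robot_state_embed = torch.randn(B, C)
cam_feature_map = torch.randn(B, C, h, w)  # stand-in for the concatenated camera features

use_input_state = True
encoder_in_feats = [latent_embed, robot_state_embed] if use_input_state else [latent_embed]
encoder_in = torch.cat(
    [
        torch.stack(encoder_in_feats, dim=0),                       # (1 or 2, B, C)
        einops.rearrange(cam_feature_map, "b c h w -> (h w) b c"),  # (h*w, B, C)
    ]
)
print(encoder_in.shape)  # (h*w + 2, B, C) in this configuration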
@@ -357,6 +373,7 @@ class ACT(nn.Module):
         # Forward pass through the transformer modules.
         encoder_out = self.encoder(encoder_in, pos_embed=pos_embed)
+        # TODO(rcadene, alexander-soare): remove call to `device` ; precompute and use buffer
         decoder_in = torch.zeros(
             (self.config.chunk_size, batch_size, self.config.dim_model),
             dtype=pos_embed.dtype,