@@ -97,7 +97,8 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
         self.eval()
 
         batch = self.normalize_inputs(batch)
-        batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
+        if len(self.expected_image_keys) > 0:
+            batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
 
         # If we are doing temporal ensembling, keep track of the exponential moving average (EMA), and return
         # the first action.
@@ -135,7 +136,8 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
     def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
         """Run the batch through the model and compute the loss for training or validation."""
         batch = self.normalize_inputs(batch)
-        batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
+        if len(self.expected_image_keys) > 0:
+            batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
         batch = self.normalize_targets(batch)
         actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)
 
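For reference, a minimal sketch (not part of the diff; `expected_image_keys` and the batch contents are hypothetical) of what the guarded stacking does when no camera keys are configured:

```python
import torch

# Hypothetical setup: a policy configured with only an environment state and no cameras.
expected_image_keys = []
batch = {"observation.environment_state": torch.randn(8, 16)}

# Mirrors the conditional above: with no image keys, the stack is skipped entirely and
# the batch never gains an "observation.images" entry.
if len(expected_image_keys) > 0:
    batch["observation.images"] = torch.stack([batch[k] for k in expected_image_keys], dim=-4)

assert "observation.images" not in batch
```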
@@ -200,12 +202,14 @@ class ACT(nn.Module):
         self.config = config
         # BERT style VAE encoder with input tokens [cls, robot_state, *action_sequence].
         # The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]).
-        self.use_input_state = "observation.state" in config.input_shapes
+        self.use_robot_state = "observation.state" in config.input_shapes
+        self.use_images = any(k.startswith("observation.image") for k in config.input_shapes)
+        self.use_env_state = "observation.environment_state" in config.input_shapes
         if self.config.use_vae:
             self.vae_encoder = ACTEncoder(config)
             self.vae_encoder_cls_embed = nn.Embedding(1, config.dim_model)
             # Projection layer for joint-space configuration to hidden dimension.
-            if self.use_input_state:
+            if self.use_robot_state:
                 self.vae_encoder_robot_state_input_proj = nn.Linear(
                     config.input_shapes["observation.state"][0], config.dim_model
                 )
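As an aside, a minimal sketch (assuming a hypothetical `input_shapes` mapping; not part of the diff) of how the three new capability flags evaluate for a setup with proprioceptive and environment state but no cameras:

```python
# Hypothetical stand-in for config.input_shapes.
input_shapes = {
    "observation.state": [14],
    "observation.environment_state": [16],
}

use_robot_state = "observation.state" in input_shapes                      # True
use_images = any(k.startswith("observation.image") for k in input_shapes)  # False
use_env_state = "observation.environment_state" in input_shapes            # True
```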
@@ -218,7 +222,7 @@ class ACT(nn.Module):
             # Fixed sinusoidal positional embedding for the input to the VAE encoder. Unsqueeze for batch
             # dimension.
             num_input_token_encoder = 1 + config.chunk_size
-            if self.use_input_state:
+            if self.use_robot_state:
                 num_input_token_encoder += 1
             self.register_buffer(
                 "vae_encoder_pos_enc",
@@ -226,34 +230,45 @@ class ACT(nn.Module):
             )
 
         # Backbone for image feature extraction.
-        backbone_model = getattr(torchvision.models, config.vision_backbone)(
-            replace_stride_with_dilation=[False, False, config.replace_final_stride_with_dilation],
-            weights=config.pretrained_backbone_weights,
-            norm_layer=FrozenBatchNorm2d,
-        )
-        # Note: The assumption here is that we are using a ResNet model (and hence layer4 is the final feature
-        # map).
-        # Note: The forward method of this returns a dict: {"feature_map": output}.
-        self.backbone = IntermediateLayerGetter(backbone_model, return_layers={"layer4": "feature_map"})
+        if self.use_images:
+            backbone_model = getattr(torchvision.models, config.vision_backbone)(
+                replace_stride_with_dilation=[False, False, config.replace_final_stride_with_dilation],
+                weights=config.pretrained_backbone_weights,
+                norm_layer=FrozenBatchNorm2d,
+            )
+            # Note: The assumption here is that we are using a ResNet model (and hence layer4 is the final
+            # feature map).
+            # Note: The forward method of this returns a dict: {"feature_map": output}.
+            self.backbone = IntermediateLayerGetter(backbone_model, return_layers={"layer4": "feature_map"})
 
         # Transformer (acts as VAE decoder when training with the variational objective).
         self.encoder = ACTEncoder(config)
         self.decoder = ACTDecoder(config)
 
         # Transformer encoder input projections. The tokens will be structured like
-        # [latent, robot_state, image_feature_map_pixels].
-        if self.use_input_state:
+        # [latent, (robot_state), (env_state), (image_feature_map_pixels)].
+        if self.use_robot_state:
             self.encoder_robot_state_input_proj = nn.Linear(
                 config.input_shapes["observation.state"][0], config.dim_model
             )
+        if self.use_env_state:
+            self.encoder_env_state_input_proj = nn.Linear(
+                config.input_shapes["observation.environment_state"][0], config.dim_model
+            )
         self.encoder_latent_input_proj = nn.Linear(config.latent_dim, config.dim_model)
-        self.encoder_img_feat_input_proj = nn.Conv2d(
-            backbone_model.fc.in_features, config.dim_model, kernel_size=1
-        )
+        if self.use_images:
+            self.encoder_img_feat_input_proj = nn.Conv2d(
+                backbone_model.fc.in_features, config.dim_model, kernel_size=1
+            )
         # Transformer encoder positional embeddings.
-        num_input_token_decoder = 2 if self.use_input_state else 1
-        self.encoder_robot_and_latent_pos_embed = nn.Embedding(num_input_token_decoder, config.dim_model)
-        self.encoder_cam_feat_pos_embed = ACTSinusoidalPositionEmbedding2d(config.dim_model // 2)
+        n_1d_tokens = 1  # for the latent
+        if self.use_robot_state:
+            n_1d_tokens += 1
+        if self.use_env_state:
+            n_1d_tokens += 1
+        self.encoder_1d_feature_pos_embed = nn.Embedding(n_1d_tokens, config.dim_model)
+        if self.use_images:
+            self.encoder_cam_feat_pos_embed = ACTSinusoidalPositionEmbedding2d(config.dim_model // 2)
 
         # Transformer decoder.
         # Learnable positional embedding for the transformer's decoder (in the style of DETR object queries).
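A small sketch (values are hypothetical; not part of the diff) of how the single 1D-token positional embedding table is sized by the new counting logic:

```python
import torch.nn as nn

dim_model = 512
use_robot_state, use_env_state = True, True  # hypothetical configuration

# One positional vector per 1D token: the latent, plus optional robot and environment state.
n_1d_tokens = 1  # for the latent
if use_robot_state:
    n_1d_tokens += 1
if use_env_state:
    n_1d_tokens += 1
encoder_1d_feature_pos_embed = nn.Embedding(n_1d_tokens, dim_model)
assert encoder_1d_feature_pos_embed.weight.shape == (3, dim_model)
```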
@@ -274,10 +289,13 @@ class ACT(nn.Module):
         """A forward pass through the Action Chunking Transformer (with optional VAE encoder).
 
         `batch` should have the following structure:
 
         {
-            "observation.state": (B, state_dim) batch of robot states.
+            "observation.state" (optional): (B, state_dim) batch of robot states.
+
             "observation.images": (B, n_cameras, C, H, W) batch of images.
+                AND/OR
+            "observation.environment_state": (B, env_dim) batch of environment states.
 
             "action" (optional, only if training with VAE): (B, chunk_size, action dim) batch of actions.
         }
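To illustrate the updated contract, a sketch (shapes are hypothetical; not part of the diff) of a training batch that supplies an environment state instead of camera images:

```python
import torch

B, state_dim, env_dim, chunk_size, action_dim = 8, 14, 16, 100, 14
batch = {
    "observation.state": torch.randn(B, state_dim),
    "observation.environment_state": torch.randn(B, env_dim),
    # "observation.images" is omitted entirely; images and environment state are now AND/OR.
    "action": torch.randn(B, chunk_size, action_dim),  # only needed when training with the VAE
}
```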
@@ -291,7 +309,11 @@ class ACT(nn.Module):
                 "action" in batch
             ), "actions must be provided when using the variational objective in training mode."
 
-        batch_size = batch["observation.images"].shape[0]
+        batch_size = (
+            batch["observation.images"]
+            if "observation.images" in batch
+            else batch["observation.environment_state"]
+        ).shape[0]
 
         # Prepare the latent for input to the transformer encoder.
         if self.config.use_vae and "action" in batch:
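A quick sketch (hypothetical batch; not part of the diff) showing that the new expression reads the batch size from whichever observation is present, preferring images when both exist:

```python
import torch

batch = {"observation.environment_state": torch.randn(8, 16)}  # no images in this batch

batch_size = (
    batch["observation.images"]
    if "observation.images" in batch
    else batch["observation.environment_state"]
).shape[0]

assert batch_size == 8
```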
@@ -299,12 +321,12 @@ class ACT(nn.Module):
             cls_embed = einops.repeat(
                 self.vae_encoder_cls_embed.weight, "1 d -> b 1 d", b=batch_size
             )  # (B, 1, D)
-            if self.use_input_state:
+            if self.use_robot_state:
                 robot_state_embed = self.vae_encoder_robot_state_input_proj(batch["observation.state"])
                 robot_state_embed = robot_state_embed.unsqueeze(1)  # (B, 1, D)
             action_embed = self.vae_encoder_action_input_proj(batch["action"])  # (B, S, D)
 
-            if self.use_input_state:
+            if self.use_robot_state:
                 vae_encoder_input = [cls_embed, robot_state_embed, action_embed]  # (B, S+2, D)
             else:
                 vae_encoder_input = [cls_embed, action_embed]
@@ -318,7 +340,7 @@ class ACT(nn.Module):
             # sequence depending whether we use the input states or not (cls and robot state)
             # False means not a padding token.
             cls_joint_is_pad = torch.full(
-                (batch_size, 2 if self.use_input_state else 1),
+                (batch_size, 2 if self.use_robot_state else 1),
                 False,
                 device=batch["observation.state"].device,
             )
@@ -347,56 +369,55 @@ class ACT(nn.Module):
                 batch["observation.state"].device
             )
 
-        # Prepare all other transformer encoder inputs.
+        # Prepare transformer encoder inputs.
+        encoder_in_tokens = [self.encoder_latent_input_proj(latent_sample)]
+        encoder_in_pos_embed = list(self.encoder_1d_feature_pos_embed.weight.unsqueeze(1))
+        # Robot state token.
+        if self.use_robot_state:
+            encoder_in_tokens.append(self.encoder_robot_state_input_proj(batch["observation.state"]))
+        # Environment state token.
+        if self.use_env_state:
+            encoder_in_tokens.append(
+                self.encoder_env_state_input_proj(batch["observation.environment_state"])
+            )
+
         # Camera observation features and positional embeddings.
-        all_cam_features = []
-        all_cam_pos_embeds = []
-        images = batch["observation.images"]
+        if self.use_images:
+            all_cam_features = []
+            all_cam_pos_embeds = []
+            images = batch["observation.images"]
 
-        for cam_index in range(images.shape[-4]):
-            cam_features = self.backbone(images[:, cam_index])["feature_map"]
-            # TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use buffer
-            cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype)
-            cam_features = self.encoder_img_feat_input_proj(cam_features)  # (B, C, h, w)
-            all_cam_features.append(cam_features)
-            all_cam_pos_embeds.append(cam_pos_embed)
-        # Concatenate camera observation feature maps and positional embeddings along the width dimension.
-        encoder_in = torch.cat(all_cam_features, axis=-1)
-        cam_pos_embed = torch.cat(all_cam_pos_embeds, axis=-1)
-
-        # Get positional embeddings for robot state and latent.
-        if self.use_input_state:
-            robot_state_embed = self.encoder_robot_state_input_proj(batch["observation.state"])  # (B, C)
-        latent_embed = self.encoder_latent_input_proj(latent_sample)  # (B, C)
-
-        # Stack encoder input and positional embeddings moving to (S, B, C).
-        encoder_in_feats = [latent_embed, robot_state_embed] if self.use_input_state else [latent_embed]
-        encoder_in = torch.cat(
-            [
-                torch.stack(encoder_in_feats, axis=0),
-                einops.rearrange(encoder_in, "b c h w -> (h w) b c"),
-            ]
-        )
-        pos_embed = torch.cat(
-            [
-                self.encoder_robot_and_latent_pos_embed.weight.unsqueeze(1),
-                cam_pos_embed.flatten(2).permute(2, 0, 1),
-            ],
-            axis=0,
-        )
+            for cam_index in range(images.shape[-4]):
+                cam_features = self.backbone(images[:, cam_index])["feature_map"]
+                # TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use
+                # buffer
+                cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype)
+                cam_features = self.encoder_img_feat_input_proj(cam_features)  # (B, C, h, w)
+                all_cam_features.append(cam_features)
+                all_cam_pos_embeds.append(cam_pos_embed)
+            # Concatenate camera observation feature maps and positional embeddings along the width dimension,
+            # and move to (sequence, batch, dim).
+            all_cam_features = torch.cat(all_cam_features, axis=-1)
+            encoder_in_tokens.extend(einops.rearrange(all_cam_features, "b c h w -> (h w) b c"))
+            all_cam_pos_embeds = torch.cat(all_cam_pos_embeds, axis=-1)
+            encoder_in_pos_embed.extend(einops.rearrange(all_cam_pos_embeds, "b c h w -> (h w) b c"))
+
+        # Stack all tokens along the sequence dimension.
+        encoder_in_tokens = torch.stack(encoder_in_tokens, axis=0)
+        encoder_in_pos_embed = torch.stack(encoder_in_pos_embed, axis=0)
 
         # Forward pass through the transformer modules.
-        encoder_out = self.encoder(encoder_in, pos_embed=pos_embed)
+        encoder_out = self.encoder(encoder_in_tokens, pos_embed=encoder_in_pos_embed)
         # TODO(rcadene, alexander-soare): remove call to `device` ; precompute and use buffer
         decoder_in = torch.zeros(
             (self.config.chunk_size, batch_size, self.config.dim_model),
-            dtype=pos_embed.dtype,
-            device=pos_embed.device,
+            dtype=encoder_in_pos_embed.dtype,
+            device=encoder_in_pos_embed.device,
        )
         decoder_out = self.decoder(
             decoder_in,
             encoder_out,
-            encoder_pos_embed=pos_embed,
+            encoder_pos_embed=encoder_in_pos_embed,
             decoder_pos_embed=self.decoder_pos_embed.weight.unsqueeze(1),
         )
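Finally, a standalone sketch (dummy tensors and shapes are hypothetical; not part of the diff) of the list-based token accumulation this refactor introduces: 1D tokens and flattened camera features are collected in plain Python lists, then stacked into the (sequence, batch, dim) layout the transformer encoder expects:

```python
import torch

B, D, h, w = 8, 512, 15, 20  # hypothetical batch size, model dim, and feature-map size

tokens = [torch.randn(B, D)]       # latent token
tokens.append(torch.randn(B, D))   # robot state token (if configured)
tokens.append(torch.randn(B, D))   # environment state token (if configured)

cam_feats = torch.randn(B, D, h, w)                    # projected feature map from one camera
tokens.extend(cam_feats.flatten(2).permute(2, 0, 1))   # h*w image tokens, each of shape (B, D)

encoder_in_tokens = torch.stack(tokens, dim=0)         # (3 + h*w, B, D)
assert encoder_in_tokens.shape == (3 + h * w, B, D)
```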