WIP: Add no_state option + fix use_vae=False to ACT

This commit is contained in:
Remi Cadene
2024-05-27 13:20:32 +00:00
parent 3189ac26b8
commit c2fd06f765
2 changed files with 43 additions and 29 deletions

View File

@@ -139,25 +139,27 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
batch = self.normalize_targets(batch) batch = self.normalize_targets(batch)
actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch) actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)
l1_loss = ( l1_loss = F.l1_loss(batch["action"], actions_hat, reduction="none") * ~batch["action_is_pad"]
F.l1_loss(batch["action"], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1) l1_loss = l1_loss.unsqueeze(-1)
).mean() bsize = l1_loss.shape[0]
l1_loss = l1_loss.view(bsize, -1).mean(dim=1)
out_dict = {}
out_dict["actions"] = self.unnormalize_outputs({"action": actions_hat})["action"]
out_dict["l1_loss"] = l1_loss
loss_dict = {"l1_loss": l1_loss.item()}
if self.config.use_vae: if self.config.use_vae:
# Calculate Dₖₗ(latent_pdf || standard_normal). Note: After computing the KL-divergence for # Calculate Dₖₗ(latent_pdf || standard_normal). Note: After computing the KL-divergence for
# each dimension independently, we sum over the latent dimension to get the total # each dimension independently, we sum over the latent dimension to get the total
# KL-divergence per batch element, then take the mean over the batch. # KL-divergence per batch element, then take the mean over the batch.
# (See App. B of https://arxiv.org/abs/1312.6114 for more details). # (See App. B of https://arxiv.org/abs/1312.6114 for more details).
mean_kld = ( kld_loss = (-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1)
(-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean() out_dict["kld_loss"] = kld_loss
) out_dict["loss"] = l1_loss + kld_loss * self.config.kl_weight
loss_dict["kld_loss"] = mean_kld.item()
loss_dict["loss"] = l1_loss + mean_kld * self.config.kl_weight
else: else:
loss_dict["loss"] = l1_loss out_dict["loss"] = l1_loss
return loss_dict return out_dict
class ACT(nn.Module): class ACT(nn.Module):
@@ -200,25 +202,28 @@ class ACT(nn.Module):
self.config = config self.config = config
# BERT style VAE encoder with input [cls, *joint_space_configuration, *action_sequence]. # BERT style VAE encoder with input [cls, *joint_space_configuration, *action_sequence].
# The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]). # The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]).
self.has_state = "observation.state" in config.input_shapes
self.latent_dim = config.latent_dim
if self.config.use_vae: if self.config.use_vae:
self.vae_encoder = ACTEncoder(config) self.vae_encoder = ACTEncoder(config)
self.vae_encoder_cls_embed = nn.Embedding(1, config.dim_model) self.vae_encoder_cls_embed = nn.Embedding(1, config.dim_model)
# Projection layer for joint-space configuration to hidden dimension. # Projection layer for joint-space configuration to hidden dimension.
self.vae_encoder_robot_state_input_proj = nn.Linear( if self.has_state:
config.input_shapes["observation.state"][0], config.dim_model self.vae_encoder_robot_state_input_proj = nn.Linear(
) config.input_shapes["observation.state"][0], config.dim_model
)
# Projection layer for action (joint-space target) to hidden dimension. # Projection layer for action (joint-space target) to hidden dimension.
self.vae_encoder_action_input_proj = nn.Linear( self.vae_encoder_action_input_proj = nn.Linear(
config.input_shapes["observation.state"][0], config.dim_model config.output_shapes["action"][0], config.dim_model
) )
self.latent_dim = config.latent_dim
# Projection layer from the VAE encoder's output to the latent distribution's parameter space. # Projection layer from the VAE encoder's output to the latent distribution's parameter space.
self.vae_encoder_latent_output_proj = nn.Linear(config.dim_model, self.latent_dim * 2) self.vae_encoder_latent_output_proj = nn.Linear(config.dim_model, self.latent_dim * 2)
# Fixed sinusoidal positional embedding the whole input to the VAE encoder. Unsqueeze for batch # Fixed sinusoidal positional embedding the whole input to the VAE encoder. Unsqueeze for batch
# dimension. # dimension.
num_input_token_encoder = 1 + 1 + config.chunk_size if self.has_state else 1 + config.chunk_size
self.register_buffer( self.register_buffer(
"vae_encoder_pos_enc", "vae_encoder_pos_enc",
create_sinusoidal_pos_embedding(1 + 1 + config.chunk_size, config.dim_model).unsqueeze(0), create_sinusoidal_pos_embedding(num_input_token_encoder, config.dim_model).unsqueeze(0),
) )
# Backbone for image feature extraction. # Backbone for image feature extraction.
@@ -238,15 +243,17 @@ class ACT(nn.Module):
# Transformer encoder input projections. The tokens will be structured like # Transformer encoder input projections. The tokens will be structured like
# [latent, robot_state, image_feature_map_pixels]. # [latent, robot_state, image_feature_map_pixels].
self.encoder_robot_state_input_proj = nn.Linear( if self.has_state:
config.input_shapes["observation.state"][0], config.dim_model self.encoder_robot_state_input_proj = nn.Linear(
) config.input_shapes["observation.state"][0], config.dim_model
)
self.encoder_latent_input_proj = nn.Linear(self.latent_dim, config.dim_model) self.encoder_latent_input_proj = nn.Linear(self.latent_dim, config.dim_model)
self.encoder_img_feat_input_proj = nn.Conv2d( self.encoder_img_feat_input_proj = nn.Conv2d(
backbone_model.fc.in_features, config.dim_model, kernel_size=1 backbone_model.fc.in_features, config.dim_model, kernel_size=1
) )
# Transformer encoder positional embeddings. # Transformer encoder positional embeddings.
self.encoder_robot_and_latent_pos_embed = nn.Embedding(2, config.dim_model) num_input_token_decoder = 2 if self.has_state else 1
self.encoder_robot_and_latent_pos_embed = nn.Embedding(num_input_token_decoder, config.dim_model)
self.encoder_cam_feat_pos_embed = ACTSinusoidalPositionEmbedding2d(config.dim_model // 2) self.encoder_cam_feat_pos_embed = ACTSinusoidalPositionEmbedding2d(config.dim_model // 2)
# Transformer decoder. # Transformer decoder.
@@ -285,7 +292,7 @@ class ACT(nn.Module):
"action" in batch "action" in batch
), "actions must be provided when using the variational objective in training mode." ), "actions must be provided when using the variational objective in training mode."
batch_size = batch["observation.state"].shape[0] batch_size = batch["action"].shape[0]
# Prepare the latent for input to the transformer encoder. # Prepare the latent for input to the transformer encoder.
if self.config.use_vae and "action" in batch: if self.config.use_vae and "action" in batch:
@@ -293,11 +300,16 @@ class ACT(nn.Module):
cls_embed = einops.repeat( cls_embed = einops.repeat(
self.vae_encoder_cls_embed.weight, "1 d -> b 1 d", b=batch_size self.vae_encoder_cls_embed.weight, "1 d -> b 1 d", b=batch_size
) # (B, 1, D) ) # (B, 1, D)
robot_state_embed = self.vae_encoder_robot_state_input_proj(batch["observation.state"]).unsqueeze( if self.has_state:
1 robot_state_embed = self.vae_encoder_robot_state_input_proj(batch["observation.state"])
) # (B, 1, D) robot_state_embed = robot_state_embed.unsqueeze(1) # (B, 1, D)
action_embed = self.vae_encoder_action_input_proj(batch["action"]) # (B, S, D) action_embed = self.vae_encoder_action_input_proj(batch["action"]) # (B, S, D)
vae_encoder_input = torch.cat([cls_embed, robot_state_embed, action_embed], axis=1) # (B, S+2, D)
if self.has_state:
vae_encoder_input = [cls_embed, robot_state_embed, action_embed] # (B, S+2, D)
else:
vae_encoder_input = [cls_embed, action_embed]
vae_encoder_input = torch.cat(vae_encoder_input, axis=1)
# Prepare fixed positional embedding. # Prepare fixed positional embedding.
# Note: detach() shouldn't be necessary but leaving it the same as the original code just in case. # Note: detach() shouldn't be necessary but leaving it the same as the original code just in case.
@@ -337,13 +349,15 @@ class ACT(nn.Module):
cam_pos_embed = torch.cat(all_cam_pos_embeds, axis=-1) cam_pos_embed = torch.cat(all_cam_pos_embeds, axis=-1)
# Get positional embeddings for robot state and latent. # Get positional embeddings for robot state and latent.
robot_state_embed = self.encoder_robot_state_input_proj(batch["observation.state"]) # (B, C) if self.has_state:
robot_state_embed = self.encoder_robot_state_input_proj(batch["observation.state"]) # (B, C)
latent_embed = self.encoder_latent_input_proj(latent_sample) # (B, C) latent_embed = self.encoder_latent_input_proj(latent_sample) # (B, C)
# Stack encoder input and positional embeddings moving to (S, B, C). # Stack encoder input and positional embeddings moving to (S, B, C).
encoder_in_feats = [latent_embed, robot_state_embed] if self.has_state else [latent_embed]
encoder_in = torch.cat( encoder_in = torch.cat(
[ [
torch.stack([latent_embed, robot_state_embed], axis=0), torch.stack(encoder_in_feats, axis=0),
einops.rearrange(encoder_in, "b c h w -> (h w) b c"), einops.rearrange(encoder_in, "b c h w -> (h w) b c"),
] ]
) )

View File

@@ -93,7 +93,7 @@ def update_policy(policy, batch, optimizer, grad_clip_norm, lr_scheduler=None):
policy.train() policy.train()
output_dict = policy.forward(batch) output_dict = policy.forward(batch)
# TODO(rcadene): policy.unnormalize_outputs(out_dict) # TODO(rcadene): policy.unnormalize_outputs(out_dict)
loss = output_dict["loss"] loss = output_dict["loss"].mean()
loss.backward() loss.backward()
grad_norm = torch.nn.utils.clip_grad_norm_( grad_norm = torch.nn.utils.clip_grad_norm_(
policy.parameters(), policy.parameters(),