forked from tangger/lerobot
WIP: Add no_state option + fix use_vae=False to ACT
@@ -139,25 +139,27 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
         batch = self.normalize_targets(batch)
         actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)

-        l1_loss = (
-            F.l1_loss(batch["action"], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1)
-        ).mean()
+        l1_loss = F.l1_loss(batch["action"], actions_hat, reduction="none")
+        l1_loss = l1_loss * ~batch["action_is_pad"].unsqueeze(-1)
+        bsize = l1_loss.shape[0]
+        l1_loss = l1_loss.view(bsize, -1).mean(dim=1)

-        loss_dict = {"l1_loss": l1_loss.item()}
+        out_dict = {}
+        out_dict["actions"] = self.unnormalize_outputs({"action": actions_hat})["action"]
+        out_dict["l1_loss"] = l1_loss
+
         if self.config.use_vae:
             # Calculate Dₖₗ(latent_pdf || standard_normal). Note: After computing the KL-divergence for
             # each dimension independently, we sum over the latent dimension to get the total
             # KL-divergence per batch element, then take the mean over the batch.
             # (See App. B of https://arxiv.org/abs/1312.6114 for more details).
-            mean_kld = (
-                (-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean()
-            )
-            loss_dict["kld_loss"] = mean_kld.item()
-            loss_dict["loss"] = l1_loss + mean_kld * self.config.kl_weight
+            kld_loss = (-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1)
+            out_dict["kld_loss"] = kld_loss
+            out_dict["loss"] = l1_loss + kld_loss * self.config.kl_weight
         else:
-            loss_dict["loss"] = l1_loss
+            out_dict["loss"] = l1_loss

-        return loss_dict
+        return out_dict


 class ACT(nn.Module):
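With this hunk, ACTPolicy.forward returns out_dict with per-sample tensors (and already-unnormalized actions) instead of a loss_dict of Python floats; the batch reduction moves to the caller (see the update_policy hunk below). The KL term itself is unchanged. As a sanity check, here is a minimal standalone sketch (illustrative shapes, not code from this commit) comparing the closed-form line against torch.distributions:

import torch
from torch.distributions import Normal, kl_divergence

B, L = 4, 32  # batch size and latent dimension, illustrative values
mu_hat = torch.randn(B, L)
log_sigma_x2_hat = torch.randn(B, L)  # log(sigma^2), as in the diff

# Closed form per App. B of https://arxiv.org/abs/1312.6114, summed over latent dims.
kld_loss = (-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - log_sigma_x2_hat.exp())).sum(-1)  # (B,)

# Reference computation via torch.distributions.
sigma = (0.5 * log_sigma_x2_hat).exp()
kld_ref = kl_divergence(Normal(mu_hat, sigma), Normal(0.0, 1.0)).sum(-1)

assert torch.allclose(kld_loss, kld_ref, atol=1e-5)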
@@ -200,25 +202,28 @@ class ACT(nn.Module):
         self.config = config
         # BERT style VAE encoder with input [cls, *joint_space_configuration, *action_sequence].
         # The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]).
+        self.has_state = "observation.state" in config.input_shapes
+        self.latent_dim = config.latent_dim
         if self.config.use_vae:
             self.vae_encoder = ACTEncoder(config)
             self.vae_encoder_cls_embed = nn.Embedding(1, config.dim_model)
             # Projection layer for joint-space configuration to hidden dimension.
-            self.vae_encoder_robot_state_input_proj = nn.Linear(
-                config.input_shapes["observation.state"][0], config.dim_model
-            )
+            if self.has_state:
+                self.vae_encoder_robot_state_input_proj = nn.Linear(
+                    config.input_shapes["observation.state"][0], config.dim_model
+                )
             # Projection layer for action (joint-space target) to hidden dimension.
             self.vae_encoder_action_input_proj = nn.Linear(
-                config.input_shapes["observation.state"][0], config.dim_model
+                config.output_shapes["action"][0], config.dim_model
             )
-            self.latent_dim = config.latent_dim
             # Projection layer from the VAE encoder's output to the latent distribution's parameter space.
             self.vae_encoder_latent_output_proj = nn.Linear(config.dim_model, self.latent_dim * 2)
             # Fixed sinusoidal positional embedding the whole input to the VAE encoder. Unsqueeze for batch
             # dimension.
+            num_input_token_encoder = 1 + 1 + config.chunk_size if self.has_state else 1 + config.chunk_size
             self.register_buffer(
                 "vae_encoder_pos_enc",
-                create_sinusoidal_pos_embedding(1 + 1 + config.chunk_size, config.dim_model).unsqueeze(0),
+                create_sinusoidal_pos_embedding(num_input_token_encoder, config.dim_model).unsqueeze(0),
             )

         # Backbone for image feature extraction.
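The new num_input_token_encoder makes the fixed positional table match the VAE encoder's token layout: [cls, robot_state, action_1..action_S] with state, [cls, action_1..action_S] without. A standalone sketch of the two resulting table shapes; sinusoidal_pos_embedding below is an assumed stand-in for lerobot's create_sinusoidal_pos_embedding, and the config values are illustrative:

import math
import torch

def sinusoidal_pos_embedding(num_positions: int, dim: int) -> torch.Tensor:
    """Classic fixed sin/cos table of shape (num_positions, dim); dim assumed even."""
    position = torch.arange(num_positions, dtype=torch.float32).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, dim, 2, dtype=torch.float32) * (-math.log(10000.0) / dim))
    table = torch.zeros(num_positions, dim)
    table[:, 0::2] = torch.sin(position * div_term)
    table[:, 1::2] = torch.cos(position * div_term)
    return table

chunk_size, dim_model = 100, 512  # illustrative config values
for has_state in (True, False):
    # [cls] (+ [robot_state] when present) + chunk_size action tokens.
    n = 1 + 1 + chunk_size if has_state else 1 + chunk_size
    pos_enc = sinusoidal_pos_embedding(n, dim_model).unsqueeze(0)  # unsqueeze for batch dim
    print(has_state, pos_enc.shape)  # (1, 102, 512) with state, (1, 101, 512) without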
@@ -238,15 +243,17 @@ class ACT(nn.Module):

         # Transformer encoder input projections. The tokens will be structured like
         # [latent, robot_state, image_feature_map_pixels].
-        self.encoder_robot_state_input_proj = nn.Linear(
-            config.input_shapes["observation.state"][0], config.dim_model
-        )
+        if self.has_state:
+            self.encoder_robot_state_input_proj = nn.Linear(
+                config.input_shapes["observation.state"][0], config.dim_model
+            )
         self.encoder_latent_input_proj = nn.Linear(self.latent_dim, config.dim_model)
         self.encoder_img_feat_input_proj = nn.Conv2d(
             backbone_model.fc.in_features, config.dim_model, kernel_size=1
         )
         # Transformer encoder positional embeddings.
-        self.encoder_robot_and_latent_pos_embed = nn.Embedding(2, config.dim_model)
+        num_input_token_decoder = 2 if self.has_state else 1
+        self.encoder_robot_and_latent_pos_embed = nn.Embedding(num_input_token_decoder, config.dim_model)
         self.encoder_cam_feat_pos_embed = ACTSinusoidalPositionEmbedding2d(config.dim_model // 2)

         # Transformer decoder.
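Guarding the robot-state projection behind self.has_state means the submodule is simply never registered for state-less configs, so its weights also vanish from the state_dict, and the positional-embedding table shrinks to match. A minimal sketch of this conditional-registration pattern; the Tiny module and its dimensions are hypothetical, not from the commit:

import torch.nn as nn

class Tiny(nn.Module):
    def __init__(self, has_state: bool, state_dim: int = 14, dim_model: int = 512):
        super().__init__()
        self.has_state = has_state
        if self.has_state:
            # Only registered when a state input exists; absent from state_dict otherwise.
            self.encoder_robot_state_input_proj = nn.Linear(state_dim, dim_model)
        self.encoder_robot_and_latent_pos_embed = nn.Embedding(2 if has_state else 1, dim_model)

print(sorted(Tiny(True).state_dict().keys()))   # includes the robot-state projection
print(sorted(Tiny(False).state_dict().keys()))  # no robot-state projection weights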
@@ -285,7 +292,7 @@ class ACT(nn.Module):
                 "action" in batch
             ), "actions must be provided when using the variational objective in training mode."

-        batch_size = batch["observation.state"].shape[0]
+        batch_size = batch["action"].shape[0]

         # Prepare the latent for input to the transformer encoder.
         if self.config.use_vae and "action" in batch:
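Reading the batch size from batch["action"] removes the last unconditional dependency on observation.state in forward. Note the assumption this encodes: "action" must be present, which holds in training but not at rollout time, consistent with the commit's WIP label. An illustrative sketch of a state-less batch:

import torch

batch = {"action": torch.zeros(8, 100, 14)}  # no "observation.state" key at all
batch_size = batch["action"].shape[0]  # 8; batch["observation.state"] would raise KeyError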
@@ -293,11 +300,16 @@ class ACT(nn.Module):
             cls_embed = einops.repeat(
                 self.vae_encoder_cls_embed.weight, "1 d -> b 1 d", b=batch_size
             )  # (B, 1, D)
-            robot_state_embed = self.vae_encoder_robot_state_input_proj(batch["observation.state"]).unsqueeze(
-                1
-            )  # (B, 1, D)
+            if self.has_state:
+                robot_state_embed = self.vae_encoder_robot_state_input_proj(batch["observation.state"])
+                robot_state_embed = robot_state_embed.unsqueeze(1)  # (B, 1, D)
             action_embed = self.vae_encoder_action_input_proj(batch["action"])  # (B, S, D)
-            vae_encoder_input = torch.cat([cls_embed, robot_state_embed, action_embed], axis=1)  # (B, S+2, D)
+
+            if self.has_state:
+                vae_encoder_input = [cls_embed, robot_state_embed, action_embed]  # (B, S+2, D)
+            else:
+                vae_encoder_input = [cls_embed, action_embed]
+            vae_encoder_input = torch.cat(vae_encoder_input, axis=1)

             # Prepare fixed positional embedding.
             # Note: detach() shouldn't be necessary but leaving it the same as the original code just in case.
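The concatenation now builds the token list first, so the state token is only included when available. A shape walk-through with illustrative sizes (B=8, S=100, D=512; zeros stand in for real embeddings):

import torch

B, S, D = 8, 100, 512
cls_embed = torch.zeros(B, 1, D)
robot_state_embed = torch.zeros(B, 1, D)
action_embed = torch.zeros(B, S, D)

for has_state in (True, False):
    tokens = [cls_embed, robot_state_embed, action_embed] if has_state else [cls_embed, action_embed]
    vae_encoder_input = torch.cat(tokens, dim=1)
    print(vae_encoder_input.shape)  # (8, 102, 512) with state, (8, 101, 512) without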
@@ -337,13 +349,15 @@ class ACT(nn.Module):
             cam_pos_embed = torch.cat(all_cam_pos_embeds, axis=-1)

         # Get positional embeddings for robot state and latent.
-        robot_state_embed = self.encoder_robot_state_input_proj(batch["observation.state"])  # (B, C)
+        if self.has_state:
+            robot_state_embed = self.encoder_robot_state_input_proj(batch["observation.state"])  # (B, C)
         latent_embed = self.encoder_latent_input_proj(latent_sample)  # (B, C)

         # Stack encoder input and positional embeddings moving to (S, B, C).
+        encoder_in_feats = [latent_embed, robot_state_embed] if self.has_state else [latent_embed]
         encoder_in = torch.cat(
             [
-                torch.stack([latent_embed, robot_state_embed], axis=0),
+                torch.stack(encoder_in_feats, axis=0),
                 einops.rearrange(cam_features, "b c h w -> (h w) b c"),
             ]
         )
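The same pattern handles the main transformer encoder: the non-image tokens shrink from [latent, robot_state] to [latent], and the camera feature map is flattened onto the sequence axis. A standalone sketch with an assumed 15x20 feature map (all sizes illustrative):

import torch
import einops

B, C, h, w = 8, 512, 15, 20
latent_embed = torch.zeros(B, C)
robot_state_embed = torch.zeros(B, C)
cam_features = torch.zeros(B, C, h, w)

for has_state in (True, False):
    feats = [latent_embed, robot_state_embed] if has_state else [latent_embed]
    encoder_in = torch.cat(
        [
            torch.stack(feats, dim=0),                               # (2 or 1, B, C)
            einops.rearrange(cam_features, "b c h w -> (h w) b c"),  # (300, B, C)
        ]
    )
    print(encoder_in.shape)  # (302, 8, 512) with state, (301, 8, 512) without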
@@ -93,7 +93,7 @@ def update_policy(policy, batch, optimizer, grad_clip_norm, lr_scheduler=None):
     policy.train()
     output_dict = policy.forward(batch)
     # TODO(rcadene): policy.unnormalize_outputs(out_dict)
-    loss = output_dict["loss"]
+    loss = output_dict["loss"].mean()
     loss.backward()
     grad_norm = torch.nn.utils.clip_grad_norm_(
         policy.parameters(),
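update_policy now performs the batch reduction itself, matching the per-sample loss returned by forward. The commit message doesn't state the motivation, but one plausible reason is compatibility with wrappers such as nn.DataParallel, which concatenate per-device outputs along the batch dimension; it also lets callers reweight samples before reducing. A minimal sketch:

import torch

# Stands in for out_dict["loss"]: one loss value per sample, shape (B,).
per_sample_loss = torch.rand(4, requires_grad=True)
loss = per_sample_loss.mean()  # scalar reduction now done by the caller
loss.backward()
print(per_sample_loss.grad)  # each sample contributes 1/B to the gradient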