diff --git a/lerobot/common/policies/act/modeling_act.py b/lerobot/common/policies/act/modeling_act.py
index 72ebdd7ad..22683affa 100644
--- a/lerobot/common/policies/act/modeling_act.py
+++ b/lerobot/common/policies/act/modeling_act.py
@@ -139,25 +139,27 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
         batch = self.normalize_targets(batch)
         actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)
 
-        l1_loss = (
-            F.l1_loss(batch["action"], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1)
-        ).mean()
+        l1_loss = F.l1_loss(batch["action"], actions_hat, reduction="none")
+        l1_loss = l1_loss * ~batch["action_is_pad"].unsqueeze(-1)
+        bsize = l1_loss.shape[0]
+        l1_loss = l1_loss.view(bsize, -1).mean(dim=1)
+
+        out_dict = {}
+        out_dict["actions"] = self.unnormalize_outputs({"action": actions_hat})["action"]
+        out_dict["l1_loss"] = l1_loss
 
-        loss_dict = {"l1_loss": l1_loss.item()}
         if self.config.use_vae:
             # Calculate Dā‚–ā‚—(latent_pdf || standard_normal). Note: After computing the KL-divergence for
             # each dimension independently, we sum over the latent dimension to get the total
             # KL-divergence per batch element, then take the mean over the batch.
             # (See App. B of https://arxiv.org/abs/1312.6114 for more details).
-            mean_kld = (
-                (-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean()
-            )
-            loss_dict["kld_loss"] = mean_kld.item()
-            loss_dict["loss"] = l1_loss + mean_kld * self.config.kl_weight
+            kld_loss = (-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1)
+            out_dict["kld_loss"] = kld_loss
+            out_dict["loss"] = l1_loss + kld_loss * self.config.kl_weight
         else:
-            loss_dict["loss"] = l1_loss
+            out_dict["loss"] = l1_loss
 
-        return loss_dict
+        return out_dict
 
 
 class ACT(nn.Module):
@@ -200,25 +202,28 @@ class ACT(nn.Module):
         self.config = config
         # BERT style VAE encoder with input [cls, *joint_space_configuration, *action_sequence].
         # The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]).
+        self.has_state = "observation.state" in config.input_shapes
+        self.latent_dim = config.latent_dim
         if self.config.use_vae:
             self.vae_encoder = ACTEncoder(config)
             self.vae_encoder_cls_embed = nn.Embedding(1, config.dim_model)
             # Projection layer for joint-space configuration to hidden dimension.
-            self.vae_encoder_robot_state_input_proj = nn.Linear(
-                config.input_shapes["observation.state"][0], config.dim_model
-            )
+            if self.has_state:
+                self.vae_encoder_robot_state_input_proj = nn.Linear(
+                    config.input_shapes["observation.state"][0], config.dim_model
+                )
             # Projection layer for action (joint-space target) to hidden dimension.
             self.vae_encoder_action_input_proj = nn.Linear(
-                config.input_shapes["observation.state"][0], config.dim_model
+                config.output_shapes["action"][0], config.dim_model
             )
-            self.latent_dim = config.latent_dim
             # Projection layer from the VAE encoder's output to the latent distribution's parameter space.
             self.vae_encoder_latent_output_proj = nn.Linear(config.dim_model, self.latent_dim * 2)
             # Fixed sinusoidal positional embedding the whole input to the VAE encoder. Unsqueeze for batch
             # dimension.
+            num_input_token_encoder = 1 + 1 + config.chunk_size if self.has_state else 1 + config.chunk_size
             self.register_buffer(
                 "vae_encoder_pos_enc",
-                create_sinusoidal_pos_embedding(1 + 1 + config.chunk_size, config.dim_model).unsqueeze(0),
+                create_sinusoidal_pos_embedding(num_input_token_encoder, config.dim_model).unsqueeze(0),
             )
 
         # Backbone for image feature extraction.
@@ -238,15 +243,17 @@ class ACT(nn.Module):
 
         # Transformer encoder input projections. The tokens will be structured like
         # [latent, robot_state, image_feature_map_pixels].
-        self.encoder_robot_state_input_proj = nn.Linear(
-            config.input_shapes["observation.state"][0], config.dim_model
-        )
+        if self.has_state:
+            self.encoder_robot_state_input_proj = nn.Linear(
+                config.input_shapes["observation.state"][0], config.dim_model
+            )
         self.encoder_latent_input_proj = nn.Linear(self.latent_dim, config.dim_model)
         self.encoder_img_feat_input_proj = nn.Conv2d(
             backbone_model.fc.in_features, config.dim_model, kernel_size=1
         )
         # Transformer encoder positional embeddings.
-        self.encoder_robot_and_latent_pos_embed = nn.Embedding(2, config.dim_model)
+        num_robot_and_latent_tokens = 2 if self.has_state else 1
+        self.encoder_robot_and_latent_pos_embed = nn.Embedding(num_robot_and_latent_tokens, config.dim_model)
         self.encoder_cam_feat_pos_embed = ACTSinusoidalPositionEmbedding2d(config.dim_model // 2)
 
         # Transformer decoder.
@@ -285,7 +292,7 @@ class ACT(nn.Module):
                 "action" in batch
             ), "actions must be provided when using the variational objective in training mode."
 
-        batch_size = batch["observation.state"].shape[0]
+        batch_size = batch["action"].shape[0]
 
         # Prepare the latent for input to the transformer encoder.
         if self.config.use_vae and "action" in batch:
@@ -293,11 +300,16 @@ class ACT(nn.Module):
             cls_embed = einops.repeat(
                 self.vae_encoder_cls_embed.weight, "1 d -> b 1 d", b=batch_size
             )  # (B, 1, D)
-            robot_state_embed = self.vae_encoder_robot_state_input_proj(batch["observation.state"]).unsqueeze(
-                1
-            )  # (B, 1, D)
+            if self.has_state:
+                robot_state_embed = self.vae_encoder_robot_state_input_proj(batch["observation.state"])
+                robot_state_embed = robot_state_embed.unsqueeze(1)  # (B, 1, D)
             action_embed = self.vae_encoder_action_input_proj(batch["action"])  # (B, S, D)
-            vae_encoder_input = torch.cat([cls_embed, robot_state_embed, action_embed], axis=1)  # (B, S+2, D)
+
+            if self.has_state:
+                vae_encoder_input = [cls_embed, robot_state_embed, action_embed]  # (B, S+2, D)
+            else:
+                vae_encoder_input = [cls_embed, action_embed]
+            vae_encoder_input = torch.cat(vae_encoder_input, axis=1)
 
             # Prepare fixed positional embedding.
             # Note: detach() shouldn't be necessary but leaving it the same as the original code just in case.
@@ -337,13 +349,15 @@ class ACT(nn.Module):
             cam_pos_embed = torch.cat(all_cam_pos_embeds, axis=-1)
 
         # Get positional embeddings for robot state and latent.
-        robot_state_embed = self.encoder_robot_state_input_proj(batch["observation.state"])  # (B, C)
+        if self.has_state:
+            robot_state_embed = self.encoder_robot_state_input_proj(batch["observation.state"])  # (B, C)
         latent_embed = self.encoder_latent_input_proj(latent_sample)  # (B, C)
 
         # Stack encoder input and positional embeddings moving to (S, B, C).
+        encoder_in_feats = [latent_embed, robot_state_embed] if self.has_state else [latent_embed]
         encoder_in = torch.cat(
             [
-                torch.stack([latent_embed, robot_state_embed], axis=0),
+                torch.stack(encoder_in_feats, axis=0),
                 einops.rearrange(encoder_in, "b c h w -> (h w) b c"),
             ]
         )
diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py
index 7ca7a0b3c..c81647f3b 100644
--- a/lerobot/scripts/train.py
+++ b/lerobot/scripts/train.py
@@ -93,7 +93,7 @@ def update_policy(policy, batch, optimizer, grad_clip_norm, lr_scheduler=None):
     policy.train()
     output_dict = policy.forward(batch)
     # TODO(rcadene): policy.unnormalize_outputs(out_dict)
-    loss = output_dict["loss"]
+    loss = output_dict["loss"].mean()
     loss.backward()
     grad_norm = torch.nn.utils.clip_grad_norm_(
         policy.parameters(),
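
Note on the loss plumbing in the first hunk: the policy now returns per-sample values of shape (B,) under "l1_loss", "loss" and (with use_vae) "kld_loss", and the scalar reduction moves to train.py via output_dict["loss"].mean(). A minimal standalone sketch of the masked L1 and per-sample reduction, with made-up shapes (batch of 2, chunk of 3, 2-dim actions), not the module's actual code:

    import torch
    import torch.nn.functional as F

    # Stand-ins for batch["action"], the model prediction, and batch["action_is_pad"].
    actions = torch.randn(2, 3, 2)        # (B, S, A)
    actions_hat = torch.randn(2, 3, 2)    # (B, S, A)
    action_is_pad = torch.tensor([[False, False, True], [False, True, True]])  # (B, S)

    # Masked L1, reduced to one value per batch element as in the patched forward().
    l1 = F.l1_loss(actions, actions_hat, reduction="none")  # (B, S, A)
    l1 = l1 * ~action_is_pad.unsqueeze(-1)                  # mask broadcasts over the action dim
    per_sample = l1.view(l1.shape[0], -1).mean(dim=1)       # (B,)

    # The training loop reduces to a scalar before backward(), as in the train.py hunk.
    loss = per_sample.mean()
    print(per_sample.shape, loss.shape)  # torch.Size([2]) torch.Size([])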
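
The KL term keeps the same closed form, Dā‚–ā‚—(N(μ, σ²) ‖ N(0, 1)) summed over the latent dimension, but is likewise left per batch element and weighted by kl_weight before the final mean. A quick sanity check of that closed form against torch.distributions, with arbitrary shapes that are not tied to the model:

    import torch
    from torch.distributions import Normal, kl_divergence

    mu = torch.randn(4, 32)            # (B, latent_dim)
    log_sigma_x2 = torch.randn(4, 32)  # log-variance head output

    closed_form = (-0.5 * (1 + log_sigma_x2 - mu.pow(2) - log_sigma_x2.exp())).sum(-1)
    reference = kl_divergence(Normal(mu, (0.5 * log_sigma_x2).exp()), Normal(0.0, 1.0)).sum(-1)
    print(torch.allclose(closed_form, reference, atol=1e-5))  # True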
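
The has_state flag drives the VAE encoder token layout: [cls, robot_state, action_1..action_S] when "observation.state" is among the inputs, [cls, action_1..action_S] otherwise, and the length of the sinusoidal positional embedding has to follow. A rough illustration of that bookkeeping, where chunk_size and dim_model mirror the config fields and the tensors are placeholders rather than the module's real projections:

    import torch

    chunk_size, dim_model = 100, 512
    has_state = False  # e.g. an environment that only provides camera images

    num_input_token_encoder = 1 + 1 + chunk_size if has_state else 1 + chunk_size

    tokens = [torch.zeros(1, 1, dim_model)]                  # cls token
    if has_state:
        tokens.append(torch.zeros(1, 1, dim_model))          # projected robot state
    tokens.append(torch.zeros(1, chunk_size, dim_model))     # projected action chunk
    vae_encoder_in = torch.cat(tokens, dim=1)

    # The positional embedding buffer must cover exactly this many tokens.
    assert vae_encoder_in.shape[1] == num_input_token_encoder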