diff --git a/lerobot/common/policies/act/modeling_act.py b/lerobot/common/policies/act/modeling_act.py
index a3dc4ccbe..cc5bc0974 100644
--- a/lerobot/common/policies/act/modeling_act.py
+++ b/lerobot/common/policies/act/modeling_act.py
@@ -265,6 +265,16 @@ class ACT(nn.Module):
 
         self._reset_parameters()
 
+        self.register_buffer(
+            "latent_sample",
+            torch.zeros(1, config.latent_dim, dtype=torch.float32),  # all-zero latent, expanded to the batch in forward
+        )
+
+        self.register_buffer(
+            "decoder_in",
+            torch.zeros(config.chunk_size, 1, config.dim_model, dtype=torch.float32),  # zero decoder input, expanded to the batch in forward
+        )
+
     def _reset_parameters(self):
         """Xavier-uniform initialization of the transformer parameters as in the original code."""
         for p in chain(self.encoder.parameters(), self.decoder.parameters()):
@@ -329,10 +339,7 @@ class ACT(nn.Module):
         else:
             # When not using the VAE encoder, we set the latent to be all zeros.
             mu = log_sigma_x2 = None
-            # TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use buffer
-            latent_sample = torch.zeros([batch_size, self.config.latent_dim], dtype=torch.float32).to(
-                batch["observation.state"].device
-            )
+            latent_sample = self.latent_sample.expand(batch_size, -1)  # broadcast the zeros buffer over the batch
 
         # Prepare all other transformer encoder inputs.
         # Camera observation features and positional embeddings.
@@ -342,8 +349,7 @@ class ACT(nn.Module):
 
             for cam_index in range(images.shape[-4]):
                 cam_features = self.backbone(images[:, cam_index])["feature_map"]
-                # TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use buffer
-                cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype)
+                cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features)
                 cam_features = self.encoder_img_feat_input_proj(cam_features)  # (B, C, h, w)
                 all_cam_features.append(cam_features)
                 all_cam_pos_embeds.append(cam_pos_embed)
@@ -374,12 +380,7 @@ class ACT(nn.Module):
 
         # Forward pass through the transformer modules.
         encoder_out = self.encoder(encoder_in, pos_embed=pos_embed)
-        # TODO(rcadene, alexander-soare): remove call to `device` ; precompute and use buffer
-        decoder_in = torch.zeros(
-            (self.config.chunk_size, batch_size, self.config.dim_model),
-            dtype=pos_embed.dtype,
-            device=pos_embed.device,
-        )
+        decoder_in = self.decoder_in.expand(-1, batch_size, -1)  # (chunk, B, D) view of the zeros buffer
         decoder_out = self.decoder(
             decoder_in,
             encoder_out,
@@ -579,6 +580,10 @@ class ACTSinusoidalPositionEmbedding2d(nn.Module):
         self._eps = 1e-6
         # Inverse "common ratio" for the geometric progression in sinusoid frequencies.
         self._temperature = 10000
+        self.register_buffer(
+            "inverse_frequency",
+            self._temperature ** (2 * (torch.arange(self.dimension, dtype=torch.float32) // 2) / self.dimension),
+        )
 
     def forward(self, x: Tensor) -> Tensor:
         """
@@ -590,8 +595,8 @@ class ACTSinusoidalPositionEmbedding2d(nn.Module):
         not_mask = torch.ones_like(x[0, :1])  # (1, H, W)
         # Note: These are like range(1, H+1) and range(1, W+1) respectively, but in most implementations
         # they would be range(0, H) and range(0, W). Keeping it at as is to match the original code.
-        y_range = not_mask.cumsum(1, dtype=torch.float32)
-        x_range = not_mask.cumsum(2, dtype=torch.float32)
+        y_range = not_mask.cumsum(1, dtype=x.dtype)
+        x_range = not_mask.cumsum(2, dtype=x.dtype)
 
         # "Normalize" the position index such that it ranges in [0, 2π].
         # Note: Adding epsilon on the denominator should not be needed as all values of y_embed and x_range
@@ -599,9 +604,7 @@ class ACTSinusoidalPositionEmbedding2d(nn.Module):
         y_range = y_range / (y_range[:, -1:, :] + self._eps) * self._two_pi
         x_range = x_range / (x_range[:, :, -1:] + self._eps) * self._two_pi
 
-        inverse_frequency = self._temperature ** (
-            2 * (torch.arange(self.dimension, dtype=torch.float32, device=x.device) // 2) / self.dimension
-        )
+        inverse_frequency = self.inverse_frequency
 
         x_range = x_range.unsqueeze(-1) / inverse_frequency  # (1, H, W, 1)
         y_range = y_range.unsqueeze(-1) / inverse_frequency  # (1, H, W, 1)
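The common thread in these hunks is replacing per-forward tensor allocation plus `.to(device)`/`.to(dtype)` calls with constants registered via `register_buffer`, which PyTorch moves and casts together with the rest of the module. Because the buffers carry a singleton batch dimension, they are broadcast with `expand()` at call time so any batch size still works. Below is the pattern in isolation: a minimal, hypothetical sketch (`ZerosBufferModule` and its sizes are illustrative stand-ins, not part of lerobot).

```python
import torch
from torch import nn


class ZerosBufferModule(nn.Module):
    """Hypothetical stand-in illustrating the register_buffer pattern used in this diff."""

    def __init__(self, chunk_size: int = 100, dim_model: int = 512):
        super().__init__()
        # A buffer travels with the module through .to(device), .half(), etc.,
        # so forward() never has to allocate a tensor or call .to() itself.
        self.register_buffer(
            "decoder_in",
            torch.zeros(chunk_size, 1, dim_model, dtype=torch.float32),
        )

    def forward(self, batch_size: int) -> torch.Tensor:
        # expand() broadcasts the singleton batch dimension as a zero-copy view,
        # which is safe because the buffer is all zeros and never written in place.
        return self.decoder_in.expand(-1, batch_size, -1)


module = ZerosBufferModule()
if torch.cuda.is_available():
    module = module.to("cuda")  # the buffer moves too; no device plumbing in forward
out = module(batch_size=8)
print(out.shape, out.device)  # torch.Size([100, 8, 512]), on the module's device
```

One trade-off worth noting: buffers are persistent by default and therefore show up in `state_dict()`; passing `persistent=False` to `register_buffer` would keep these derived constants out of checkpoints and preserve compatibility with checkpoints saved before this change.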