WIP faster act

This commit is contained in:
Remi Cadene
2024-06-09 11:57:06 +00:00
parent b65247feee
commit 33d149000a

View File

@@ -265,6 +265,16 @@ class ACT(nn.Module):
self._reset_parameters()
self.register_buffer(
"latent_sample",
torch.zeros(1, config.latent_dim, dtype=torch.float32),
)
self.register_buffer(
"decoder_in",
torch.zeros(config.chunk_size, 1, config.dim_model, dtype=torch.float32),
)
def _reset_parameters(self):
"""Xavier-uniform initialization of the transformer parameters as in the original code."""
for p in chain(self.encoder.parameters(), self.decoder.parameters()):
@@ -329,10 +339,7 @@ class ACT(nn.Module):
else:
# When not using the VAE encoder, we set the latent to be all zeros.
mu = log_sigma_x2 = None
# TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use buffer
latent_sample = torch.zeros([batch_size, self.config.latent_dim], dtype=torch.float32).to(
batch["observation.state"].device
)
latent_sample = self.latent_sample
# Prepare all other transformer encoder inputs.
# Camera observation features and positional embeddings.
@@ -342,8 +349,7 @@ class ACT(nn.Module):
for cam_index in range(images.shape[-4]):
cam_features = self.backbone(images[:, cam_index])["feature_map"]
# TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use buffer
cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype)
cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features)
cam_features = self.encoder_img_feat_input_proj(cam_features) # (B, C, h, w)
all_cam_features.append(cam_features)
all_cam_pos_embeds.append(cam_pos_embed)
@@ -374,12 +380,7 @@ class ACT(nn.Module):
# Forward pass through the transformer modules.
encoder_out = self.encoder(encoder_in, pos_embed=pos_embed)
# TODO(rcadene, alexander-soare): remove call to `device` ; precompute and use buffer
decoder_in = torch.zeros(
(self.config.chunk_size, batch_size, self.config.dim_model),
dtype=pos_embed.dtype,
device=pos_embed.device,
)
decoder_in = self.decoder_in
decoder_out = self.decoder(
decoder_in,
encoder_out,
@@ -579,6 +580,10 @@ class ACTSinusoidalPositionEmbedding2d(nn.Module):
self._eps = 1e-6
# Inverse "common ratio" for the geometric progression in sinusoid frequencies.
self._temperature = 10000
self.register_buffer(
"inverse_frequency",
self._temperature ** (2 * (torch.arange(self.dimension, dtype=torch.float32) // 2) / self.dimension),
)
def forward(self, x: Tensor) -> Tensor:
"""
@@ -590,8 +595,8 @@ class ACTSinusoidalPositionEmbedding2d(nn.Module):
not_mask = torch.ones_like(x[0, :1]) # (1, H, W)
# Note: These are like range(1, H+1) and range(1, W+1) respectively, but in most implementations
# they would be range(0, H) and range(0, W). Keeping it at as is to match the original code.
y_range = not_mask.cumsum(1, dtype=torch.float32)
x_range = not_mask.cumsum(2, dtype=torch.float32)
y_range = not_mask.cumsum(1, dtype=x.dtype)
x_range = not_mask.cumsum(2, dtype=x.dtype)
# "Normalize" the position index such that it ranges in [0, 2π].
# Note: Adding epsilon on the denominator should not be needed as all values of y_embed and x_range
@@ -599,9 +604,7 @@ class ACTSinusoidalPositionEmbedding2d(nn.Module):
y_range = y_range / (y_range[:, -1:, :] + self._eps) * self._two_pi
x_range = x_range / (x_range[:, :, -1:] + self._eps) * self._two_pi
inverse_frequency = self._temperature ** (
2 * (torch.arange(self.dimension, dtype=torch.float32, device=x.device) // 2) / self.dimension
)
inverse_frequency = self.inverse_frequency
x_range = x_range.unsqueeze(-1) / inverse_frequency # (1, H, W, 1)
y_range = y_range.unsqueeze(-1) / inverse_frequency # (1, H, W, 1)