forked from tangger/lerobot
WIP faster act
This commit is contained in:
@@ -265,6 +265,16 @@ class ACT(nn.Module):
|
||||
|
||||
self._reset_parameters()
|
||||
|
||||
self.register_buffer(
|
||||
"latent_sample",
|
||||
torch.zeros(1, config.latent_dim, dtype=torch.float32),
|
||||
)
|
||||
|
||||
self.register_buffer(
|
||||
"decoder_in",
|
||||
torch.zeros(config.chunk_size, 1, config.dim_model, dtype=torch.float32),
|
||||
)
|
||||
|
||||
def _reset_parameters(self):
|
||||
"""Xavier-uniform initialization of the transformer parameters as in the original code."""
|
||||
for p in chain(self.encoder.parameters(), self.decoder.parameters()):
|
||||
@@ -329,10 +339,7 @@ class ACT(nn.Module):
|
||||
else:
|
||||
# When not using the VAE encoder, we set the latent to be all zeros.
|
||||
mu = log_sigma_x2 = None
|
||||
# TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use buffer
|
||||
latent_sample = torch.zeros([batch_size, self.config.latent_dim], dtype=torch.float32).to(
|
||||
batch["observation.state"].device
|
||||
)
|
||||
latent_sample = self.latent_sample
|
||||
|
||||
# Prepare all other transformer encoder inputs.
|
||||
# Camera observation features and positional embeddings.
|
||||
@@ -342,8 +349,7 @@ class ACT(nn.Module):
|
||||
|
||||
for cam_index in range(images.shape[-4]):
|
||||
cam_features = self.backbone(images[:, cam_index])["feature_map"]
|
||||
# TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use buffer
|
||||
cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype)
|
||||
cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features)
|
||||
cam_features = self.encoder_img_feat_input_proj(cam_features) # (B, C, h, w)
|
||||
all_cam_features.append(cam_features)
|
||||
all_cam_pos_embeds.append(cam_pos_embed)
|
||||
@@ -374,12 +380,7 @@ class ACT(nn.Module):
|
||||
|
||||
# Forward pass through the transformer modules.
|
||||
encoder_out = self.encoder(encoder_in, pos_embed=pos_embed)
|
||||
# TODO(rcadene, alexander-soare): remove call to `device` ; precompute and use buffer
|
||||
decoder_in = torch.zeros(
|
||||
(self.config.chunk_size, batch_size, self.config.dim_model),
|
||||
dtype=pos_embed.dtype,
|
||||
device=pos_embed.device,
|
||||
)
|
||||
decoder_in = self.decoder_in
|
||||
decoder_out = self.decoder(
|
||||
decoder_in,
|
||||
encoder_out,
|
||||
@@ -579,6 +580,10 @@ class ACTSinusoidalPositionEmbedding2d(nn.Module):
|
||||
self._eps = 1e-6
|
||||
# Inverse "common ratio" for the geometric progression in sinusoid frequencies.
|
||||
self._temperature = 10000
|
||||
self.register_buffer(
|
||||
"inverse_frequency",
|
||||
self._temperature ** (2 * (torch.arange(self.dimension, dtype=torch.float32) // 2) / self.dimension),
|
||||
)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
"""
|
||||
@@ -590,8 +595,8 @@ class ACTSinusoidalPositionEmbedding2d(nn.Module):
|
||||
not_mask = torch.ones_like(x[0, :1]) # (1, H, W)
|
||||
# Note: These are like range(1, H+1) and range(1, W+1) respectively, but in most implementations
|
||||
# they would be range(0, H) and range(0, W). Keeping it at as is to match the original code.
|
||||
y_range = not_mask.cumsum(1, dtype=torch.float32)
|
||||
x_range = not_mask.cumsum(2, dtype=torch.float32)
|
||||
y_range = not_mask.cumsum(1, dtype=x.dtype)
|
||||
x_range = not_mask.cumsum(2, dtype=x.dtype)
|
||||
|
||||
# "Normalize" the position index such that it ranges in [0, 2π].
|
||||
# Note: Adding epsilon on the denominator should not be needed as all values of y_embed and x_range
|
||||
@@ -599,9 +604,7 @@ class ACTSinusoidalPositionEmbedding2d(nn.Module):
|
||||
y_range = y_range / (y_range[:, -1:, :] + self._eps) * self._two_pi
|
||||
x_range = x_range / (x_range[:, :, -1:] + self._eps) * self._two_pi
|
||||
|
||||
inverse_frequency = self._temperature ** (
|
||||
2 * (torch.arange(self.dimension, dtype=torch.float32, device=x.device) // 2) / self.dimension
|
||||
)
|
||||
inverse_frequency = self.inverse_frequency
|
||||
|
||||
x_range = x_range.unsqueeze(-1) / inverse_frequency # (1, H, W, 1)
|
||||
y_range = y_range.unsqueeze(-1) / inverse_frequency # (1, H, W, 1)
|
||||
|
||||
Reference in New Issue
Block a user