forked from tangger/lerobot
WIP faster act
This commit is contained in:
@@ -265,6 +265,16 @@ class ACT(nn.Module):
|
|||||||
|
|
||||||
self._reset_parameters()
|
self._reset_parameters()
|
||||||
|
|
||||||
|
self.register_buffer(
|
||||||
|
"latent_sample",
|
||||||
|
torch.zeros(1, config.latent_dim, dtype=torch.float32),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.register_buffer(
|
||||||
|
"decoder_in",
|
||||||
|
torch.zeros(config.chunk_size, 1, config.dim_model, dtype=torch.float32),
|
||||||
|
)
|
||||||
|
|
||||||
def _reset_parameters(self):
|
def _reset_parameters(self):
|
||||||
"""Xavier-uniform initialization of the transformer parameters as in the original code."""
|
"""Xavier-uniform initialization of the transformer parameters as in the original code."""
|
||||||
for p in chain(self.encoder.parameters(), self.decoder.parameters()):
|
for p in chain(self.encoder.parameters(), self.decoder.parameters()):
|
||||||
@@ -329,10 +339,7 @@ class ACT(nn.Module):
|
|||||||
else:
|
else:
|
||||||
# When not using the VAE encoder, we set the latent to be all zeros.
|
# When not using the VAE encoder, we set the latent to be all zeros.
|
||||||
mu = log_sigma_x2 = None
|
mu = log_sigma_x2 = None
|
||||||
# TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use buffer
|
latent_sample = self.latent_sample
|
||||||
latent_sample = torch.zeros([batch_size, self.config.latent_dim], dtype=torch.float32).to(
|
|
||||||
batch["observation.state"].device
|
|
||||||
)
|
|
||||||
|
|
||||||
# Prepare all other transformer encoder inputs.
|
# Prepare all other transformer encoder inputs.
|
||||||
# Camera observation features and positional embeddings.
|
# Camera observation features and positional embeddings.
|
||||||
@@ -342,8 +349,7 @@ class ACT(nn.Module):
|
|||||||
|
|
||||||
for cam_index in range(images.shape[-4]):
|
for cam_index in range(images.shape[-4]):
|
||||||
cam_features = self.backbone(images[:, cam_index])["feature_map"]
|
cam_features = self.backbone(images[:, cam_index])["feature_map"]
|
||||||
# TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use buffer
|
cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features)
|
||||||
cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype)
|
|
||||||
cam_features = self.encoder_img_feat_input_proj(cam_features) # (B, C, h, w)
|
cam_features = self.encoder_img_feat_input_proj(cam_features) # (B, C, h, w)
|
||||||
all_cam_features.append(cam_features)
|
all_cam_features.append(cam_features)
|
||||||
all_cam_pos_embeds.append(cam_pos_embed)
|
all_cam_pos_embeds.append(cam_pos_embed)
|
||||||
@@ -374,12 +380,7 @@ class ACT(nn.Module):
|
|||||||
|
|
||||||
# Forward pass through the transformer modules.
|
# Forward pass through the transformer modules.
|
||||||
encoder_out = self.encoder(encoder_in, pos_embed=pos_embed)
|
encoder_out = self.encoder(encoder_in, pos_embed=pos_embed)
|
||||||
# TODO(rcadene, alexander-soare): remove call to `device` ; precompute and use buffer
|
decoder_in = self.decoder_in
|
||||||
decoder_in = torch.zeros(
|
|
||||||
(self.config.chunk_size, batch_size, self.config.dim_model),
|
|
||||||
dtype=pos_embed.dtype,
|
|
||||||
device=pos_embed.device,
|
|
||||||
)
|
|
||||||
decoder_out = self.decoder(
|
decoder_out = self.decoder(
|
||||||
decoder_in,
|
decoder_in,
|
||||||
encoder_out,
|
encoder_out,
|
||||||
@@ -579,6 +580,10 @@ class ACTSinusoidalPositionEmbedding2d(nn.Module):
|
|||||||
self._eps = 1e-6
|
self._eps = 1e-6
|
||||||
# Inverse "common ratio" for the geometric progression in sinusoid frequencies.
|
# Inverse "common ratio" for the geometric progression in sinusoid frequencies.
|
||||||
self._temperature = 10000
|
self._temperature = 10000
|
||||||
|
self.register_buffer(
|
||||||
|
"inverse_frequency",
|
||||||
|
self._temperature ** (2 * (torch.arange(self.dimension, dtype=torch.float32) // 2) / self.dimension),
|
||||||
|
)
|
||||||
|
|
||||||
def forward(self, x: Tensor) -> Tensor:
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
"""
|
"""
|
||||||
@@ -590,8 +595,8 @@ class ACTSinusoidalPositionEmbedding2d(nn.Module):
|
|||||||
not_mask = torch.ones_like(x[0, :1]) # (1, H, W)
|
not_mask = torch.ones_like(x[0, :1]) # (1, H, W)
|
||||||
# Note: These are like range(1, H+1) and range(1, W+1) respectively, but in most implementations
|
# Note: These are like range(1, H+1) and range(1, W+1) respectively, but in most implementations
|
||||||
# they would be range(0, H) and range(0, W). Keeping it at as is to match the original code.
|
# they would be range(0, H) and range(0, W). Keeping it at as is to match the original code.
|
||||||
y_range = not_mask.cumsum(1, dtype=torch.float32)
|
y_range = not_mask.cumsum(1, dtype=x.dtype)
|
||||||
x_range = not_mask.cumsum(2, dtype=torch.float32)
|
x_range = not_mask.cumsum(2, dtype=x.dtype)
|
||||||
|
|
||||||
# "Normalize" the position index such that it ranges in [0, 2π].
|
# "Normalize" the position index such that it ranges in [0, 2π].
|
||||||
# Note: Adding epsilon on the denominator should not be needed as all values of y_embed and x_range
|
# Note: Adding epsilon on the denominator should not be needed as all values of y_embed and x_range
|
||||||
@@ -599,9 +604,7 @@ class ACTSinusoidalPositionEmbedding2d(nn.Module):
|
|||||||
y_range = y_range / (y_range[:, -1:, :] + self._eps) * self._two_pi
|
y_range = y_range / (y_range[:, -1:, :] + self._eps) * self._two_pi
|
||||||
x_range = x_range / (x_range[:, :, -1:] + self._eps) * self._two_pi
|
x_range = x_range / (x_range[:, :, -1:] + self._eps) * self._two_pi
|
||||||
|
|
||||||
inverse_frequency = self._temperature ** (
|
inverse_frequency = self.inverse_frequency
|
||||||
2 * (torch.arange(self.dimension, dtype=torch.float32, device=x.device) // 2) / self.dimension
|
|
||||||
)
|
|
||||||
|
|
||||||
x_range = x_range.unsqueeze(-1) / inverse_frequency # (1, H, W, 1)
|
x_range = x_range.unsqueeze(-1) / inverse_frequency # (1, H, W, 1)
|
||||||
y_range = y_range.unsqueeze(-1) / inverse_frequency # (1, H, W, 1)
|
y_range = y_range.unsqueeze(-1) / inverse_frequency # (1, H, W, 1)
|
||||||
|
|||||||
Reference in New Issue
Block a user