Act works with n_obs > 1

commit 220b32441d
parent 960589849f
Author: Remi Cadene
Date:   2024-05-28 20:55:02 +00:00

2 changed files with 52 additions and 9 deletions


@@ -151,7 +151,3 @@ class ACTConfig:
             f"The chunk size is the upper bound for the number of action steps per model invocation. Got "
             f"{self.n_action_steps} for `n_action_steps` and {self.chunk_size} for `chunk_size`."
         )
-        if self.n_obs_steps != 1:
-            raise ValueError(
-                f"Multiple observation steps not handled yet. Got `nobs_steps={self.n_obs_steps}`"
-            )
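Note: with this guard removed, the config no longer rejects multi-step observations at validation time. A minimal sketch of what now passes (the import path is an assumption, not shown in this diff; other fields keep their defaults):

    from lerobot.common.policies.act.configuration_act import ACTConfig  # assumed import path

    config = ACTConfig(n_obs_steps=2)  # previously raised "Multiple observation steps not handled yet."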


@@ -139,9 +139,9 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
         batch = self.normalize_targets(batch)
         actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)
 
-        l1_loss = F.l1_loss(batch["action"], actions_hat, reduction="none") * ~batch["action_is_pad"]
-        l1_loss = l1_loss.unsqueeze(-1)
-        bsize = l1_loss.shape[0]
+        bsize = actions_hat.shape[0]
+        l1_loss = F.l1_loss(batch["action"], actions_hat, reduction="none")
+        l1_loss = l1_loss * ~batch["action_is_pad"].unsqueeze(-1)
         l1_loss = l1_loss.view(bsize, -1).mean(dim=1)
 
         out_dict = {}
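Note: the rewritten block derives bsize from actions_hat and unsqueezes the (batch, chunk) padding mask so it broadcasts over the action dimension before the per-sample mean. A self-contained sketch with dummy tensors (all sizes illustrative):

    import torch
    import torch.nn.functional as F

    bsize, chunk, action_dim = 2, 4, 3
    action = torch.randn(bsize, chunk, action_dim)
    actions_hat = torch.randn(bsize, chunk, action_dim)
    action_is_pad = torch.zeros(bsize, chunk, dtype=torch.bool)
    action_is_pad[:, -1] = True  # pretend the last step of each chunk is padding

    l1_loss = F.l1_loss(action, actions_hat, reduction="none")  # (bsize, chunk, action_dim)
    l1_loss = l1_loss * ~action_is_pad.unsqueeze(-1)            # zero out the padded steps
    l1_loss = l1_loss.view(bsize, -1).mean(dim=1)               # one loss per sample, shape (bsize,)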
@@ -237,6 +237,13 @@ class ACT(nn.Module):
         # Note: The forward method of this returns a dict: {"feature_map": output}.
         self.backbone = IntermediateLayerGetter(backbone_model, return_layers={"layer4": "feature_map"})
 
+        if self.config.n_obs_steps > 1:
+            self.factorized_conv3d = FactorizedConv3d(
+                backbone_model.fc.in_features,
+                backbone_model.fc.in_features,
+                backbone_model.fc.in_features,
+            )
+
         # Transformer (acts as VAE decoder when training with the variational objective).
         self.encoder = ACTEncoder(config)
         self.decoder = ACTDecoder(config)
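Note: all three FactorizedConv3d arguments are backbone_model.fc.in_features, which equals the channel count of the "layer4" feature map the wrapped backbone emits. A hedged sketch of that equality (assuming a torchvision resnet18, where the value is 512):

    import torch
    import torchvision
    from torchvision.models._utils import IntermediateLayerGetter

    backbone_model = torchvision.models.resnet18()
    backbone = IntermediateLayerGetter(backbone_model, return_layers={"layer4": "feature_map"})

    feature_map = backbone(torch.randn(2, 3, 224, 224))["feature_map"]
    print(backbone_model.fc.in_features, feature_map.shape[1])  # 512 512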
@@ -292,7 +299,7 @@ class ACT(nn.Module):
                 "action" in batch
             ), "actions must be provided when using the variational objective in training mode."
 
-        batch_size = batch["action"].shape[0]
+        batch_size = batch["observation.images"].shape[0]
 
         # Prepare the latent for input to the transformer encoder.
         if self.config.use_vae and "action" in batch:
@@ -329,6 +336,7 @@ class ACT(nn.Module):
         else:
             # When not using the VAE encoder, we set the latent to be all zeros.
             mu = log_sigma_x2 = None
+            # TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use buffer
            latent_sample = torch.zeros([batch_size, self.latent_dim], dtype=torch.float32).to(
                 batch["observation.state"].device
             )
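Note: the new TODO suggests replacing the per-forward zeros(...).to(device) with a precomputed buffer. A minimal sketch of that idea (the module and attribute names here are illustrative, not from the diff):

    import torch
    from torch import nn

    class LatentPrior(nn.Module):
        def __init__(self, latent_dim: int):
            super().__init__()
            # A registered buffer follows the module across .to(device) calls,
            # so the forward pass needs no fresh allocation or device transfer.
            self.register_buffer("zero_latent", torch.zeros(1, latent_dim))

        def forward(self, batch_size: int) -> torch.Tensor:
            # expand() returns a view; the zeros are never copied.
            return self.zero_latent.expand(batch_size, -1)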
@@ -338,8 +346,25 @@ class ACT(nn.Module):
         all_cam_features = []
         all_cam_pos_embeds = []
         images = batch["observation.images"]
         for cam_index in range(images.shape[-4]):
-            cam_features = self.backbone(images[:, cam_index])["feature_map"]
+            if self.config.n_obs_steps > 1:
+                assert images.ndim == 6
+                assert images.shape[1] == self.config.n_obs_steps
+                cam_images = images[:, :, cam_index]
+                b, t, c_in, h_in, w_in = cam_images.shape
+                cam_images = cam_images.reshape(b * t, c_in, h_in, w_in)
+                cam_features = self.backbone(cam_images)["feature_map"]
+                bt_, c_out, h_out, w_out = cam_features.shape
+                cam_features = cam_features.view(b, t, c_out, h_out, w_out)
+                cam_features = cam_features.permute(0, 2, 1, 3, 4).contiguous()
+                cam_features = self.factorized_conv3d(cam_features)
+                cam_features = cam_features.mean(2)
+            else:
+                assert images.ndim == 5
+                cam_image = images[:, cam_index]
+                cam_features = self.backbone(cam_image)["feature_map"]
+            # TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use buffer
             cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype)
             cam_features = self.encoder_img_feat_input_proj(cam_features)  # (B, C, h, w)
             all_cam_features.append(cam_features)
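Note on the new branch: the 2D backbone only accepts 4-D inputs, so the time axis is folded into the batch axis, all frames are encoded at once, then time is restored and moved behind channels for the nn.Conv3d layout; .mean(2) finally pools time away so downstream code still sees (B, C, h, w). A standalone shape walkthrough (sizes illustrative; the backbone output is faked with a random tensor):

    import torch

    b, t, c_in, h_in, w_in = 2, 3, 3, 224, 224
    cam_images = torch.randn(b, t, c_in, h_in, w_in)
    flat = cam_images.reshape(b * t, c_in, h_in, w_in)   # (6, 3, 224, 224): frames become batch items
    cam_features = torch.randn(b * t, 512, 7, 7)         # stand-in for backbone(flat)["feature_map"]
    cam_features = cam_features.view(b, t, 512, 7, 7)    # split the batch back into (b, t)
    cam_features = cam_features.permute(0, 2, 1, 3, 4)   # (b, c, t, h, w), the nn.Conv3d layout
    print(cam_features.shape)          # torch.Size([2, 512, 3, 7, 7])
    print(cam_features.mean(2).shape)  # torch.Size([2, 512, 7, 7]) after temporal pooling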
@@ -371,6 +396,7 @@ class ACT(nn.Module):
         # Forward pass through the transformer modules.
         encoder_out = self.encoder(encoder_in, pos_embed=pos_embed)
+        # TODO(rcadene, alexander-soare): remove call to `device` ; precompute and use buffer
         decoder_in = torch.zeros(
             (self.config.chunk_size, batch_size, self.config.dim_model),
             dtype=pos_embed.dtype,
@@ -620,3 +646,24 @@ def get_activation_fn(activation: str) -> Callable:
     if activation == "glu":
         return F.glu
     raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")
+
+
+class FactorizedConv3d(nn.Module):
+    def __init__(self, in_channels, hidden_dim, out_channels):
+        super().__init__()
+        self.in_channels = in_channels
+        self.hidden_dim = hidden_dim
+        self.out_channels = out_channels
+        self.spatial_conv = nn.Conv3d(in_channels, hidden_dim, kernel_size=(1, 3, 3), padding=(0, 1, 1))
+        self.temporal_conv = nn.Conv3d(hidden_dim, out_channels, kernel_size=(3, 1, 1), padding=(1, 0, 0))
+
+    def forward(self, x):
+        assert x.ndim == 5, f"Expected shape is (b, c, t, h, w) but {x.shape} given."
+        assert (
+            x.shape[1] == self.in_channels
+        ), f"Expected number of input channels is {self.in_channels}, but {x.shape[1]} given."
+        x = self.spatial_conv(x)
+        x = F.relu(x, inplace=True)
+        x = self.temporal_conv(x)
+        x = F.relu(x, inplace=True)
+        return x
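Note: a quick usage sketch of the new module. The spatial (1, 3, 3) and temporal (3, 1, 1) kernels each preserve the (t, h, w) extents, so input and output shapes match; the caller above then reduces time with .mean(2). The 512-channel width matches a resnet18 backbone and is only an example:

    import torch

    conv = FactorizedConv3d(in_channels=512, hidden_dim=512, out_channels=512)
    x = torch.randn(2, 512, 3, 7, 7)  # (b, c, t, h, w)
    y = conv(x)
    print(y.shape)          # torch.Size([2, 512, 3, 7, 7])
    print(y.mean(2).shape)  # torch.Size([2, 512, 7, 7])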