From 220b32441db26eadbe4fc1fe7175123ce36ece11 Mon Sep 17 00:00:00 2001
From: Remi Cadene
Date: Tue, 28 May 2024 20:55:02 +0000
Subject: [PATCH] ACT works with n_obs_steps > 1

---
 .../common/policies/act/configuration_act.py |  4 --
 lerobot/common/policies/act/modeling_act.py  | 57 +++++++++++++++++--
 2 files changed, 52 insertions(+), 9 deletions(-)

diff --git a/lerobot/common/policies/act/configuration_act.py b/lerobot/common/policies/act/configuration_act.py
index 95374f4d..80759f66 100644
--- a/lerobot/common/policies/act/configuration_act.py
+++ b/lerobot/common/policies/act/configuration_act.py
@@ -151,7 +151,3 @@ class ACTConfig:
                 f"The chunk size is the upper bound for the number of action steps per model invocation. Got "
                 f"{self.n_action_steps} for `n_action_steps` and {self.chunk_size} for `chunk_size`."
             )
-        if self.n_obs_steps != 1:
-            raise ValueError(
-                f"Multiple observation steps not handled yet. Got `nobs_steps={self.n_obs_steps}`"
-            )
diff --git a/lerobot/common/policies/act/modeling_act.py b/lerobot/common/policies/act/modeling_act.py
index 22683aff..649ea0ac 100644
--- a/lerobot/common/policies/act/modeling_act.py
+++ b/lerobot/common/policies/act/modeling_act.py
@@ -139,9 +139,9 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
         batch = self.normalize_targets(batch)
         actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)
 
-        l1_loss = F.l1_loss(batch["action"], actions_hat, reduction="none") * ~batch["action_is_pad"]
-        l1_loss = l1_loss.unsqueeze(-1)
-        bsize = l1_loss.shape[0]
+        bsize = actions_hat.shape[0]
+        l1_loss = F.l1_loss(batch["action"], actions_hat, reduction="none")
+        l1_loss = l1_loss * ~batch["action_is_pad"].unsqueeze(-1)
         l1_loss = l1_loss.view(bsize, -1).mean(dim=1)
 
         out_dict = {}
@@ -237,6 +237,13 @@ class ACT(nn.Module):
         # Note: The forward method of this returns a dict: {"feature_map": output}.
         self.backbone = IntermediateLayerGetter(backbone_model, return_layers={"layer4": "feature_map"})
 
+        if self.config.n_obs_steps > 1:
+            self.factorized_conv3d = FactorizedConv3d(
+                backbone_model.fc.in_features,
+                backbone_model.fc.in_features,
+                backbone_model.fc.in_features,
+            )
+
         # Transformer (acts as VAE decoder when training with the variational objective).
         self.encoder = ACTEncoder(config)
         self.decoder = ACTDecoder(config)
@@ -292,7 +299,7 @@ class ACT(nn.Module):
                 "action" in batch
             ), "actions must be provided when using the variational objective in training mode."
 
-        batch_size = batch["action"].shape[0]
+        batch_size = batch["observation.images"].shape[0]
 
         # Prepare the latent for input to the transformer encoder.
         if self.config.use_vae and "action" in batch:
@@ -329,6 +336,7 @@ class ACT(nn.Module):
         else:
             # When not using the VAE encoder, we set the latent to be all zeros.
             mu = log_sigma_x2 = None
+            # TODO(rcadene, alexander-soare): remove call to `.to` to speed up forward; precompute and use a buffer
             latent_sample = torch.zeros([batch_size, self.latent_dim], dtype=torch.float32).to(
                 batch["observation.state"].device
             )
@@ -338,8 +346,25 @@ class ACT(nn.Module):
         all_cam_features = []
         all_cam_pos_embeds = []
         images = batch["observation.images"]
+
         for cam_index in range(images.shape[-4]):
-            cam_features = self.backbone(images[:, cam_index])["feature_map"]
+            if self.config.n_obs_steps > 1:
+                assert images.ndim == 6
+                assert images.shape[1] == self.config.n_obs_steps
+                cam_images = images[:, :, cam_index]
+                b, t, c_in, h_in, w_in = cam_images.shape
+                cam_images = cam_images.reshape(b * t, c_in, h_in, w_in)
+                cam_features = self.backbone(cam_images)["feature_map"]
+                bt_, c_out, h_out, w_out = cam_features.shape
+                cam_features = cam_features.view(b, t, c_out, h_out, w_out)
+                cam_features = cam_features.permute(0, 2, 1, 3, 4).contiguous()
+                cam_features = self.factorized_conv3d(cam_features)
+                cam_features = cam_features.mean(2)
+            else:
+                assert images.ndim == 5
+                cam_image = images[:, cam_index]
+                cam_features = self.backbone(cam_image)["feature_map"]
+            # TODO(rcadene, alexander-soare): remove call to `.to` to speed up forward; precompute and use a buffer
             cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype)
             cam_features = self.encoder_img_feat_input_proj(cam_features)  # (B, C, h, w)
             all_cam_features.append(cam_features)
@@ -371,6 +396,7 @@ class ACT(nn.Module):
 
         # Forward pass through the transformer modules.
         encoder_out = self.encoder(encoder_in, pos_embed=pos_embed)
+        # TODO(rcadene, alexander-soare): remove call to `device`; precompute and use a buffer
         decoder_in = torch.zeros(
             (self.config.chunk_size, batch_size, self.config.dim_model),
             dtype=pos_embed.dtype,
@@ -620,3 +646,24 @@ def get_activation_fn(activation: str) -> Callable:
     if activation == "glu":
         return F.glu
     raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")
+
+
+class FactorizedConv3d(nn.Module):
+    def __init__(self, in_channels, hidden_dim, out_channels):
+        super().__init__()
+        self.in_channels = in_channels
+        self.hidden_dim = hidden_dim
+        self.out_channels = out_channels
+        self.spatial_conv = nn.Conv3d(in_channels, hidden_dim, kernel_size=(1, 3, 3), padding=(0, 1, 1))
+        self.temporal_conv = nn.Conv3d(hidden_dim, out_channels, kernel_size=(3, 1, 1), padding=(1, 0, 0))
+
+    def forward(self, x):
+        assert x.ndim == 5, f"Expected shape is (b, c, t, h, w) but {x.shape} given."
+        assert (
+            x.shape[1] == self.in_channels
+        ), f"Expected number of input channels is {self.in_channels}, but {x.shape[1]} given."
+        x = self.spatial_conv(x)
+        x = F.relu(x, inplace=True)
+        x = self.temporal_conv(x)
+        x = F.relu(x, inplace=True)
+        return x
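
For context, the `n_obs_steps > 1` branch runs the 2D backbone on every observation step, folds the result back into a time dimension, mixes the steps with `FactorizedConv3d`, and averages over time so the rest of the transformer sees the same (B, C, h, w) feature map as before. Below is a minimal, standalone sketch of that shape flow, not part of the patch: the batch size, step count, and spatial sizes are arbitrary, and c = 512 assumes the default resnet18 backbone, whose `fc.in_features` is 512.

import torch
import torch.nn.functional as F
from torch import nn


class FactorizedConv3d(nn.Module):
    """Trimmed copy of the module added by the patch: spatial (1x3x3) conv, then temporal (3x1x1) conv."""

    def __init__(self, in_channels, hidden_dim, out_channels):
        super().__init__()
        self.in_channels = in_channels
        self.spatial_conv = nn.Conv3d(in_channels, hidden_dim, kernel_size=(1, 3, 3), padding=(0, 1, 1))
        self.temporal_conv = nn.Conv3d(hidden_dim, out_channels, kernel_size=(3, 1, 1), padding=(1, 0, 0))

    def forward(self, x):
        # Expected layout is (b, c, t, h, w), i.e. channels first as required by nn.Conv3d.
        assert x.ndim == 5 and x.shape[1] == self.in_channels
        x = F.relu(self.spatial_conv(x), inplace=True)
        return F.relu(self.temporal_conv(x), inplace=True)


# Arbitrary sizes for the sketch; c = 512 stands in for backbone_model.fc.in_features (resnet18 default).
b, t, c, h, w = 2, 4, 512, 10, 14
cam_features = torch.randn(b, t, c, h, w)                        # per-camera backbone maps, viewed back to (b, t, c, h, w)
cam_features = cam_features.permute(0, 2, 1, 3, 4).contiguous()  # -> (b, c, t, h, w) for Conv3d
cam_features = FactorizedConv3d(c, c, c)(cam_features)           # mixes information across the t observation steps
cam_features = cam_features.mean(2)                              # average over time -> (b, c, h, w)
assert cam_features.shape == (b, c, h, w)                        # same shape the single-step path passes on

Factorizing the 3D convolution into a spatial step followed by a temporal step keeps the added parameters and compute close to one extra 2D conv per step, while still letting features from different observation steps interact before the temporal mean.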