Act works with n_obs > 1

This commit is contained in:
Remi Cadene
2024-05-28 20:55:02 +00:00
parent 960589849f
commit 220b32441d
2 changed files with 52 additions and 9 deletions

View File

@@ -151,7 +151,3 @@ class ACTConfig:
f"The chunk size is the upper bound for the number of action steps per model invocation. Got " f"The chunk size is the upper bound for the number of action steps per model invocation. Got "
f"{self.n_action_steps} for `n_action_steps` and {self.chunk_size} for `chunk_size`." f"{self.n_action_steps} for `n_action_steps` and {self.chunk_size} for `chunk_size`."
) )
if self.n_obs_steps != 1:
raise ValueError(
f"Multiple observation steps not handled yet. Got `nobs_steps={self.n_obs_steps}`"
)

View File

@@ -139,9 +139,9 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
batch = self.normalize_targets(batch) batch = self.normalize_targets(batch)
actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch) actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)
l1_loss = F.l1_loss(batch["action"], actions_hat, reduction="none") * ~batch["action_is_pad"] bsize = actions_hat.shape[0]
l1_loss = l1_loss.unsqueeze(-1) l1_loss = F.l1_loss(batch["action"], actions_hat, reduction="none")
bsize = l1_loss.shape[0] l1_loss = l1_loss * ~batch["action_is_pad"].unsqueeze(-1)
l1_loss = l1_loss.view(bsize, -1).mean(dim=1) l1_loss = l1_loss.view(bsize, -1).mean(dim=1)
out_dict = {} out_dict = {}
@@ -237,6 +237,13 @@ class ACT(nn.Module):
# Note: The forward method of this returns a dict: {"feature_map": output}. # Note: The forward method of this returns a dict: {"feature_map": output}.
self.backbone = IntermediateLayerGetter(backbone_model, return_layers={"layer4": "feature_map"}) self.backbone = IntermediateLayerGetter(backbone_model, return_layers={"layer4": "feature_map"})
if self.config.n_obs_steps > 1:
self.factorized_conv3d = FactorizedConv3d(
backbone_model.fc.in_features,
backbone_model.fc.in_features,
backbone_model.fc.in_features,
)
# Transformer (acts as VAE decoder when training with the variational objective). # Transformer (acts as VAE decoder when training with the variational objective).
self.encoder = ACTEncoder(config) self.encoder = ACTEncoder(config)
self.decoder = ACTDecoder(config) self.decoder = ACTDecoder(config)
@@ -292,7 +299,7 @@ class ACT(nn.Module):
"action" in batch "action" in batch
), "actions must be provided when using the variational objective in training mode." ), "actions must be provided when using the variational objective in training mode."
batch_size = batch["action"].shape[0] batch_size = batch["observation.images"].shape[0]
# Prepare the latent for input to the transformer encoder. # Prepare the latent for input to the transformer encoder.
if self.config.use_vae and "action" in batch: if self.config.use_vae and "action" in batch:
@@ -329,6 +336,7 @@ class ACT(nn.Module):
else: else:
# When not using the VAE encoder, we set the latent to be all zeros. # When not using the VAE encoder, we set the latent to be all zeros.
mu = log_sigma_x2 = None mu = log_sigma_x2 = None
# TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use buffer
latent_sample = torch.zeros([batch_size, self.latent_dim], dtype=torch.float32).to( latent_sample = torch.zeros([batch_size, self.latent_dim], dtype=torch.float32).to(
batch["observation.state"].device batch["observation.state"].device
) )
@@ -338,8 +346,25 @@ class ACT(nn.Module):
all_cam_features = [] all_cam_features = []
all_cam_pos_embeds = [] all_cam_pos_embeds = []
images = batch["observation.images"] images = batch["observation.images"]
for cam_index in range(images.shape[-4]): for cam_index in range(images.shape[-4]):
cam_features = self.backbone(images[:, cam_index])["feature_map"] if self.config.n_obs_steps >= 1:
assert images.ndim == 6
assert images.shape[1] == self.config.n_obs_steps
cam_images = images[:, :, cam_index]
b, t, c_in, h_in, w_in = cam_images.shape
cam_images = cam_images.reshape(b * t, c_in, h_in, w_in)
cam_features = self.backbone(cam_images)["feature_map"]
bt_, c_out, h_out, w_out = cam_features.shape
cam_features = cam_features.view(b, t, c_out, h_out, w_out)
cam_features = cam_features.permute(0, 2, 1, 3, 4).contiguous()
cam_features = self.factorized_conv3d(cam_features)
cam_features = cam_features.mean(2)
else:
assert images.ndim == 5
cam_image = images[:, cam_index]
cam_features = self.backbone(cam_image)["feature_map"]
# TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use buffer
cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype) cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype)
cam_features = self.encoder_img_feat_input_proj(cam_features) # (B, C, h, w) cam_features = self.encoder_img_feat_input_proj(cam_features) # (B, C, h, w)
all_cam_features.append(cam_features) all_cam_features.append(cam_features)
@@ -371,6 +396,7 @@ class ACT(nn.Module):
# Forward pass through the transformer modules. # Forward pass through the transformer modules.
encoder_out = self.encoder(encoder_in, pos_embed=pos_embed) encoder_out = self.encoder(encoder_in, pos_embed=pos_embed)
# TODO(rcadene, alexander-soare): remove call to `device` ; precompute and use buffer
decoder_in = torch.zeros( decoder_in = torch.zeros(
(self.config.chunk_size, batch_size, self.config.dim_model), (self.config.chunk_size, batch_size, self.config.dim_model),
dtype=pos_embed.dtype, dtype=pos_embed.dtype,
@@ -620,3 +646,24 @@ def get_activation_fn(activation: str) -> Callable:
if activation == "glu": if activation == "glu":
return F.glu return F.glu
raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.") raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")
class FactorizedConv3d(nn.Module):
def __init__(self, in_channels, hidden_dim, out_channels):
super().__init__()
self.in_channels = in_channels
self.hidden_dim = hidden_dim
self.out_channels = out_channels
self.spatial_conv = nn.Conv3d(in_channels, hidden_dim, kernel_size=(1, 3, 3), padding=(0, 1, 1))
self.temporal_conv = nn.Conv3d(hidden_dim, out_channels, kernel_size=(3, 1, 1), padding=(1, 0, 0))
def forward(self, x):
assert x.ndim == 5, f"Expected shape is b,t,c,h,w but {x.shape} given."
assert (
x.shape[1] == self.in_channels
), f"Expected number channels as input is {self.in_channels}, but {x.shape[2]} given."
x = self.spatial_conv(x)
x = F.relu(x, inplace=True)
x = self.temporal_conv(x)
x = F.relu(x, inplace=True)
return x