Act works with n_obs > 1
@@ -151,7 +151,3 @@ class ACTConfig:
                 f"The chunk size is the upper bound for the number of action steps per model invocation. Got "
                 f"{self.n_action_steps} for `n_action_steps` and {self.chunk_size} for `chunk_size`."
             )
-        if self.n_obs_steps != 1:
-            raise ValueError(
-                f"Multiple observation steps not handled yet. Got `nobs_steps={self.n_obs_steps}`"
-            )
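With the n_obs_steps check above removed, a config requesting more than one observation step is no longer rejected at construction time. A minimal illustration, assuming ACTConfig is a dataclass whose remaining fields keep their defaults:

    config = ACTConfig(n_obs_steps=2)  # previously raised "Multiple observation steps not handled yet."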
@@ -139,9 +139,9 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
         batch = self.normalize_targets(batch)
         actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)

-        l1_loss = F.l1_loss(batch["action"], actions_hat, reduction="none") * ~batch["action_is_pad"]
-        l1_loss = l1_loss.unsqueeze(-1)
-        bsize = l1_loss.shape[0]
+        bsize = actions_hat.shape[0]
+        l1_loss = F.l1_loss(batch["action"], actions_hat, reduction="none")
+        l1_loss = l1_loss * ~batch["action_is_pad"].unsqueeze(-1)
+        l1_loss = l1_loss.view(bsize, -1).mean(dim=1)

         out_dict = {}
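For reference, a self-contained sketch of what the new lines compute: an unreduced L1 loss, masked so padded action steps contribute zero, then averaged into one loss value per sample. The tensor names mirror the diff; the concrete sizes are illustrative assumptions.

    import torch
    import torch.nn.functional as F

    bsize, chunk_size, action_dim = 2, 4, 6
    action = torch.randn(bsize, chunk_size, action_dim)        # batch["action"]
    actions_hat = torch.randn(bsize, chunk_size, action_dim)   # model prediction
    action_is_pad = torch.zeros(bsize, chunk_size, dtype=torch.bool)
    action_is_pad[:, -1] = True                                # pretend the last step is padding

    l1_loss = F.l1_loss(action, actions_hat, reduction="none")  # (B, S, A)
    l1_loss = l1_loss * ~action_is_pad.unsqueeze(-1)            # zero out padded steps
    l1_loss = l1_loss.view(bsize, -1).mean(dim=1)               # (B,) one loss value per sample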
@@ -237,6 +237,13 @@ class ACT(nn.Module):
         # Note: The forward method of this returns a dict: {"feature_map": output}.
         self.backbone = IntermediateLayerGetter(backbone_model, return_layers={"layer4": "feature_map"})

+        if self.config.n_obs_steps > 1:
+            self.factorized_conv3d = FactorizedConv3d(
+                backbone_model.fc.in_features,
+                backbone_model.fc.in_features,
+                backbone_model.fc.in_features,
+            )
+
         # Transformer (acts as VAE decoder when training with the variational objective).
         self.encoder = ACTEncoder(config)
         self.decoder = ACTDecoder(config)
@@ -292,7 +299,7 @@ class ACT(nn.Module):
                 "action" in batch
             ), "actions must be provided when using the variational objective in training mode."

-        batch_size = batch["action"].shape[0]
+        batch_size = batch["observation.images"].shape[0]

         # Prepare the latent for input to the transformer encoder.
         if self.config.use_vae and "action" in batch:
@@ -329,6 +336,7 @@ class ACT(nn.Module):
         else:
             # When not using the VAE encoder, we set the latent to be all zeros.
             mu = log_sigma_x2 = None
             # TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use buffer
             latent_sample = torch.zeros([batch_size, self.latent_dim], dtype=torch.float32).to(
                 batch["observation.state"].device
             )
@@ -338,8 +346,25 @@ class ACT(nn.Module):
         all_cam_features = []
         all_cam_pos_embeds = []
         images = batch["observation.images"]

         for cam_index in range(images.shape[-4]):
-            cam_features = self.backbone(images[:, cam_index])["feature_map"]
+            if self.config.n_obs_steps > 1:
+                assert images.ndim == 6
+                assert images.shape[1] == self.config.n_obs_steps
+                cam_images = images[:, :, cam_index]
+                b, t, c_in, h_in, w_in = cam_images.shape
+                cam_images = cam_images.reshape(b * t, c_in, h_in, w_in)
+                cam_features = self.backbone(cam_images)["feature_map"]
+                bt_, c_out, h_out, w_out = cam_features.shape
+                cam_features = cam_features.view(b, t, c_out, h_out, w_out)
+                cam_features = cam_features.permute(0, 2, 1, 3, 4).contiguous()
+                cam_features = self.factorized_conv3d(cam_features)
+                cam_features = cam_features.mean(2)
+            else:
+                assert images.ndim == 5
+                cam_image = images[:, cam_index]
+                cam_features = self.backbone(cam_image)["feature_map"]
             # TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use buffer
             cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype)
             cam_features = self.encoder_img_feat_input_proj(cam_features)  # (B, C, h, w)
             all_cam_features.append(cam_features)
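The n_obs_steps > 1 branch above folds the time dimension into the batch so the 2D backbone can run on every frame, then unfolds it and rearranges the features into the (B, C, T, h, w) layout expected by Conv3d. A standalone shape check, with a dummy convolution standing in for the ResNet backbone and illustrative sizes throughout:

    import torch
    import torch.nn as nn

    b, t, num_cams, c_in, h_in, w_in = 2, 3, 2, 3, 96, 96
    images = torch.randn(b, t, num_cams, c_in, h_in, w_in)           # (B, T, num_cams, C, H, W)

    # Stand-in for self.backbone(...)["feature_map"]: maps (B*T, C, H, W) -> (B*T, C', h, w).
    dummy_backbone = nn.Conv2d(c_in, 512, kernel_size=8, stride=8)

    cam_images = images[:, :, 0]                                      # one camera: (B, T, C, H, W)
    cam_images = cam_images.reshape(b * t, c_in, h_in, w_in)          # fold time into the batch
    cam_features = dummy_backbone(cam_images)                         # (B*T, C', h, w)
    bt_, c_out, h_out, w_out = cam_features.shape
    cam_features = cam_features.view(b, t, c_out, h_out, w_out)       # unfold time
    cam_features = cam_features.permute(0, 2, 1, 3, 4).contiguous()   # (B, C', T, h, w), Conv3d layout
    # The real code applies self.factorized_conv3d here, which preserves this shape, before the mean.
    cam_features = cam_features.mean(2)                               # (B, C', h, w)
    print(cam_features.shape)                                         # torch.Size([2, 512, 12, 12])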
@@ -371,6 +396,7 @@ class ACT(nn.Module):

         # Forward pass through the transformer modules.
         encoder_out = self.encoder(encoder_in, pos_embed=pos_embed)
         # TODO(rcadene, alexander-soare): remove call to `device` ; precompute and use buffer
         decoder_in = torch.zeros(
             (self.config.chunk_size, batch_size, self.config.dim_model),
             dtype=pos_embed.dtype,
@@ -620,3 +646,24 @@ def get_activation_fn(activation: str) -> Callable:
     if activation == "glu":
         return F.glu
     raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")
+
+
+class FactorizedConv3d(nn.Module):
+    def __init__(self, in_channels, hidden_dim, out_channels):
+        super().__init__()
+        self.in_channels = in_channels
+        self.hidden_dim = hidden_dim
+        self.out_channels = out_channels
+        self.spatial_conv = nn.Conv3d(in_channels, hidden_dim, kernel_size=(1, 3, 3), padding=(0, 1, 1))
+        self.temporal_conv = nn.Conv3d(hidden_dim, out_channels, kernel_size=(3, 1, 1), padding=(1, 0, 0))
+
+    def forward(self, x):
+        assert x.ndim == 5, f"Expected shape is b,c,t,h,w but {x.shape} given."
+        assert (
+            x.shape[1] == self.in_channels
+        ), f"Expected number of input channels is {self.in_channels}, but {x.shape[1]} given."
+        x = self.spatial_conv(x)
+        x = F.relu(x, inplace=True)
+        x = self.temporal_conv(x)
+        x = F.relu(x, inplace=True)
+        return x
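FactorizedConv3d splits a full 3D convolution into a spatial (1x3x3) pass followed by a temporal (3x1x1) pass, with padding chosen so the temporal and spatial extents are preserved. A minimal usage sketch, assuming the class above is in scope; the sizes are illustrative (512 matches the final feature width of a ResNet-18 backbone):

    import torch

    conv = FactorizedConv3d(in_channels=512, hidden_dim=512, out_channels=512)
    x = torch.randn(2, 512, 3, 12, 12)   # (B, C, T, h, w), as produced by the permute in the forward pass
    y = conv(x)                          # shape preserved: (2, 512, 3, 12, 12)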