Make policies compatible with other/multiple image keys (#149)

This commit is contained in:
Alexander Soare
2024-05-16 13:51:53 +01:00
committed by GitHub
parent f52f4f2cd2
commit 68c1b13406
9 changed files with 107 additions and 69 deletions

View File

@@ -145,10 +145,3 @@ class ACTConfig:
raise ValueError(
f"Multiple observation steps not handled yet. Got `nobs_steps={self.n_obs_steps}`"
)
# Check that there is only one image.
# TODO(alexander-soare): generalize this to multiple images.
if (
sum(k.startswith("observation.images.") for k in self.input_shapes) != 1
or "observation.images.top" not in self.input_shapes
):
raise ValueError('For now, only "observation.images.top" is accepted for an image input.')

View File

@@ -62,6 +62,7 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
if config is None:
config = ACTConfig()
self.config = config
self.normalize_inputs = Normalize(
config.input_shapes, config.input_normalization_modes, dataset_stats
)
@@ -71,8 +72,13 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
self.unnormalize_outputs = Unnormalize(
config.output_shapes, config.output_normalization_modes, dataset_stats
)
self.model = ACT(config)
self.expected_image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
self.reset()
def reset(self):
"""This should be called whenever the environment is reset."""
if self.config.n_action_steps is not None:
@@ -86,13 +92,10 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
environment. It works by managing the actions in a queue and only calling `select_actions` when the
queue is empty.
"""
assert "observation.images.top" in batch
assert "observation.state" in batch
self.eval()
batch = self.normalize_inputs(batch)
self._stack_images(batch)
batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
if len(self._action_queue) == 0:
# `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue
@@ -108,8 +111,8 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
"""Run the batch through the model and compute the loss for training or validation."""
batch = self.normalize_inputs(batch)
batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
batch = self.normalize_targets(batch)
self._stack_images(batch)
actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)
l1_loss = (
@@ -132,21 +135,6 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
return loss_dict
def _stack_images(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
"""Stacks all the images in a batch and puts them in a new key: "observation.images".
This function expects `batch` to have (at least):
{
"observation.state": (B, state_dim) batch of robot states.
"observation.images.{name}": (B, C, H, W) tensor of images.
}
"""
# Stack images in the order dictated by input_shapes.
batch["observation.images"] = torch.stack(
[batch[k] for k in self.config.input_shapes if k.startswith("observation.images.")],
dim=-4,
)
class ACT(nn.Module):
"""Action Chunking Transformer: The underlying neural network for ACTPolicy.
@@ -176,10 +164,10 @@ class ACT(nn.Module):
│ encoder │ │ │ │Transf.│ │
│ │ │ │ │encoder│ │
└───▲─────┘ │ │ │ │ │
│ │ │ └──▲──┘ │
│ │ │
inputs └─────┼─────┘
│ │ │ └──▲──┘ │
│ │ │
inputs └─────┼──┘ │ image emb.
state emb.
└───────────────────────┘
"""
@@ -321,18 +309,18 @@ class ACT(nn.Module):
all_cam_features.append(cam_features)
all_cam_pos_embeds.append(cam_pos_embed)
# Concatenate camera observation feature maps and positional embeddings along the width dimension.
encoder_in = torch.cat(all_cam_features, axis=3)
cam_pos_embed = torch.cat(all_cam_pos_embeds, axis=3)
encoder_in = torch.cat(all_cam_features, axis=-1)
cam_pos_embed = torch.cat(all_cam_pos_embeds, axis=-1)
# Get positional embeddings for robot state and latent.
robot_state_embed = self.encoder_robot_state_input_proj(batch["observation.state"])
latent_embed = self.encoder_latent_input_proj(latent_sample)
robot_state_embed = self.encoder_robot_state_input_proj(batch["observation.state"]) # (B, C)
latent_embed = self.encoder_latent_input_proj(latent_sample) # (B, C)
# Stack encoder input and positional embeddings moving to (S, B, C).
encoder_in = torch.cat(
[
torch.stack([latent_embed, robot_state_embed], axis=0),
encoder_in.flatten(2).permute(2, 0, 1),
einops.rearrange(encoder_in, "b c h w -> (h w) b c"),
]
)
pos_embed = torch.cat(