Move normalization to policy for act and diffusion (#90)

Co-authored-by: Alexander Soare <alexander.soare159@gmail.com>
This commit is contained in:
Remi
2024-04-25 11:47:38 +02:00
committed by GitHub
parent c1bcf857c5
commit e760e4cd63
25 changed files with 543 additions and 288 deletions

View File

@@ -8,23 +8,30 @@ class ActionChunkingTransformerConfig:
Defaults are configured for training on bimanual Aloha tasks like "insertion" or "transfer".
The parameters you will most likely need to change are the ones which depend on the environment / sensors.
Those are: `state_dim`, `action_dim` and `camera_names`.
Those are: `input_shapes` and 'output_shapes`.
Args:
state_dim: Dimensionality of the observation state space (excluding images).
action_dim: Dimensionality of the action space.
n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the
current step and additional steps going back).
camera_names: The (unique) set of names for the cameras.
chunk_size: The size of the action prediction "chunks" in units of environment steps.
n_action_steps: The number of action steps to run in the environment for one invocation of the policy.
This should be no greater than the chunk size. For example, if the chunk size size 100, you may
set this to 50. This would mean that the model predicts 100 steps worth of actions, runs 50 in the
environment, and throws the other 50 out.
image_normalization_mean: Value to subtract from the input image pixels (inputs are assumed to be in
[0, 1]) for normalization.
image_normalization_std: Value by which to divide the input image pixels (after the mean has been
subtracted).
input_shapes: A dictionary defining the shapes of the input data for the policy.
The key represents the input data name, and the value is a list indicating the dimensions
of the corresponding data. For example, "observation.images.top" refers to an input from the
"top" camera with dimensions [3, 96, 96], indicating it has three color channels and 96x96 resolution.
Importantly, shapes doesnt include batch dimension or temporal dimension.
output_shapes: A dictionary defining the shapes of the output data for the policy.
The key represents the output data name, and the value is a list indicating the dimensions
of the corresponding data. For example, "action" refers to an output shape of [14], indicating
14-dimensional actions. Importantly, shapes doesnt include batch dimension or temporal dimension.
normalize_input_modes: A dictionary with key represents the modality (e.g. "observation.state"),
and the value specifies the normalization mode to apply. The two availables
modes are "mean_std" which substracts the mean and divide by the standard
deviation and "min_max" which rescale in a [-1, 1] range.
unnormalize_output_modes: Similar dictionary as `normalize_input_modes`, but to unormalize in original scale.
vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
use_pretrained_backbone: Whether the backbone should be initialized with pretrained weights from
torchvision.
@@ -50,21 +57,35 @@ class ActionChunkingTransformerConfig:
is enabled. Loss is then calculated as: `reconstruction_loss + kl_weight * kld_loss`.
"""
# Environment.
state_dim: int = 14
action_dim: int = 14
# Inputs / output structure.
# Input / output structure.
n_obs_steps: int = 1
camera_names: tuple[str] = ("top",)
chunk_size: int = 100
n_action_steps: int = 100
# Vision preprocessing.
image_normalization_mean: tuple[float, float, float] = field(
default_factory=lambda: [0.485, 0.456, 0.406]
input_shapes: dict[str, list[str]] = field(
default_factory=lambda: {
"observation.images.top": [3, 480, 640],
"observation.state": [14],
}
)
output_shapes: dict[str, list[str]] = field(
default_factory=lambda: {
"action": [14],
}
)
# Normalization / Unnormalization
normalize_input_modes: dict[str, str] = field(
default_factory=lambda: {
"observation.image": "mean_std",
"observation.state": "mean_std",
}
)
unnormalize_output_modes: dict[str, str] = field(
default_factory=lambda: {
"action": "mean_std",
}
)
image_normalization_std: tuple[float, float, float] = field(default_factory=lambda: [0.229, 0.224, 0.225])
# Architecture.
# Vision backbone.
@@ -117,7 +138,10 @@ class ActionChunkingTransformerConfig:
raise ValueError(
f"Multiple observation steps not handled yet. Got `nobs_steps={self.n_obs_steps}`"
)
if self.camera_names != ["top"]:
raise ValueError(f"For now, `camera_names` can only be ['top']. Got {self.camera_names}.")
if len(set(self.camera_names)) != len(self.camera_names):
raise ValueError(f"`camera_names` should not have any repeated entries. Got {self.camera_names}.")
# Check that there is only one image.
# TODO(alexander-soare): generalize this to multiple images.
if (
sum(k.startswith("observation.images.") for k in self.input_shapes) != 1
or "observation.images.top" not in self.input_shapes
):
raise ValueError('For now, only "observation.images.top" is accepted for an image input.')

View File

@@ -15,12 +15,12 @@ import numpy as np
import torch
import torch.nn.functional as F # noqa: N812
import torchvision
import torchvision.transforms as transforms
from torch import Tensor, nn
from torchvision.models._utils import IntermediateLayerGetter
from torchvision.ops.misc import FrozenBatchNorm2d
from lerobot.common.policies.act.configuration_act import ActionChunkingTransformerConfig
from lerobot.common.policies.normalize import Normalize, Unnormalize
class ActionChunkingTransformerPolicy(nn.Module):
@@ -62,7 +62,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
name = "act"
def __init__(self, cfg: ActionChunkingTransformerConfig | None = None):
def __init__(self, cfg: ActionChunkingTransformerConfig | None = None, dataset_stats=None):
"""
Args:
cfg: Policy configuration class instance or None, in which case the default instantiation of the
@@ -72,6 +72,8 @@ class ActionChunkingTransformerPolicy(nn.Module):
if cfg is None:
cfg = ActionChunkingTransformerConfig()
self.cfg = cfg
self.normalize_inputs = Normalize(cfg.input_shapes, cfg.normalize_input_modes, dataset_stats)
self.unnormalize_outputs = Unnormalize(cfg.output_shapes, cfg.unnormalize_output_modes, dataset_stats)
# BERT style VAE encoder with input [cls, *joint_space_configuration, *action_sequence].
# The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]).
@@ -79,9 +81,13 @@ class ActionChunkingTransformerPolicy(nn.Module):
self.vae_encoder = _TransformerEncoder(cfg)
self.vae_encoder_cls_embed = nn.Embedding(1, cfg.d_model)
# Projection layer for joint-space configuration to hidden dimension.
self.vae_encoder_robot_state_input_proj = nn.Linear(cfg.state_dim, cfg.d_model)
self.vae_encoder_robot_state_input_proj = nn.Linear(
cfg.input_shapes["observation.state"][0], cfg.d_model
)
# Projection layer for action (joint-space target) to hidden dimension.
self.vae_encoder_action_input_proj = nn.Linear(cfg.state_dim, cfg.d_model)
self.vae_encoder_action_input_proj = nn.Linear(
cfg.input_shapes["observation.state"][0], cfg.d_model
)
self.latent_dim = cfg.latent_dim
# Projection layer from the VAE encoder's output to the latent distribution's parameter space.
self.vae_encoder_latent_output_proj = nn.Linear(cfg.d_model, self.latent_dim * 2)
@@ -93,9 +99,6 @@ class ActionChunkingTransformerPolicy(nn.Module):
)
# Backbone for image feature extraction.
self.image_normalizer = transforms.Normalize(
mean=cfg.image_normalization_mean, std=cfg.image_normalization_std
)
backbone_model = getattr(torchvision.models, cfg.vision_backbone)(
replace_stride_with_dilation=[False, False, cfg.replace_final_stride_with_dilation],
pretrained=cfg.use_pretrained_backbone,
@@ -112,7 +115,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
# Transformer encoder input projections. The tokens will be structured like
# [latent, robot_state, image_feature_map_pixels].
self.encoder_robot_state_input_proj = nn.Linear(cfg.state_dim, cfg.d_model)
self.encoder_robot_state_input_proj = nn.Linear(cfg.input_shapes["observation.state"][0], cfg.d_model)
self.encoder_latent_input_proj = nn.Linear(self.latent_dim, cfg.d_model)
self.encoder_img_feat_input_proj = nn.Conv2d(
backbone_model.fc.in_features, cfg.d_model, kernel_size=1
@@ -126,7 +129,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
self.decoder_pos_embed = nn.Embedding(cfg.chunk_size, cfg.d_model)
# Final action regression head on the output of the transformer's decoder.
self.action_head = nn.Linear(cfg.d_model, cfg.action_dim)
self.action_head = nn.Linear(cfg.d_model, cfg.output_shapes["action"][0])
self._reset_parameters()
self._create_optimizer()
@@ -169,10 +172,18 @@ class ActionChunkingTransformerPolicy(nn.Module):
queue is empty.
"""
self.eval()
batch = self.normalize_inputs(batch)
if len(self._action_queue) == 0:
# `_forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue effectively
# has shape (n_action_steps, batch_size, *), hence the transpose.
self._action_queue.extend(self._forward(batch)[0][: self.cfg.n_action_steps].transpose(0, 1))
actions = self._forward(batch)[0][: self.cfg.n_action_steps]
# TODO(rcadene): make _forward return output dictionary?
actions = self.unnormalize_outputs({"action": actions})["action"]
self._action_queue.extend(actions.transpose(0, 1))
return self._action_queue.popleft()
def forward(self, batch, **_) -> dict[str, Tensor]:
@@ -203,7 +214,11 @@ class ActionChunkingTransformerPolicy(nn.Module):
"""Run the model in train mode, compute the loss, and do an optimization step."""
start_time = time.time()
self.train()
batch = self.normalize_inputs(batch)
loss_dict = self.forward(batch)
# TODO(rcadene): self.unnormalize_outputs(out_dict)
loss = loss_dict["loss"]
loss.backward()
@@ -232,17 +247,9 @@ class ActionChunkingTransformerPolicy(nn.Module):
"observation.images.{name}": (B, C, H, W) tensor of images.
}
"""
# Check that there is only one image.
# TODO(alexander-soare): generalize this to multiple images.
provided_cameras = {k.rsplit(".", 1)[-1] for k in batch if k.startswith("observation.images.")}
if len(missing := set(self.cfg.camera_names).difference(provided_cameras)) > 0:
raise ValueError(
f"The following camera images are missing from the provided batch: {missing}. Check the "
"configuration parameter: `camera_names`."
)
# Stack images in the order dictated by the camera names.
# Stack images in the order dictated by input_shapes.
batch["observation.images"] = torch.stack(
[batch[f"observation.images.{name}"] for name in self.cfg.camera_names],
[batch[k] for k in self.cfg.input_shapes if k.startswith("observation.images.")],
dim=-4,
)
@@ -309,8 +316,8 @@ class ActionChunkingTransformerPolicy(nn.Module):
# Camera observation features and positional embeddings.
all_cam_features = []
all_cam_pos_embeds = []
images = self.image_normalizer(batch["observation.images"])
for cam_index in range(len(self.cfg.camera_names)):
images = batch["observation.images"]
for cam_index in range(images.shape[-4]):
cam_features = self.backbone(images[:, cam_index])["feature_map"]
cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype)
cam_features = self.encoder_img_feat_input_proj(cam_features) # (B, C, h, w)