Move normalization to policy for act and diffusion (#90)

Co-authored-by: Alexander Soare <alexander.soare159@gmail.com>
This commit is contained in:
Remi
2024-04-25 11:47:38 +02:00
committed by GitHub
parent c1bcf857c5
commit e760e4cd63
25 changed files with 543 additions and 288 deletions

View File

@@ -15,12 +15,12 @@ import numpy as np
import torch
import torch.nn.functional as F # noqa: N812
import torchvision
import torchvision.transforms as transforms
from torch import Tensor, nn
from torchvision.models._utils import IntermediateLayerGetter
from torchvision.ops.misc import FrozenBatchNorm2d
from lerobot.common.policies.act.configuration_act import ActionChunkingTransformerConfig
from lerobot.common.policies.normalize import Normalize, Unnormalize
class ActionChunkingTransformerPolicy(nn.Module):
@@ -62,7 +62,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
name = "act"
def __init__(self, cfg: ActionChunkingTransformerConfig | None = None):
def __init__(self, cfg: ActionChunkingTransformerConfig | None = None, dataset_stats=None):
"""
Args:
cfg: Policy configuration class instance or None, in which case the default instantiation of the
@@ -72,6 +72,8 @@ class ActionChunkingTransformerPolicy(nn.Module):
if cfg is None:
cfg = ActionChunkingTransformerConfig()
self.cfg = cfg
self.normalize_inputs = Normalize(cfg.input_shapes, cfg.normalize_input_modes, dataset_stats)
self.unnormalize_outputs = Unnormalize(cfg.output_shapes, cfg.unnormalize_output_modes, dataset_stats)
# BERT style VAE encoder with input [cls, *joint_space_configuration, *action_sequence].
# The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]).
@@ -79,9 +81,13 @@ class ActionChunkingTransformerPolicy(nn.Module):
self.vae_encoder = _TransformerEncoder(cfg)
self.vae_encoder_cls_embed = nn.Embedding(1, cfg.d_model)
# Projection layer for joint-space configuration to hidden dimension.
self.vae_encoder_robot_state_input_proj = nn.Linear(cfg.state_dim, cfg.d_model)
self.vae_encoder_robot_state_input_proj = nn.Linear(
cfg.input_shapes["observation.state"][0], cfg.d_model
)
# Projection layer for action (joint-space target) to hidden dimension.
self.vae_encoder_action_input_proj = nn.Linear(cfg.state_dim, cfg.d_model)
self.vae_encoder_action_input_proj = nn.Linear(
cfg.input_shapes["observation.state"][0], cfg.d_model
)
self.latent_dim = cfg.latent_dim
# Projection layer from the VAE encoder's output to the latent distribution's parameter space.
self.vae_encoder_latent_output_proj = nn.Linear(cfg.d_model, self.latent_dim * 2)
@@ -93,9 +99,6 @@ class ActionChunkingTransformerPolicy(nn.Module):
)
# Backbone for image feature extraction.
self.image_normalizer = transforms.Normalize(
mean=cfg.image_normalization_mean, std=cfg.image_normalization_std
)
backbone_model = getattr(torchvision.models, cfg.vision_backbone)(
replace_stride_with_dilation=[False, False, cfg.replace_final_stride_with_dilation],
pretrained=cfg.use_pretrained_backbone,
@@ -112,7 +115,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
# Transformer encoder input projections. The tokens will be structured like
# [latent, robot_state, image_feature_map_pixels].
self.encoder_robot_state_input_proj = nn.Linear(cfg.state_dim, cfg.d_model)
self.encoder_robot_state_input_proj = nn.Linear(cfg.input_shapes["observation.state"][0], cfg.d_model)
self.encoder_latent_input_proj = nn.Linear(self.latent_dim, cfg.d_model)
self.encoder_img_feat_input_proj = nn.Conv2d(
backbone_model.fc.in_features, cfg.d_model, kernel_size=1
@@ -126,7 +129,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
self.decoder_pos_embed = nn.Embedding(cfg.chunk_size, cfg.d_model)
# Final action regression head on the output of the transformer's decoder.
self.action_head = nn.Linear(cfg.d_model, cfg.action_dim)
self.action_head = nn.Linear(cfg.d_model, cfg.output_shapes["action"][0])
self._reset_parameters()
self._create_optimizer()
@@ -169,10 +172,18 @@ class ActionChunkingTransformerPolicy(nn.Module):
queue is empty.
"""
self.eval()
batch = self.normalize_inputs(batch)
if len(self._action_queue) == 0:
# `_forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue effectively
# has shape (n_action_steps, batch_size, *), hence the transpose.
self._action_queue.extend(self._forward(batch)[0][: self.cfg.n_action_steps].transpose(0, 1))
actions = self._forward(batch)[0][: self.cfg.n_action_steps]
# TODO(rcadene): make _forward return output dictionary?
actions = self.unnormalize_outputs({"action": actions})["action"]
self._action_queue.extend(actions.transpose(0, 1))
return self._action_queue.popleft()
def forward(self, batch, **_) -> dict[str, Tensor]:
@@ -203,7 +214,11 @@ class ActionChunkingTransformerPolicy(nn.Module):
"""Run the model in train mode, compute the loss, and do an optimization step."""
start_time = time.time()
self.train()
batch = self.normalize_inputs(batch)
loss_dict = self.forward(batch)
# TODO(rcadene): self.unnormalize_outputs(out_dict)
loss = loss_dict["loss"]
loss.backward()
@@ -232,17 +247,9 @@ class ActionChunkingTransformerPolicy(nn.Module):
"observation.images.{name}": (B, C, H, W) tensor of images.
}
"""
# Check that there is only one image.
# TODO(alexander-soare): generalize this to multiple images.
provided_cameras = {k.rsplit(".", 1)[-1] for k in batch if k.startswith("observation.images.")}
if len(missing := set(self.cfg.camera_names).difference(provided_cameras)) > 0:
raise ValueError(
f"The following camera images are missing from the provided batch: {missing}. Check the "
"configuration parameter: `camera_names`."
)
# Stack images in the order dictated by the camera names.
# Stack images in the order dictated by input_shapes.
batch["observation.images"] = torch.stack(
[batch[f"observation.images.{name}"] for name in self.cfg.camera_names],
[batch[k] for k in self.cfg.input_shapes if k.startswith("observation.images.")],
dim=-4,
)
@@ -309,8 +316,8 @@ class ActionChunkingTransformerPolicy(nn.Module):
# Camera observation features and positional embeddings.
all_cam_features = []
all_cam_pos_embeds = []
images = self.image_normalizer(batch["observation.images"])
for cam_index in range(len(self.cfg.camera_names)):
images = batch["observation.images"]
for cam_index in range(images.shape[-4]):
cam_features = self.backbone(images[:, cam_index])["feature_map"]
cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype)
cam_features = self.encoder_img_feat_input_proj(cam_features) # (B, C, h, w)