forked from tangger/lerobot
Move normalization to policy for act and diffusion (#90)
Co-authored-by: Alexander Soare <alexander.soare159@gmail.com>
This commit is contained in:
@@ -15,12 +15,12 @@ import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
import torchvision
|
||||
import torchvision.transforms as transforms
|
||||
from torch import Tensor, nn
|
||||
from torchvision.models._utils import IntermediateLayerGetter
|
||||
from torchvision.ops.misc import FrozenBatchNorm2d
|
||||
|
||||
from lerobot.common.policies.act.configuration_act import ActionChunkingTransformerConfig
|
||||
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
||||
|
||||
|
||||
class ActionChunkingTransformerPolicy(nn.Module):
|
||||
@@ -62,7 +62,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
||||
|
||||
name = "act"
|
||||
|
||||
def __init__(self, cfg: ActionChunkingTransformerConfig | None = None):
|
||||
def __init__(self, cfg: ActionChunkingTransformerConfig | None = None, dataset_stats=None):
|
||||
"""
|
||||
Args:
|
||||
cfg: Policy configuration class instance or None, in which case the default instantiation of the
|
||||
@@ -72,6 +72,8 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
||||
if cfg is None:
|
||||
cfg = ActionChunkingTransformerConfig()
|
||||
self.cfg = cfg
|
||||
self.normalize_inputs = Normalize(cfg.input_shapes, cfg.normalize_input_modes, dataset_stats)
|
||||
self.unnormalize_outputs = Unnormalize(cfg.output_shapes, cfg.unnormalize_output_modes, dataset_stats)
|
||||
|
||||
# BERT style VAE encoder with input [cls, *joint_space_configuration, *action_sequence].
|
||||
# The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]).
|
||||
@@ -79,9 +81,13 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
||||
self.vae_encoder = _TransformerEncoder(cfg)
|
||||
self.vae_encoder_cls_embed = nn.Embedding(1, cfg.d_model)
|
||||
# Projection layer for joint-space configuration to hidden dimension.
|
||||
self.vae_encoder_robot_state_input_proj = nn.Linear(cfg.state_dim, cfg.d_model)
|
||||
self.vae_encoder_robot_state_input_proj = nn.Linear(
|
||||
cfg.input_shapes["observation.state"][0], cfg.d_model
|
||||
)
|
||||
# Projection layer for action (joint-space target) to hidden dimension.
|
||||
self.vae_encoder_action_input_proj = nn.Linear(cfg.state_dim, cfg.d_model)
|
||||
self.vae_encoder_action_input_proj = nn.Linear(
|
||||
cfg.input_shapes["observation.state"][0], cfg.d_model
|
||||
)
|
||||
self.latent_dim = cfg.latent_dim
|
||||
# Projection layer from the VAE encoder's output to the latent distribution's parameter space.
|
||||
self.vae_encoder_latent_output_proj = nn.Linear(cfg.d_model, self.latent_dim * 2)
|
||||
@@ -93,9 +99,6 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
||||
)
|
||||
|
||||
# Backbone for image feature extraction.
|
||||
self.image_normalizer = transforms.Normalize(
|
||||
mean=cfg.image_normalization_mean, std=cfg.image_normalization_std
|
||||
)
|
||||
backbone_model = getattr(torchvision.models, cfg.vision_backbone)(
|
||||
replace_stride_with_dilation=[False, False, cfg.replace_final_stride_with_dilation],
|
||||
pretrained=cfg.use_pretrained_backbone,
|
||||
@@ -112,7 +115,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
||||
|
||||
# Transformer encoder input projections. The tokens will be structured like
|
||||
# [latent, robot_state, image_feature_map_pixels].
|
||||
self.encoder_robot_state_input_proj = nn.Linear(cfg.state_dim, cfg.d_model)
|
||||
self.encoder_robot_state_input_proj = nn.Linear(cfg.input_shapes["observation.state"][0], cfg.d_model)
|
||||
self.encoder_latent_input_proj = nn.Linear(self.latent_dim, cfg.d_model)
|
||||
self.encoder_img_feat_input_proj = nn.Conv2d(
|
||||
backbone_model.fc.in_features, cfg.d_model, kernel_size=1
|
||||
@@ -126,7 +129,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
||||
self.decoder_pos_embed = nn.Embedding(cfg.chunk_size, cfg.d_model)
|
||||
|
||||
# Final action regression head on the output of the transformer's decoder.
|
||||
self.action_head = nn.Linear(cfg.d_model, cfg.action_dim)
|
||||
self.action_head = nn.Linear(cfg.d_model, cfg.output_shapes["action"][0])
|
||||
|
||||
self._reset_parameters()
|
||||
self._create_optimizer()
|
||||
@@ -169,10 +172,18 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
||||
queue is empty.
|
||||
"""
|
||||
self.eval()
|
||||
|
||||
batch = self.normalize_inputs(batch)
|
||||
|
||||
if len(self._action_queue) == 0:
|
||||
# `_forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue effectively
|
||||
# has shape (n_action_steps, batch_size, *), hence the transpose.
|
||||
self._action_queue.extend(self._forward(batch)[0][: self.cfg.n_action_steps].transpose(0, 1))
|
||||
actions = self._forward(batch)[0][: self.cfg.n_action_steps]
|
||||
|
||||
# TODO(rcadene): make _forward return output dictionary?
|
||||
actions = self.unnormalize_outputs({"action": actions})["action"]
|
||||
|
||||
self._action_queue.extend(actions.transpose(0, 1))
|
||||
return self._action_queue.popleft()
|
||||
|
||||
def forward(self, batch, **_) -> dict[str, Tensor]:
|
||||
@@ -203,7 +214,11 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
||||
"""Run the model in train mode, compute the loss, and do an optimization step."""
|
||||
start_time = time.time()
|
||||
self.train()
|
||||
|
||||
batch = self.normalize_inputs(batch)
|
||||
|
||||
loss_dict = self.forward(batch)
|
||||
# TODO(rcadene): self.unnormalize_outputs(out_dict)
|
||||
loss = loss_dict["loss"]
|
||||
loss.backward()
|
||||
|
||||
@@ -232,17 +247,9 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
||||
"observation.images.{name}": (B, C, H, W) tensor of images.
|
||||
}
|
||||
"""
|
||||
# Check that there is only one image.
|
||||
# TODO(alexander-soare): generalize this to multiple images.
|
||||
provided_cameras = {k.rsplit(".", 1)[-1] for k in batch if k.startswith("observation.images.")}
|
||||
if len(missing := set(self.cfg.camera_names).difference(provided_cameras)) > 0:
|
||||
raise ValueError(
|
||||
f"The following camera images are missing from the provided batch: {missing}. Check the "
|
||||
"configuration parameter: `camera_names`."
|
||||
)
|
||||
# Stack images in the order dictated by the camera names.
|
||||
# Stack images in the order dictated by input_shapes.
|
||||
batch["observation.images"] = torch.stack(
|
||||
[batch[f"observation.images.{name}"] for name in self.cfg.camera_names],
|
||||
[batch[k] for k in self.cfg.input_shapes if k.startswith("observation.images.")],
|
||||
dim=-4,
|
||||
)
|
||||
|
||||
@@ -309,8 +316,8 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
||||
# Camera observation features and positional embeddings.
|
||||
all_cam_features = []
|
||||
all_cam_pos_embeds = []
|
||||
images = self.image_normalizer(batch["observation.images"])
|
||||
for cam_index in range(len(self.cfg.camera_names)):
|
||||
images = batch["observation.images"]
|
||||
for cam_index in range(images.shape[-4]):
|
||||
cam_features = self.backbone(images[:, cam_index])["feature_map"]
|
||||
cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype)
|
||||
cam_features = self.encoder_img_feat_input_proj(cam_features) # (B, C, h, w)
|
||||
|
||||
Reference in New Issue
Block a user