Move normalize/unnormalize transforms to policy for act and diffusion

Author: Cadene
Date: 2024-04-20 21:08:14 +00:00
parent c1bcf857c5
commit 42ed7bb670
19 changed files with 145 additions and 195 deletions

lerobot/common/policies/act/configuration_act.py

@@ -1,4 +1,4 @@
from dataclasses import dataclass, field
from dataclasses import dataclass
@dataclass
@@ -60,12 +60,14 @@ class ActionChunkingTransformerConfig:
chunk_size: int = 100
n_action_steps: int = 100
# Vision preprocessing.
image_normalization_mean: tuple[float, float, float] = field(
default_factory=lambda: [0.485, 0.456, 0.406]
)
image_normalization_std: tuple[float, float, float] = field(default_factory=lambda: [0.229, 0.224, 0.225])
# Normalization / Unnormalization
normalize_input_modes: dict[str, str] = {
"observation.image": "mean_std",
"observation.state": "mean_std",
}
unnormalize_output_modes: dict[str, str] = {
"action": "mean_std",
}
# Architecture.
# Vision backbone.
vision_backbone: str = "resnet18"
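Side note for readers who want to reuse this pattern: Python dataclasses require field(default_factory=...) for dict defaults (a plain dict literal raises ValueError at class-definition time), so a self-contained, runnable sketch of these two fields and of overriding a mode at construction time could look like the following. Only the default values are taken from the hunk above; the class name is illustrative.

from dataclasses import dataclass, field

@dataclass
class NormalizationModesSketch:
    # Which normalization to apply to each input key before the model sees it.
    normalize_input_modes: dict[str, str] = field(
        default_factory=lambda: {
            "observation.image": "mean_std",
            "observation.state": "mean_std",
        }
    )
    # Which unnormalization to apply to each output key after the model.
    unnormalize_output_modes: dict[str, str] = field(
        default_factory=lambda: {"action": "mean_std"}
    )

cfg = NormalizationModesSketch(unnormalize_output_modes={"action": "min_max"})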

lerobot/common/policies/act/modeling_act.py

@@ -15,12 +15,15 @@ import numpy as np
import torch
import torch.nn.functional as F # noqa: N812
import torchvision
import torchvision.transforms as transforms
from torch import Tensor, nn
from torchvision.models._utils import IntermediateLayerGetter
from torchvision.ops.misc import FrozenBatchNorm2d
from lerobot.common.policies.act.configuration_act import ActionChunkingTransformerConfig
from lerobot.common.policies.utils import (
normalize_inputs,
unnormalize_outputs,
)
class ActionChunkingTransformerPolicy(nn.Module):
@@ -62,7 +65,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
name = "act"
def __init__(self, cfg: ActionChunkingTransformerConfig | None = None):
def __init__(self, cfg: ActionChunkingTransformerConfig | None = None, dataset_stats=None):
"""
Args:
cfg: Policy configuration class instance or None, in which case the default instantiation of the
@@ -72,6 +75,9 @@ class ActionChunkingTransformerPolicy(nn.Module):
if cfg is None:
cfg = ActionChunkingTransformerConfig()
self.cfg = cfg
self.register_buffer("dataset_stats", dataset_stats)
self.normalize_input_modes = cfg.normalize_input_modes
self.unnormalize_output_modes = cfg.unnormalize_output_modes
# BERT style VAE encoder with input [cls, *joint_space_configuration, *action_sequence].
# The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]).
@@ -93,9 +99,6 @@ class ActionChunkingTransformerPolicy(nn.Module):
)
# Backbone for image feature extraction.
self.image_normalizer = transforms.Normalize(
mean=cfg.image_normalization_mean, std=cfg.image_normalization_std
)
backbone_model = getattr(torchvision.models, cfg.vision_backbone)(
replace_stride_with_dilation=[False, False, cfg.replace_final_stride_with_dilation],
pretrained=cfg.use_pretrained_backbone,
@@ -169,10 +172,15 @@ class ActionChunkingTransformerPolicy(nn.Module):
queue is empty.
"""
self.eval()
batch = normalize_inputs(batch, self.dataset_stats, self.normalize_input_modes)
if len(self._action_queue) == 0:
# `_forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue effectively
# has shape (n_action_steps, batch_size, *), hence the transpose.
self._action_queue.extend(self._forward(batch)[0][: self.cfg.n_action_steps].transpose(0, 1))
actions = self._forward(batch)[0][: self.cfg.n_action_steps]
actions = unnormalize_outputs(actions, self.dataset_stats, self.unnormalize_output_modes)
self._action_queue.extend(actions.transpose(0, 1))
return self._action_queue.popleft()
def forward(self, batch, **_) -> dict[str, Tensor]:
@@ -203,7 +211,10 @@ class ActionChunkingTransformerPolicy(nn.Module):
"""Run the model in train mode, compute the loss, and do an optimization step."""
start_time = time.time()
self.train()
batch = normalize_inputs(batch, self.dataset_stats, self.normalize_input_modes)
loss_dict = self.forward(batch)
# TODO(rcadene): unnormalize_outputs(actions, self.dataset_stats, self.unnormalize_output_modes)
loss = loss_dict["loss"]
loss.backward()
@@ -309,7 +320,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
# Camera observation features and positional embeddings.
all_cam_features = []
all_cam_pos_embeds = []
images = self.image_normalizer(batch["observation.images"])
images = batch["observation.images"]
for cam_index in range(len(self.cfg.camera_names)):
cam_features = self.backbone(images[:, cam_index])["feature_map"]
cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype)

lerobot/common/policies/diffusion/configuration_diffusion.py

@@ -69,9 +69,14 @@ class DiffusionConfig:
horizon: int = 16
n_action_steps: int = 8
# Vision preprocessing.
image_normalization_mean: tuple[float, float, float] = (0.5, 0.5, 0.5)
image_normalization_std: tuple[float, float, float] = (0.5, 0.5, 0.5)
# Normalization / Unnormalization
normalize_input_modes: dict[str, str] = {
"observation.image": "mean_std",
"observation.state": "min_max",
}
unnormalize_output_modes: dict[str, str] = {
"action": "min_max",
}
# Architecture / modeling.
# Vision backbone.
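Here "min_max" means rescaling to [-1, 1] (per the helpers added in utils.py at the bottom of this diff): x_norm = 2 * (x - min) / (max - min) - 1, and the inverse maps back with x = (x_norm + 1) / 2 * (max - min) + min. For example, with min = 0 and max = 10, a raw value of 7.5 normalizes to 0.5 and unnormalizes back to 7.5.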

lerobot/common/policies/diffusion/modeling_diffusion.py

@@ -13,7 +13,6 @@ import logging
import math
import time
from collections import deque
from itertools import chain
from typing import Callable
import einops
@@ -30,7 +29,9 @@ from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionC
from lerobot.common.policies.utils import (
get_device_from_parameters,
get_dtype_from_parameters,
normalize_inputs,
populate_queues,
unnormalize_outputs,
)
@@ -42,7 +43,9 @@ class DiffusionPolicy(nn.Module):
name = "diffusion"
def __init__(self, cfg: DiffusionConfig | None, lr_scheduler_num_training_steps: int = 0):
def __init__(
self, cfg: DiffusionConfig | None, lr_scheduler_num_training_steps: int = 0, dataset_stats=None
):
"""
Args:
cfg: Policy configuration class instance or None, in which case the default instantiation of the
@@ -54,6 +57,9 @@ class DiffusionPolicy(nn.Module):
if cfg is None:
cfg = DiffusionConfig()
self.cfg = cfg
self.register_buffer("dataset_stats", dataset_stats)
self.normalize_input_modes = cfg.normalize_input_modes
self.unnormalize_output_modes = cfg.unnormalize_output_modes
# queues are populated during rollout of the policy, they contain the n latest observations and actions
self._queues = None
@@ -126,6 +132,8 @@ class DiffusionPolicy(nn.Module):
assert "observation.state" in batch
assert len(batch) == 2
batch = normalize_inputs(batch, self.dataset_stats, self.normalize_input_modes)
self._queues = populate_queues(self._queues, batch)
if len(self._queues["action"]) == 0:
@@ -135,6 +143,8 @@ class DiffusionPolicy(nn.Module):
actions = self.ema_diffusion.generate_actions(batch)
else:
actions = self.diffusion.generate_actions(batch)
actions = unnormalize_outputs(actions, self.dataset_stats, self.unnormalize_output_modes)
self._queues["action"].extend(actions.transpose(0, 1))
action = self._queues["action"].popleft()
@@ -151,9 +161,13 @@ class DiffusionPolicy(nn.Module):
self.diffusion.train()
batch = normalize_inputs(batch, self.dataset_stats, self.normalize_input_modes)
loss = self.forward(batch)["loss"]
loss.backward()
# TODO(rcadene): unnormalize_outputs(actions, self.dataset_stats, self.unnormalize_output_modes)
grad_norm = torch.nn.utils.clip_grad_norm_(
self.diffusion.parameters(),
self.cfg.grad_clip_norm,
@@ -346,12 +360,6 @@ class _RgbEncoder(nn.Module):
def __init__(self, cfg: DiffusionConfig):
super().__init__()
# Set up optional preprocessing.
if all(v == 1.0 for v in chain(cfg.image_normalization_mean, cfg.image_normalization_std)):
self.normalizer = nn.Identity()
else:
self.normalizer = torchvision.transforms.Normalize(
mean=cfg.image_normalization_mean, std=cfg.image_normalization_std
)
if cfg.crop_shape is not None:
self.do_crop = True
# Always use center crop for eval
@@ -397,8 +405,7 @@ class _RgbEncoder(nn.Module):
Returns:
(B, D) image feature.
"""
# Preprocess: normalize and maybe crop (if it was set up in the __init__).
x = self.normalizer(x)
# Preprocess: maybe crop (if it was set up in the __init__).
if self.do_crop:
if self.training: # noqa: SIM108
x = self.maybe_random_crop(x)

lerobot/common/policies/factory.py

@@ -20,7 +20,7 @@ def _policy_cfg_from_hydra_cfg(policy_cfg_class, hydra_cfg):
return policy_cfg
def make_policy(hydra_cfg: DictConfig):
def make_policy(hydra_cfg: DictConfig, dataset_stats=None):
if hydra_cfg.policy.name == "tdmpc":
from lerobot.common.policies.tdmpc.policy import TDMPCPolicy
@@ -35,14 +35,14 @@ def make_policy(hydra_cfg: DictConfig):
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
policy_cfg = _policy_cfg_from_hydra_cfg(DiffusionConfig, hydra_cfg)
policy = DiffusionPolicy(policy_cfg, hydra_cfg.offline_steps)
policy = DiffusionPolicy(policy_cfg, hydra_cfg.offline_steps, dataset_stats)
policy.to(get_safe_torch_device(hydra_cfg.device))
elif hydra_cfg.policy.name == "act":
from lerobot.common.policies.act.configuration_act import ActionChunkingTransformerConfig
from lerobot.common.policies.act.modeling_act import ActionChunkingTransformerPolicy
policy_cfg = _policy_cfg_from_hydra_cfg(ActionChunkingTransformerConfig, hydra_cfg)
policy = ActionChunkingTransformerPolicy(policy_cfg)
policy = ActionChunkingTransformerPolicy(policy_cfg, dataset_stats)
policy.to(get_safe_torch_device(hydra_cfg.device))
else:
raise ValueError(hydra_cfg.policy.name)
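For context, the new dataset_stats argument is meant to be threaded in from the training data on the calling side. A minimal call sketch, assuming the training script has a dataset object that exposes a per-key statistics mapping (the make_dataset helper and the stats attribute are assumptions, not shown in this diff):

# Sketch only: wiring dataset statistics into the policy factory.
dataset = make_dataset(hydra_cfg)                              # assumed dataset factory
policy = make_policy(hydra_cfg, dataset_stats=dataset.stats)   # assumed per-key stats mapping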

lerobot/common/policies/utils.py

@@ -28,3 +28,41 @@ def get_dtype_from_parameters(module: nn.Module) -> torch.dtype:
Note: assumes that all parameters have the same dtype.
"""
return next(iter(module.parameters())).dtype
def normalize_inputs(batch, stats, normalize_input_modes):
if normalize_input_modes is None:
return batch
for key, mode in normalize_input_modes.items():
if mode == "mean_std":
mean = stats[key]["mean"].unsqueeze(0)
std = stats[key]["std"].unsqueeze(0)
batch[key] = (batch[key] - mean) / (std + 1e-8)
elif mode == "min_max":
min = stats[key]["min"].unsqueeze(0)
max = stats[key]["max"].unsqueeze(0)
# normalize to [0,1]
batch[key] = (batch[key] - min) / (max - min)
# normalize to [-1, 1]
batch[key] = batch[key] * 2 - 1
else:
raise ValueError(mode)
return batch
def unnormalize_outputs(batch, stats, unnormalize_output_modes):
if unnormalize_output_modes is None:
return batch
for key, mode in unnormalize_output_modes.items():
if mode == "mean_std":
mean = stats[key]["mean"].unsqueeze(0)
std = stats[key]["std"].unsqueeze(0)
batch[key] = batch[key] * std + mean
elif mode == "min_max":
min = stats[key]["min"].unsqueeze(0)
max = stats[key]["max"].unsqueeze(0)
batch[key] = (batch[key] + 1) / 2
batch[key] = batch[key] * (max - min) + min
else:
raise ValueError(mode)
return batch
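A small usage sketch of these two helpers, to make the expected shape of stats concrete: a nested mapping from batch key to per-statistic tensors that broadcast against the batch (all numbers below are made up).

import torch

stats = {
    "observation.state": {"mean": torch.tensor([0.0, 1.0]), "std": torch.tensor([2.0, 4.0])},
    "action": {"min": torch.tensor([-1.0, -1.0]), "max": torch.tensor([1.0, 3.0])},
}

batch = {"observation.state": torch.tensor([[4.0, 9.0]])}
batch = normalize_inputs(batch, stats, {"observation.state": "mean_std"})
# batch["observation.state"] is now approximately [[2.0, 2.0]]

outputs = {"action": torch.tensor([[0.0, 0.5]])}  # model outputs in [-1, 1]
outputs = unnormalize_outputs(outputs, stats, {"action": "min_max"})
# outputs["action"] is now [[0.0, 2.0]]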