Move normalize/unnormalize transforms to policy for act and diffusion
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
from dataclasses import dataclass, field
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -60,12 +60,14 @@ class ActionChunkingTransformerConfig:
|
||||
chunk_size: int = 100
|
||||
n_action_steps: int = 100
|
||||
|
||||
# Vision preprocessing.
|
||||
image_normalization_mean: tuple[float, float, float] = field(
|
||||
default_factory=lambda: [0.485, 0.456, 0.406]
|
||||
)
|
||||
image_normalization_std: tuple[float, float, float] = field(default_factory=lambda: [0.229, 0.224, 0.225])
|
||||
|
||||
# Normalization / Unnormalization
|
||||
normalize_input_modes: dict[str, str] = {
|
||||
"observation.image": "mean_std",
|
||||
"observation.state": "mean_std",
|
||||
}
|
||||
unnormalize_output_modes: dict[str, str] = {
|
||||
"action": "mean_std",
|
||||
}
|
||||
# Architecture.
|
||||
# Vision backbone.
|
||||
vision_backbone: str = "resnet18"
|
||||
|
||||
@@ -15,12 +15,15 @@ import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
import torchvision
|
||||
import torchvision.transforms as transforms
|
||||
from torch import Tensor, nn
|
||||
from torchvision.models._utils import IntermediateLayerGetter
|
||||
from torchvision.ops.misc import FrozenBatchNorm2d
|
||||
|
||||
from lerobot.common.policies.act.configuration_act import ActionChunkingTransformerConfig
|
||||
from lerobot.common.policies.utils import (
|
||||
normalize_inputs,
|
||||
unnormalize_outputs,
|
||||
)
|
||||
|
||||
|
||||
class ActionChunkingTransformerPolicy(nn.Module):
|
||||
@@ -62,7 +65,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
||||
|
||||
name = "act"
|
||||
|
||||
def __init__(self, cfg: ActionChunkingTransformerConfig | None = None):
|
||||
def __init__(self, cfg: ActionChunkingTransformerConfig | None = None, dataset_stats=None):
|
||||
"""
|
||||
Args:
|
||||
cfg: Policy configuration class instance or None, in which case the default instantiation of the
|
||||
@@ -72,6 +75,9 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
||||
if cfg is None:
|
||||
cfg = ActionChunkingTransformerConfig()
|
||||
self.cfg = cfg
|
||||
self.register_buffer("dataset_stats", dataset_stats)
|
||||
self.normalize_input_modes = cfg.normalize_input_modes
|
||||
self.unnormalize_output_modes = cfg.unnormalize_output_modes
|
||||
|
||||
# BERT style VAE encoder with input [cls, *joint_space_configuration, *action_sequence].
|
||||
# The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]).
|
||||
@@ -93,9 +99,6 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
||||
)
|
||||
|
||||
# Backbone for image feature extraction.
|
||||
self.image_normalizer = transforms.Normalize(
|
||||
mean=cfg.image_normalization_mean, std=cfg.image_normalization_std
|
||||
)
|
||||
backbone_model = getattr(torchvision.models, cfg.vision_backbone)(
|
||||
replace_stride_with_dilation=[False, False, cfg.replace_final_stride_with_dilation],
|
||||
pretrained=cfg.use_pretrained_backbone,
|
||||
@@ -169,10 +172,15 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
||||
queue is empty.
|
||||
"""
|
||||
self.eval()
|
||||
|
||||
batch = normalize_inputs(batch, self.dataset_stats, self.normalize_input_modes)
|
||||
|
||||
if len(self._action_queue) == 0:
|
||||
# `_forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue effectively
|
||||
# has shape (n_action_steps, batch_size, *), hence the transpose.
|
||||
self._action_queue.extend(self._forward(batch)[0][: self.cfg.n_action_steps].transpose(0, 1))
|
||||
actions = self._forward(batch)[0][: self.cfg.n_action_steps]
|
||||
actions = unnormalize_outputs(actions, self.dataset_stats, self.unnormalize_output_modes)
|
||||
self._action_queue.extend(actions.transpose(0, 1))
|
||||
return self._action_queue.popleft()
|
||||
|
||||
def forward(self, batch, **_) -> dict[str, Tensor]:
|
||||
@@ -203,7 +211,10 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
||||
"""Run the model in train mode, compute the loss, and do an optimization step."""
|
||||
start_time = time.time()
|
||||
self.train()
|
||||
|
||||
batch = normalize_inputs(batch, self.dataset_stats, self.normalize_input_modes)
|
||||
loss_dict = self.forward(batch)
|
||||
# TODO(rcadene): unnormalize_outputs(actions, self.dataset_stats, self.unnormalize_output_modes)
|
||||
loss = loss_dict["loss"]
|
||||
loss.backward()
|
||||
|
||||
@@ -309,7 +320,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
||||
# Camera observation features and positional embeddings.
|
||||
all_cam_features = []
|
||||
all_cam_pos_embeds = []
|
||||
images = self.image_normalizer(batch["observation.images"])
|
||||
images = batch["observation.images"]
|
||||
for cam_index in range(len(self.cfg.camera_names)):
|
||||
cam_features = self.backbone(images[:, cam_index])["feature_map"]
|
||||
cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype)
|
||||
|
||||
@@ -69,9 +69,14 @@ class DiffusionConfig:
|
||||
horizon: int = 16
|
||||
n_action_steps: int = 8
|
||||
|
||||
# Vision preprocessing.
|
||||
image_normalization_mean: tuple[float, float, float] = (0.5, 0.5, 0.5)
|
||||
image_normalization_std: tuple[float, float, float] = (0.5, 0.5, 0.5)
|
||||
# Normalization / Unnormalization
|
||||
normalize_input_modes: dict[str, str] = {
|
||||
"observation.image": "mean_std",
|
||||
"observation.state": "min_max",
|
||||
}
|
||||
unnormalize_output_modes: dict[str, str] = {
|
||||
"action": "min_max",
|
||||
}
|
||||
|
||||
# Architecture / modeling.
|
||||
# Vision backbone.
|
||||
|
||||
@@ -13,7 +13,6 @@ import logging
|
||||
import math
|
||||
import time
|
||||
from collections import deque
|
||||
from itertools import chain
|
||||
from typing import Callable
|
||||
|
||||
import einops
|
||||
@@ -30,7 +29,9 @@ from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionC
|
||||
from lerobot.common.policies.utils import (
|
||||
get_device_from_parameters,
|
||||
get_dtype_from_parameters,
|
||||
normalize_inputs,
|
||||
populate_queues,
|
||||
unnormalize_outputs,
|
||||
)
|
||||
|
||||
|
||||
@@ -42,7 +43,9 @@ class DiffusionPolicy(nn.Module):
|
||||
|
||||
name = "diffusion"
|
||||
|
||||
def __init__(self, cfg: DiffusionConfig | None, lr_scheduler_num_training_steps: int = 0):
|
||||
def __init__(
|
||||
self, cfg: DiffusionConfig | None, lr_scheduler_num_training_steps: int = 0, dataset_stats=None
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
cfg: Policy configuration class instance or None, in which case the default instantiation of the
|
||||
@@ -54,6 +57,9 @@ class DiffusionPolicy(nn.Module):
|
||||
if cfg is None:
|
||||
cfg = DiffusionConfig()
|
||||
self.cfg = cfg
|
||||
self.register_buffer("dataset_stats", dataset_stats)
|
||||
self.normalize_input_modes = cfg.normalize_input_modes
|
||||
self.unnormalize_output_modes = cfg.unnormalize_output_modes
|
||||
|
||||
# queues are populated during rollout of the policy, they contain the n latest observations and actions
|
||||
self._queues = None
|
||||
@@ -126,6 +132,8 @@ class DiffusionPolicy(nn.Module):
|
||||
assert "observation.state" in batch
|
||||
assert len(batch) == 2
|
||||
|
||||
batch = normalize_inputs(batch, self.dataset_stats, self.normalize_input_modes)
|
||||
|
||||
self._queues = populate_queues(self._queues, batch)
|
||||
|
||||
if len(self._queues["action"]) == 0:
|
||||
@@ -135,6 +143,8 @@ class DiffusionPolicy(nn.Module):
|
||||
actions = self.ema_diffusion.generate_actions(batch)
|
||||
else:
|
||||
actions = self.diffusion.generate_actions(batch)
|
||||
|
||||
actions = unnormalize_outputs(actions, self.dataset_stats, self.unnormalize_output_modes)
|
||||
self._queues["action"].extend(actions.transpose(0, 1))
|
||||
|
||||
action = self._queues["action"].popleft()
|
||||
@@ -151,9 +161,13 @@ class DiffusionPolicy(nn.Module):
|
||||
|
||||
self.diffusion.train()
|
||||
|
||||
batch = normalize_inputs(batch, self.dataset_stats, self.normalize_input_modes)
|
||||
|
||||
loss = self.forward(batch)["loss"]
|
||||
loss.backward()
|
||||
|
||||
# TODO(rcadene): unnormalize_outputs(actions, self.dataset_stats, self.unnormalize_output_modes)
|
||||
|
||||
grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
self.diffusion.parameters(),
|
||||
self.cfg.grad_clip_norm,
|
||||
@@ -346,12 +360,6 @@ class _RgbEncoder(nn.Module):
|
||||
def __init__(self, cfg: DiffusionConfig):
|
||||
super().__init__()
|
||||
# Set up optional preprocessing.
|
||||
if all(v == 1.0 for v in chain(cfg.image_normalization_mean, cfg.image_normalization_std)):
|
||||
self.normalizer = nn.Identity()
|
||||
else:
|
||||
self.normalizer = torchvision.transforms.Normalize(
|
||||
mean=cfg.image_normalization_mean, std=cfg.image_normalization_std
|
||||
)
|
||||
if cfg.crop_shape is not None:
|
||||
self.do_crop = True
|
||||
# Always use center crop for eval
|
||||
@@ -397,8 +405,7 @@ class _RgbEncoder(nn.Module):
|
||||
Returns:
|
||||
(B, D) image feature.
|
||||
"""
|
||||
# Preprocess: normalize and maybe crop (if it was set up in the __init__).
|
||||
x = self.normalizer(x)
|
||||
# Preprocess: maybe crop (if it was set up in the __init__).
|
||||
if self.do_crop:
|
||||
if self.training: # noqa: SIM108
|
||||
x = self.maybe_random_crop(x)
|
||||
|
||||
@@ -20,7 +20,7 @@ def _policy_cfg_from_hydra_cfg(policy_cfg_class, hydra_cfg):
|
||||
return policy_cfg
|
||||
|
||||
|
||||
def make_policy(hydra_cfg: DictConfig):
|
||||
def make_policy(hydra_cfg: DictConfig, dataset_stats=None):
|
||||
if hydra_cfg.policy.name == "tdmpc":
|
||||
from lerobot.common.policies.tdmpc.policy import TDMPCPolicy
|
||||
|
||||
@@ -35,14 +35,14 @@ def make_policy(hydra_cfg: DictConfig):
|
||||
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
|
||||
|
||||
policy_cfg = _policy_cfg_from_hydra_cfg(DiffusionConfig, hydra_cfg)
|
||||
policy = DiffusionPolicy(policy_cfg, hydra_cfg.offline_steps)
|
||||
policy = DiffusionPolicy(policy_cfg, hydra_cfg.offline_steps, dataset_stats)
|
||||
policy.to(get_safe_torch_device(hydra_cfg.device))
|
||||
elif hydra_cfg.policy.name == "act":
|
||||
from lerobot.common.policies.act.configuration_act import ActionChunkingTransformerConfig
|
||||
from lerobot.common.policies.act.modeling_act import ActionChunkingTransformerPolicy
|
||||
|
||||
policy_cfg = _policy_cfg_from_hydra_cfg(ActionChunkingTransformerConfig, hydra_cfg)
|
||||
policy = ActionChunkingTransformerPolicy(policy_cfg)
|
||||
policy = ActionChunkingTransformerPolicy(policy_cfg, dataset_stats)
|
||||
policy.to(get_safe_torch_device(hydra_cfg.device))
|
||||
else:
|
||||
raise ValueError(hydra_cfg.policy.name)
|
||||
|
||||
@@ -28,3 +28,41 @@ def get_dtype_from_parameters(module: nn.Module) -> torch.dtype:
|
||||
Note: assumes that all parameters have the same dtype.
|
||||
"""
|
||||
return next(iter(module.parameters())).dtype
|
||||
|
||||
|
||||
def normalize_inputs(batch, stats, normalize_input_modes):
|
||||
if normalize_input_modes is None:
|
||||
return batch
|
||||
for key, mode in normalize_input_modes.items():
|
||||
if mode == "mean_std":
|
||||
mean = stats[key]["mean"].unsqueeze(0)
|
||||
std = stats[key]["std"].unsqueeze(0)
|
||||
batch[key] = (batch[key] - mean) / (std + 1e-8)
|
||||
elif mode == "min_max":
|
||||
min = stats[key]["min"].unsqueeze(0)
|
||||
max = stats[key]["max"].unsqueeze(0)
|
||||
# normalize to [0,1]
|
||||
batch[key] = (batch[key] - min) / (max - min)
|
||||
# normalize to [-1, 1]
|
||||
batch[key] = batch[key] * 2 - 1
|
||||
else:
|
||||
raise ValueError(mode)
|
||||
return batch
|
||||
|
||||
|
||||
def unnormalize_outputs(batch, stats, unnormalize_output_modes):
|
||||
if unnormalize_output_modes is None:
|
||||
return batch
|
||||
for key, mode in unnormalize_output_modes.items():
|
||||
if mode == "mean_std":
|
||||
mean = stats[key]["mean"].unsqueeze(0)
|
||||
std = stats[key]["std"].unsqueeze(0)
|
||||
batch[key] = batch[key] * std + mean
|
||||
elif mode == "min_max":
|
||||
min = stats[key]["min"].unsqueeze(0)
|
||||
max = stats[key]["max"].unsqueeze(0)
|
||||
batch[key] = (batch[key] + 1) / 2
|
||||
batch[key] = batch[key] * (max - min) + min
|
||||
else:
|
||||
raise ValueError(mode)
|
||||
return batch
|
||||
|
||||
Reference in New Issue
Block a user