Move normalization to policy for act and diffusion (#90)

Co-authored-by: Alexander Soare <alexander.soare159@gmail.com>
This commit is contained in:
Remi
2024-04-25 11:47:38 +02:00
committed by GitHub
parent c1bcf857c5
commit e760e4cd63
25 changed files with 543 additions and 288 deletions

View File

@@ -13,7 +13,6 @@ import logging
import math
import time
from collections import deque
from itertools import chain
from typing import Callable
import einops
@@ -27,6 +26,7 @@ from torch import Tensor, nn
from torch.nn.modules.batchnorm import _BatchNorm
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.utils import (
get_device_from_parameters,
get_dtype_from_parameters,
@@ -42,7 +42,9 @@ class DiffusionPolicy(nn.Module):
name = "diffusion"
def __init__(self, cfg: DiffusionConfig | None, lr_scheduler_num_training_steps: int = 0):
def __init__(
self, cfg: DiffusionConfig | None, lr_scheduler_num_training_steps: int = 0, dataset_stats=None
):
"""
Args:
cfg: Policy configuration class instance or None, in which case the default instantiation of the
@@ -54,6 +56,8 @@ class DiffusionPolicy(nn.Module):
if cfg is None:
cfg = DiffusionConfig()
self.cfg = cfg
self.normalize_inputs = Normalize(cfg.input_shapes, cfg.normalize_input_modes, dataset_stats)
self.unnormalize_outputs = Unnormalize(cfg.output_shapes, cfg.unnormalize_output_modes, dataset_stats)
# queues are populated during rollout of the policy, they contain the n latest observations and actions
self._queues = None
@@ -126,6 +130,8 @@ class DiffusionPolicy(nn.Module):
assert "observation.state" in batch
assert len(batch) == 2
batch = self.normalize_inputs(batch)
self._queues = populate_queues(self._queues, batch)
if len(self._queues["action"]) == 0:
@@ -135,6 +141,10 @@ class DiffusionPolicy(nn.Module):
actions = self.ema_diffusion.generate_actions(batch)
else:
actions = self.diffusion.generate_actions(batch)
# TODO(rcadene): make above methods return output dictionary?
actions = self.unnormalize_outputs({"action": actions})["action"]
self._queues["action"].extend(actions.transpose(0, 1))
action = self._queues["action"].popleft()
@@ -151,9 +161,13 @@ class DiffusionPolicy(nn.Module):
self.diffusion.train()
batch = self.normalize_inputs(batch)
loss = self.forward(batch)["loss"]
loss.backward()
# TODO(rcadene): self.unnormalize_outputs(out_dict)
grad_norm = torch.nn.utils.clip_grad_norm_(
self.diffusion.parameters(),
self.cfg.grad_clip_norm,
@@ -197,7 +211,8 @@ class _DiffusionUnetImagePolicy(nn.Module):
self.rgb_encoder = _RgbEncoder(cfg)
self.unet = _ConditionalUnet1D(
cfg, global_cond_dim=(cfg.action_dim + self.rgb_encoder.feature_dim) * cfg.n_obs_steps
cfg,
global_cond_dim=(cfg.output_shapes["action"][0] + self.rgb_encoder.feature_dim) * cfg.n_obs_steps,
)
self.noise_scheduler = DDPMScheduler(
@@ -225,7 +240,7 @@ class _DiffusionUnetImagePolicy(nn.Module):
# Sample prior.
sample = torch.randn(
size=(batch_size, self.cfg.horizon, self.cfg.action_dim),
size=(batch_size, self.cfg.horizon, self.cfg.output_shapes["action"][0]),
dtype=dtype,
device=device,
generator=generator,
@@ -268,7 +283,7 @@ class _DiffusionUnetImagePolicy(nn.Module):
sample = self.conditional_sample(batch_size, global_cond=global_cond)
# `horizon` steps worth of actions (from the first observation).
actions = sample[..., : self.cfg.action_dim]
actions = sample[..., : self.cfg.output_shapes["action"][0]]
# Extract `n_action_steps` steps worth of actions (from the current observation).
start = n_obs_steps - 1
end = start + self.cfg.n_action_steps
@@ -346,12 +361,6 @@ class _RgbEncoder(nn.Module):
def __init__(self, cfg: DiffusionConfig):
super().__init__()
# Set up optional preprocessing.
if all(v == 1.0 for v in chain(cfg.image_normalization_mean, cfg.image_normalization_std)):
self.normalizer = nn.Identity()
else:
self.normalizer = torchvision.transforms.Normalize(
mean=cfg.image_normalization_mean, std=cfg.image_normalization_std
)
if cfg.crop_shape is not None:
self.do_crop = True
# Always use center crop for eval
@@ -384,7 +393,9 @@ class _RgbEncoder(nn.Module):
# Set up pooling and final layers.
# Use a dry run to get the feature map shape.
with torch.inference_mode():
feat_map_shape = tuple(self.backbone(torch.zeros(size=(1, 3, *cfg.image_size))).shape[1:])
feat_map_shape = tuple(
self.backbone(torch.zeros(size=(1, *cfg.input_shapes["observation.image"]))).shape[1:]
)
self.pool = SpatialSoftmax(feat_map_shape, num_kp=cfg.spatial_softmax_num_keypoints)
self.feature_dim = cfg.spatial_softmax_num_keypoints * 2
self.out = nn.Linear(cfg.spatial_softmax_num_keypoints * 2, self.feature_dim)
@@ -397,8 +408,7 @@ class _RgbEncoder(nn.Module):
Returns:
(B, D) image feature.
"""
# Preprocess: normalize and maybe crop (if it was set up in the __init__).
x = self.normalizer(x)
# Preprocess: maybe crop (if it was set up in the __init__).
if self.do_crop:
if self.training: # noqa: SIM108
x = self.maybe_random_crop(x)
@@ -502,7 +512,7 @@ class _ConditionalUnet1D(nn.Module):
# In channels / out channels for each downsampling block in the Unet's encoder. For the decoder, we
# just reverse these.
in_out = [(cfg.action_dim, cfg.down_dims[0])] + list(
in_out = [(cfg.output_shapes["action"][0], cfg.down_dims[0])] + list(
zip(cfg.down_dims[:-1], cfg.down_dims[1:], strict=True)
)
@@ -553,7 +563,7 @@ class _ConditionalUnet1D(nn.Module):
self.final_conv = nn.Sequential(
_Conv1dBlock(cfg.down_dims[0], cfg.down_dims[0], kernel_size=cfg.kernel_size),
nn.Conv1d(cfg.down_dims[0], cfg.action_dim, 1),
nn.Conv1d(cfg.down_dims[0], cfg.output_shapes["action"][0], 1),
)
def forward(self, x: Tensor, timestep: Tensor | int, global_cond=None) -> Tensor: