Move normalization to policy for act and diffusion (#90)
Co-authored-by: Alexander Soare <alexander.soare159@gmail.com>
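
The hunks below move input/output normalization out of the scripts and into the policy itself: the policy is constructed with `dataset_stats` and wraps them in `Normalize`/`Unnormalize` modules. As a rough illustration of that pattern, here is a minimal sketch; the class name, the mean/std scheme, and the stats layout are assumptions, not the actual lerobot.common.policies.normalize implementation.

# Minimal sketch of per-key normalization with dataset statistics.
# Hypothetical: the real Normalize/Unnormalize modules live in
# lerobot.common.policies.normalize and may use a different stats layout.
import torch
from torch import Tensor, nn


class NormalizeSketch(nn.Module):
    def __init__(self, stats: dict[str, dict[str, Tensor]]):
        super().__init__()
        self.keys = list(stats)
        for key, s in stats.items():
            # Buffers follow the module across devices and into state_dict.
            self.register_buffer(key.replace(".", "_") + "_mean", s["mean"])
            self.register_buffer(key.replace(".", "_") + "_std", s["std"])

    def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
        out = dict(batch)  # avoid mutating the caller's dict
        for key in self.keys:
            mean = getattr(self, key.replace(".", "_") + "_mean")
            std = getattr(self, key.replace(".", "_") + "_std")
            out[key] = (batch[key] - mean) / (std + 1e-8)
        return out


stats = {"observation.state": {"mean": torch.zeros(2), "std": torch.ones(2)}}
normalize_inputs = NormalizeSketch(stats)
batch = normalize_inputs({"observation.state": torch.randn(1, 2)})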
@@ -13,7 +13,6 @@ import logging
 import math
 import time
 from collections import deque
-from itertools import chain
 from typing import Callable
 
 import einops
@@ -27,6 +26,7 @@ from torch import Tensor, nn
 from torch.nn.modules.batchnorm import _BatchNorm
 
 from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
+from lerobot.common.policies.normalize import Normalize, Unnormalize
 from lerobot.common.policies.utils import (
     get_device_from_parameters,
     get_dtype_from_parameters,
@@ -42,7 +42,9 @@ class DiffusionPolicy(nn.Module):
 
     name = "diffusion"
 
-    def __init__(self, cfg: DiffusionConfig | None, lr_scheduler_num_training_steps: int = 0):
+    def __init__(
+        self, cfg: DiffusionConfig | None, lr_scheduler_num_training_steps: int = 0, dataset_stats=None
+    ):
         """
         Args:
             cfg: Policy configuration class instance or None, in which case the default instantiation of the
@@ -54,6 +56,8 @@ class DiffusionPolicy(nn.Module):
         if cfg is None:
             cfg = DiffusionConfig()
         self.cfg = cfg
+        self.normalize_inputs = Normalize(cfg.input_shapes, cfg.normalize_input_modes, dataset_stats)
+        self.unnormalize_outputs = Unnormalize(cfg.output_shapes, cfg.unnormalize_output_modes, dataset_stats)
 
         # queues are populated during rollout of the policy, they contain the n latest observations and actions
         self._queues = None
@@ -126,6 +130,8 @@ class DiffusionPolicy(nn.Module):
         assert "observation.state" in batch
         assert len(batch) == 2
 
+        batch = self.normalize_inputs(batch)
+
         self._queues = populate_queues(self._queues, batch)
 
         if len(self._queues["action"]) == 0:
@@ -135,6 +141,10 @@ class DiffusionPolicy(nn.Module):
                 actions = self.ema_diffusion.generate_actions(batch)
             else:
                 actions = self.diffusion.generate_actions(batch)
+
+            # TODO(rcadene): make above methods return output dictionary?
+            actions = self.unnormalize_outputs({"action": actions})["action"]
+
             self._queues["action"].extend(actions.transpose(0, 1))
 
         action = self._queues["action"].popleft()
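
The select_action changes above sit inside a receding-horizon rollout: the policy generates a chunk of actions once, queues them, and then serves one (now unnormalized) action per environment step. A standalone sketch of that queueing pattern, with illustrative shapes:

# Sketch of the action queue used during rollout. Shapes are illustrative.
from collections import deque

import torch

batch_size, n_action_steps, action_dim = 1, 8, 2
queue: deque[torch.Tensor] = deque(maxlen=n_action_steps)

if len(queue) == 0:
    # One generation call yields (batch, n_action_steps, action_dim);
    # transpose so extend() enqueues one (batch, action_dim) tensor per step.
    actions = torch.randn(batch_size, n_action_steps, action_dim)
    queue.extend(actions.transpose(0, 1))

action = queue.popleft()  # served once per env step until the queue drains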
@@ -151,9 +161,13 @@ class DiffusionPolicy(nn.Module):
 
         self.diffusion.train()
 
+        batch = self.normalize_inputs(batch)
+
         loss = self.forward(batch)["loss"]
         loss.backward()
+
+        # TODO(rcadene): self.unnormalize_outputs(out_dict)
 
         grad_norm = torch.nn.utils.clip_grad_norm_(
             self.diffusion.parameters(),
             self.cfg.grad_clip_norm,
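
In the update step above, normalization now happens before the forward pass, and the TODO notes that outputs are not yet unnormalized there. The surrounding clip-then-step pattern, sketched with a stand-in model and optimizer (both assumptions, not the policy's actual training loop):

# Sketch of the gradient-clipped update step. Model/optimizer are stand-ins.
import torch
from torch import nn

model = nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
grad_clip_norm = 10.0

batch = torch.randn(8, 4)
loss = model(batch).pow(2).mean()
loss.backward()
# Clip the global grad norm before stepping, mirroring the diff above.
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip_norm)
optimizer.step()
optimizer.zero_grad()
print(float(grad_norm))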
@@ -197,7 +211,8 @@ class _DiffusionUnetImagePolicy(nn.Module):
 
         self.rgb_encoder = _RgbEncoder(cfg)
         self.unet = _ConditionalUnet1D(
-            cfg, global_cond_dim=(cfg.action_dim + self.rgb_encoder.feature_dim) * cfg.n_obs_steps
+            cfg,
+            global_cond_dim=(cfg.output_shapes["action"][0] + self.rgb_encoder.feature_dim) * cfg.n_obs_steps,
         )
 
         self.noise_scheduler = DDPMScheduler(
@@ -225,7 +240,7 @@ class _DiffusionUnetImagePolicy(nn.Module):
 
         # Sample prior.
         sample = torch.randn(
-            size=(batch_size, self.cfg.horizon, self.cfg.action_dim),
+            size=(batch_size, self.cfg.horizon, self.cfg.output_shapes["action"][0]),
             dtype=dtype,
             device=device,
             generator=generator,
@@ -268,7 +283,7 @@ class _DiffusionUnetImagePolicy(nn.Module):
         sample = self.conditional_sample(batch_size, global_cond=global_cond)
 
         # `horizon` steps worth of actions (from the first observation).
-        actions = sample[..., : self.cfg.action_dim]
+        actions = sample[..., : self.cfg.output_shapes["action"][0]]
         # Extract `n_action_steps` steps worth of actions (from the current observation).
         start = n_obs_steps - 1
         end = start + self.cfg.n_action_steps
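
For concreteness, a worked example of the slice computed above (values illustrative): with two observation steps and eight action steps, the slice keeps horizon steps 1 through 8, i.e. the actions starting at the current observation.

# Worked example of the start/end indices above (illustrative values).
n_obs_steps, n_action_steps = 2, 8
start = n_obs_steps - 1       # 1: the current observation's step
end = start + n_action_steps  # 9
# actions[:, start:end] then selects horizon steps 1..8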
@@ -346,12 +361,6 @@ class _RgbEncoder(nn.Module):
     def __init__(self, cfg: DiffusionConfig):
         super().__init__()
         # Set up optional preprocessing.
-        if all(v == 1.0 for v in chain(cfg.image_normalization_mean, cfg.image_normalization_std)):
-            self.normalizer = nn.Identity()
-        else:
-            self.normalizer = torchvision.transforms.Normalize(
-                mean=cfg.image_normalization_mean, std=cfg.image_normalization_std
-            )
         if cfg.crop_shape is not None:
             self.do_crop = True
             # Always use center crop for eval
@@ -384,7 +393,9 @@ class _RgbEncoder(nn.Module):
         # Set up pooling and final layers.
         # Use a dry run to get the feature map shape.
         with torch.inference_mode():
-            feat_map_shape = tuple(self.backbone(torch.zeros(size=(1, 3, *cfg.image_size))).shape[1:])
+            feat_map_shape = tuple(
+                self.backbone(torch.zeros(size=(1, *cfg.input_shapes["observation.image"]))).shape[1:]
+            )
         self.pool = SpatialSoftmax(feat_map_shape, num_kp=cfg.spatial_softmax_num_keypoints)
         self.feature_dim = cfg.spatial_softmax_num_keypoints * 2
         self.out = nn.Linear(cfg.spatial_softmax_num_keypoints * 2, self.feature_dim)
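
The dry run above now infers the backbone's output feature-map shape from the configured input shape rather than a hardcoded (3, *image_size). The same trick in isolation, with a toy backbone (an assumption, not the encoder's real ResNet):

# Sketch of the dry-run shape probe. The backbone here is a toy stand-in.
import torch
from torch import nn

backbone = nn.Sequential(nn.Conv2d(3, 16, kernel_size=3, stride=2), nn.ReLU())
input_shape = (3, 96, 96)  # illustrative; the diff reads cfg.input_shapes

with torch.inference_mode():
    feat_map_shape = tuple(backbone(torch.zeros(size=(1, *input_shape))).shape[1:])

print(feat_map_shape)  # (16, 47, 47) for this toy backbone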
@@ -397,8 +408,7 @@ class _RgbEncoder(nn.Module):
         Returns:
             (B, D) image feature.
         """
-        # Preprocess: normalize and maybe crop (if it was set up in the __init__).
-        x = self.normalizer(x)
+        # Preprocess: maybe crop (if it was set up in the __init__).
         if self.do_crop:
             if self.training:  # noqa: SIM108
                 x = self.maybe_random_crop(x)
@@ -502,7 +512,7 @@ class _ConditionalUnet1D(nn.Module):
 
         # In channels / out channels for each downsampling block in the Unet's encoder. For the decoder, we
        # just reverse these.
-        in_out = [(cfg.action_dim, cfg.down_dims[0])] + list(
+        in_out = [(cfg.output_shapes["action"][0], cfg.down_dims[0])] + list(
             zip(cfg.down_dims[:-1], cfg.down_dims[1:], strict=True)
         )
 
@@ -553,7 +563,7 @@ class _ConditionalUnet1D(nn.Module):
 
         self.final_conv = nn.Sequential(
             _Conv1dBlock(cfg.down_dims[0], cfg.down_dims[0], kernel_size=cfg.kernel_size),
-            nn.Conv1d(cfg.down_dims[0], cfg.action_dim, 1),
+            nn.Conv1d(cfg.down_dims[0], cfg.output_shapes["action"][0], 1),
         )
 
     def forward(self, x: Tensor, timestep: Tensor | int, global_cond=None) -> Tensor:
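
Taken together, the signature change means callers now pass dataset statistics when constructing the policy. A hypothetical call site (the stats layout and values are assumptions; the keyword names follow the new __init__ above):

# Hypothetical construction after this change. Stats layout is assumed.
import torch

from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig

dataset_stats = {
    "observation.state": {"mean": torch.zeros(2), "std": torch.ones(2)},
    "action": {"mean": torch.zeros(2), "std": torch.ones(2)},
}
cfg = DiffusionConfig()
# DiffusionPolicy is the class edited above; its import path is not shown in the diff:
# policy = DiffusionPolicy(cfg, lr_scheduler_num_training_steps=5000, dataset_stats=dataset_stats)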