From e760e4cd638920afec33ce065ad5e36dc37b34eb Mon Sep 17 00:00:00 2001 From: Remi Date: Thu, 25 Apr 2024 11:47:38 +0200 Subject: [PATCH] Move normalization to policy for act and diffusion (#90) Co-authored-by: Alexander Soare --- README.md | 4 +- examples/1_load_hugging_face_dataset.py | 2 +- examples/3_evaluate_pretrained_policy.py | 2 - examples/4_train_policy.py | 5 +- lerobot/common/datasets/factory.py | 57 +---- lerobot/common/envs/utils.py | 18 +- .../common/policies/act/configuration_act.py | 68 ++++-- lerobot/common/policies/act/modeling_act.py | 51 +++-- .../diffusion/configuration_diffusion.py | 68 ++++-- .../policies/diffusion/modeling_diffusion.py | 42 ++-- lerobot/common/policies/factory.py | 6 +- lerobot/common/policies/normalize.py | 196 ++++++++++++++++++ lerobot/common/transforms.py | 65 ------ lerobot/configs/env/aloha.yaml | 2 - lerobot/configs/env/pusht.yaml | 2 - lerobot/configs/env/xarm.yaml | 2 - lerobot/configs/policy/act.yaml | 30 ++- lerobot/configs/policy/diffusion.yaml | 40 ++-- lerobot/configs/policy/tdmpc.yaml | 4 +- lerobot/scripts/eval.py | 22 +- lerobot/scripts/train.py | 3 +- lerobot/scripts/visualize_dataset.py | 6 +- tests/test_envs.py | 11 +- tests/test_examples.py | 2 +- tests/test_policies.py | 123 ++++++++++- 25 files changed, 543 insertions(+), 288 deletions(-) create mode 100644 lerobot/common/policies/normalize.py delete mode 100644 lerobot/common/transforms.py diff --git a/README.md b/README.md index a0045bf2..8b78ca3e 100644 --- a/README.md +++ b/README.md @@ -263,15 +263,13 @@ Secondly, assuming you have trained a policy, you need: - `config.yaml` which you can get from the `.hydra` directory of your training output folder. - `model.pt` which should be one of the saved models in the `models` directory of your training output folder (they won't be named `model.pt` but you will need to choose one). -- `stats.pth` which should point to the same file in the dataset directory (found in `data/{dataset_name}`). To upload these to the hub, prepare a folder with the following structure (you can use symlinks rather than copying): ``` to_upload ├── config.yaml - ├── model.pt - └── stats.pth + └── model.pt ``` With the folder prepared, run the following with a desired revision ID. diff --git a/examples/1_load_hugging_face_dataset.py b/examples/1_load_hugging_face_dataset.py index d249394a..ca66769c 100644 --- a/examples/1_load_hugging_face_dataset.py +++ b/examples/1_load_hugging_face_dataset.py @@ -44,7 +44,7 @@ from datasets import load_dataset # TODO(rcadene): list available datasets on lerobot page using `datasets` # download/load hugging face dataset in pyarrow format -hf_dataset, fps = load_dataset("lerobot/pusht", split="train"), 10 +hf_dataset, fps = load_dataset("lerobot/pusht", split="train", revision="v1.1"), 10 # display name of dataset and its features # TODO(rcadene): update to make the print pretty diff --git a/examples/3_evaluate_pretrained_policy.py b/examples/3_evaluate_pretrained_policy.py index a892fa23..392ad1c6 100644 --- a/examples/3_evaluate_pretrained_policy.py +++ b/examples/3_evaluate_pretrained_policy.py @@ -19,7 +19,6 @@ folder = Path(snapshot_download(hub_id)) config_path = folder / "config.yaml" weights_path = folder / "model.pt" -stats_path = folder / "stats.pth" # normalization stats # Override some config parameters to do with evaluation. overrides = [ @@ -36,5 +35,4 @@ cfg = init_hydra_config(config_path, overrides) eval( cfg, out_dir=f"outputs/eval/example_{cfg.env.name}_{cfg.policy.name}", - stats_path=stats_path, ) diff --git a/examples/4_train_policy.py b/examples/4_train_policy.py index 1ccb40d6..2b1fafd5 100644 --- a/examples/4_train_policy.py +++ b/examples/4_train_policy.py @@ -34,7 +34,7 @@ dataset = make_dataset(hydra_cfg) # If you're doing something different, you will likely need to change at least some of the defaults. cfg = DiffusionConfig() # TODO(alexander-soare): Remove LR scheduler from the policy. -policy = DiffusionPolicy(cfg, lr_scheduler_num_training_steps=training_steps) +policy = DiffusionPolicy(cfg, lr_scheduler_num_training_steps=training_steps, dataset_stats=dataset.stats) policy.train() policy.to(device) @@ -62,7 +62,6 @@ while not done: done = True break -# Save the policy, configuration, and normalization stats for later use. +# Save the policy and configuration for later use. policy.save(output_directory / "model.pt") OmegaConf.save(hydra_cfg, output_directory / "config.yaml") -torch.save(dataset.transform.transforms[-1].stats, output_directory / "stats.pth") diff --git a/lerobot/common/datasets/factory.py b/lerobot/common/datasets/factory.py index 0fbfff65..9753cde7 100644 --- a/lerobot/common/datasets/factory.py +++ b/lerobot/common/datasets/factory.py @@ -2,18 +2,13 @@ import os from pathlib import Path import torch -from torchvision.transforms import v2 - -from lerobot.common.transforms import NormalizeTransform +from omegaconf import OmegaConf DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None def make_dataset( cfg, - # set normalize=False to remove all transformations and keep images unnormalized in [0,255] - normalize=True, - stats_path=None, split="train", ): if cfg.env.name == "xarm": @@ -33,58 +28,26 @@ def make_dataset( else: raise ValueError(cfg.env.name) - transforms = None - if normalize: - # TODO(rcadene): make normalization strategy configurable between mean_std, min_max, manual_min_max, - # min_max_from_spec - # TODO(rcadene): remove this and put it in config. Ideally we want to reproduce SOTA results just with mean_std - normalization_mode = "mean_std" if cfg.env.name == "aloha" else "min_max" - - if cfg.policy.name == "diffusion" and cfg.env.name == "pusht": - stats = {} - # TODO(rcadene): we overwrite stats to have the same as pretrained model, but we should remove this - stats["observation.state"] = {} - stats["observation.state"]["min"] = torch.tensor([13.456424, 32.938293], dtype=torch.float32) - stats["observation.state"]["max"] = torch.tensor([496.14618, 510.9579], dtype=torch.float32) - stats["action"] = {} - stats["action"]["min"] = torch.tensor([12.0, 25.0], dtype=torch.float32) - stats["action"]["max"] = torch.tensor([511.0, 511.0], dtype=torch.float32) - elif stats_path is None: - # load a first dataset to access precomputed stats - stats_dataset = clsfunc( - dataset_id=cfg.dataset_id, - split="train", - root=DATA_DIR, - ) - stats = stats_dataset.stats - else: - stats = torch.load(stats_path) - - transforms = v2.Compose( - [ - NormalizeTransform( - stats, - in_keys=[ - "observation.state", - "action", - ], - mode=normalization_mode, - ), - ] - ) - delta_timestamps = cfg.policy.get("delta_timestamps") if delta_timestamps is not None: for key in delta_timestamps: if isinstance(delta_timestamps[key], str): delta_timestamps[key] = eval(delta_timestamps[key]) + # TODO(rcadene): add data augmentations + dataset = clsfunc( dataset_id=cfg.dataset_id, split=split, root=DATA_DIR, delta_timestamps=delta_timestamps, - transform=transforms, ) + if cfg.get("override_dataset_stats"): + for key, stats_dict in cfg.override_dataset_stats.items(): + for stats_type, listconfig in stats_dict.items(): + # example of stats_type: min, max, mean, std + stats = OmegaConf.to_container(listconfig, resolve=True) + dataset.stats[key][stats_type] = torch.tensor(stats, dtype=torch.float32) + return dataset diff --git a/lerobot/common/envs/utils.py b/lerobot/common/envs/utils.py index dcce1bcc..17021880 100644 --- a/lerobot/common/envs/utils.py +++ b/lerobot/common/envs/utils.py @@ -1,10 +1,8 @@ import einops import torch -from lerobot.common.transforms import apply_inverse_transform - -def preprocess_observation(observation, transform=None): +def preprocess_observation(observation): # map to expected inputs for the policy obs = {} @@ -24,7 +22,7 @@ def preprocess_observation(observation, transform=None): assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}" # convert to channel first of type float32 in range [0,1] - img = einops.rearrange(img, "b h w c -> b c h w") + img = einops.rearrange(img, "b h w c -> b c h w").contiguous() img = img.type(torch.float32) img /= 255 @@ -33,19 +31,11 @@ def preprocess_observation(observation, transform=None): # TODO(rcadene): enable pixels only baseline with `obs_type="pixels"` in environment by removing requirement for "agent_pos" obs["observation.state"] = torch.from_numpy(observation["agent_pos"]).float() - # apply same transforms as in training - if transform is not None: - for key in obs: - obs[key] = torch.stack([transform({key: item})[key] for item in obs[key]]) - return obs -def postprocess_action(action, transform=None): - action = action.to("cpu") - # action is a batch (num_env,action_dim) instead of an item (action_dim), - # we assume applying inverse transform on a batch works the same - action = apply_inverse_transform({"action": action}, transform)["action"].numpy() +def postprocess_action(action): + action = action.to("cpu").numpy() assert ( action.ndim == 2 ), "we assume dimensions are respectively the number of parallel envs, action dimensions" diff --git a/lerobot/common/policies/act/configuration_act.py b/lerobot/common/policies/act/configuration_act.py index 211a8ed0..82280b2c 100644 --- a/lerobot/common/policies/act/configuration_act.py +++ b/lerobot/common/policies/act/configuration_act.py @@ -8,23 +8,30 @@ class ActionChunkingTransformerConfig: Defaults are configured for training on bimanual Aloha tasks like "insertion" or "transfer". The parameters you will most likely need to change are the ones which depend on the environment / sensors. - Those are: `state_dim`, `action_dim` and `camera_names`. + Those are: `input_shapes` and 'output_shapes`. Args: - state_dim: Dimensionality of the observation state space (excluding images). - action_dim: Dimensionality of the action space. n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the current step and additional steps going back). - camera_names: The (unique) set of names for the cameras. chunk_size: The size of the action prediction "chunks" in units of environment steps. n_action_steps: The number of action steps to run in the environment for one invocation of the policy. This should be no greater than the chunk size. For example, if the chunk size size 100, you may set this to 50. This would mean that the model predicts 100 steps worth of actions, runs 50 in the environment, and throws the other 50 out. - image_normalization_mean: Value to subtract from the input image pixels (inputs are assumed to be in - [0, 1]) for normalization. - image_normalization_std: Value by which to divide the input image pixels (after the mean has been - subtracted). + input_shapes: A dictionary defining the shapes of the input data for the policy. + The key represents the input data name, and the value is a list indicating the dimensions + of the corresponding data. For example, "observation.images.top" refers to an input from the + "top" camera with dimensions [3, 96, 96], indicating it has three color channels and 96x96 resolution. + Importantly, shapes doesnt include batch dimension or temporal dimension. + output_shapes: A dictionary defining the shapes of the output data for the policy. + The key represents the output data name, and the value is a list indicating the dimensions + of the corresponding data. For example, "action" refers to an output shape of [14], indicating + 14-dimensional actions. Importantly, shapes doesnt include batch dimension or temporal dimension. + normalize_input_modes: A dictionary with key represents the modality (e.g. "observation.state"), + and the value specifies the normalization mode to apply. The two availables + modes are "mean_std" which substracts the mean and divide by the standard + deviation and "min_max" which rescale in a [-1, 1] range. + unnormalize_output_modes: Similar dictionary as `normalize_input_modes`, but to unormalize in original scale. vision_backbone: Name of the torchvision resnet backbone to use for encoding images. use_pretrained_backbone: Whether the backbone should be initialized with pretrained weights from torchvision. @@ -50,21 +57,35 @@ class ActionChunkingTransformerConfig: is enabled. Loss is then calculated as: `reconstruction_loss + kl_weight * kld_loss`. """ - # Environment. - state_dim: int = 14 - action_dim: int = 14 - - # Inputs / output structure. + # Input / output structure. n_obs_steps: int = 1 - camera_names: tuple[str] = ("top",) chunk_size: int = 100 n_action_steps: int = 100 - # Vision preprocessing. - image_normalization_mean: tuple[float, float, float] = field( - default_factory=lambda: [0.485, 0.456, 0.406] + input_shapes: dict[str, list[str]] = field( + default_factory=lambda: { + "observation.images.top": [3, 480, 640], + "observation.state": [14], + } + ) + output_shapes: dict[str, list[str]] = field( + default_factory=lambda: { + "action": [14], + } + ) + + # Normalization / Unnormalization + normalize_input_modes: dict[str, str] = field( + default_factory=lambda: { + "observation.image": "mean_std", + "observation.state": "mean_std", + } + ) + unnormalize_output_modes: dict[str, str] = field( + default_factory=lambda: { + "action": "mean_std", + } ) - image_normalization_std: tuple[float, float, float] = field(default_factory=lambda: [0.229, 0.224, 0.225]) # Architecture. # Vision backbone. @@ -117,7 +138,10 @@ class ActionChunkingTransformerConfig: raise ValueError( f"Multiple observation steps not handled yet. Got `nobs_steps={self.n_obs_steps}`" ) - if self.camera_names != ["top"]: - raise ValueError(f"For now, `camera_names` can only be ['top']. Got {self.camera_names}.") - if len(set(self.camera_names)) != len(self.camera_names): - raise ValueError(f"`camera_names` should not have any repeated entries. Got {self.camera_names}.") + # Check that there is only one image. + # TODO(alexander-soare): generalize this to multiple images. + if ( + sum(k.startswith("observation.images.") for k in self.input_shapes) != 1 + or "observation.images.top" not in self.input_shapes + ): + raise ValueError('For now, only "observation.images.top" is accepted for an image input.') diff --git a/lerobot/common/policies/act/modeling_act.py b/lerobot/common/policies/act/modeling_act.py index c1af4ef4..c2dd5bf7 100644 --- a/lerobot/common/policies/act/modeling_act.py +++ b/lerobot/common/policies/act/modeling_act.py @@ -15,12 +15,12 @@ import numpy as np import torch import torch.nn.functional as F # noqa: N812 import torchvision -import torchvision.transforms as transforms from torch import Tensor, nn from torchvision.models._utils import IntermediateLayerGetter from torchvision.ops.misc import FrozenBatchNorm2d from lerobot.common.policies.act.configuration_act import ActionChunkingTransformerConfig +from lerobot.common.policies.normalize import Normalize, Unnormalize class ActionChunkingTransformerPolicy(nn.Module): @@ -62,7 +62,7 @@ class ActionChunkingTransformerPolicy(nn.Module): name = "act" - def __init__(self, cfg: ActionChunkingTransformerConfig | None = None): + def __init__(self, cfg: ActionChunkingTransformerConfig | None = None, dataset_stats=None): """ Args: cfg: Policy configuration class instance or None, in which case the default instantiation of the @@ -72,6 +72,8 @@ class ActionChunkingTransformerPolicy(nn.Module): if cfg is None: cfg = ActionChunkingTransformerConfig() self.cfg = cfg + self.normalize_inputs = Normalize(cfg.input_shapes, cfg.normalize_input_modes, dataset_stats) + self.unnormalize_outputs = Unnormalize(cfg.output_shapes, cfg.unnormalize_output_modes, dataset_stats) # BERT style VAE encoder with input [cls, *joint_space_configuration, *action_sequence]. # The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]). @@ -79,9 +81,13 @@ class ActionChunkingTransformerPolicy(nn.Module): self.vae_encoder = _TransformerEncoder(cfg) self.vae_encoder_cls_embed = nn.Embedding(1, cfg.d_model) # Projection layer for joint-space configuration to hidden dimension. - self.vae_encoder_robot_state_input_proj = nn.Linear(cfg.state_dim, cfg.d_model) + self.vae_encoder_robot_state_input_proj = nn.Linear( + cfg.input_shapes["observation.state"][0], cfg.d_model + ) # Projection layer for action (joint-space target) to hidden dimension. - self.vae_encoder_action_input_proj = nn.Linear(cfg.state_dim, cfg.d_model) + self.vae_encoder_action_input_proj = nn.Linear( + cfg.input_shapes["observation.state"][0], cfg.d_model + ) self.latent_dim = cfg.latent_dim # Projection layer from the VAE encoder's output to the latent distribution's parameter space. self.vae_encoder_latent_output_proj = nn.Linear(cfg.d_model, self.latent_dim * 2) @@ -93,9 +99,6 @@ class ActionChunkingTransformerPolicy(nn.Module): ) # Backbone for image feature extraction. - self.image_normalizer = transforms.Normalize( - mean=cfg.image_normalization_mean, std=cfg.image_normalization_std - ) backbone_model = getattr(torchvision.models, cfg.vision_backbone)( replace_stride_with_dilation=[False, False, cfg.replace_final_stride_with_dilation], pretrained=cfg.use_pretrained_backbone, @@ -112,7 +115,7 @@ class ActionChunkingTransformerPolicy(nn.Module): # Transformer encoder input projections. The tokens will be structured like # [latent, robot_state, image_feature_map_pixels]. - self.encoder_robot_state_input_proj = nn.Linear(cfg.state_dim, cfg.d_model) + self.encoder_robot_state_input_proj = nn.Linear(cfg.input_shapes["observation.state"][0], cfg.d_model) self.encoder_latent_input_proj = nn.Linear(self.latent_dim, cfg.d_model) self.encoder_img_feat_input_proj = nn.Conv2d( backbone_model.fc.in_features, cfg.d_model, kernel_size=1 @@ -126,7 +129,7 @@ class ActionChunkingTransformerPolicy(nn.Module): self.decoder_pos_embed = nn.Embedding(cfg.chunk_size, cfg.d_model) # Final action regression head on the output of the transformer's decoder. - self.action_head = nn.Linear(cfg.d_model, cfg.action_dim) + self.action_head = nn.Linear(cfg.d_model, cfg.output_shapes["action"][0]) self._reset_parameters() self._create_optimizer() @@ -169,10 +172,18 @@ class ActionChunkingTransformerPolicy(nn.Module): queue is empty. """ self.eval() + + batch = self.normalize_inputs(batch) + if len(self._action_queue) == 0: # `_forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue effectively # has shape (n_action_steps, batch_size, *), hence the transpose. - self._action_queue.extend(self._forward(batch)[0][: self.cfg.n_action_steps].transpose(0, 1)) + actions = self._forward(batch)[0][: self.cfg.n_action_steps] + + # TODO(rcadene): make _forward return output dictionary? + actions = self.unnormalize_outputs({"action": actions})["action"] + + self._action_queue.extend(actions.transpose(0, 1)) return self._action_queue.popleft() def forward(self, batch, **_) -> dict[str, Tensor]: @@ -203,7 +214,11 @@ class ActionChunkingTransformerPolicy(nn.Module): """Run the model in train mode, compute the loss, and do an optimization step.""" start_time = time.time() self.train() + + batch = self.normalize_inputs(batch) + loss_dict = self.forward(batch) + # TODO(rcadene): self.unnormalize_outputs(out_dict) loss = loss_dict["loss"] loss.backward() @@ -232,17 +247,9 @@ class ActionChunkingTransformerPolicy(nn.Module): "observation.images.{name}": (B, C, H, W) tensor of images. } """ - # Check that there is only one image. - # TODO(alexander-soare): generalize this to multiple images. - provided_cameras = {k.rsplit(".", 1)[-1] for k in batch if k.startswith("observation.images.")} - if len(missing := set(self.cfg.camera_names).difference(provided_cameras)) > 0: - raise ValueError( - f"The following camera images are missing from the provided batch: {missing}. Check the " - "configuration parameter: `camera_names`." - ) - # Stack images in the order dictated by the camera names. + # Stack images in the order dictated by input_shapes. batch["observation.images"] = torch.stack( - [batch[f"observation.images.{name}"] for name in self.cfg.camera_names], + [batch[k] for k in self.cfg.input_shapes if k.startswith("observation.images.")], dim=-4, ) @@ -309,8 +316,8 @@ class ActionChunkingTransformerPolicy(nn.Module): # Camera observation features and positional embeddings. all_cam_features = [] all_cam_pos_embeds = [] - images = self.image_normalizer(batch["observation.images"]) - for cam_index in range(len(self.cfg.camera_names)): + images = batch["observation.images"] + for cam_index in range(images.shape[-4]): cam_features = self.backbone(images[:, cam_index])["feature_map"] cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype) cam_features = self.encoder_img_feat_input_proj(cam_features) # (B, C, h, w) diff --git a/lerobot/common/policies/diffusion/configuration_diffusion.py b/lerobot/common/policies/diffusion/configuration_diffusion.py index d8820a0b..9a725a56 100644 --- a/lerobot/common/policies/diffusion/configuration_diffusion.py +++ b/lerobot/common/policies/diffusion/configuration_diffusion.py @@ -1,4 +1,4 @@ -from dataclasses import dataclass +from dataclasses import dataclass, field @dataclass @@ -8,21 +8,28 @@ class DiffusionConfig: Defaults are configured for training with PushT providing proprioceptive and single camera observations. The parameters you will most likely need to change are the ones which depend on the environment / sensors. - Those are: `state_dim`, `action_dim` and `image_size`. + Those are: `input_shapes` and `output_shapes`. Args: - state_dim: Dimensionality of the observation state space (excluding images). - action_dim: Dimensionality of the action space. - image_size: (H, W) size of the input images. n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the current step and additional steps going back). horizon: Diffusion model action prediction size as detailed in `DiffusionPolicy.select_action`. n_action_steps: The number of action steps to run in the environment for one invocation of the policy. See `DiffusionPolicy.select_action` for more details. - image_normalization_mean: Value to subtract from the input image pixels (inputs are assumed to be in - [0, 1]) for normalization. - image_normalization_std: Value by which to divide the input image pixels (after the mean has been - subtracted). + input_shapes: A dictionary defining the shapes of the input data for the policy. + The key represents the input data name, and the value is a list indicating the dimensions + of the corresponding data. For example, "observation.image" refers to an input from + a camera with dimensions [3, 96, 96], indicating it has three color channels and 96x96 resolution. + Importantly, shapes doesnt include batch dimension or temporal dimension. + output_shapes: A dictionary defining the shapes of the output data for the policy. + The key represents the output data name, and the value is a list indicating the dimensions + of the corresponding data. For example, "action" refers to an output shape of [14], indicating + 14-dimensional actions. Importantly, shapes doesnt include batch dimension or temporal dimension. + normalize_input_modes: A dictionary with key represents the modality (e.g. "observation.state"), + and the value specifies the normalization mode to apply. The two availables + modes are "mean_std" which substracts the mean and divide by the standard + deviation and "min_max" which rescale in a [-1, 1] range. + unnormalize_output_modes: Similar dictionary as `normalize_input_modes`, but to unormalize in original scale. vision_backbone: Name of the torchvision resnet backbone to use for encoding images. crop_shape: (H, W) shape to crop images to as a preprocessing step for the vision backbone. Must fit within the image size. If None, no cropping is done. @@ -58,20 +65,35 @@ class DiffusionConfig: spaced). If not provided, this defaults to be the same as `num_train_timesteps`. """ - # Environment. - # Inherit these from the environment config. - state_dim: int = 2 - action_dim: int = 2 - image_size: tuple[int, int] = (96, 96) - # Inputs / output structure. n_obs_steps: int = 2 horizon: int = 16 n_action_steps: int = 8 - # Vision preprocessing. - image_normalization_mean: tuple[float, float, float] = (0.5, 0.5, 0.5) - image_normalization_std: tuple[float, float, float] = (0.5, 0.5, 0.5) + input_shapes: dict[str, list[str]] = field( + default_factory=lambda: { + "observation.image": [3, 96, 96], + "observation.state": [2], + } + ) + output_shapes: dict[str, list[str]] = field( + default_factory=lambda: { + "action": [2], + } + ) + + # Normalization / Unnormalization + normalize_input_modes: dict[str, str] = field( + default_factory=lambda: { + "observation.image": "mean_std", + "observation.state": "min_max", + } + ) + unnormalize_output_modes: dict[str, str] = field( + default_factory=lambda: { + "action": "min_max", + } + ) # Architecture / modeling. # Vision backbone. @@ -123,10 +145,14 @@ class DiffusionConfig: raise ValueError( f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}." ) - if self.crop_shape[0] > self.image_size[0] or self.crop_shape[1] > self.image_size[1]: + if ( + self.crop_shape[0] > self.input_shapes["observation.image"][1] + or self.crop_shape[1] > self.input_shapes["observation.image"][2] + ): raise ValueError( - f"`crop_shape` should fit within `image_size`. Got {self.crop_shape} for `crop_shape` and " - f"{self.image_size} for `image_size`." + f'`crop_shape` should fit within `input_shapes["observation.image"]`. Got {self.crop_shape} ' + f'for `crop_shape` and {self.input_shapes["observation.image"]} for ' + '`input_shapes["observation.image"]`.' ) supported_prediction_types = ["epsilon", "sample"] if self.prediction_type not in supported_prediction_types: diff --git a/lerobot/common/policies/diffusion/modeling_diffusion.py b/lerobot/common/policies/diffusion/modeling_diffusion.py index e7cc62f4..7a639375 100644 --- a/lerobot/common/policies/diffusion/modeling_diffusion.py +++ b/lerobot/common/policies/diffusion/modeling_diffusion.py @@ -13,7 +13,6 @@ import logging import math import time from collections import deque -from itertools import chain from typing import Callable import einops @@ -27,6 +26,7 @@ from torch import Tensor, nn from torch.nn.modules.batchnorm import _BatchNorm from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig +from lerobot.common.policies.normalize import Normalize, Unnormalize from lerobot.common.policies.utils import ( get_device_from_parameters, get_dtype_from_parameters, @@ -42,7 +42,9 @@ class DiffusionPolicy(nn.Module): name = "diffusion" - def __init__(self, cfg: DiffusionConfig | None, lr_scheduler_num_training_steps: int = 0): + def __init__( + self, cfg: DiffusionConfig | None, lr_scheduler_num_training_steps: int = 0, dataset_stats=None + ): """ Args: cfg: Policy configuration class instance or None, in which case the default instantiation of the @@ -54,6 +56,8 @@ class DiffusionPolicy(nn.Module): if cfg is None: cfg = DiffusionConfig() self.cfg = cfg + self.normalize_inputs = Normalize(cfg.input_shapes, cfg.normalize_input_modes, dataset_stats) + self.unnormalize_outputs = Unnormalize(cfg.output_shapes, cfg.unnormalize_output_modes, dataset_stats) # queues are populated during rollout of the policy, they contain the n latest observations and actions self._queues = None @@ -126,6 +130,8 @@ class DiffusionPolicy(nn.Module): assert "observation.state" in batch assert len(batch) == 2 + batch = self.normalize_inputs(batch) + self._queues = populate_queues(self._queues, batch) if len(self._queues["action"]) == 0: @@ -135,6 +141,10 @@ class DiffusionPolicy(nn.Module): actions = self.ema_diffusion.generate_actions(batch) else: actions = self.diffusion.generate_actions(batch) + + # TODO(rcadene): make above methods return output dictionary? + actions = self.unnormalize_outputs({"action": actions})["action"] + self._queues["action"].extend(actions.transpose(0, 1)) action = self._queues["action"].popleft() @@ -151,9 +161,13 @@ class DiffusionPolicy(nn.Module): self.diffusion.train() + batch = self.normalize_inputs(batch) + loss = self.forward(batch)["loss"] loss.backward() + # TODO(rcadene): self.unnormalize_outputs(out_dict) + grad_norm = torch.nn.utils.clip_grad_norm_( self.diffusion.parameters(), self.cfg.grad_clip_norm, @@ -197,7 +211,8 @@ class _DiffusionUnetImagePolicy(nn.Module): self.rgb_encoder = _RgbEncoder(cfg) self.unet = _ConditionalUnet1D( - cfg, global_cond_dim=(cfg.action_dim + self.rgb_encoder.feature_dim) * cfg.n_obs_steps + cfg, + global_cond_dim=(cfg.output_shapes["action"][0] + self.rgb_encoder.feature_dim) * cfg.n_obs_steps, ) self.noise_scheduler = DDPMScheduler( @@ -225,7 +240,7 @@ class _DiffusionUnetImagePolicy(nn.Module): # Sample prior. sample = torch.randn( - size=(batch_size, self.cfg.horizon, self.cfg.action_dim), + size=(batch_size, self.cfg.horizon, self.cfg.output_shapes["action"][0]), dtype=dtype, device=device, generator=generator, @@ -268,7 +283,7 @@ class _DiffusionUnetImagePolicy(nn.Module): sample = self.conditional_sample(batch_size, global_cond=global_cond) # `horizon` steps worth of actions (from the first observation). - actions = sample[..., : self.cfg.action_dim] + actions = sample[..., : self.cfg.output_shapes["action"][0]] # Extract `n_action_steps` steps worth of actions (from the current observation). start = n_obs_steps - 1 end = start + self.cfg.n_action_steps @@ -346,12 +361,6 @@ class _RgbEncoder(nn.Module): def __init__(self, cfg: DiffusionConfig): super().__init__() # Set up optional preprocessing. - if all(v == 1.0 for v in chain(cfg.image_normalization_mean, cfg.image_normalization_std)): - self.normalizer = nn.Identity() - else: - self.normalizer = torchvision.transforms.Normalize( - mean=cfg.image_normalization_mean, std=cfg.image_normalization_std - ) if cfg.crop_shape is not None: self.do_crop = True # Always use center crop for eval @@ -384,7 +393,9 @@ class _RgbEncoder(nn.Module): # Set up pooling and final layers. # Use a dry run to get the feature map shape. with torch.inference_mode(): - feat_map_shape = tuple(self.backbone(torch.zeros(size=(1, 3, *cfg.image_size))).shape[1:]) + feat_map_shape = tuple( + self.backbone(torch.zeros(size=(1, *cfg.input_shapes["observation.image"]))).shape[1:] + ) self.pool = SpatialSoftmax(feat_map_shape, num_kp=cfg.spatial_softmax_num_keypoints) self.feature_dim = cfg.spatial_softmax_num_keypoints * 2 self.out = nn.Linear(cfg.spatial_softmax_num_keypoints * 2, self.feature_dim) @@ -397,8 +408,7 @@ class _RgbEncoder(nn.Module): Returns: (B, D) image feature. """ - # Preprocess: normalize and maybe crop (if it was set up in the __init__). - x = self.normalizer(x) + # Preprocess: maybe crop (if it was set up in the __init__). if self.do_crop: if self.training: # noqa: SIM108 x = self.maybe_random_crop(x) @@ -502,7 +512,7 @@ class _ConditionalUnet1D(nn.Module): # In channels / out channels for each downsampling block in the Unet's encoder. For the decoder, we # just reverse these. - in_out = [(cfg.action_dim, cfg.down_dims[0])] + list( + in_out = [(cfg.output_shapes["action"][0], cfg.down_dims[0])] + list( zip(cfg.down_dims[:-1], cfg.down_dims[1:], strict=True) ) @@ -553,7 +563,7 @@ class _ConditionalUnet1D(nn.Module): self.final_conv = nn.Sequential( _Conv1dBlock(cfg.down_dims[0], cfg.down_dims[0], kernel_size=cfg.kernel_size), - nn.Conv1d(cfg.down_dims[0], cfg.action_dim, 1), + nn.Conv1d(cfg.down_dims[0], cfg.output_shapes["action"][0], 1), ) def forward(self, x: Tensor, timestep: Tensor | int, global_cond=None) -> Tensor: diff --git a/lerobot/common/policies/factory.py b/lerobot/common/policies/factory.py index 9698175d..a8235388 100644 --- a/lerobot/common/policies/factory.py +++ b/lerobot/common/policies/factory.py @@ -20,7 +20,7 @@ def _policy_cfg_from_hydra_cfg(policy_cfg_class, hydra_cfg): return policy_cfg -def make_policy(hydra_cfg: DictConfig): +def make_policy(hydra_cfg: DictConfig, dataset_stats=None): if hydra_cfg.policy.name == "tdmpc": from lerobot.common.policies.tdmpc.policy import TDMPCPolicy @@ -35,14 +35,14 @@ def make_policy(hydra_cfg: DictConfig): from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy policy_cfg = _policy_cfg_from_hydra_cfg(DiffusionConfig, hydra_cfg) - policy = DiffusionPolicy(policy_cfg, hydra_cfg.offline_steps) + policy = DiffusionPolicy(policy_cfg, hydra_cfg.offline_steps, dataset_stats) policy.to(get_safe_torch_device(hydra_cfg.device)) elif hydra_cfg.policy.name == "act": from lerobot.common.policies.act.configuration_act import ActionChunkingTransformerConfig from lerobot.common.policies.act.modeling_act import ActionChunkingTransformerPolicy policy_cfg = _policy_cfg_from_hydra_cfg(ActionChunkingTransformerConfig, hydra_cfg) - policy = ActionChunkingTransformerPolicy(policy_cfg) + policy = ActionChunkingTransformerPolicy(policy_cfg, dataset_stats) policy.to(get_safe_torch_device(hydra_cfg.device)) else: raise ValueError(hydra_cfg.policy.name) diff --git a/lerobot/common/policies/normalize.py b/lerobot/common/policies/normalize.py new file mode 100644 index 00000000..4d230b16 --- /dev/null +++ b/lerobot/common/policies/normalize.py @@ -0,0 +1,196 @@ +import torch +from torch import nn + + +def create_stats_buffers(shapes, modes, stats=None): + """ + Create buffers per modality (e.g. "observation.image", "action") containing their mean, std, min, max statistics. + + Parameters: + shapes (dict): A dictionary where keys are input modalities (e.g. "observation.image") and values are their shapes (e.g. `[3,96,96]`]). + These shapes are used to create the tensor buffer containing mean, std, min, max statistics. If the provided `shapes` contain keys related to images, the shape is adjusted to be invariant to height + and width, assuming a channel-first (c, h, w) format. + modes (dict): A dictionary where keys are output modalities (e.g. "observation.image") and values are their normalization modes among: + - "mean_std": substract the mean and divide by standard deviation. + - "min_max": map to [-1, 1] range. + stats (dict, optional): A dictionary where keys are output modalities (e.g. "observation.image") and values are dictionaries of statistic types and their values + (e.g. `{"mean": torch.randn(3,1,1)}, "std": torch.randn(3,1,1)}`). If provided, as expected for training the model for the first time, + these statistics will overwrite the default buffers. If not provided, as expected for finetuning or evaluation, the default buffers should to be + be overwritten by a call to `policy.load_state_dict(state_dict)`. That way, initializing the dataset is not needed to get the stats, since + they are already in the policy state_dict. + + Returns: + dict: A dictionary where keys are modalities and values are `nn.ParameterDict` containing `nn.Parameters` set to + `requires_grad=False`, suitable to not be updated during backpropagation. + """ + stats_buffers = {} + + for key, mode in modes.items(): + assert mode in ["mean_std", "min_max"] + + shape = tuple(shapes[key]) + + if "image" in key: + # sanity checks + assert len(shape) == 3, f"number of dimensions of {key} != 3 ({shape=}" + c, h, w = shape + assert c < h and c < w, f"{key} is not channel first ({shape=})" + # override image shape to be invariant to height and width + shape = (c, 1, 1) + + # Note: we initialize mean, std, min, max to infinity. They should be overwritten + # downstream by `stats` or `policy.load_state_dict`, as expected. During forward, + # we assert they are not infinity anymore. + + buffer = {} + if mode == "mean_std": + mean = torch.ones(shape, dtype=torch.float32) * torch.inf + std = torch.ones(shape, dtype=torch.float32) * torch.inf + buffer = nn.ParameterDict( + { + "mean": nn.Parameter(mean, requires_grad=False), + "std": nn.Parameter(std, requires_grad=False), + } + ) + elif mode == "min_max": + min = torch.ones(shape, dtype=torch.float32) * torch.inf + max = torch.ones(shape, dtype=torch.float32) * torch.inf + buffer = nn.ParameterDict( + { + "min": nn.Parameter(min, requires_grad=False), + "max": nn.Parameter(max, requires_grad=False), + } + ) + + if stats is not None: + if mode == "mean_std": + buffer["mean"].data = stats[key]["mean"] + buffer["std"].data = stats[key]["std"] + elif mode == "min_max": + buffer["min"].data = stats[key]["min"] + buffer["max"].data = stats[key]["max"] + + stats_buffers[key] = buffer + return stats_buffers + + +class Normalize(nn.Module): + """ + Normalizes the input data (e.g. "observation.image") for more stable and faster convergence during training. + + Parameters: + shapes (dict): A dictionary where keys are input modalities (e.g. "observation.image") and values are their shapes (e.g. `[3,96,96]`]). + These shapes are used to create the tensor buffer containing mean, std, min, max statistics. If the provided `shapes` contain keys related to images, the shape is adjusted to be invariant to height + and width, assuming a channel-first (c, h, w) format. + modes (dict): A dictionary where keys are output modalities (e.g. "observation.image") and values are their normalization modes among: + - "mean_std": substract the mean and divide by standard deviation. + - "min_max": map to [-1, 1] range. + stats (dict, optional): A dictionary where keys are output modalities (e.g. "observation.image") and values are dictionaries of statistic types and their values + (e.g. `{"mean": torch.randn(3,1,1)}, "std": torch.randn(3,1,1)}`). If provided, as expected for training the model for the first time, + these statistics will overwrite the default buffers. If not provided, as expected for finetuning or evaluation, the default buffers should to be + be overwritten by a call to `policy.load_state_dict(state_dict)`. That way, initializing the dataset is not needed to get the stats, since + they are already in the policy state_dict. + """ + + def __init__(self, shapes, modes, stats=None): + super().__init__() + self.shapes = shapes + self.modes = modes + self.stats = stats + # `self.buffer_observation_state["mean"]` contains `torch.tensor(state_dim)` + stats_buffers = create_stats_buffers(shapes, modes, stats) + for key, buffer in stats_buffers.items(): + setattr(self, "buffer_" + key.replace(".", "_"), buffer) + + # TODO(rcadene): should we remove torch.no_grad? + @torch.no_grad + def forward(self, batch): + for key, mode in self.modes.items(): + buffer = getattr(self, "buffer_" + key.replace(".", "_")) + + if mode == "mean_std": + mean = buffer["mean"] + std = buffer["std"] + assert not torch.isinf( + mean + ).any(), "`mean` is infinity. You forgot to initialize with `stats` as argument, or called `policy.load_state_dict`." + assert not torch.isinf( + std + ).any(), "`std` is infinity. You forgot to initialize with `stats` as argument, or called `policy.load_state_dict`." + batch[key] = (batch[key] - mean) / (std + 1e-8) + elif mode == "min_max": + min = buffer["min"] + max = buffer["max"] + assert not torch.isinf( + min + ).any(), "`min` is infinity. You forgot to initialize with `stats` as argument, or called `policy.load_state_dict`." + assert not torch.isinf( + max + ).any(), "`max` is infinity. You forgot to initialize with `stats` as argument, or called `policy.load_state_dict`." + # normalize to [0,1] + batch[key] = (batch[key] - min) / (max - min) + # normalize to [-1, 1] + batch[key] = batch[key] * 2 - 1 + else: + raise ValueError(mode) + return batch + + +class Unnormalize(nn.Module): + """ + Similar to `Normalize` but unnormalizes output data (e.g. `{"action": torch.randn(b,c)}`) in their original range used by the environment. + + Parameters: + shapes (dict): A dictionary where keys are output modalities (e.g. "action") and values are their shapes (e.g. [10]). + These shapes are used to create the tensor buffer containing mean, std, min, max statistics. If the provided `shapes` contain keys related to images, the shape is adjusted to be invariant to height + and width, assuming a channel-first (c, h, w) format. + modes (dict): A dictionary where keys are output modalities (e.g. "action") and values are their unnormalization modes among: + - "mean_std": multiply by standard deviation and add mean + - "min_max": go from [-1, 1] range to original range. + stats (dict, optional): A dictionary where keys are output modalities (e.g. "action") and values are dictionaries of statistic types and their values + (e.g. `{"max": torch.tensor(1)}, "min": torch.tensor(0)}`). If provided, as expected for training the model for the first time, + these statistics will overwrite the default buffers. If not provided, as expected for finetuning or evaluation, the default buffers should to be + be overwritten by a call to `policy.load_state_dict(state_dict)`. That way, initializing the dataset is not needed to get the stats, since + they are already in the policy state_dict. + """ + + def __init__(self, shapes, modes, stats=None): + super().__init__() + self.shapes = shapes + self.modes = modes + self.stats = stats + # `self.buffer_observation_state["mean"]` contains `torch.tensor(state_dim)` + stats_buffers = create_stats_buffers(shapes, modes, stats) + for key, buffer in stats_buffers.items(): + setattr(self, "buffer_" + key.replace(".", "_"), buffer) + + # TODO(rcadene): should we remove torch.no_grad? + @torch.no_grad + def forward(self, batch): + for key, mode in self.modes.items(): + buffer = getattr(self, "buffer_" + key.replace(".", "_")) + + if mode == "mean_std": + mean = buffer["mean"] + std = buffer["std"] + assert not torch.isinf( + mean + ).any(), "`mean` is infinity. You forgot to initialize with `stats` as argument, or called `policy.load_state_dict`." + assert not torch.isinf( + std + ).any(), "`std` is infinity. You forgot to initialize with `stats` as argument, or called `policy.load_state_dict`." + batch[key] = batch[key] * std + mean + elif mode == "min_max": + min = buffer["min"] + max = buffer["max"] + assert not torch.isinf( + min + ).any(), "`min` is infinity. You forgot to initialize with `stats` as argument, or called `policy.load_state_dict`." + assert not torch.isinf( + max + ).any(), "`max` is infinity. You forgot to initialize with `stats` as argument, or called `policy.load_state_dict`." + batch[key] = (batch[key] + 1) / 2 + batch[key] = batch[key] * (max - min) + min + else: + raise ValueError(mode) + return batch diff --git a/lerobot/common/transforms.py b/lerobot/common/transforms.py deleted file mode 100644 index fffa835a..00000000 --- a/lerobot/common/transforms.py +++ /dev/null @@ -1,65 +0,0 @@ -from torchvision.transforms.v2 import Compose, Transform - - -def apply_inverse_transform(item, transform): - transforms = transform.transforms if isinstance(transform, Compose) else [transform] - for tf in transforms[::-1]: - if tf.invertible: - item = tf.inverse_transform(item) - else: - raise ValueError(f"Inverse transform called on a non invertible transform ({tf}).") - return item - - -class NormalizeTransform(Transform): - invertible = True - - def __init__( - self, - stats: dict, - in_keys: list[str] = None, - out_keys: list[str] | None = None, - in_keys_inv: list[str] | None = None, - out_keys_inv: list[str] | None = None, - mode="mean_std", - ): - super().__init__() - self.in_keys = in_keys - self.out_keys = in_keys if out_keys is None else out_keys - self.in_keys_inv = self.out_keys if in_keys_inv is None else in_keys_inv - self.out_keys_inv = self.in_keys if out_keys_inv is None else out_keys_inv - self.stats = stats - assert mode in ["mean_std", "min_max"] - self.mode = mode - - def forward(self, item): - for inkey, outkey in zip(self.in_keys, self.out_keys, strict=False): - if inkey not in item: - continue - if self.mode == "mean_std": - mean = self.stats[inkey]["mean"] - std = self.stats[inkey]["std"] - item[outkey] = (item[inkey] - mean) / (std + 1e-8) - else: - min = self.stats[inkey]["min"] - max = self.stats[inkey]["max"] - # normalize to [0,1] - item[outkey] = (item[inkey] - min) / (max - min) - # normalize to [-1, 1] - item[outkey] = item[outkey] * 2 - 1 - return item - - def inverse_transform(self, item): - for inkey, outkey in zip(self.in_keys_inv, self.out_keys_inv, strict=False): - if inkey not in item: - continue - if self.mode == "mean_std": - mean = self.stats[inkey]["mean"] - std = self.stats[inkey]["std"] - item[outkey] = item[inkey] * std + mean - else: - min = self.stats[inkey]["min"] - max = self.stats[inkey]["max"] - item[outkey] = (item[inkey] + 1) / 2 - item[outkey] = item[outkey] * (max - min) + min - return item diff --git a/lerobot/configs/env/aloha.yaml b/lerobot/configs/env/aloha.yaml index 6b836795..26493711 100644 --- a/lerobot/configs/env/aloha.yaml +++ b/lerobot/configs/env/aloha.yaml @@ -20,7 +20,5 @@ env: image_size: [3, 480, 640] episode_length: 400 fps: ${fps} - -policy: state_dim: 14 action_dim: 14 diff --git a/lerobot/configs/env/pusht.yaml b/lerobot/configs/env/pusht.yaml index a7097ffd..92b6a33b 100644 --- a/lerobot/configs/env/pusht.yaml +++ b/lerobot/configs/env/pusht.yaml @@ -20,7 +20,5 @@ env: image_size: 96 episode_length: 300 fps: ${fps} - -policy: state_dim: 2 action_dim: 2 diff --git a/lerobot/configs/env/xarm.yaml b/lerobot/configs/env/xarm.yaml index bcba659e..72ca12a0 100644 --- a/lerobot/configs/env/xarm.yaml +++ b/lerobot/configs/env/xarm.yaml @@ -19,7 +19,5 @@ env: image_size: 84 episode_length: 25 fps: ${fps} - -policy: state_dim: 4 action_dim: 4 diff --git a/lerobot/configs/policy/act.yaml b/lerobot/configs/policy/act.yaml index eb4e512b..cfde3b91 100644 --- a/lerobot/configs/policy/act.yaml +++ b/lerobot/configs/policy/act.yaml @@ -11,26 +11,36 @@ log_freq: 250 n_obs_steps: 1 # when temporal_agg=False, n_action_steps=horizon +override_dataset_stats: + observation.images.top: + # stats from imagenet, since we use a pretrained vision model + mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1) + std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1) + # See `configuration_act.py` for more details. policy: name: act pretrained_model_path: - # Environment. - # Inherit these from the environment config. - state_dim: ??? - action_dim: ??? - - # Inputs / output structure. + # Input / output structure. n_obs_steps: ${n_obs_steps} - camera_names: [top] # [top, front_close, left_pillar, right_pillar] chunk_size: 100 # chunk_size n_action_steps: 100 - # Vision preprocessing. - image_normalization_mean: [0.485, 0.456, 0.406] - image_normalization_std: [0.229, 0.224, 0.225] + input_shapes: + # TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env? + observation.images.top: [3, 480, 640] + observation.state: ["${env.state_dim}"] + output_shapes: + action: ["${env.action_dim}"] + + # Normalization / Unnormalization + normalize_input_modes: + observation.images.top: mean_std + observation.state: mean_std + unnormalize_output_modes: + action: mean_std # Architecture. # Vision backbone. diff --git a/lerobot/configs/policy/diffusion.yaml b/lerobot/configs/policy/diffusion.yaml index 44746dfc..f844534e 100644 --- a/lerobot/configs/policy/diffusion.yaml +++ b/lerobot/configs/policy/diffusion.yaml @@ -18,27 +18,43 @@ online_steps: 0 offline_prioritized_sampler: true +override_dataset_stats: + # TODO(rcadene, alexander-soare): should we remove image stats as well? do we use a pretrained vision model? + observation.image: + mean: [[[0.5]], [[0.5]], [[0.5]]] # (c,1,1) + std: [[[0.5]], [[0.5]], [[0.5]]] # (c,1,1) + # TODO(rcadene, alexander-soare): we override state and action stats to use the same as the pretrained model + # from the original codebase, but we should remove these and train our own pretrained model + observation.state: + min: [13.456424, 32.938293] + max: [496.14618, 510.9579] + action: + min: [12.0, 25.0] + max: [511.0, 511.0] + policy: name: diffusion pretrained_model_path: - # Environment. - # Inherit these from the environment config. - state_dim: ??? - action_dim: ??? - image_size: - - ${env.image_size} # height - - ${env.image_size} # width - - # Inputs / output structure. + # Input / output structure. n_obs_steps: ${n_obs_steps} horizon: ${horizon} n_action_steps: ${n_action_steps} - # Vision preprocessing. - image_normalization_mean: [0.5, 0.5, 0.5] - image_normalization_std: [0.5, 0.5, 0.5] + input_shapes: + # TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env? + observation.image: [3, 96, 96] + observation.state: ["${env.state_dim}"] + output_shapes: + action: ["${env.action_dim}"] + + # Normalization / Unnormalization + normalize_input_modes: + observation.image: mean_std + observation.state: min_max + unnormalize_output_modes: + action: min_max # Architecture / modeling. # Vision backbone. diff --git a/lerobot/configs/policy/tdmpc.yaml b/lerobot/configs/policy/tdmpc.yaml index 4fd2b6bb..c78a5d73 100644 --- a/lerobot/configs/policy/tdmpc.yaml +++ b/lerobot/configs/policy/tdmpc.yaml @@ -16,8 +16,8 @@ policy: frame_stack: 1 num_channels: 32 img_size: ${env.image_size} - state_dim: ??? - action_dim: ??? + state_dim: ${env.action_dim} + action_dim: ${env.action_dim} # planning mpc: true diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py index 32b7e26b..c66e7ee9 100644 --- a/lerobot/scripts/eval.py +++ b/lerobot/scripts/eval.py @@ -46,7 +46,6 @@ from huggingface_hub import snapshot_download from PIL import Image as PILImage from tqdm import trange -from lerobot.common.datasets.factory import make_dataset from lerobot.common.datasets.utils import hf_transform_to_torch from lerobot.common.envs.factory import make_env from lerobot.common.envs.utils import postprocess_action, preprocess_observation @@ -64,8 +63,6 @@ def eval_policy( policy: torch.nn.Module, max_episodes_rendered: int = 0, video_dir: Path = None, - # TODO(rcadene): make it possible to overwrite fps? we should use env.fps - transform: callable = None, return_episode_data: bool = False, seed=None, ): @@ -132,10 +129,6 @@ def eval_policy( if return_episode_data: observations.append(deepcopy(observation)) - # apply transform to normalize the observations - for key in observation: - observation[key] = torch.stack([transform({key: item})[key] for item in observation[key]]) - # send observation to device/gpu observation = {key: observation[key].to(device, non_blocking=True) for key in observation} @@ -143,8 +136,8 @@ def eval_policy( with torch.inference_mode(): action = policy.select_action(observation, step=step) - # apply inverse transform to unnormalize the action - action = postprocess_action(action, transform) + # convert to cpu numpy + action = postprocess_action(action) # apply the next action observation, reward, terminated, truncated, info = env.step(action) @@ -360,7 +353,7 @@ def eval_policy( return info -def eval(cfg: dict, out_dir=None, stats_path=None): +def eval(cfg: dict, out_dir=None): if out_dir is None: raise NotImplementedError() @@ -375,10 +368,6 @@ def eval(cfg: dict, out_dir=None, stats_path=None): log_output_dir(out_dir) - logging.info("Making transforms.") - # TODO(alexander-soare): Completely decouple datasets from evaluation. - transform = make_dataset(cfg, stats_path=stats_path).transform - logging.info("Making environment.") env = make_env(cfg, num_parallel_envs=cfg.eval_episodes) @@ -390,7 +379,6 @@ def eval(cfg: dict, out_dir=None, stats_path=None): policy, max_episodes_rendered=10, video_dir=Path(out_dir) / "eval", - transform=transform, return_episode_data=False, seed=cfg.seed, ) @@ -423,17 +411,13 @@ if __name__ == "__main__": if args.config is not None: # Note: For the config_path, Hydra wants a path relative to this script file. cfg = init_hydra_config(args.config, args.overrides) - # TODO(alexander-soare): Save and load stats in trained model directory. - stats_path = None elif args.hub_id is not None: folder = Path(snapshot_download(args.hub_id, revision=args.revision)) cfg = init_hydra_config( folder / "config.yaml", [f"policy.pretrained_model_path={folder / 'model.pt'}", *args.overrides] ) - stats_path = folder / "stats.pth" eval( cfg, out_dir=f"outputs/eval/{dt.now().strftime('%Y-%m-%d/%H-%M-%S')}_{cfg.env.name}_{cfg.policy.name}", - stats_path=stats_path, ) diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py index 1f4ee16a..c849cce8 100644 --- a/lerobot/scripts/train.py +++ b/lerobot/scripts/train.py @@ -232,7 +232,7 @@ def train(cfg: dict, out_dir=None, job_name=None): env = make_env(cfg, num_parallel_envs=cfg.eval_episodes) logging.info("make_policy") - policy = make_policy(cfg) + policy = make_policy(cfg, dataset_stats=offline_dataset.stats) num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad) num_total_params = sum(p.numel() for p in policy.parameters()) @@ -339,7 +339,6 @@ def train(cfg: dict, out_dir=None, job_name=None): eval_info = eval_policy( rollout_env, policy, - transform=offline_dataset.transform, return_episode_data=True, seed=cfg.seed, ) diff --git a/lerobot/scripts/visualize_dataset.py b/lerobot/scripts/visualize_dataset.py index b51e62b4..3d4d8c53 100644 --- a/lerobot/scripts/visualize_dataset.py +++ b/lerobot/scripts/visualize_dataset.py @@ -50,11 +50,7 @@ def visualize_dataset(cfg: dict, out_dir=None): log_output_dir(out_dir) logging.info("make_dataset") - dataset = make_dataset( - cfg, - # remove all transformations such as rescale images from [0,255] to [0,1] or normalization - normalize=False, - ) + dataset = make_dataset(cfg) logging.info("Start rendering episodes from offline buffer") video_paths = render_dataset(dataset, out_dir, MAX_NUM_STEPS * NUM_EPISODES_TO_RENDER) diff --git a/tests/test_envs.py b/tests/test_envs.py index 33928a62..85363702 100644 --- a/tests/test_envs.py +++ b/tests/test_envs.py @@ -6,7 +6,6 @@ import torch from gymnasium.utils.env_checker import check_env import lerobot -from lerobot.common.datasets.factory import make_dataset from lerobot.common.envs.factory import make_env from lerobot.common.envs.utils import preprocess_observation from lerobot.common.utils.utils import init_hydra_config @@ -38,12 +37,14 @@ def test_factory(env_name): overrides=[f"env={env_name}", f"device={DEVICE}"], ) - dataset = make_dataset(cfg) - env = make_env(cfg, num_parallel_envs=1) obs, _ = env.reset() - obs = preprocess_observation(obs, transform=dataset.transform) - for key in dataset.image_keys: + obs = preprocess_observation(obs) + + # test image keys are float32 in range [0,1] + for key in obs: + if "image" not in key: + continue img = obs[key] assert img.dtype == torch.float32 # TODO(rcadene): we assume for now that image normalization takes place in the model diff --git a/tests/test_examples.py b/tests/test_examples.py index 3ac040b1..876735b4 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -51,7 +51,7 @@ def test_examples_4_and_3(): # Pass empty globals to allow dictionary comprehension https://stackoverflow.com/a/32897127/4391249. exec(file_contents, {}) - for file_name in ["model.pt", "stats.pth", "config.yaml"]: + for file_name in ["model.pt", "config.yaml"]: assert Path(f"outputs/train/example_pusht_diffusion/{file_name}").exists() path = "examples/3_evaluate_pretrained_policy.py" diff --git a/tests/test_policies.py b/tests/test_policies.py index ab679fcb..0e4ce654 100644 --- a/tests/test_policies.py +++ b/tests/test_policies.py @@ -6,10 +6,10 @@ from lerobot.common.datasets.utils import cycle from lerobot.common.envs.factory import make_env from lerobot.common.envs.utils import postprocess_action, preprocess_observation from lerobot.common.policies.factory import make_policy +from lerobot.common.policies.normalize import Normalize, Unnormalize from lerobot.common.policies.policy_protocol import Policy from lerobot.common.utils.utils import init_hydra_config - -from .utils import DEFAULT_CONFIG_PATH, DEVICE, require_env +from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, require_env # TODO(aliberts): refactor using lerobot/__init__.py variables @@ -44,14 +44,16 @@ def test_policy(env_name, policy_name, extra_overrides): ] + extra_overrides, ) + # Check that we can make the policy object. - policy = make_policy(cfg) + dataset = make_dataset(cfg) + policy = make_policy(cfg, dataset_stats=dataset.stats) # Check that the policy follows the required protocol. assert isinstance( policy, Policy ), f"The policy does not follow the required protocol. Please see {Policy.__module__}.{Policy.__name__}." + # Check that we run select_actions and get the appropriate output. - dataset = make_dataset(cfg) env = make_env(cfg, num_parallel_envs=2) dataloader = torch.utils.data.DataLoader( @@ -77,7 +79,7 @@ def test_policy(env_name, policy_name, extra_overrides): observation, _ = env.reset(seed=cfg.seed) # apply transform to normalize the observations - observation = preprocess_observation(observation, dataset.transform) + observation = preprocess_observation(observation) # send observation to device/gpu observation = {key: observation[key].to(DEVICE, non_blocking=True) for key in observation} @@ -86,8 +88,115 @@ def test_policy(env_name, policy_name, extra_overrides): with torch.inference_mode(): action = policy.select_action(observation, step=0) - # apply inverse transform to unnormalize the action - action = postprocess_action(action, dataset.transform) + # convert action to cpu numpy array + action = postprocess_action(action) # Test step through policy env.step(action) + + # Test load state_dict + if policy_name != "tdmpc": + # TODO(rcadene, alexander-soare): make it work for tdmpc + new_policy = make_policy(cfg) + new_policy.load_state_dict(policy.state_dict()) + + +@pytest.mark.parametrize( + "insert_temporal_dim", + [ + False, + True, + ], +) +def test_normalize(insert_temporal_dim): + """ + Test that normalize/unnormalize can run without exceptions when properly set up, and that they raise + an exception when the forward pass is called without the stats having been provided. + + TODO(rcadene, alexander-soare): This should also test that the normalization / unnormalization works as + expected. + """ + + input_shapes = { + "observation.image": [3, 96, 96], + "observation.state": [10], + } + output_shapes = { + "action": [5], + } + + normalize_input_modes = { + "observation.image": "mean_std", + "observation.state": "min_max", + } + unnormalize_output_modes = { + "action": "min_max", + } + + dataset_stats = { + "observation.image": { + "mean": torch.randn(3, 1, 1), + "std": torch.randn(3, 1, 1), + "min": torch.randn(3, 1, 1), + "max": torch.randn(3, 1, 1), + }, + "observation.state": { + "mean": torch.randn(10), + "std": torch.randn(10), + "min": torch.randn(10), + "max": torch.randn(10), + }, + "action": { + "mean": torch.randn(5), + "std": torch.randn(5), + "min": torch.randn(5), + "max": torch.randn(5), + }, + } + + bsize = 2 + input_batch = { + "observation.image": torch.randn(bsize, 3, 96, 96), + "observation.state": torch.randn(bsize, 10), + } + output_batch = { + "action": torch.randn(bsize, 5), + } + + if insert_temporal_dim: + tdim = 4 + + for key in input_batch: + # [2,3,96,96] -> [2,tdim,3,96,96] + input_batch[key] = torch.stack([input_batch[key]] * tdim, dim=1) + + for key in output_batch: + output_batch[key] = torch.stack([output_batch[key]] * tdim, dim=1) + + # test without stats + normalize = Normalize(input_shapes, normalize_input_modes, stats=None) + with pytest.raises(AssertionError): + normalize(input_batch) + + # test with stats + normalize = Normalize(input_shapes, normalize_input_modes, stats=dataset_stats) + normalize(input_batch) + + # test loading pretrained models + new_normalize = Normalize(input_shapes, normalize_input_modes, stats=None) + new_normalize.load_state_dict(normalize.state_dict()) + new_normalize(input_batch) + + # test without stats + unnormalize = Unnormalize(output_shapes, unnormalize_output_modes, stats=None) + with pytest.raises(AssertionError): + unnormalize(output_batch) + + # test with stats + unnormalize = Unnormalize(output_shapes, unnormalize_output_modes, stats=dataset_stats) + unnormalize(output_batch) + + # test loading pretrained models + new_unnormalize = Unnormalize(output_shapes, unnormalize_output_modes, stats=None) + new_unnormalize.load_state_dict(unnormalize.state_dict()) + unnormalize(output_batch)