From 72751b7cf617ee645db55c84dab056fc1b159c98 Mon Sep 17 00:00:00 2001 From: Cadene Date: Wed, 24 Apr 2024 15:40:09 +0000 Subject: [PATCH] make load_state_dict work --- .../common/policies/act/configuration_act.py | 36 +++- lerobot/common/policies/act/modeling_act.py | 20 +- .../diffusion/configuration_diffusion.py | 35 +++- .../policies/diffusion/modeling_diffusion.py | 17 +- lerobot/common/policies/normalize.py | 174 ++++++++++++++++++ lerobot/common/policies/utils.py | 55 ------ lerobot/configs/policy/act.yaml | 7 + lerobot/configs/policy/diffusion.yaml | 7 + tests/test_policies.py | 112 ++++++++++- 9 files changed, 376 insertions(+), 87 deletions(-) create mode 100644 lerobot/common/policies/normalize.py diff --git a/lerobot/common/policies/act/configuration_act.py b/lerobot/common/policies/act/configuration_act.py index 6fc8b72a..be904425 100644 --- a/lerobot/common/policies/act/configuration_act.py +++ b/lerobot/common/policies/act/configuration_act.py @@ -21,10 +21,24 @@ class ActionChunkingTransformerConfig: This should be no greater than the chunk size. For example, if the chunk size size 100, you may set this to 50. This would mean that the model predicts 100 steps worth of actions, runs 50 in the environment, and throws the other 50 out. - image_normalization_mean: Value to subtract from the input image pixels (inputs are assumed to be in - [0, 1]) for normalization. - image_normalization_std: Value by which to divide the input image pixels (after the mean has been - subtracted). + input_shapes: A dictionary defining the shapes of the input data for the policy. + The key represents the input data name, and the value is a list indicating the dimensions + of the corresponding data. For example, "observation.images.top" refers to an input from the + "top" camera with dimensions [3, 96, 96], indicating it has three color channels and 96x96 resolution. + Importantly, shapes doesnt include batch dimension or temporal dimension. + output_shapes: A dictionary defining the shapes of the output data for the policy. + The key represents the output data name, and the value is a list indicating the dimensions + of the corresponding data. For example, "action" refers to an output shape of [14], indicating + 14-dimensional actions. Importantly, shapes doesnt include batch dimension or temporal dimension. + normalize_input_modes: A dictionary specifying the normalization mode to be applied to various inputs. + The key represents the input data name, and the value specifies the type of normalization to apply. + Common normalization methods include "mean_std" (mean and standard deviation) or "min_max" (to normalize + between -1 and 1). + unnormalize_output_modes: A dictionary specifying the method to unnormalize outputs. + This parameter maps output data types to their unnormalization modes, allowing the results to be + transformed back from a normalized state to a standard state. It is typically used when output + data needs to be interpreted in its original scale or units. For example, for "action", the + unnormalization mode might be "mean_std" or "min_max". vision_backbone: Name of the torchvision resnet backbone to use for encoding images. use_pretrained_backbone: Whether the backbone should be initialized with pretrained weights from torchvision. @@ -51,6 +65,7 @@ class ActionChunkingTransformerConfig: """ # Environment. + # TODO(rcadene, alexander-soar): remove these as they are defined in input_shapes, output_shapes state_dim: int = 14 action_dim: int = 14 @@ -60,6 +75,18 @@ class ActionChunkingTransformerConfig: chunk_size: int = 100 n_action_steps: int = 100 + input_shapes: dict[str, str] = field( + default_factory=lambda: { + "observation.images.top": [3, 480, 640], + "observation.state": [14], + } + ) + output_shapes: dict[str, str] = field( + default_factory=lambda: { + "action": [14], + } + ) + # Normalization / Unnormalization normalize_input_modes: dict[str, str] = field( default_factory=lambda: { @@ -72,6 +99,7 @@ class ActionChunkingTransformerConfig: "action": "mean_std", } ) + # Architecture. # Vision backbone. vision_backbone: str = "resnet18" diff --git a/lerobot/common/policies/act/modeling_act.py b/lerobot/common/policies/act/modeling_act.py index eef38ff8..3682598f 100644 --- a/lerobot/common/policies/act/modeling_act.py +++ b/lerobot/common/policies/act/modeling_act.py @@ -20,11 +20,7 @@ from torchvision.models._utils import IntermediateLayerGetter from torchvision.ops.misc import FrozenBatchNorm2d from lerobot.common.policies.act.configuration_act import ActionChunkingTransformerConfig -from lerobot.common.policies.utils import ( - normalize_inputs, - to_buffer_dict, - unnormalize_outputs, -) +from lerobot.common.policies.normalize import Normalize, Unnormalize class ActionChunkingTransformerPolicy(nn.Module): @@ -76,9 +72,10 @@ class ActionChunkingTransformerPolicy(nn.Module): if cfg is None: cfg = ActionChunkingTransformerConfig() self.cfg = cfg - self.dataset_stats = to_buffer_dict(dataset_stats) self.normalize_input_modes = cfg.normalize_input_modes self.unnormalize_output_modes = cfg.unnormalize_output_modes + self.normalize_inputs = Normalize(cfg.input_shapes, cfg.normalize_input_modes, dataset_stats) + self.unnormalize_outputs = Unnormalize(cfg.output_shapes, cfg.unnormalize_output_modes, dataset_stats) # BERT style VAE encoder with input [cls, *joint_space_configuration, *action_sequence]. # The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]). @@ -174,7 +171,7 @@ class ActionChunkingTransformerPolicy(nn.Module): """ self.eval() - batch = normalize_inputs(batch, self.dataset_stats, self.normalize_input_modes) + batch = self.normalize_inputs(batch) if len(self._action_queue) == 0: # `_forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue effectively @@ -182,9 +179,7 @@ class ActionChunkingTransformerPolicy(nn.Module): actions = self._forward(batch)[0][: self.cfg.n_action_steps] # TODO(rcadene): make _forward return output dictionary? - out_dict = {"action": actions} - out_dict = unnormalize_outputs(out_dict, self.dataset_stats, self.unnormalize_output_modes) - actions = out_dict["action"] + actions = self.unnormalize_outputs({"action": actions})["action"] self._action_queue.extend(actions.transpose(0, 1)) return self._action_queue.popleft() @@ -218,9 +213,10 @@ class ActionChunkingTransformerPolicy(nn.Module): start_time = time.time() self.train() - batch = normalize_inputs(batch, self.dataset_stats, self.normalize_input_modes) + batch = self.normalize_inputs(batch) + loss_dict = self.forward(batch) - # TODO(rcadene): unnormalize_outputs(out_dict, self.dataset_stats, self.unnormalize_output_modes) + # TODO(rcadene): self.unnormalize_outputs(out_dict) loss = loss_dict["loss"] loss.backward() diff --git a/lerobot/common/policies/diffusion/configuration_diffusion.py b/lerobot/common/policies/diffusion/configuration_diffusion.py index 18f371d4..79652342 100644 --- a/lerobot/common/policies/diffusion/configuration_diffusion.py +++ b/lerobot/common/policies/diffusion/configuration_diffusion.py @@ -19,10 +19,24 @@ class DiffusionConfig: horizon: Diffusion model action prediction size as detailed in `DiffusionPolicy.select_action`. n_action_steps: The number of action steps to run in the environment for one invocation of the policy. See `DiffusionPolicy.select_action` for more details. - image_normalization_mean: Value to subtract from the input image pixels (inputs are assumed to be in - [0, 1]) for normalization. - image_normalization_std: Value by which to divide the input image pixels (after the mean has been - subtracted). + input_shapes: A dictionary defining the shapes of the input data for the policy. + The key represents the input data name, and the value is a list indicating the dimensions + of the corresponding data. For example, "observation.image" refers to an input from + a camera with dimensions [3, 96, 96], indicating it has three color channels and 96x96 resolution. + Importantly, shapes doesnt include batch dimension or temporal dimension. + output_shapes: A dictionary defining the shapes of the output data for the policy. + The key represents the output data name, and the value is a list indicating the dimensions + of the corresponding data. For example, "action" refers to an output shape of [14], indicating + 14-dimensional actions. Importantly, shapes doesnt include batch dimension or temporal dimension. + normalize_input_modes: A dictionary specifying the normalization mode to be applied to various inputs. + The key represents the input data name, and the value specifies the type of normalization to apply. + Common normalization methods include "mean_std" (mean and standard deviation) or "min_max" (to normalize + between -1 and 1). + unnormalize_output_modes: A dictionary specifying the method to unnormalize outputs. + This parameter maps output data types to their unnormalization modes, allowing the results to be + transformed back from a normalized state to a standard state. It is typically used when output + data needs to be interpreted in its original scale or units. For example, for "action", the + unnormalization mode might be "mean_std" or "min_max". vision_backbone: Name of the torchvision resnet backbone to use for encoding images. crop_shape: (H, W) shape to crop images to as a preprocessing step for the vision backbone. Must fit within the image size. If None, no cropping is done. @@ -60,6 +74,7 @@ class DiffusionConfig: # Environment. # Inherit these from the environment config. + # TODO(rcadene, alexander-soar): remove these as they are defined in input_shapes, output_shapes state_dim: int = 2 action_dim: int = 2 image_size: tuple[int, int] = (96, 96) @@ -69,6 +84,18 @@ class DiffusionConfig: horizon: int = 16 n_action_steps: int = 8 + input_shapes: dict[str, str] = field( + default_factory=lambda: { + "observation.image": [3, 96, 96], + "observation.state": [2], + } + ) + output_shapes: dict[str, str] = field( + default_factory=lambda: { + "action": [2], + } + ) + # Normalization / Unnormalization normalize_input_modes: dict[str, str] = field( default_factory=lambda: { diff --git a/lerobot/common/policies/diffusion/modeling_diffusion.py b/lerobot/common/policies/diffusion/modeling_diffusion.py index 32e89a1e..4bedf373 100644 --- a/lerobot/common/policies/diffusion/modeling_diffusion.py +++ b/lerobot/common/policies/diffusion/modeling_diffusion.py @@ -26,13 +26,11 @@ from torch import Tensor, nn from torch.nn.modules.batchnorm import _BatchNorm from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig +from lerobot.common.policies.normalize import Normalize, Unnormalize from lerobot.common.policies.utils import ( get_device_from_parameters, get_dtype_from_parameters, - normalize_inputs, populate_queues, - to_buffer_dict, - unnormalize_outputs, ) @@ -58,9 +56,10 @@ class DiffusionPolicy(nn.Module): if cfg is None: cfg = DiffusionConfig() self.cfg = cfg - self.dataset_stats = to_buffer_dict(dataset_stats) self.normalize_input_modes = cfg.normalize_input_modes self.unnormalize_output_modes = cfg.unnormalize_output_modes + self.normalize_inputs = Normalize(cfg.input_shapes, cfg.normalize_input_modes, dataset_stats) + self.unnormalize_outputs = Unnormalize(cfg.output_shapes, cfg.unnormalize_output_modes, dataset_stats) # queues are populated during rollout of the policy, they contain the n latest observations and actions self._queues = None @@ -133,7 +132,7 @@ class DiffusionPolicy(nn.Module): assert "observation.state" in batch assert len(batch) == 2 - batch = normalize_inputs(batch, self.dataset_stats, self.normalize_input_modes) + batch = self.normalize_inputs(batch) self._queues = populate_queues(self._queues, batch) @@ -146,9 +145,7 @@ class DiffusionPolicy(nn.Module): actions = self.diffusion.generate_actions(batch) # TODO(rcadene): make above methods return output dictionary? - out_dict = {"action": actions} - out_dict = unnormalize_outputs(out_dict, self.dataset_stats, self.unnormalize_output_modes) - actions = out_dict["action"] + actions = self.unnormalize_outputs({"action": actions})["action"] self._queues["action"].extend(actions.transpose(0, 1)) @@ -166,12 +163,12 @@ class DiffusionPolicy(nn.Module): self.diffusion.train() - batch = normalize_inputs(batch, self.dataset_stats, self.normalize_input_modes) + batch = self.normalize_inputs(batch) loss = self.forward(batch)["loss"] loss.backward() - # TODO(rcadene): unnormalize_outputs(out_dict, self.dataset_stats, self.unnormalize_output_modes) + # TODO(rcadene): self.unnormalize_outputs(out_dict) grad_norm = torch.nn.utils.clip_grad_norm_( self.diffusion.parameters(), diff --git a/lerobot/common/policies/normalize.py b/lerobot/common/policies/normalize.py new file mode 100644 index 00000000..f61066eb --- /dev/null +++ b/lerobot/common/policies/normalize.py @@ -0,0 +1,174 @@ +import torch +from torch import nn + + +def create_stats_buffers(shapes, modes, stats=None): + """ + This function generates buffers to store the mean and standard deviation, or minimum and maximum values, + used for normalizing tensors. The mode of normalization is determined by the `modes` dictionary, which can + be either "mean_std" (for mean and standard deviation) or "min_max" (for minimum and maximum). These buffers + are created as PyTorch nn.ParameterDict objects with nn.Parameters set to not require gradients, suitable + for normalization purposes. + + If the provided `shapes` contain keys related to images, the shape is adjusted to be invariant to height + and width, assuming a channel-first (c, h, w) format. + + Parameters: + shapes (dict): A dictionary where keys represent tensor identifiers and values represent the shapes of those tensors. + modes (dict): A dictionary specifying the normalization mode for each key in `shapes`. Valid modes are "mean_std" or "min_max". + stats (dict, optional): A dictionary containing pre-defined statistics for normalization. It can contain 'mean' and 'std' for + "mean_std" mode, or 'min' and 'max' for "min_max" mode. If provided, these statistics will overwrite the default buffers. + It's expected for training the model for the first time. If not provided, the default buffers are supposed to be overriden + by a call to `policy.load_state_dict(state_dict)`. It's useful for loading a pretrained model for finetuning or evaluation, + without requiring to initialize the dataset used to train the model just to acess the `stats`. + + Returns: + dict: A dictionary where keys match the `modes` and `shapes` keys, and values are nn.ParameterDict objects containing + the appropriate buffers for normalization. + """ + stats_buffers = {} + + for key, mode in modes.items(): + assert mode in ["mean_std", "min_max"] + + shape = shapes[key] + + # override shape to be invariant to height and width + if "image" in key: + # assume shape is channel first (b, c, h, w) or (b, t, c, h, w) + shape[-1] = 1 + shape[-2] = 1 + + buffer = {} + if mode == "mean_std": + mean = torch.zeros(shape, dtype=torch.float32) + std = torch.ones(shape, dtype=torch.float32) + buffer = nn.ParameterDict( + { + "mean": nn.Parameter(mean, requires_grad=False), + "std": nn.Parameter(std, requires_grad=False), + } + ) + elif mode == "min_max": + # TODO(rcadene): should we assume input is in [-1, 1] range? + min = torch.ones(shape, dtype=torch.float32) * -1 + max = torch.ones(shape, dtype=torch.float32) + buffer = nn.ParameterDict( + { + "min": nn.Parameter(min, requires_grad=False), + "max": nn.Parameter(max, requires_grad=False), + } + ) + + if stats is not None: + if mode == "mean_std": + buffer["mean"].data = stats[key]["mean"] + buffer["std"].data = stats[key]["std"] + elif mode == "min_max": + buffer["min"].data = stats[key]["min"] + buffer["max"].data = stats[key]["max"] + + stats_buffers[key] = buffer + return stats_buffers + + +class Normalize(nn.Module): + """ + A PyTorch module for normalizing data based on predefined statistics. + + The class is initialized with a set of shapes, modes, and optional pre-defined statistics. It creates buffers for normalization based + on these inputs, which are then used to adjust data during the forward pass. The normalization process operates on a batch of data, + with different keys in the batch being normalized according to the specified modes. The following normalization modes are supported: + - "mean_std": Normalizes data using the mean and standard deviation. + - "min_max": Normalizes data to a [0, 1] range and then to a [-1, 1] range. + + Parameters: + shapes (dict): A dictionary where keys represent tensor identifiers and values represent the shapes of those tensors. + modes (dict): A dictionary indicating the normalization mode for each tensor key. Valid modes are "mean_std" or "min_max". + stats (dict, optional): A dictionary containing pre-defined statistics for normalization. It can contain 'mean' and 'std' for + "mean_std" mode, or 'min' and 'max' for "min_max" mode. If provided, these statistics will overwrite the default buffers. + It's expected for training the model for the first time. If not provided, the default buffers are supposed to be overriden + by a call to `policy.load_state_dict(state_dict)`. It's useful for loading a pretrained model for finetuning or evaluation, + without requiring to initialize the dataset used to train the model just to acess the `stats`. + """ + + def __init__(self, shapes, modes, stats=None): + super().__init__() + self.shapes = shapes + self.modes = modes + self.stats = stats + # `self.buffer_observation_state["mean"]` contains `torch.tensor(state_dim)` + stats_buffers = create_stats_buffers(shapes, modes, stats) + for key, buffer in stats_buffers.items(): + setattr(self, "buffer_" + key.replace(".", "_"), buffer) + + # TODO(rcadene): should we remove torch.no_grad? + @torch.no_grad + def forward(self, batch): + for key, mode in self.modes.items(): + buffer = getattr(self, "buffer_" + key.replace(".", "_")) + + if mode == "mean_std": + mean = buffer["mean"].unsqueeze(0) + std = buffer["std"].unsqueeze(0) + batch[key] = (batch[key] - mean) / (std + 1e-8) + elif mode == "min_max": + min = buffer["min"].unsqueeze(0) + max = buffer["max"].unsqueeze(0) + # normalize to [0,1] + batch[key] = (batch[key] - min) / (max - min) + # normalize to [-1, 1] + batch[key] = batch[key] * 2 - 1 + else: + raise ValueError(mode) + return batch + + +class Unnormalize(nn.Module): + """ + A PyTorch module for unnormalizing data based on predefined statistics. + + The class is initialized with a set of shapes, modes, and optional pre-defined statistics. It creates buffers for unnormalization based + on these inputs, which are then used to adjust data during the forward pass. The unnormalization process operates on a batch of data, + with different keys in the batch being normalized according to the specified modes. The following unnormalization modes are supported: + - "mean_std": Unnormalizes data using the mean and standard deviation. + - "min_max": Unnormalizes data to a [0, 1] range and then to a [-1, 1] range. + + Parameters: + shapes (dict): A dictionary where keys represent tensor identifiers and values represent the shapes of those tensors. + modes (dict): A dictionary indicating the unnormalization mode for each tensor key. Valid modes are "mean_std" or "min_max". + stats (dict, optional): A dictionary containing pre-defined statistics for unnormalization. It can contain 'mean' and 'std' for + "mean_std" mode, or 'min' and 'max' for "min_max" mode. If provided, these statistics will overwrite the default buffers. + It's expected for training the model for the first time. If not provided, the default buffers are supposed to be overriden + by a call to `policy.load_state_dict(state_dict)`. It's useful for loading a pretrained model for finetuning or evaluation, + without requiring to initialize the dataset used to train the model just to acess the `stats`. + """ + + def __init__(self, shapes, modes, stats=None): + super().__init__() + self.shapes = shapes + self.modes = modes + self.stats = stats + # `self.buffer_observation_state["mean"]` contains `torch.tensor(state_dim)` + stats_buffers = create_stats_buffers(shapes, modes, stats) + for key, buffer in stats_buffers.items(): + setattr(self, "buffer_" + key.replace(".", "_"), buffer) + + # TODO(rcadene): should we remove torch.no_grad? + @torch.no_grad + def forward(self, batch): + for key, mode in self.modes.items(): + buffer = getattr(self, "buffer_" + key.replace(".", "_")) + + if mode == "mean_std": + mean = buffer["mean"].unsqueeze(0) + std = buffer["std"].unsqueeze(0) + batch[key] = batch[key] * std + mean + elif mode == "min_max": + min = buffer["min"].unsqueeze(0) + max = buffer["max"].unsqueeze(0) + batch[key] = (batch[key] + 1) / 2 + batch[key] = batch[key] * (max - min) + min + else: + raise ValueError(mode) + return batch diff --git a/lerobot/common/policies/utils.py b/lerobot/common/policies/utils.py index f5640266..b23c1336 100644 --- a/lerobot/common/policies/utils.py +++ b/lerobot/common/policies/utils.py @@ -28,58 +28,3 @@ def get_dtype_from_parameters(module: nn.Module) -> torch.dtype: Note: assumes that all parameters have the same dtype. """ return next(iter(module.parameters())).dtype - - -def normalize_inputs(batch, stats, normalize_input_modes): - if normalize_input_modes is None: - return batch - for key, mode in normalize_input_modes.items(): - if mode == "mean_std": - mean = stats[key]["mean"].unsqueeze(0) - std = stats[key]["std"].unsqueeze(0) - batch[key] = (batch[key] - mean) / (std + 1e-8) - elif mode == "min_max": - min = stats[key]["min"].unsqueeze(0) - max = stats[key]["max"].unsqueeze(0) - # normalize to [0,1] - batch[key] = (batch[key] - min) / (max - min) - # normalize to [-1, 1] - batch[key] = batch[key] * 2 - 1 - else: - raise ValueError(mode) - return batch - - -def unnormalize_outputs(batch, stats, unnormalize_output_modes): - if unnormalize_output_modes is None: - return batch - for key, mode in unnormalize_output_modes.items(): - if mode == "mean_std": - mean = stats[key]["mean"].unsqueeze(0) - std = stats[key]["std"].unsqueeze(0) - batch[key] = batch[key] * std + mean - elif mode == "min_max": - min = stats[key]["min"].unsqueeze(0) - max = stats[key]["max"].unsqueeze(0) - batch[key] = (batch[key] + 1) / 2 - batch[key] = batch[key] * (max - min) + min - else: - raise ValueError(mode) - return batch - - -def to_buffer_dict(dataset_stats): - # TODO(rcadene): replace this function by `torch.BufferDict` when it exists - # see: https://github.com/pytorch/pytorch/issues/37386 - # TODO(rcadene): make `to_buffer_dict` generic and add docstring - if dataset_stats is None: - return None - - new_ds_stats = {} - for key, stats_dict in dataset_stats.items(): - new_stats_dict = {} - for stats_type, value in stats_dict.items(): - # set requires_grad=False to have the same behavior as a nn.Buffer - new_stats_dict[stats_type] = nn.Parameter(value, requires_grad=False) - new_ds_stats[key] = nn.ParameterDict(new_stats_dict) - return nn.ParameterDict(new_ds_stats) diff --git a/lerobot/configs/policy/act.yaml b/lerobot/configs/policy/act.yaml index 9428e232..6fd7467f 100644 --- a/lerobot/configs/policy/act.yaml +++ b/lerobot/configs/policy/act.yaml @@ -34,6 +34,13 @@ policy: chunk_size: 100 # chunk_size n_action_steps: 100 + input_shapes: + # TODO(rcadene, alexander-soar): add variables for height and width from the dataset/env? + observation.images.top: [3, 480, 640] + observation.state: ["${policy.state_dim}"] + output_shapes: + action: ["${policy.action_dim}"] + # Normalization / Unnormalization normalize_input_modes: observation.images.top: mean_std diff --git a/lerobot/configs/policy/diffusion.yaml b/lerobot/configs/policy/diffusion.yaml index 1dc14104..d769413e 100644 --- a/lerobot/configs/policy/diffusion.yaml +++ b/lerobot/configs/policy/diffusion.yaml @@ -50,6 +50,13 @@ policy: horizon: ${horizon} n_action_steps: ${n_action_steps} + input_shapes: + # TODO(rcadene, alexander-soar): add variables for height and width from the dataset/env? + observation.image: [3, 96, 96] + observation.state: ["${policy.state_dim}"] + output_shapes: + action: ["${policy.action_dim}"] + # Normalization / Unnormalization normalize_input_modes: observation.image: mean_std diff --git a/tests/test_policies.py b/tests/test_policies.py index 37401598..d0d53e7b 100644 --- a/tests/test_policies.py +++ b/tests/test_policies.py @@ -6,10 +6,10 @@ from lerobot.common.datasets.utils import cycle from lerobot.common.envs.factory import make_env from lerobot.common.envs.utils import postprocess_action, preprocess_observation from lerobot.common.policies.factory import make_policy +from lerobot.common.policies.normalize import Normalize, Unnormalize from lerobot.common.policies.policy_protocol import Policy from lerobot.common.utils.utils import init_hydra_config - -from .utils import DEFAULT_CONFIG_PATH, DEVICE, require_env +from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, require_env # TODO(aliberts): refactor using lerobot/__init__.py variables @@ -93,3 +93,111 @@ def test_policy(env_name, policy_name, extra_overrides): # Test step through policy env.step(action) + + # Test load state_dict + if policy_name != "tdmpc": + # TODO(rcadene, alexander-soar): make it work for tdmpc + # TODO(rcadene, alexander-soar): how to remove need for dataset_stats? + new_policy = make_policy(cfg, dataset_stats=dataset.stats) + new_policy.load_state_dict(policy.state_dict()) + new_policy.update(batch, step=0) + + +@pytest.mark.parametrize( + "insert_temporal_dim", + [ + False, + True, + ], +) +def test_normalize(insert_temporal_dim): + # TODO(rcadene, alexander-soar): test with real data and assert results of normalization/unnormalization + + input_shapes = { + "observation.image": [3, 96, 96], + "observation.state": [10], + } + output_shapes = { + "action": [5], + } + + normalize_input_modes = { + "observation.image": "mean_std", + "observation.state": "min_max", + } + unnormalize_output_modes = { + "action": "min_max", + } + + dataset_stats = { + "observation.image": { + "mean": torch.randn(3, 1, 1), + "std": torch.randn(3, 1, 1), + "min": torch.randn(3, 1, 1), + "max": torch.randn(3, 1, 1), + }, + "observation.state": { + "mean": torch.randn(10), + "std": torch.randn(10), + "min": torch.randn(10), + "max": torch.randn(10), + }, + "action": { + "mean": torch.randn(5), + "std": torch.randn(5), + "min": torch.randn(5), + "max": torch.randn(5), + }, + } + + bsize = 2 + input_batch = { + "observation.image": torch.randn(bsize, 3, 96, 96), + "observation.state": torch.randn(bsize, 10), + } + output_batch = { + "action": torch.randn(bsize, 5), + } + + if insert_temporal_dim: + tdim = 4 + + for key in input_batch: + # [2,3,96,96] -> [2,tdim,3,96,96] + input_batch[key] = torch.stack([input_batch[key]] * tdim, dim=1) + + for key in output_batch: + output_batch[key] = torch.stack([output_batch[key]] * tdim, dim=1) + + # test without stats + normalize = Normalize(input_shapes, normalize_input_modes, stats=None) + normalize(input_batch) + + # test with stats + normalize = Normalize(input_shapes, normalize_input_modes, stats=dataset_stats) + normalize(input_batch) + + # test loading pretrained models + new_normalize = Normalize(input_shapes, normalize_input_modes, stats=None) + new_normalize.load_state_dict(normalize.state_dict()) + new_normalize(input_batch) + + # test wihtout stats + unnormalize = Unnormalize(output_shapes, unnormalize_output_modes, stats=None) + unnormalize(output_batch) + + # test with stats + unnormalize = Unnormalize(output_shapes, unnormalize_output_modes, stats=dataset_stats) + unnormalize(output_batch) + + # test loading pretrained models + new_unnormalize = Unnormalize(output_shapes, unnormalize_output_modes, stats=None) + new_unnormalize.load_state_dict(unnormalize.state_dict()) + unnormalize(output_batch) + + +if __name__ == "__main__": + test_policy( + *("aloha", "act", ["env.task=AlohaTransferCube-v0", "dataset_id=aloha_sim_transfer_cube_scripted"]) + ) + # test_policy(insert_temporal_dim=True)