From 45f351c618c7b56e9c6c11ea414b3b048b8c250a Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Fri, 26 Apr 2024 11:18:39 +0100 Subject: [PATCH] Make sure targets are normalized too (#106) --- .../common/policies/act/configuration_act.py | 4 +- lerobot/common/policies/act/modeling_act.py | 8 +- .../diffusion/configuration_diffusion.py | 8 +- .../policies/diffusion/modeling_diffusion.py | 8 +- lerobot/common/policies/normalize.py | 171 ++++++++++-------- lerobot/configs/policy/act.yaml | 4 +- lerobot/configs/policy/diffusion.yaml | 4 +- lerobot/scripts/train.py | 1 - 8 files changed, 116 insertions(+), 92 deletions(-) diff --git a/lerobot/common/policies/act/configuration_act.py b/lerobot/common/policies/act/configuration_act.py index 82280b2c..c8c85c04 100644 --- a/lerobot/common/policies/act/configuration_act.py +++ b/lerobot/common/policies/act/configuration_act.py @@ -75,13 +75,13 @@ class ActionChunkingTransformerConfig: ) # Normalization / Unnormalization - normalize_input_modes: dict[str, str] = field( + input_normalization_modes: dict[str, str] = field( default_factory=lambda: { "observation.image": "mean_std", "observation.state": "mean_std", } ) - unnormalize_output_modes: dict[str, str] = field( + output_normalization_modes: dict[str, str] = field( default_factory=lambda: { "action": "mean_std", } diff --git a/lerobot/common/policies/act/modeling_act.py b/lerobot/common/policies/act/modeling_act.py index c2dd5bf7..4501c6cc 100644 --- a/lerobot/common/policies/act/modeling_act.py +++ b/lerobot/common/policies/act/modeling_act.py @@ -72,8 +72,11 @@ class ActionChunkingTransformerPolicy(nn.Module): if cfg is None: cfg = ActionChunkingTransformerConfig() self.cfg = cfg - self.normalize_inputs = Normalize(cfg.input_shapes, cfg.normalize_input_modes, dataset_stats) - self.unnormalize_outputs = Unnormalize(cfg.output_shapes, cfg.unnormalize_output_modes, dataset_stats) + self.normalize_inputs = Normalize(cfg.input_shapes, cfg.input_normalization_modes, dataset_stats) + self.normalize_targets = Normalize(cfg.output_shapes, cfg.output_normalization_modes, dataset_stats) + self.unnormalize_outputs = Unnormalize( + cfg.output_shapes, cfg.output_normalization_modes, dataset_stats + ) # BERT style VAE encoder with input [cls, *joint_space_configuration, *action_sequence]. # The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]). @@ -216,6 +219,7 @@ class ActionChunkingTransformerPolicy(nn.Module): self.train() batch = self.normalize_inputs(batch) + batch = self.normalize_targets(batch) loss_dict = self.forward(batch) # TODO(rcadene): self.unnormalize_outputs(out_dict) diff --git a/lerobot/common/policies/diffusion/configuration_diffusion.py b/lerobot/common/policies/diffusion/configuration_diffusion.py index 9a725a56..a5c739c4 100644 --- a/lerobot/common/policies/diffusion/configuration_diffusion.py +++ b/lerobot/common/policies/diffusion/configuration_diffusion.py @@ -83,17 +83,13 @@ class DiffusionConfig: ) # Normalization / Unnormalization - normalize_input_modes: dict[str, str] = field( + input_normalization_modes: dict[str, str] = field( default_factory=lambda: { "observation.image": "mean_std", "observation.state": "min_max", } ) - unnormalize_output_modes: dict[str, str] = field( - default_factory=lambda: { - "action": "min_max", - } - ) + output_normalization_modes: dict[str, str] = field(default_factory=lambda: {"action": "min_max"}) # Architecture / modeling. # Vision backbone. diff --git a/lerobot/common/policies/diffusion/modeling_diffusion.py b/lerobot/common/policies/diffusion/modeling_diffusion.py index 7a639375..1dd545d3 100644 --- a/lerobot/common/policies/diffusion/modeling_diffusion.py +++ b/lerobot/common/policies/diffusion/modeling_diffusion.py @@ -56,8 +56,11 @@ class DiffusionPolicy(nn.Module): if cfg is None: cfg = DiffusionConfig() self.cfg = cfg - self.normalize_inputs = Normalize(cfg.input_shapes, cfg.normalize_input_modes, dataset_stats) - self.unnormalize_outputs = Unnormalize(cfg.output_shapes, cfg.unnormalize_output_modes, dataset_stats) + self.normalize_inputs = Normalize(cfg.input_shapes, cfg.input_normalization_modes, dataset_stats) + self.normalize_targets = Normalize(cfg.output_shapes, cfg.output_normalization_modes, dataset_stats) + self.unnormalize_outputs = Unnormalize( + cfg.output_shapes, cfg.output_normalization_modes, dataset_stats + ) # queues are populated during rollout of the policy, they contain the n latest observations and actions self._queues = None @@ -162,6 +165,7 @@ class DiffusionPolicy(nn.Module): self.diffusion.train() batch = self.normalize_inputs(batch) + batch = self.normalize_targets(batch) loss = self.forward(batch)["loss"] loss.backward() diff --git a/lerobot/common/policies/normalize.py b/lerobot/common/policies/normalize.py index 4d230b16..df615a21 100644 --- a/lerobot/common/policies/normalize.py +++ b/lerobot/common/policies/normalize.py @@ -1,27 +1,21 @@ import torch -from torch import nn +from torch import Tensor, nn -def create_stats_buffers(shapes, modes, stats=None): +def create_stats_buffers( + shapes: dict[str, list[int]], + modes: dict[str, str], + stats: dict[str, dict[str, Tensor]] | None = None, +) -> dict[str, dict[str, nn.ParameterDict]]: """ - Create buffers per modality (e.g. "observation.image", "action") containing their mean, std, min, max statistics. + Create buffers per modality (e.g. "observation.image", "action") containing their mean, std, min, max + statistics. - Parameters: - shapes (dict): A dictionary where keys are input modalities (e.g. "observation.image") and values are their shapes (e.g. `[3,96,96]`]). - These shapes are used to create the tensor buffer containing mean, std, min, max statistics. If the provided `shapes` contain keys related to images, the shape is adjusted to be invariant to height - and width, assuming a channel-first (c, h, w) format. - modes (dict): A dictionary where keys are output modalities (e.g. "observation.image") and values are their normalization modes among: - - "mean_std": substract the mean and divide by standard deviation. - - "min_max": map to [-1, 1] range. - stats (dict, optional): A dictionary where keys are output modalities (e.g. "observation.image") and values are dictionaries of statistic types and their values - (e.g. `{"mean": torch.randn(3,1,1)}, "std": torch.randn(3,1,1)}`). If provided, as expected for training the model for the first time, - these statistics will overwrite the default buffers. If not provided, as expected for finetuning or evaluation, the default buffers should to be - be overwritten by a call to `policy.load_state_dict(state_dict)`. That way, initializing the dataset is not needed to get the stats, since - they are already in the policy state_dict. + Args: (see Normalize and Unnormalize) Returns: - dict: A dictionary where keys are modalities and values are `nn.ParameterDict` containing `nn.Parameters` set to - `requires_grad=False`, suitable to not be updated during backpropagation. + dict: A dictionary where keys are modalities and values are `nn.ParameterDict` containing + `nn.Parameters` set to `requires_grad=False`, suitable to not be updated during backpropagation. """ stats_buffers = {} @@ -75,24 +69,32 @@ def create_stats_buffers(shapes, modes, stats=None): class Normalize(nn.Module): - """ - Normalizes the input data (e.g. "observation.image") for more stable and faster convergence during training. + """Normalizes data (e.g. "observation.image") for more stable and faster convergence during training.""" - Parameters: - shapes (dict): A dictionary where keys are input modalities (e.g. "observation.image") and values are their shapes (e.g. `[3,96,96]`]). - These shapes are used to create the tensor buffer containing mean, std, min, max statistics. If the provided `shapes` contain keys related to images, the shape is adjusted to be invariant to height - and width, assuming a channel-first (c, h, w) format. - modes (dict): A dictionary where keys are output modalities (e.g. "observation.image") and values are their normalization modes among: - - "mean_std": substract the mean and divide by standard deviation. - - "min_max": map to [-1, 1] range. - stats (dict, optional): A dictionary where keys are output modalities (e.g. "observation.image") and values are dictionaries of statistic types and their values - (e.g. `{"mean": torch.randn(3,1,1)}, "std": torch.randn(3,1,1)}`). If provided, as expected for training the model for the first time, - these statistics will overwrite the default buffers. If not provided, as expected for finetuning or evaluation, the default buffers should to be - be overwritten by a call to `policy.load_state_dict(state_dict)`. That way, initializing the dataset is not needed to get the stats, since - they are already in the policy state_dict. - """ - - def __init__(self, shapes, modes, stats=None): + def __init__( + self, + shapes: dict[str, list[int]], + modes: dict[str, str], + stats: dict[str, dict[str, Tensor]] | None = None, + ): + """ + Args: + shapes (dict): A dictionary where keys are input modalities (e.g. "observation.image") and values + are their shapes (e.g. `[3,96,96]`]). These shapes are used to create the tensor buffer containing + mean, std, min, max statistics. If the provided `shapes` contain keys related to images, the shape + is adjusted to be invariant to height and width, assuming a channel-first (c, h, w) format. + modes (dict): A dictionary where keys are output modalities (e.g. "observation.image") and values + are their normalization modes among: + - "mean_std": subtract the mean and divide by standard deviation. + - "min_max": map to [-1, 1] range. + stats (dict, optional): A dictionary where keys are output modalities (e.g. "observation.image") + and values are dictionaries of statistic types and their values (e.g. + `{"mean": torch.randn(3,1,1)}, "std": torch.randn(3,1,1)}`). If provided, as expected for + training the model for the first time, these statistics will overwrite the default buffers. If + not provided, as expected for finetuning or evaluation, the default buffers should to be + overwritten by a call to `policy.load_state_dict(state_dict)`. That way, initializing the + dataset is not needed to get the stats, since they are already in the policy state_dict. + """ super().__init__() self.shapes = shapes self.modes = modes @@ -104,29 +106,33 @@ class Normalize(nn.Module): # TODO(rcadene): should we remove torch.no_grad? @torch.no_grad - def forward(self, batch): + def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: for key, mode in self.modes.items(): buffer = getattr(self, "buffer_" + key.replace(".", "_")) if mode == "mean_std": mean = buffer["mean"] std = buffer["std"] - assert not torch.isinf( - mean - ).any(), "`mean` is infinity. You forgot to initialize with `stats` as argument, or called `policy.load_state_dict`." - assert not torch.isinf( - std - ).any(), "`std` is infinity. You forgot to initialize with `stats` as argument, or called `policy.load_state_dict`." + assert not torch.isinf(mean).any(), ( + "`mean` is infinity. You forgot to initialize with `stats` as argument, or called " + "`policy.load_state_dict`." + ) + assert not torch.isinf(std).any(), ( + "`std` is infinity. You forgot to initialize with `stats` as argument, or called " + "`policy.load_state_dict`." + ) batch[key] = (batch[key] - mean) / (std + 1e-8) elif mode == "min_max": min = buffer["min"] max = buffer["max"] - assert not torch.isinf( - min - ).any(), "`min` is infinity. You forgot to initialize with `stats` as argument, or called `policy.load_state_dict`." - assert not torch.isinf( - max - ).any(), "`max` is infinity. You forgot to initialize with `stats` as argument, or called `policy.load_state_dict`." + assert not torch.isinf(min).any(), ( + "`min` is infinity. You forgot to initialize with `stats` as argument, or called " + "`policy.load_state_dict`." + ) + assert not torch.isinf(max).any(), ( + "`max` is infinity. You forgot to initialize with `stats` as argument, or called " + "`policy.load_state_dict`." + ) # normalize to [0,1] batch[key] = (batch[key] - min) / (max - min) # normalize to [-1, 1] @@ -138,23 +144,34 @@ class Normalize(nn.Module): class Unnormalize(nn.Module): """ - Similar to `Normalize` but unnormalizes output data (e.g. `{"action": torch.randn(b,c)}`) in their original range used by the environment. - - Parameters: - shapes (dict): A dictionary where keys are output modalities (e.g. "action") and values are their shapes (e.g. [10]). - These shapes are used to create the tensor buffer containing mean, std, min, max statistics. If the provided `shapes` contain keys related to images, the shape is adjusted to be invariant to height - and width, assuming a channel-first (c, h, w) format. - modes (dict): A dictionary where keys are output modalities (e.g. "action") and values are their unnormalization modes among: - - "mean_std": multiply by standard deviation and add mean - - "min_max": go from [-1, 1] range to original range. - stats (dict, optional): A dictionary where keys are output modalities (e.g. "action") and values are dictionaries of statistic types and their values - (e.g. `{"max": torch.tensor(1)}, "min": torch.tensor(0)}`). If provided, as expected for training the model for the first time, - these statistics will overwrite the default buffers. If not provided, as expected for finetuning or evaluation, the default buffers should to be - be overwritten by a call to `policy.load_state_dict(state_dict)`. That way, initializing the dataset is not needed to get the stats, since - they are already in the policy state_dict. + Similar to `Normalize` but unnormalizes output data (e.g. `{"action": torch.randn(b,c)}`) in their + original range used by the environment. """ - def __init__(self, shapes, modes, stats=None): + def __init__( + self, + shapes: dict[str, list[int]], + modes: dict[str, str], + stats: dict[str, dict[str, Tensor]] | None = None, + ): + """ + Args: + shapes (dict): A dictionary where keys are input modalities (e.g. "observation.image") and values + are their shapes (e.g. `[3,96,96]`]). These shapes are used to create the tensor buffer containing + mean, std, min, max statistics. If the provided `shapes` contain keys related to images, the shape + is adjusted to be invariant to height and width, assuming a channel-first (c, h, w) format. + modes (dict): A dictionary where keys are output modalities (e.g. "observation.image") and values + are their normalization modes among: + - "mean_std": subtract the mean and divide by standard deviation. + - "min_max": map to [-1, 1] range. + stats (dict, optional): A dictionary where keys are output modalities (e.g. "observation.image") + and values are dictionaries of statistic types and their values (e.g. + `{"mean": torch.randn(3,1,1)}, "std": torch.randn(3,1,1)}`). If provided, as expected for + training the model for the first time, these statistics will overwrite the default buffers. If + not provided, as expected for finetuning or evaluation, the default buffers should to be + overwritten by a call to `policy.load_state_dict(state_dict)`. That way, initializing the + dataset is not needed to get the stats, since they are already in the policy state_dict. + """ super().__init__() self.shapes = shapes self.modes = modes @@ -166,29 +183,33 @@ class Unnormalize(nn.Module): # TODO(rcadene): should we remove torch.no_grad? @torch.no_grad - def forward(self, batch): + def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: for key, mode in self.modes.items(): buffer = getattr(self, "buffer_" + key.replace(".", "_")) if mode == "mean_std": mean = buffer["mean"] std = buffer["std"] - assert not torch.isinf( - mean - ).any(), "`mean` is infinity. You forgot to initialize with `stats` as argument, or called `policy.load_state_dict`." - assert not torch.isinf( - std - ).any(), "`std` is infinity. You forgot to initialize with `stats` as argument, or called `policy.load_state_dict`." + assert not torch.isinf(mean).any(), ( + "`mean` is infinity. You forgot to initialize with `stats` as argument, or called " + "`policy.load_state_dict`." + ) + assert not torch.isinf(std).any(), ( + "`std` is infinity. You forgot to initialize with `stats` as argument, or called " + "`policy.load_state_dict`." + ) batch[key] = batch[key] * std + mean elif mode == "min_max": min = buffer["min"] max = buffer["max"] - assert not torch.isinf( - min - ).any(), "`min` is infinity. You forgot to initialize with `stats` as argument, or called `policy.load_state_dict`." - assert not torch.isinf( - max - ).any(), "`max` is infinity. You forgot to initialize with `stats` as argument, or called `policy.load_state_dict`." + assert not torch.isinf(min).any(), ( + "`min` is infinity. You forgot to initialize with `stats` as argument, or called " + "`policy.load_state_dict`." + ) + assert not torch.isinf(max).any(), ( + "`max` is infinity. You forgot to initialize with `stats` as argument, or called " + "`policy.load_state_dict`." + ) batch[key] = (batch[key] + 1) / 2 batch[key] = batch[key] * (max - min) + min else: diff --git a/lerobot/configs/policy/act.yaml b/lerobot/configs/policy/act.yaml index cfde3b91..d4ad195c 100644 --- a/lerobot/configs/policy/act.yaml +++ b/lerobot/configs/policy/act.yaml @@ -36,10 +36,10 @@ policy: action: ["${env.action_dim}"] # Normalization / Unnormalization - normalize_input_modes: + input_normalization_modes: observation.images.top: mean_std observation.state: mean_std - unnormalize_output_modes: + output_normalization_modes: action: mean_std # Architecture. diff --git a/lerobot/configs/policy/diffusion.yaml b/lerobot/configs/policy/diffusion.yaml index f844534e..999d62ea 100644 --- a/lerobot/configs/policy/diffusion.yaml +++ b/lerobot/configs/policy/diffusion.yaml @@ -50,10 +50,10 @@ policy: action: ["${env.action_dim}"] # Normalization / Unnormalization - normalize_input_modes: + input_normalization_modes: observation.image: mean_std observation.state: min_max - unnormalize_output_modes: + output_normalization_modes: action: min_max # Architecture / modeling. diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py index c849cce8..0447c84e 100644 --- a/lerobot/scripts/train.py +++ b/lerobot/scripts/train.py @@ -258,7 +258,6 @@ def train(cfg: dict, out_dir=None, job_name=None): policy, video_dir=Path(out_dir) / "eval", max_episodes_rendered=4, - transform=offline_dataset.transform, seed=cfg.seed, ) log_eval_info(logger, eval_info["aggregated"], step, cfg, offline_dataset, is_offline)