From ea89b29fe5a293be0796cd0a4df3309a67d18207 Mon Sep 17 00:00:00 2001
From: Michel Aractingi
Date: Fri, 25 Apr 2025 18:10:59 +0200
Subject: [PATCH] checkout normalize.py to prev commit

---
 lerobot/common/policies/normalize.py | 57 ++++++++++++++++++++--------
 1 file changed, 41 insertions(+), 16 deletions(-)

diff --git a/lerobot/common/policies/normalize.py b/lerobot/common/policies/normalize.py
index 38fb05c7..845e139a 100644
--- a/lerobot/common/policies/normalize.py
+++ b/lerobot/common/policies/normalize.py
@@ -79,28 +79,48 @@ def create_stats_buffers(
         )
 
         # TODO(aliberts, rcadene): harmonize this to only use one framework (np or torch)
-        if stats:
-            if isinstance(stats[key]["mean"], np.ndarray):
-                if norm_mode is NormalizationMode.MEAN_STD:
+        if stats and key in stats:
+            # NOTE(maractingi, azouitine): The order of these conditions was changed because in online environments we don't have dataset stats.
+            # Therefore, we don't have access to the full stats of the data; some elements carry only min-max or only mean-std.
+            if norm_mode is NormalizationMode.MEAN_STD:
+                if "mean" not in stats[key] or "std" not in stats[key]:
+                    raise ValueError(
+                        f"Missing 'mean' or 'std' in stats for key {key} with MEAN_STD normalization"
+                    )
+
+                if isinstance(stats[key]["mean"], np.ndarray):
                     buffer["mean"].data = torch.from_numpy(stats[key]["mean"]).to(dtype=torch.float32)
                     buffer["std"].data = torch.from_numpy(stats[key]["std"]).to(dtype=torch.float32)
-                elif norm_mode is NormalizationMode.MIN_MAX:
-                    buffer["min"].data = torch.from_numpy(stats[key]["min"]).to(dtype=torch.float32)
-                    buffer["max"].data = torch.from_numpy(stats[key]["max"]).to(dtype=torch.float32)
-            elif isinstance(stats[key]["mean"], torch.Tensor):
-                # Note: The clone is needed to make sure that the logic in save_pretrained doesn't see duplicated
-                # tensors anywhere (for example, when we use the same stats for normalization and
-                # unnormalization). See the logic here
-                # https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97.
-                if norm_mode is NormalizationMode.MEAN_STD:
+                elif isinstance(stats[key]["mean"], torch.Tensor):
+                    # Note: The clone is needed to make sure that the logic in save_pretrained doesn't see duplicated
+                    # tensors anywhere (for example, when we use the same stats for normalization and
+                    # unnormalization). See the logic here
+                    # https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97.
                     buffer["mean"].data = stats[key]["mean"].clone().to(dtype=torch.float32)
                     buffer["std"].data = stats[key]["std"].clone().to(dtype=torch.float32)
-                elif norm_mode is NormalizationMode.MIN_MAX:
+                else:
+                    type_ = type(stats[key]["mean"])
+                    raise ValueError(
+                        f"np.ndarray or torch.Tensor expected for 'mean', but type is '{type_}' instead."
+                    )
+
+            elif norm_mode is NormalizationMode.MIN_MAX:
+                if "min" not in stats[key] or "max" not in stats[key]:
+                    raise ValueError(
+                        f"Missing 'min' or 'max' in stats for key {key} with MIN_MAX normalization"
+                    )
+
+                if isinstance(stats[key]["min"], np.ndarray):
+                    buffer["min"].data = torch.from_numpy(stats[key]["min"]).to(dtype=torch.float32)
+                    buffer["max"].data = torch.from_numpy(stats[key]["max"]).to(dtype=torch.float32)
+                elif isinstance(stats[key]["min"], torch.Tensor):
                     buffer["min"].data = stats[key]["min"].clone().to(dtype=torch.float32)
                     buffer["max"].data = stats[key]["max"].clone().to(dtype=torch.float32)
-            else:
-                type_ = type(stats[key]["mean"])
-                raise ValueError(f"np.ndarray or torch.Tensor expected, but type is '{type_}' instead.")
+                else:
+                    type_ = type(stats[key]["min"])
+                    raise ValueError(
+                        f"np.ndarray or torch.Tensor expected for 'min', but type is '{type_}' instead."
+                    )
 
         stats_buffers[key] = buffer
     return stats_buffers
@@ -148,11 +168,14 @@ class Normalize(nn.Module):
         for key, buffer in stats_buffers.items():
             setattr(self, "buffer_" + key.replace(".", "_"), buffer)
 
+    # TODO(rcadene): should we remove torch.no_grad?
+    # @torch.no_grad
     def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
         batch = dict(batch)  # shallow copy avoids mutating the input batch
         for key, ft in self.features.items():
             if key not in batch:
                 # FIXME(aliberts, rcadene): This might lead to silent fail!
+                # NOTE(azouitine): This `continue` lets us instantiate SACPolicy before full stats are available.
                 continue
 
             norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY)
@@ -220,6 +243,8 @@ class Unnormalize(nn.Module):
         for key, buffer in stats_buffers.items():
             setattr(self, "buffer_" + key.replace(".", "_"), buffer)
 
+    # TODO(rcadene): should we remove torch.no_grad?
+    # @torch.no_grad
    def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
         batch = dict(batch)  # shallow copy avoids mutating the input batch
         for key, ft in self.features.items():
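
Below is a minimal usage sketch (not part of the patch) of what the reordered checks enable: per-key partial stats, where one feature carries only mean/std and another only min/max. Before this change, the code indexed stats[key]["mean"] before branching on norm_mode, so a key carrying only min/max raised a KeyError. The sketch assumes a Normalize(features, norm_map, stats) constructor and the PolicyFeature/FeatureType/NormalizationMode types from lerobot.configs.types; import paths and signatures may differ by version.

import torch

from lerobot.common.policies.normalize import Normalize
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature

# Two features with different normalization modes.
features = {
    "observation.state": PolicyFeature(type=FeatureType.STATE, shape=(2,)),
    "action": PolicyFeature(type=FeatureType.ACTION, shape=(2,)),
}
norm_map = {
    FeatureType.STATE: NormalizationMode.MEAN_STD,
    FeatureType.ACTION: NormalizationMode.MIN_MAX,
}

# Partial stats, as produced by an online environment: mean/std for the state
# key only, min/max for the action key only. With this patch, each key is
# validated only against the stats its own norm_mode requires.
stats = {
    "observation.state": {"mean": torch.zeros(2), "std": torch.ones(2)},
    "action": {"min": -torch.ones(2), "max": torch.ones(2)},
}

normalize = Normalize(features, norm_map, stats)
batch = normalize({"observation.state": torch.randn(4, 2), "action": torch.rand(4, 2)})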