From fa72aed5b637015236782c63e38e0ecde075661e Mon Sep 17 00:00:00 2001 From: AdilZouitine Date: Thu, 15 May 2025 18:42:59 +0200 Subject: [PATCH] Remove numpy array support --- lerobot/common/policies/normalize.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/lerobot/common/policies/normalize.py b/lerobot/common/policies/normalize.py index 9734bcab..05aa320e 100644 --- a/lerobot/common/policies/normalize.py +++ b/lerobot/common/policies/normalize.py @@ -254,7 +254,7 @@ class Unnormalize(nn.Module): return batch -# TODO: We should replace all normalization on the policies with register_buffer normalization +# TODO (azouitine): We should replace all normalization on the policies with register_buffer normalization # and remove the `Normalize` and `Unnormalize` classes. def _initialize_stats_buffers( module: nn.Module, @@ -287,10 +287,11 @@ def _initialize_stats_buffers( if stats and key in stats and "mean" in stats[key] and "std" in stats[key]: mean_data = stats[key]["mean"] std_data = stats[key]["std"] - if isinstance(mean_data, np.ndarray): - mean = torch.from_numpy(mean_data).to(dtype=torch.float32) - std = torch.from_numpy(std_data).to(dtype=torch.float32) - elif isinstance(mean_data, torch.Tensor): + if isinstance(mean_data, torch.Tensor): + # Note: The clone is needed to make sure that the logic in save_pretrained doesn't see duplicated + # tensors anywhere (for example, when we use the same stats for normalization and + # unnormalization). See the logic here + # https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97. mean = mean_data.clone().to(dtype=torch.float32) std = std_data.clone().to(dtype=torch.float32) else: @@ -307,10 +308,7 @@ def _initialize_stats_buffers( if stats and key in stats and "min" in stats[key] and "max" in stats[key]: min_data = stats[key]["min"] max_data = stats[key]["max"] - if isinstance(min_data, np.ndarray): - min_val = torch.from_numpy(min_data).to(dtype=torch.float32) - max_val = torch.from_numpy(max_data).to(dtype=torch.float32) - elif isinstance(min_data, torch.Tensor): + if isinstance(min_data, torch.Tensor): min_val = min_data.clone().to(dtype=torch.float32) max_val = max_data.clone().to(dtype=torch.float32) else: