diff --git a/lerobot/common/policies/normalize.py b/lerobot/common/policies/normalize.py
index 9b055f7e..e51def2d 100644
--- a/lerobot/common/policies/normalize.py
+++ b/lerobot/common/policies/normalize.py
@@ -21,6 +21,7 @@ def create_stats_buffers(
     shapes: dict[str, list[int]],
     modes: dict[str, str],
     stats: dict[str, dict[str, Tensor]] | None = None,
+    std_epsilon: float = 1e-5,
 ) -> dict[str, dict[str, nn.ParameterDict]]:
     """
     Create buffers per modality (e.g. "observation.image", "action") containing their mean, std, min, max
@@ -78,10 +79,14 @@ def create_stats_buffers(
             # https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97.
             if mode == "mean_std":
                 buffer["mean"].data = stats[key]["mean"].clone()
-                buffer["std"].data = stats[key]["std"].clone()
+                buffer["std"].data = stats[key]["std"].clone().clamp_min(std_epsilon)
             elif mode == "min_max":
                 buffer["min"].data = stats[key]["min"].clone()
                 buffer["max"].data = stats[key]["max"].clone()
+                epsilon = (std_epsilon - (stats[key]["max"] - stats[key]["min"]).abs()).clamp_min(
+                    0
+                )  # padding so that max - min is at least std_epsilon
+                buffer["max"].data += epsilon
 
         stats_buffers[key] = buffer
     return stats_buffers
@@ -102,6 +107,7 @@ class Normalize(nn.Module):
         shapes: dict[str, list[int]],
         modes: dict[str, str],
         stats: dict[str, dict[str, Tensor]] | None = None,
+        std_epsilon: float = 1e-5,
     ):
         """
         Args:
@@ -120,18 +126,22 @@ class Normalize(nn.Module):
                 not provided, as expected for finetuning or evaluation, the default buffers should to be
                 overwritten by a call to `policy.load_state_dict(state_dict)`. That way, initializing the
                 dataset is not needed to get the stats, since they are already in the policy state_dict.
+            std_epsilon (float, optional): Minimum value allowed for the standard deviation, to avoid
+                division by zero. Defaults to `1e-5`. `clamp_min` is used to ensure that the standard
+                deviation (or the difference between min and max) is at least `std_epsilon`.
         """
         super().__init__()
         self.shapes = shapes
         self.modes = modes
         self.stats = stats
-        stats_buffers = create_stats_buffers(shapes, modes, stats)
+        stats_buffers = create_stats_buffers(shapes, modes, stats, std_epsilon=std_epsilon)
         for key, buffer in stats_buffers.items():
             setattr(self, "buffer_" + key.replace(".", "_"), buffer)
 
     # TODO(rcadene): should we remove torch.no_grad?
     @torch.no_grad
     def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
+        output_batch = {}
         for key, mode in self.modes.items():
             buffer = getattr(self, "buffer_" + key.replace(".", "_"))
 
@@ -140,19 +150,19 @@ class Normalize(nn.Module):
                 std = buffer["std"]
                 assert not torch.isinf(mean).any(), _no_stats_error_str("mean")
                 assert not torch.isinf(std).any(), _no_stats_error_str("std")
-                batch[key] = (batch[key] - mean) / (std + 1e-8)
+                output_batch[key] = (batch[key] - mean) / std
             elif mode == "min_max":
                 min = buffer["min"]
                 max = buffer["max"]
                 assert not torch.isinf(min).any(), _no_stats_error_str("min")
                 assert not torch.isinf(max).any(), _no_stats_error_str("max")
                 # normalize to [0,1]
-                batch[key] = (batch[key] - min) / (max - min + 1e-8)
+                output_batch[key] = (batch[key] - min) / (max - min)
                 # normalize to [-1, 1]
-                batch[key] = batch[key] * 2 - 1
+                output_batch[key] = output_batch[key] * 2 - 1
             else:
                 raise ValueError(mode)
-        return batch
+        return output_batch
 
 
 class Unnormalize(nn.Module):
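Note on the change above: `create_stats_buffers` now floors the stored `std` at `std_epsilon` via `clamp_min`, and in `min_max` mode pads `max` so that `max - min` is at least `std_epsilon`. That lets `forward` drop the old `+ 1e-8` fudge terms from the denominators, and it now writes into a fresh `output_batch` instead of mutating its input. A minimal standalone sketch of the clamping arithmetic (illustrative values only, not part of the patch):

```python
import torch

std_epsilon = 1e-5

# mean_std: a zero std is floored to std_epsilon, so (x - mean) / std stays finite.
std = torch.tensor([0.2, 0.0])
std = std.clamp_min(std_epsilon)  # -> tensor([2.0000e-01, 1.0000e-05])

# min_max: when the range is below std_epsilon (e.g. max == min),
# max is pushed up just enough to make the range exactly std_epsilon.
min_ = torch.tensor([1.0, 1.0])
max_ = torch.tensor([3.0, 1.0])
epsilon = (std_epsilon - (max_ - min_).abs()).clamp_min(0)  # -> tensor([0.0000e+00, 1.0000e-05])
max_ = max_ + epsilon

assert ((max_ - min_) >= std_epsilon).all()
```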
diff --git a/tests/test_policies.py b/tests/test_policies.py
index c099bef0..963b4a3a 100644
--- a/tests/test_policies.py
+++ b/tests/test_policies.py
@@ -204,17 +204,33 @@ def test_normalize(insert_temporal_dim):
     input_shapes = {
         "observation.image": [3, 96, 96],
         "observation.state": [10],
+        "action_test_std": [1],
+        "action_test_min_max": [1],
+        "action_test_std_cap": [2],
+        "action_test_min_max_cap": [2],
     }
     output_shapes = {
         "action": [5],
+        "action_test_std": [1],
+        "action_test_min_max": [1],
+        "action_test_std_cap": [2],
+        "action_test_min_max_cap": [2],
     }
 
     normalize_input_modes = {
         "observation.image": "mean_std",
         "observation.state": "min_max",
+        "action_test_std": "mean_std",
+        "action_test_min_max": "min_max",
+        "action_test_std_cap": "mean_std",
+        "action_test_min_max_cap": "min_max",
     }
     unnormalize_output_modes = {
         "action": "min_max",
+        "action_test_std": "mean_std",
+        "action_test_min_max": "min_max",
+        "action_test_std_cap": "mean_std",
+        "action_test_min_max_cap": "min_max",
     }
 
     dataset_stats = {
@@ -236,15 +252,43 @@ def test_normalize(insert_temporal_dim):
             "min": torch.randn(5),
             "max": torch.randn(5),
         },
+        "action_test_std": {
+            "mean": torch.ones(1) * 2,
+            "std": torch.ones(1) * 0.2,
+        },
+        "action_test_min_max": {
+            "min": torch.ones(1) * 1,
+            "max": torch.ones(1) * 3,
+        },
+        "action_test_std_cap": {
+            "mean": torch.ones(2) * 2,
+            "std": torch.ones(2) * 0.2,
+        },
+        "action_test_min_max_cap": {
+            "min": torch.ones(2) * 1.0,
+            "max": torch.ones(2) * 3.0,
+        },
     }
 
+    # Set one std entry to 0 to test the capping; for min_max, set max equal to min
+    dataset_stats["action_test_std_cap"]["std"][1] = 0.0
+    dataset_stats["action_test_min_max_cap"]["max"][1] = dataset_stats["action_test_min_max_cap"]["min"][1]
+
     bsize = 2
     input_batch = {
         "observation.image": torch.randn(bsize, 3, 96, 96),
         "observation.state": torch.randn(bsize, 10),
+        "action_test_std": torch.ones(bsize, 1) * 2.5,
+        "action_test_min_max": torch.ones(bsize, 1) * 2.5,
+        "action_test_std_cap": torch.ones(bsize, 2) * 2.5,
+        "action_test_min_max_cap": torch.ones(bsize, 2) * 2.5,
    }
     output_batch = {
         "action": torch.randn(bsize, 5),
+        "action_test_std": torch.ones(bsize, 1) * 2.5,
+        "action_test_min_max": torch.ones(bsize, 1) * 2.5,
+        "action_test_std_cap": torch.ones(bsize, 2) * 2.5,
+        "action_test_min_max_cap": torch.ones(bsize, 2) * 2.5,
     }
 
     if insert_temporal_dim:
@@ -263,8 +307,158 @@ def test_normalize(insert_temporal_dim):
         normalize(input_batch)
 
     # test with stats
-    normalize = Normalize(input_shapes, normalize_input_modes, stats=dataset_stats)
-    normalize(input_batch)
+    std_epsilon = 1e-2
+    normalize = Normalize(input_shapes, normalize_input_modes, stats=dataset_stats, std_epsilon=std_epsilon)
+
+    # check that the stats are correctly set, including the std_epsilon capping
+    assert torch.isclose(
+        normalize.buffer_action_test_std.mean, dataset_stats["action_test_std"]["mean"], rtol=0.1, atol=1e-7
+    ).all()
+    assert torch.isclose(
+        normalize.buffer_action_test_std.std, dataset_stats["action_test_std"]["std"], rtol=0.1, atol=1e-7
+    ).all()
+    assert torch.isclose(
+        normalize.buffer_action_test_min_max.min,
+        dataset_stats["action_test_min_max"]["min"],
+        rtol=0.1,
+        atol=1e-7,
+    ).all()
+    assert torch.isclose(
+        normalize.buffer_action_test_min_max.max,
+        dataset_stats["action_test_min_max"]["max"],
+        rtol=0.1,
+        atol=1e-7,
+    ).all()
+
+    assert torch.isclose(
+        normalize.buffer_action_test_std_cap.std[0],
+        dataset_stats["action_test_std_cap"]["std"][0],
+        rtol=0.1,
+        atol=1e-7,
+    ).all()
+    assert torch.isclose(
+        normalize.buffer_action_test_std_cap.std[1], torch.ones(1) * std_epsilon, rtol=0.1, atol=1e-7
+    ).all()
+    assert torch.isclose(
+        normalize.buffer_action_test_min_max_cap.max[0] - normalize.buffer_action_test_min_max_cap.min[0],
+        dataset_stats["action_test_min_max_cap"]["max"][0]
+        - dataset_stats["action_test_min_max_cap"]["min"][0],
+        rtol=0.1,
+        atol=1e-7,
+    ).all()
+    assert torch.isclose(
+        normalize.buffer_action_test_min_max_cap.max[1] - normalize.buffer_action_test_min_max_cap.min[1],
+        torch.ones(1) * std_epsilon,
+        rtol=0.1,
+        atol=1e-7,
+    ).all()
+
+    normalized_output = normalize(input_batch)
+
+    # check that the normalization is correct
+    assert torch.isclose(
+        normalized_output["action_test_std"],
+        (input_batch["action_test_std"] - dataset_stats["action_test_std"]["mean"])
+        / dataset_stats["action_test_std"]["std"],
+        rtol=0.1,
+        atol=1e-7,
+    ).all()
+    assert torch.isclose(
+        normalized_output["action_test_min_max"],
+        (input_batch["action_test_min_max"] - dataset_stats["action_test_min_max"]["min"])
+        / (dataset_stats["action_test_min_max"]["max"] - dataset_stats["action_test_min_max"]["min"])
+        * 2
+        - 1,
+        rtol=0.1,
+        atol=1e-7,
+    ).all()
+
+    if insert_temporal_dim:
+        assert torch.isclose(
+            normalized_output["action_test_std_cap"][0, 0, 0],
+            (input_batch["action_test_std_cap"][0, 0, 0] - dataset_stats["action_test_std_cap"]["mean"][0])
+            / dataset_stats["action_test_std_cap"]["std"][0],
+            rtol=0.1,
+            atol=1e-7,
+        ).all()
+        assert torch.isclose(
+            normalized_output["action_test_std_cap"][0, 0, 1],
+            (input_batch["action_test_std_cap"][0, 0, 1] - dataset_stats["action_test_std_cap"]["mean"][1])
+            / std_epsilon,
+            rtol=0.1,
+            atol=1e-7,
+        ).all()
+
+        assert torch.isclose(
+            normalized_output["action_test_min_max_cap"][0, 0, 0],
+            (
+                input_batch["action_test_min_max_cap"][0, 0, 0]
+                - dataset_stats["action_test_min_max_cap"]["min"][0]
+            )
+            / (
+                dataset_stats["action_test_min_max_cap"]["max"][0]
+                - dataset_stats["action_test_min_max_cap"]["min"][0]
+            )
+            * 2
+            - 1,
+            rtol=0.1,
+            atol=1e-7,
+        ).all()
+        assert torch.isclose(
+            normalized_output["action_test_min_max_cap"][0, 0, 1],
+            (
+                input_batch["action_test_min_max_cap"][0, 0, 1]
+                - dataset_stats["action_test_min_max_cap"]["min"][1]
+            )
+            / std_epsilon
+            * 2
+            - 1,
+            rtol=0.1,
+            atol=1e-7,
+        ).all()
+    else:
+        assert torch.isclose(
+            normalized_output["action_test_std_cap"][0, 0],
+            (input_batch["action_test_std_cap"][0, 0] - dataset_stats["action_test_std_cap"]["mean"][0])
+            / dataset_stats["action_test_std_cap"]["std"][0],
+            rtol=0.1,
+            atol=1e-7,
+        ).all()
+        assert torch.isclose(
+            normalized_output["action_test_std_cap"][0, 1],
+            (input_batch["action_test_std_cap"][0, 1] - dataset_stats["action_test_std_cap"]["mean"][1])
+            / std_epsilon,
+            rtol=0.1,
+            atol=1e-7,
+        ).all()
+
+        assert torch.isclose(
+            normalized_output["action_test_min_max_cap"][0, 0],
+            (
+                input_batch["action_test_min_max_cap"][0, 0]
+                - dataset_stats["action_test_min_max_cap"]["min"][0]
+            )
+            / (
+                dataset_stats["action_test_min_max_cap"]["max"][0]
+                - dataset_stats["action_test_min_max_cap"]["min"][0]
+            )
+            * 2
+            - 1,
+            rtol=0.1,
+            atol=1e-7,
+        ).all()
+        assert torch.isclose(
+            normalized_output["action_test_min_max_cap"][0, 1],
+            (
+                input_batch["action_test_min_max_cap"][0, 1]
+                - dataset_stats["action_test_min_max_cap"]["min"][1]
+            )
+            / std_epsilon
+            * 2
+            - 1,
+            rtol=0.1,
+            atol=1e-7,
+        ).all()
 
     # test loading pretrained models
     new_normalize = Normalize(input_shapes, normalize_input_modes, stats=None)
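For context, a hypothetical end-to-end usage sketch (assuming the patched module is importable as `lerobot.common.policies.normalize`): a degenerate zero std is floored at construction time so normalization stays finite, and `forward` no longer modifies the input batch in place.

```python
import torch
from lerobot.common.policies.normalize import Normalize

# One modality whose second std entry is degenerate (zero).
stats = {"action": {"mean": torch.zeros(2), "std": torch.tensor([1.0, 0.0])}}
normalize = Normalize(
    shapes={"action": [2]},
    modes={"action": "mean_std"},
    stats=stats,
    std_epsilon=1e-5,
)

batch = {"action": torch.ones(4, 2)}
out = normalize(batch)

# The zero std was clamped to std_epsilon, so no inf/NaN appears...
assert torch.isfinite(out["action"]).all()
# ...and forward() returns a new dict instead of mutating `batch`.
assert batch["action"].eq(1).all()
```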