From c9150c361b9191f3e7141fb4b2b5873b6efd2359 Mon Sep 17 00:00:00 2001
From: Thomas Wolf
Date: Tue, 18 Jun 2024 11:55:15 +0200
Subject: [PATCH] update unnormalize

---
Review notes (commentary only; this area before the first `diff --git` line
is not part of the commit):
- Unnormalize.forward builds its result from a shallow copy of `batch`, so
  keys that are not listed in `modes` are still returned and the caller's
  dict is never mutated.
- tests: in the `insert_temporal_dim` branch, the min_max_cap check for
  element 1 now adds `["min"][1]` (was `["min"][0]`), matching the `else`
  branch.
- Hunk line counts and the diffstat are unchanged; the `index` hashes are
  kept as originally recorded.

 lerobot/common/policies/normalize.py |  16 ++-
 tests/test_policies.py               | 147 +++++++++++++++++++++++++--
 2 files changed, 152 insertions(+), 11 deletions(-)

diff --git a/lerobot/common/policies/normalize.py b/lerobot/common/policies/normalize.py
index e51def2d..a711dcaa 100644
--- a/lerobot/common/policies/normalize.py
+++ b/lerobot/common/policies/normalize.py
@@ -176,6 +176,7 @@ class Unnormalize(nn.Module):
         shapes: dict[str, list[int]],
         modes: dict[str, str],
         stats: dict[str, dict[str, Tensor]] | None = None,
+        std_epsilon: float = 1e-5,
     ):
         """
         Args:
@@ -194,19 +195,24 @@ class Unnormalize(nn.Module):
                 not provided, as expected for finetuning or evaluation, the default buffers should to be
                 overwritten by a call to `policy.load_state_dict(state_dict)`. That way, initializing the
                 dataset is not needed to get the stats, since they are already in the policy state_dict.
+            std_epsilon (float, optional): A small minimal value for the standard deviation to avoid division by
+                zero in the Normalize step. We use the same value for unnormalization here to have a consistent
+                behavior. Default is `1e-5`. We use `clamp_min` to make sure the standard deviation (or the difference
+                between min and max) is at least `std_epsilon`.
         """
         super().__init__()
         self.shapes = shapes
         self.modes = modes
         self.stats = stats
         # `self.buffer_observation_state["mean"]` contains `torch.tensor(state_dim)`
-        stats_buffers = create_stats_buffers(shapes, modes, stats)
+        stats_buffers = create_stats_buffers(shapes, modes, stats, std_epsilon=std_epsilon)
         for key, buffer in stats_buffers.items():
             setattr(self, "buffer_" + key.replace(".", "_"), buffer)
 
     # TODO(rcadene): should we remove torch.no_grad?
     @torch.no_grad
     def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
+        output_batch = dict(batch)  # shallow copy: keep pass-through keys, don't mutate the input
         for key, mode in self.modes.items():
             buffer = getattr(self, "buffer_" + key.replace(".", "_"))
 
@@ -215,14 +221,14 @@ class Unnormalize(nn.Module):
                 std = buffer["std"]
                 assert not torch.isinf(mean).any(), _no_stats_error_str("mean")
                 assert not torch.isinf(std).any(), _no_stats_error_str("std")
-                batch[key] = batch[key] * std + mean
+                output_batch[key] = batch[key] * std + mean
             elif mode == "min_max":
                 min = buffer["min"]
                 max = buffer["max"]
                 assert not torch.isinf(min).any(), _no_stats_error_str("min")
                 assert not torch.isinf(max).any(), _no_stats_error_str("max")
-                batch[key] = (batch[key] + 1) / 2
-                batch[key] = batch[key] * (max - min) + min
+                output_batch[key] = (batch[key] + 1) / 2
+                output_batch[key] = output_batch[key] * (max - min) + min
             else:
                 raise ValueError(mode)
-        return batch
+        return output_batch
diff --git a/tests/test_policies.py b/tests/test_policies.py
index 963b4a3a..310e530a 100644
--- a/tests/test_policies.py
+++ b/tests/test_policies.py
@@ -285,10 +285,10 @@ def test_normalize(insert_temporal_dim):
     }
     output_batch = {
         "action": torch.randn(bsize, 5),
-        "action_test_std": torch.ones(bsize, 1) * 2.5,
-        "action_test_min_max": torch.ones(bsize, 1) * 2.5,
-        "action_test_std_cap": torch.ones(bsize, 2) * 2.5,
-        "action_test_min_max_cap": torch.ones(bsize, 2) * 2.5,
+        "action_test_std": torch.ones(bsize, 1) * 1.5,
+        "action_test_min_max": torch.ones(bsize, 1) * 1.5,
+        "action_test_std_cap": torch.ones(bsize, 2) * 1.5,
+        "action_test_min_max_cap": torch.ones(bsize, 2) * 1.5,
     }
 
     if insert_temporal_dim:
@@ -471,8 +471,143 @@ def test_normalize(insert_temporal_dim):
         unnormalize(output_batch)
 
     # test with stats
-    unnormalize = Unnormalize(output_shapes, unnormalize_output_modes, stats=dataset_stats)
-    unnormalize(output_batch)
+    unnormalize = Unnormalize(
+        output_shapes, unnormalize_output_modes, stats=dataset_stats, std_epsilon=std_epsilon
+    )
+
+    # check that the stats are correctly set including the min capping
+    assert torch.isclose(
+        unnormalize.buffer_action_test_std.mean, dataset_stats["action_test_std"]["mean"], rtol=0.1, atol=1e-7
+    ).all()
+    assert torch.isclose(
+        unnormalize.buffer_action_test_std.std, dataset_stats["action_test_std"]["std"], rtol=0.1, atol=1e-7
+    ).all()
+    assert torch.isclose(
+        unnormalize.buffer_action_test_min_max.min,
+        dataset_stats["action_test_min_max"]["min"],
+        rtol=0.1,
+        atol=1e-7,
+    ).all()
+    assert torch.isclose(
+        unnormalize.buffer_action_test_min_max.max,
+        dataset_stats["action_test_min_max"]["max"],
+        rtol=0.1,
+        atol=1e-7,
+    ).all()
+
+    assert torch.isclose(
+        unnormalize.buffer_action_test_std_cap.std[0],
+        dataset_stats["action_test_std_cap"]["std"][0],
+        rtol=0.1,
+        atol=1e-7,
+    ).all()
+    assert torch.isclose(
+        unnormalize.buffer_action_test_std_cap.std[1], torch.ones(1) * std_epsilon, rtol=0.1, atol=1e-7
+    ).all()
+    assert torch.isclose(
+        unnormalize.buffer_action_test_min_max_cap.max[0] - unnormalize.buffer_action_test_min_max_cap.min[0],
+        dataset_stats["action_test_min_max_cap"]["max"][0]
+        - dataset_stats["action_test_min_max_cap"]["min"][0],
+        rtol=0.1,
+        atol=1e-7,
+    ).all()
+    assert torch.isclose(
+        unnormalize.buffer_action_test_min_max_cap.max[1] - unnormalize.buffer_action_test_min_max_cap.min[1],
+        torch.ones(1) * std_epsilon,
+        rtol=0.1,
+        atol=1e-7,
+    ).all()
+
+    unnormalize_output = unnormalize(output_batch)
+
+    # check that the unnormalization is correct
+    assert torch.isclose(
+        unnormalize_output["action_test_std"],
+        output_batch["action_test_std"] * dataset_stats["action_test_std"]["std"]
+        + dataset_stats["action_test_std"]["mean"],
+        rtol=0.1,
+        atol=1e-7,
+    ).all()
+    assert torch.isclose(
+        unnormalize_output["action_test_min_max"],
+        (output_batch["action_test_min_max"] + 1)
+        / 2
+        * (dataset_stats["action_test_min_max"]["max"] - dataset_stats["action_test_min_max"]["min"])
+        + dataset_stats["action_test_min_max"]["min"],
+        rtol=0.1,
+        atol=1e-7,
+    ).all()
+
+    if insert_temporal_dim:
+        assert torch.isclose(
+            unnormalize_output["action_test_std_cap"][0, 0, 0],
+            output_batch["action_test_std_cap"][0, 0, 0] * dataset_stats["action_test_std_cap"]["std"][0]
+            + dataset_stats["action_test_std_cap"]["mean"][0],
+            rtol=0.1,
+            atol=1e-7,
+        ).all()
+        assert torch.isclose(
+            unnormalize_output["action_test_std_cap"][0, 0, 1],
+            output_batch["action_test_std_cap"][0, 0, 1] * std_epsilon
+            + dataset_stats["action_test_std_cap"]["mean"][1],
+            rtol=0.1,
+            atol=1e-7,
+        ).all()
+
+        assert torch.isclose(
+            unnormalize_output["action_test_min_max_cap"][0, 0, 0],
+            (output_batch["action_test_min_max_cap"][0, 0, 0] + 1)
+            / 2
+            * (
+                dataset_stats["action_test_min_max_cap"]["max"][0]
+                - dataset_stats["action_test_min_max_cap"]["min"][0]
+            )
+            + dataset_stats["action_test_min_max_cap"]["min"][0],
+            rtol=0.1,
+            atol=1e-7,
+        ).all()
+        assert torch.isclose(
+            unnormalize_output["action_test_min_max_cap"][0, 0, 1],
+            (output_batch["action_test_min_max_cap"][0, 0, 1] + 1) / 2 * std_epsilon
+            + dataset_stats["action_test_min_max_cap"]["min"][1],
+            rtol=0.1,
+            atol=1e-7,
+        ).all()
+    else:
+        assert torch.isclose(
+            unnormalize_output["action_test_std_cap"][0, 0],
+            output_batch["action_test_std_cap"][0, 0] * dataset_stats["action_test_std_cap"]["std"][0]
+            + dataset_stats["action_test_std_cap"]["mean"][0],
+            rtol=0.1,
+            atol=1e-7,
+        ).all()
+        assert torch.isclose(
+            unnormalize_output["action_test_std_cap"][0, 1],
+            output_batch["action_test_std_cap"][0, 1] * std_epsilon
+            + dataset_stats["action_test_std_cap"]["mean"][1],
+            rtol=0.1,
+            atol=1e-7,
+        ).all()
+
+        assert torch.isclose(
+            unnormalize_output["action_test_min_max_cap"][0, 0],
+            (output_batch["action_test_min_max_cap"][0, 0] + 1)
+            / 2
+            * (
+                dataset_stats["action_test_min_max_cap"]["max"][0]
+                - dataset_stats["action_test_min_max_cap"]["min"][0]
+            )
+            + dataset_stats["action_test_min_max_cap"]["min"][0],
+            rtol=0.1,
+            atol=1e-7,
+        ).all()
+        assert torch.isclose(
+            unnormalize_output["action_test_min_max_cap"][0, 1],
+            (output_batch["action_test_min_max_cap"][0, 1] + 1) / 2 * std_epsilon
+            + dataset_stats["action_test_min_max_cap"]["min"][1],
+            rtol=0.1,
+            atol=1e-7,
+        ).all()
 
     # test loading pretrained models
     new_unnormalize = Unnormalize(output_shapes, unnormalize_output_modes, stats=None)