update unnormalize

This commit is contained in:
Thomas Wolf
2024-06-18 11:55:15 +02:00
parent cd9ace20b6
commit c9150c361b
2 changed files with 152 additions and 11 deletions

View File

@@ -285,10 +285,10 @@ def test_normalize(insert_temporal_dim):
}
output_batch = {
"action": torch.randn(bsize, 5),
"action_test_std": torch.ones(bsize, 1) * 2.5,
"action_test_min_max": torch.ones(bsize, 1) * 2.5,
"action_test_std_cap": torch.ones(bsize, 2) * 2.5,
"action_test_min_max_cap": torch.ones(bsize, 2) * 2.5,
"action_test_std": torch.ones(bsize, 1) * 1.5,
"action_test_min_max": torch.ones(bsize, 1) * 1.5,
"action_test_std_cap": torch.ones(bsize, 2) * 1.5,
"action_test_min_max_cap": torch.ones(bsize, 2) * 1.5,
}
if insert_temporal_dim:
@@ -471,8 +471,143 @@ def test_normalize(insert_temporal_dim):
unnormalize(output_batch)
# test with stats
unnormalize = Unnormalize(output_shapes, unnormalize_output_modes, stats=dataset_stats)
unnormalize(output_batch)
unnormalize = Unnormalize(
output_shapes, unnormalize_output_modes, stats=dataset_stats, std_epsilon=std_epsilon
)
# check that the stats are correctly set including the min capping
assert torch.isclose(
unnormalize.buffer_action_test_std.mean, dataset_stats["action_test_std"]["mean"], rtol=0.1, atol=1e-7
).all()
assert torch.isclose(
unnormalize.buffer_action_test_std.std, dataset_stats["action_test_std"]["std"], rtol=0.1, atol=1e-7
).all()
assert torch.isclose(
unnormalize.buffer_action_test_min_max.min,
dataset_stats["action_test_min_max"]["min"],
rtol=0.1,
atol=1e-7,
).all()
assert torch.isclose(
unnormalize.buffer_action_test_min_max.max,
dataset_stats["action_test_min_max"]["max"],
rtol=0.1,
atol=1e-7,
).all()
assert torch.isclose(
unnormalize.buffer_action_test_std_cap.std[0],
dataset_stats["action_test_std_cap"]["std"][0],
rtol=0.1,
atol=1e-7,
).all()
assert torch.isclose(
unnormalize.buffer_action_test_std_cap.std[1], torch.ones(1) * std_epsilon, rtol=0.1, atol=1e-7
).all()
assert torch.isclose(
unnormalize.buffer_action_test_min_max_cap.max[0] - unnormalize.buffer_action_test_min_max_cap.min[0],
dataset_stats["action_test_min_max_cap"]["max"][0]
- dataset_stats["action_test_min_max_cap"]["min"][0],
rtol=0.1,
atol=1e-7,
).all()
assert torch.isclose(
unnormalize.buffer_action_test_min_max_cap.max[1] - unnormalize.buffer_action_test_min_max_cap.min[1],
torch.ones(1) * std_epsilon,
rtol=0.1,
atol=1e-7,
).all()
unnormalize_output = unnormalize(output_batch)
# check that the unnormalization is correct
assert torch.isclose(
unnormalize_output["action_test_std"],
output_batch["action_test_std"] * dataset_stats["action_test_std"]["std"]
+ dataset_stats["action_test_std"]["mean"],
rtol=0.1,
atol=1e-7,
).all()
assert torch.isclose(
unnormalize_output["action_test_min_max"],
(output_batch["action_test_min_max"] + 1)
/ 2
* (dataset_stats["action_test_min_max"]["max"] - dataset_stats["action_test_min_max"]["min"])
+ dataset_stats["action_test_min_max"]["min"],
rtol=0.1,
atol=1e-7,
).all()
if insert_temporal_dim:
assert torch.isclose(
unnormalize_output["action_test_std_cap"][0, 0, 0],
output_batch["action_test_std_cap"][0, 0, 0] * dataset_stats["action_test_std_cap"]["std"][0]
+ dataset_stats["action_test_std_cap"]["mean"][0],
rtol=0.1,
atol=1e-7,
).all()
assert torch.isclose(
unnormalize_output["action_test_std_cap"][0, 0, 1],
output_batch["action_test_std_cap"][0, 0, 1] * std_epsilon
+ dataset_stats["action_test_std_cap"]["mean"][1],
rtol=0.1,
atol=1e-7,
).all()
assert torch.isclose(
unnormalize_output["action_test_min_max_cap"][0, 0, 0],
(output_batch["action_test_min_max_cap"][0, 0, 0] + 1)
/ 2
* (
dataset_stats["action_test_min_max_cap"]["max"][0]
- dataset_stats["action_test_min_max_cap"]["min"][0]
)
+ dataset_stats["action_test_min_max_cap"]["min"][0],
rtol=0.1,
atol=1e-7,
).all()
# Temporal-dim case, channel 1 of the min/max-capped feature: the expected
# value scales by std_epsilon and must be offset by that same channel's
# minimum (index 1), matching the non-temporal branch of this check.
assert torch.isclose(
    unnormalize_output["action_test_min_max_cap"][0, 0, 1],
    (output_batch["action_test_min_max_cap"][0, 0, 1] + 1) / 2 * std_epsilon
    + dataset_stats["action_test_min_max_cap"]["min"][1],
    rtol=0.1,
    atol=1e-7,
).all()
else:
assert torch.isclose(
unnormalize_output["action_test_std_cap"][0, 0],
output_batch["action_test_std_cap"][0, 0] * dataset_stats["action_test_std_cap"]["std"][0]
+ dataset_stats["action_test_std_cap"]["mean"][0],
rtol=0.1,
atol=1e-7,
).all()
assert torch.isclose(
unnormalize_output["action_test_std_cap"][0, 1],
output_batch["action_test_std_cap"][0, 1] * std_epsilon
+ dataset_stats["action_test_std_cap"]["mean"][1],
rtol=0.1,
atol=1e-7,
).all()
assert torch.isclose(
unnormalize_output["action_test_min_max_cap"][0, 0],
(output_batch["action_test_min_max_cap"][0, 0] + 1)
/ 2
* (
dataset_stats["action_test_min_max_cap"]["max"][0]
- dataset_stats["action_test_min_max_cap"]["min"][0]
)
+ dataset_stats["action_test_min_max_cap"]["min"][0],
rtol=0.1,
atol=1e-7,
).all()
assert torch.isclose(
unnormalize_output["action_test_min_max_cap"][0, 1],
(output_batch["action_test_min_max_cap"][0, 1] + 1) / 2 * std_epsilon
+ dataset_stats["action_test_min_max_cap"]["min"][1],
rtol=0.1,
atol=1e-7,
).all()
# test loading pretrained models
new_unnormalize = Unnormalize(output_shapes, unnormalize_output_modes, stats=None)