forked from tangger/lerobot
update
This commit is contained in:
@@ -204,17 +204,33 @@ def test_normalize(insert_temporal_dim):
|
||||
input_shapes = {
|
||||
"observation.image": [3, 96, 96],
|
||||
"observation.state": [10],
|
||||
"action_test_std": [1],
|
||||
"action_test_min_max": [1],
|
||||
"action_test_std_cap": [2],
|
||||
"action_test_min_max_cap": [2],
|
||||
}
|
||||
output_shapes = {
|
||||
"action": [5],
|
||||
"action_test_std": [1],
|
||||
"action_test_min_max": [1],
|
||||
"action_test_std_cap": [2],
|
||||
"action_test_min_max_cap": [2],
|
||||
}
|
||||
|
||||
normalize_input_modes = {
|
||||
"observation.image": "mean_std",
|
||||
"observation.state": "min_max",
|
||||
"action_test_std": "mean_std",
|
||||
"action_test_min_max": "min_max",
|
||||
"action_test_std_cap": "mean_std",
|
||||
"action_test_min_max_cap": "min_max",
|
||||
}
|
||||
unnormalize_output_modes = {
|
||||
"action": "min_max",
|
||||
"action_test_std": "mean_std",
|
||||
"action_test_min_max": "min_max",
|
||||
"action_test_std_cap": "mean_std",
|
||||
"action_test_min_max_cap": "min_max",
|
||||
}
|
||||
|
||||
dataset_stats = {
|
||||
@@ -236,15 +252,43 @@ def test_normalize(insert_temporal_dim):
|
||||
"min": torch.randn(5),
|
||||
"max": torch.randn(5),
|
||||
},
|
||||
"action_test_std": {
|
||||
"mean": torch.ones(1) * 2,
|
||||
"std": torch.ones(1) * 0.2,
|
||||
},
|
||||
"action_test_min_max": {
|
||||
"min": torch.ones(1) * 1,
|
||||
"max": torch.ones(1) * 3,
|
||||
},
|
||||
"action_test_std_cap": {
|
||||
"mean": torch.ones(2) * 2,
|
||||
"std": torch.ones(2) * 0.2,
|
||||
},
|
||||
"action_test_min_max_cap": {
|
||||
"min": torch.ones(2) * 1.0,
|
||||
"max": torch.ones(2) * 3.0,
|
||||
},
|
||||
}
|
||||
|
||||
# Set some values to 0 to test the case where the std is 0 - for max we set it to min
|
||||
dataset_stats["action_test_std_cap"]["std"][1] = 0.0
|
||||
dataset_stats["action_test_min_max_cap"]["max"][1] = dataset_stats["action_test_min_max_cap"]["min"][1]
|
||||
|
||||
bsize = 2
|
||||
input_batch = {
|
||||
"observation.image": torch.randn(bsize, 3, 96, 96),
|
||||
"observation.state": torch.randn(bsize, 10),
|
||||
"action_test_std": torch.ones(bsize, 1) * 2.5,
|
||||
"action_test_min_max": torch.ones(bsize, 1) * 2.5,
|
||||
"action_test_std_cap": torch.ones(bsize, 2) * 2.5,
|
||||
"action_test_min_max_cap": torch.ones(bsize, 2) * 2.5,
|
||||
}
|
||||
output_batch = {
|
||||
"action": torch.randn(bsize, 5),
|
||||
"action_test_std": torch.ones(bsize, 1) * 2.5,
|
||||
"action_test_min_max": torch.ones(bsize, 1) * 2.5,
|
||||
"action_test_std_cap": torch.ones(bsize, 2) * 2.5,
|
||||
"action_test_min_max_cap": torch.ones(bsize, 2) * 2.5,
|
||||
}
|
||||
|
||||
if insert_temporal_dim:
|
||||
@@ -263,8 +307,158 @@ def test_normalize(insert_temporal_dim):
|
||||
normalize(input_batch)
|
||||
|
||||
# test with stats
|
||||
normalize = Normalize(input_shapes, normalize_input_modes, stats=dataset_stats)
|
||||
normalize(input_batch)
|
||||
std_epsilon = 1e-2
|
||||
normalize = Normalize(input_shapes, normalize_input_modes, stats=dataset_stats, std_epsilon=std_epsilon)
|
||||
|
||||
# check that the stats are correctly set including the min capping
|
||||
assert torch.isclose(
|
||||
normalize.buffer_action_test_std.mean, dataset_stats["action_test_std"]["mean"], rtol=0.1, atol=1e-7
|
||||
).all()
|
||||
assert torch.isclose(
|
||||
normalize.buffer_action_test_std.std, dataset_stats["action_test_std"]["std"], rtol=0.1, atol=1e-7
|
||||
).all()
|
||||
assert torch.isclose(
|
||||
normalize.buffer_action_test_min_max.min,
|
||||
dataset_stats["action_test_min_max"]["min"],
|
||||
rtol=0.1,
|
||||
atol=1e-7,
|
||||
).all()
|
||||
assert torch.isclose(
|
||||
normalize.buffer_action_test_min_max.max,
|
||||
dataset_stats["action_test_min_max"]["max"],
|
||||
rtol=0.1,
|
||||
atol=1e-7,
|
||||
).all()
|
||||
|
||||
assert torch.isclose(
|
||||
normalize.buffer_action_test_std_cap.std[0],
|
||||
dataset_stats["action_test_std_cap"]["std"][0],
|
||||
rtol=0.1,
|
||||
atol=1e-7,
|
||||
).all()
|
||||
assert torch.isclose(
|
||||
normalize.buffer_action_test_std_cap.std[1], torch.ones(1) * std_epsilon, rtol=0.1, atol=1e-7
|
||||
).all()
|
||||
assert torch.isclose(
|
||||
normalize.buffer_action_test_min_max_cap.max[0] - normalize.buffer_action_test_min_max_cap.min[0],
|
||||
dataset_stats["action_test_min_max_cap"]["max"][0]
|
||||
- dataset_stats["action_test_min_max_cap"]["min"][0],
|
||||
rtol=0.1,
|
||||
atol=1e-7,
|
||||
).all()
|
||||
assert torch.isclose(
|
||||
normalize.buffer_action_test_min_max_cap.max[1] - normalize.buffer_action_test_min_max_cap.min[1],
|
||||
torch.ones(1) * std_epsilon,
|
||||
rtol=0.1,
|
||||
atol=1e-7,
|
||||
).all()
|
||||
|
||||
normalized_output = normalize(input_batch)
|
||||
|
||||
# check that the normalization is correct
|
||||
assert torch.isclose(
|
||||
normalized_output["action_test_std"],
|
||||
(input_batch["action_test_std"] - dataset_stats["action_test_std"]["mean"])
|
||||
/ dataset_stats["action_test_std"]["std"],
|
||||
rtol=0.1,
|
||||
atol=1e-7,
|
||||
).all()
|
||||
assert torch.isclose(
|
||||
normalized_output["action_test_min_max"],
|
||||
(input_batch["action_test_min_max"] - dataset_stats["action_test_min_max"]["min"])
|
||||
/ (dataset_stats["action_test_min_max"]["max"] - dataset_stats["action_test_min_max"]["min"])
|
||||
* 2
|
||||
- 1,
|
||||
rtol=0.1,
|
||||
atol=1e-7,
|
||||
).all()
|
||||
|
||||
if insert_temporal_dim:
|
||||
assert torch.isclose(
|
||||
normalized_output["action_test_std_cap"][0, 0, 0],
|
||||
(input_batch["action_test_std_cap"][0, 0, 0] - dataset_stats["action_test_std_cap"]["mean"][0])
|
||||
/ dataset_stats["action_test_std_cap"]["std"][0],
|
||||
rtol=0.1,
|
||||
atol=1e-7,
|
||||
).all()
|
||||
assert torch.isclose(
|
||||
normalized_output["action_test_std_cap"][0, 0, 1],
|
||||
(input_batch["action_test_std_cap"][0, 0, 1] - dataset_stats["action_test_std_cap"]["mean"][1])
|
||||
/ std_epsilon,
|
||||
rtol=0.1,
|
||||
atol=1e-7,
|
||||
).all()
|
||||
|
||||
assert torch.isclose(
|
||||
normalized_output["action_test_min_max_cap"][0, 0, 0],
|
||||
(
|
||||
input_batch["action_test_min_max_cap"][0, 0, 0]
|
||||
- dataset_stats["action_test_min_max_cap"]["min"][0]
|
||||
)
|
||||
/ (
|
||||
dataset_stats["action_test_min_max_cap"]["max"][0]
|
||||
- dataset_stats["action_test_min_max_cap"]["min"][0]
|
||||
)
|
||||
* 2
|
||||
- 1,
|
||||
rtol=0.1,
|
||||
atol=1e-7,
|
||||
).all()
|
||||
assert torch.isclose(
|
||||
normalized_output["action_test_min_max_cap"][0, 0, 1],
|
||||
(
|
||||
input_batch["action_test_min_max_cap"][0, 0, 1]
|
||||
- dataset_stats["action_test_min_max_cap"]["min"][1]
|
||||
)
|
||||
/ std_epsilon
|
||||
* 2
|
||||
- 1,
|
||||
rtol=0.1,
|
||||
atol=1e-7,
|
||||
).all()
|
||||
else:
|
||||
assert torch.isclose(
|
||||
normalized_output["action_test_std_cap"][0, 0],
|
||||
(input_batch["action_test_std_cap"][0, 0] - dataset_stats["action_test_std_cap"]["mean"][0])
|
||||
/ dataset_stats["action_test_std_cap"]["std"][0],
|
||||
rtol=0.1,
|
||||
atol=1e-7,
|
||||
).all()
|
||||
assert torch.isclose(
|
||||
normalized_output["action_test_std_cap"][0, 1],
|
||||
(input_batch["action_test_std_cap"][0, 1] - dataset_stats["action_test_std_cap"]["mean"][1])
|
||||
/ std_epsilon,
|
||||
rtol=0.1,
|
||||
atol=1e-7,
|
||||
).all()
|
||||
|
||||
assert torch.isclose(
|
||||
normalized_output["action_test_min_max_cap"][0, 0],
|
||||
(
|
||||
input_batch["action_test_min_max_cap"][0, 0]
|
||||
- dataset_stats["action_test_min_max_cap"]["min"][0]
|
||||
)
|
||||
/ (
|
||||
dataset_stats["action_test_min_max_cap"]["max"][0]
|
||||
- dataset_stats["action_test_min_max_cap"]["min"][0]
|
||||
)
|
||||
* 2
|
||||
- 1,
|
||||
rtol=0.1,
|
||||
atol=1e-7,
|
||||
).all()
|
||||
assert torch.isclose(
|
||||
normalized_output["action_test_min_max_cap"][0, 1],
|
||||
(
|
||||
input_batch["action_test_min_max_cap"][0, 1]
|
||||
- dataset_stats["action_test_min_max_cap"]["min"][1]
|
||||
)
|
||||
/ std_epsilon
|
||||
* 2
|
||||
- 1,
|
||||
rtol=0.1,
|
||||
atol=1e-7,
|
||||
).all()
|
||||
|
||||
# test loading pretrained models
|
||||
new_normalize = Normalize(input_shapes, normalize_input_modes, stats=None)
|
||||
|
||||
Reference in New Issue
Block a user