This commit is contained in:
Thomas Wolf
2024-06-19 10:07:41 +02:00
parent 1cd7ca71a1
commit 33166e1d43
2 changed files with 21 additions and 38 deletions

View File

@@ -331,24 +331,14 @@ def test_normalize(insert_temporal_dim):
).all()
assert torch.isclose(
normalize.buffer_action_test_std_cap.std[0],
dataset_stats["action_test_std_cap"]["std"][0],
normalize.buffer_action_test_std_cap.std,
dataset_stats["action_test_std_cap"]["std"],
rtol=0.1,
atol=1e-7,
).all()
assert torch.isclose(
normalize.buffer_action_test_std_cap.std[1], torch.ones(1) * std_epsilon, rtol=0.1, atol=1e-7
).all()
assert torch.isclose(
normalize.buffer_action_test_min_max_cap.max[0] - normalize.buffer_action_test_min_max_cap.min[0],
dataset_stats["action_test_min_max_cap"]["max"][0]
- dataset_stats["action_test_min_max_cap"]["min"][0],
rtol=0.1,
atol=1e-7,
).all()
assert torch.isclose(
normalize.buffer_action_test_min_max_cap.max[1] - normalize.buffer_action_test_min_max_cap.min[1],
torch.ones(1) * std_epsilon,
normalize.buffer_action_test_min_max_cap.max - normalize.buffer_action_test_min_max_cap.min,
dataset_stats["action_test_min_max_cap"]["max"] - dataset_stats["action_test_min_max_cap"]["min"],
rtol=0.1,
atol=1e-7,
).all()
@@ -496,24 +486,14 @@ def test_normalize(insert_temporal_dim):
).all()
assert torch.isclose(
unnormalize.buffer_action_test_std_cap.std[0],
dataset_stats["action_test_std_cap"]["std"][0],
unnormalize.buffer_action_test_std_cap.std,
dataset_stats["action_test_std_cap"]["std"],
rtol=0.1,
atol=1e-7,
).all()
assert torch.isclose(
unnormalize.buffer_action_test_std_cap.std[1], torch.ones(1) * std_epsilon, rtol=0.1, atol=1e-7
).all()
assert torch.isclose(
unnormalize.buffer_action_test_min_max_cap.max[0] - unnormalize.buffer_action_test_min_max_cap.min[0],
dataset_stats["action_test_min_max_cap"]["max"][0]
- dataset_stats["action_test_min_max_cap"]["min"][0],
rtol=0.1,
atol=1e-7,
).all()
assert torch.isclose(
unnormalize.buffer_action_test_min_max_cap.max[1] - unnormalize.buffer_action_test_min_max_cap.min[1],
torch.ones(1) * std_epsilon,
unnormalize.buffer_action_test_min_max_cap.max - unnormalize.buffer_action_test_min_max_cap.min,
dataset_stats["action_test_min_max_cap"]["max"] - dataset_stats["action_test_min_max_cap"]["min"],
rtol=0.1,
atol=1e-7,
).all()