Format file

This commit is contained in:
AdilZouitine
2025-05-07 10:26:18 +02:00
parent adbf8bb85e
commit b36ec31fea
13 changed files with 43 additions and 169 deletions

View File

@@ -61,10 +61,7 @@ def test_sample_indices():
assert len(indices) == estimate_num_samples(10)
@patch(
"lerobot.common.datasets.compute_stats.load_image_as_numpy",
side_effect=mock_load_image_as_numpy,
)
@patch("lerobot.common.datasets.compute_stats.load_image_as_numpy", side_effect=mock_load_image_as_numpy)
def test_sample_images(mock_load):
image_paths = [f"image_{i}.jpg" for i in range(100)]
images = sample_images(image_paths)
@@ -148,8 +145,7 @@ def test_compute_episode_stats():
}
with patch(
"lerobot.common.datasets.compute_stats.load_image_as_numpy",
side_effect=mock_load_image_as_numpy,
"lerobot.common.datasets.compute_stats.load_image_as_numpy", side_effect=mock_load_image_as_numpy
):
stats = compute_episode_stats(episode_data, features)
@@ -237,13 +233,7 @@ def test_aggregate_stats():
"std": [2.87, 5.87, 8.87],
"count": 10,
},
"observation.state": {
"min": 1,
"max": 10,
"mean": 5.5,
"std": 2.87,
"count": 10,
},
"observation.state": {"min": 1, "max": 10, "mean": 5.5, "std": 2.87, "count": 10},
"extra_key_0": {"min": 5, "max": 25, "mean": 15, "std": 6, "count": 6},
},
{
@@ -254,13 +244,7 @@ def test_aggregate_stats():
"std": [3.42, 2.42, 1.42],
"count": 15,
},
"observation.state": {
"min": 2,
"max": 15,
"mean": 8.5,
"std": 3.42,
"count": 15,
},
"observation.state": {"min": 2, "max": 15, "mean": 8.5, "std": 3.42, "count": 15},
"extra_key_1": {"min": 0, "max": 20, "mean": 10, "std": 5, "count": 5},
},
]
@@ -322,9 +306,6 @@ def test_aggregate_stats():
np.testing.assert_allclose(results[fkey]["max"], expected_agg_stats[fkey]["max"])
np.testing.assert_allclose(results[fkey]["mean"], expected_agg_stats[fkey]["mean"])
np.testing.assert_allclose(
results[fkey]["std"],
expected_agg_stats[fkey]["std"],
atol=1e-04,
rtol=1e-04,
results[fkey]["std"], expected_agg_stats[fkey]["std"], atol=1e-04, rtol=1e-04
)
np.testing.assert_allclose(results[fkey]["count"], expected_agg_stats[fkey]["count"])