[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
This commit is contained in:
pre-commit-ci[bot]
2025-03-24 13:16:38 +00:00
committed by Michel Aractingi
parent cdcf346061
commit 1c8daf11fd
95 changed files with 1592 additions and 491 deletions

View File

@@ -31,7 +31,11 @@ from lerobot.common.datasets.compute_stats import (
def mock_load_image_as_numpy(path, dtype, channel_first):
return np.ones((3, 32, 32), dtype=dtype) if channel_first else np.ones((32, 32, 3), dtype=dtype)
return (
np.ones((3, 32, 32), dtype=dtype)
if channel_first
else np.ones((32, 32, 3), dtype=dtype)
)
@pytest.fixture
@@ -61,7 +65,10 @@ def test_sample_indices():
assert len(indices) == estimate_num_samples(10)
@patch("lerobot.common.datasets.compute_stats.load_image_as_numpy", side_effect=mock_load_image_as_numpy)
@patch(
"lerobot.common.datasets.compute_stats.load_image_as_numpy",
side_effect=mock_load_image_as_numpy,
)
def test_sample_images(mock_load):
image_paths = [f"image_{i}.jpg" for i in range(100)]
images = sample_images(image_paths)
@@ -74,9 +81,20 @@ def test_sample_images(mock_load):
def test_get_feature_stats_images():
data = np.random.rand(100, 3, 32, 32)
stats = get_feature_stats(data, axis=(0, 2, 3), keepdims=True)
assert "min" in stats and "max" in stats and "mean" in stats and "std" in stats and "count" in stats
assert (
"min" in stats
and "max" in stats
and "mean" in stats
and "std" in stats
and "count" in stats
)
np.testing.assert_equal(stats["count"], np.array([100]))
assert stats["min"].shape == stats["max"].shape == stats["mean"].shape == stats["std"].shape
assert (
stats["min"].shape
== stats["max"].shape
== stats["mean"].shape
== stats["std"].shape
)
def test_get_feature_stats_axis_0_keepdims(sample_array):
@@ -145,7 +163,8 @@ def test_compute_episode_stats():
}
with patch(
"lerobot.common.datasets.compute_stats.load_image_as_numpy", side_effect=mock_load_image_as_numpy
"lerobot.common.datasets.compute_stats.load_image_as_numpy",
side_effect=mock_load_image_as_numpy,
):
stats = compute_episode_stats(episode_data, features)
@@ -233,7 +252,13 @@ def test_aggregate_stats():
"std": [2.87, 5.87, 8.87],
"count": 10,
},
"observation.state": {"min": 1, "max": 10, "mean": 5.5, "std": 2.87, "count": 10},
"observation.state": {
"min": 1,
"max": 10,
"mean": 5.5,
"std": 2.87,
"count": 10,
},
"extra_key_0": {"min": 5, "max": 25, "mean": 15, "std": 6, "count": 6},
},
{
@@ -244,7 +269,13 @@ def test_aggregate_stats():
"std": [3.42, 2.42, 1.42],
"count": 15,
},
"observation.state": {"min": 2, "max": 15, "mean": 8.5, "std": 3.42, "count": 15},
"observation.state": {
"min": 2,
"max": 15,
"mean": 8.5,
"std": 3.42,
"count": 15,
},
"extra_key_1": {"min": 0, "max": 20, "mean": 10, "std": 5, "count": 5},
},
]
@@ -284,28 +315,47 @@ def test_aggregate_stats():
for ep_stats in all_stats:
for fkey, stats in ep_stats.items():
for k in stats:
stats[k] = np.array(stats[k], dtype=np.int64 if k == "count" else np.float32)
stats[k] = np.array(
stats[k], dtype=np.int64 if k == "count" else np.float32
)
if fkey == "observation.image" and k != "count":
stats[k] = stats[k].reshape(3, 1, 1) # for normalization on image channels
stats[k] = stats[k].reshape(
3, 1, 1
) # for normalization on image channels
else:
stats[k] = stats[k].reshape(1)
# cast to numpy
for fkey, stats in expected_agg_stats.items():
for k in stats:
stats[k] = np.array(stats[k], dtype=np.int64 if k == "count" else np.float32)
stats[k] = np.array(
stats[k], dtype=np.int64 if k == "count" else np.float32
)
if fkey == "observation.image" and k != "count":
stats[k] = stats[k].reshape(3, 1, 1) # for normalization on image channels
stats[k] = stats[k].reshape(
3, 1, 1
) # for normalization on image channels
else:
stats[k] = stats[k].reshape(1)
results = aggregate_stats(all_stats)
for fkey in expected_agg_stats:
np.testing.assert_allclose(results[fkey]["min"], expected_agg_stats[fkey]["min"])
np.testing.assert_allclose(results[fkey]["max"], expected_agg_stats[fkey]["max"])
np.testing.assert_allclose(results[fkey]["mean"], expected_agg_stats[fkey]["mean"])
np.testing.assert_allclose(
results[fkey]["std"], expected_agg_stats[fkey]["std"], atol=1e-04, rtol=1e-04
results[fkey]["min"], expected_agg_stats[fkey]["min"]
)
np.testing.assert_allclose(
results[fkey]["max"], expected_agg_stats[fkey]["max"]
)
np.testing.assert_allclose(
results[fkey]["mean"], expected_agg_stats[fkey]["mean"]
)
np.testing.assert_allclose(
results[fkey]["std"],
expected_agg_stats[fkey]["std"],
atol=1e-04,
rtol=1e-04,
)
np.testing.assert_allclose(
results[fkey]["count"], expected_agg_stats[fkey]["count"]
)
np.testing.assert_allclose(results[fkey]["count"], expected_agg_stats[fkey]["count"])