[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
This commit is contained in:
committed by
Michel Aractingi
parent
2abbd60a0d
commit
0ea27704f6
@@ -31,11 +31,7 @@ from lerobot.common.datasets.compute_stats import (
|
||||
|
||||
|
||||
def mock_load_image_as_numpy(path, dtype, channel_first):
|
||||
return (
|
||||
np.ones((3, 32, 32), dtype=dtype)
|
||||
if channel_first
|
||||
else np.ones((32, 32, 3), dtype=dtype)
|
||||
)
|
||||
return np.ones((3, 32, 32), dtype=dtype) if channel_first else np.ones((32, 32, 3), dtype=dtype)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -81,20 +77,9 @@ def test_sample_images(mock_load):
|
||||
def test_get_feature_stats_images():
|
||||
data = np.random.rand(100, 3, 32, 32)
|
||||
stats = get_feature_stats(data, axis=(0, 2, 3), keepdims=True)
|
||||
assert (
|
||||
"min" in stats
|
||||
and "max" in stats
|
||||
and "mean" in stats
|
||||
and "std" in stats
|
||||
and "count" in stats
|
||||
)
|
||||
assert "min" in stats and "max" in stats and "mean" in stats and "std" in stats and "count" in stats
|
||||
np.testing.assert_equal(stats["count"], np.array([100]))
|
||||
assert (
|
||||
stats["min"].shape
|
||||
== stats["max"].shape
|
||||
== stats["mean"].shape
|
||||
== stats["std"].shape
|
||||
)
|
||||
assert stats["min"].shape == stats["max"].shape == stats["mean"].shape == stats["std"].shape
|
||||
|
||||
|
||||
def test_get_feature_stats_axis_0_keepdims(sample_array):
|
||||
@@ -315,47 +300,31 @@ def test_aggregate_stats():
|
||||
for ep_stats in all_stats:
|
||||
for fkey, stats in ep_stats.items():
|
||||
for k in stats:
|
||||
stats[k] = np.array(
|
||||
stats[k], dtype=np.int64 if k == "count" else np.float32
|
||||
)
|
||||
stats[k] = np.array(stats[k], dtype=np.int64 if k == "count" else np.float32)
|
||||
if fkey == "observation.image" and k != "count":
|
||||
stats[k] = stats[k].reshape(
|
||||
3, 1, 1
|
||||
) # for normalization on image channels
|
||||
stats[k] = stats[k].reshape(3, 1, 1) # for normalization on image channels
|
||||
else:
|
||||
stats[k] = stats[k].reshape(1)
|
||||
|
||||
# cast to numpy
|
||||
for fkey, stats in expected_agg_stats.items():
|
||||
for k in stats:
|
||||
stats[k] = np.array(
|
||||
stats[k], dtype=np.int64 if k == "count" else np.float32
|
||||
)
|
||||
stats[k] = np.array(stats[k], dtype=np.int64 if k == "count" else np.float32)
|
||||
if fkey == "observation.image" and k != "count":
|
||||
stats[k] = stats[k].reshape(
|
||||
3, 1, 1
|
||||
) # for normalization on image channels
|
||||
stats[k] = stats[k].reshape(3, 1, 1) # for normalization on image channels
|
||||
else:
|
||||
stats[k] = stats[k].reshape(1)
|
||||
|
||||
results = aggregate_stats(all_stats)
|
||||
|
||||
for fkey in expected_agg_stats:
|
||||
np.testing.assert_allclose(
|
||||
results[fkey]["min"], expected_agg_stats[fkey]["min"]
|
||||
)
|
||||
np.testing.assert_allclose(
|
||||
results[fkey]["max"], expected_agg_stats[fkey]["max"]
|
||||
)
|
||||
np.testing.assert_allclose(
|
||||
results[fkey]["mean"], expected_agg_stats[fkey]["mean"]
|
||||
)
|
||||
np.testing.assert_allclose(results[fkey]["min"], expected_agg_stats[fkey]["min"])
|
||||
np.testing.assert_allclose(results[fkey]["max"], expected_agg_stats[fkey]["max"])
|
||||
np.testing.assert_allclose(results[fkey]["mean"], expected_agg_stats[fkey]["mean"])
|
||||
np.testing.assert_allclose(
|
||||
results[fkey]["std"],
|
||||
expected_agg_stats[fkey]["std"],
|
||||
atol=1e-04,
|
||||
rtol=1e-04,
|
||||
)
|
||||
np.testing.assert_allclose(
|
||||
results[fkey]["count"], expected_agg_stats[fkey]["count"]
|
||||
)
|
||||
np.testing.assert_allclose(results[fkey]["count"], expected_agg_stats[fkey]["count"])
|
||||
|
||||
Reference in New Issue
Block a user