Merge remote-tracking branch 'origin/main' into user/aliberts/2025_02_25_refactor_robots
311 tests/datasets/test_compute_stats.py Normal file
@@ -0,0 +1,311 @@
#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest.mock import patch

import numpy as np
import pytest

from lerobot.common.datasets.compute_stats import (
    _assert_type_and_shape,
    aggregate_feature_stats,
    aggregate_stats,
    compute_episode_stats,
    estimate_num_samples,
    get_feature_stats,
    sample_images,
    sample_indices,
)


def mock_load_image_as_numpy(path, dtype, channel_first):
    return np.ones((3, 32, 32), dtype=dtype) if channel_first else np.ones((32, 32, 3), dtype=dtype)


@pytest.fixture
def sample_array():
    return np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])


def test_estimate_num_samples():
    assert estimate_num_samples(1) == 1
    assert estimate_num_samples(10) == 10
    assert estimate_num_samples(100) == 100
    assert estimate_num_samples(200) == 100
    assert estimate_num_samples(1000) == 177
    assert estimate_num_samples(2000) == 299
    assert estimate_num_samples(5000) == 594
    assert estimate_num_samples(10_000) == 1000
    assert estimate_num_samples(20_000) == 1681
    assert estimate_num_samples(50_000) == 3343
    assert estimate_num_samples(500_000) == 10_000
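
    # The expected values above are consistent with a power-law sampling heuristic. A
    # minimal sketch that reproduces every assertion (inferred from the asserted values,
    # not necessarily the actual implementation of `estimate_num_samples`):
    #
    #     def estimate_num_samples_sketch(n: int) -> int:
    #         return min(n, max(100, min(int(n**0.75), 10_000)))
    #
    # e.g. int(1000**0.75) == 177, and 500_000**0.75 ~= 18_803 gets capped at 10_000.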


def test_sample_indices():
    indices = sample_indices(10)
    assert len(indices) > 0
    assert indices[0] == 0
    assert indices[-1] == 9
    assert len(indices) == estimate_num_samples(10)


@patch("lerobot.common.datasets.compute_stats.load_image_as_numpy", side_effect=mock_load_image_as_numpy)
def test_sample_images(mock_load):
    image_paths = [f"image_{i}.jpg" for i in range(100)]
    images = sample_images(image_paths)
    assert isinstance(images, np.ndarray)
    assert images.shape[1:] == (3, 32, 32)
    assert images.dtype == np.uint8
    assert len(images) == estimate_num_samples(100)


def test_get_feature_stats_images():
    data = np.random.rand(100, 3, 32, 32)
    stats = get_feature_stats(data, axis=(0, 2, 3), keepdims=True)
    assert "min" in stats and "max" in stats and "mean" in stats and "std" in stats and "count" in stats
    np.testing.assert_equal(stats["count"], np.array([100]))
    assert stats["min"].shape == stats["max"].shape == stats["mean"].shape == stats["std"].shape


def test_get_feature_stats_axis_0_keepdims(sample_array):
    expected = {
        "min": np.array([[1, 2, 3]]),
        "max": np.array([[7, 8, 9]]),
        "mean": np.array([[4.0, 5.0, 6.0]]),
        "std": np.array([[2.44948974, 2.44948974, 2.44948974]]),
        "count": np.array([3]),
    }
    result = get_feature_stats(sample_array, axis=(0,), keepdims=True)
    for key in expected:
        np.testing.assert_allclose(result[key], expected[key])
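
    # Note: the expected "std" entries are population standard deviations (ddof=0): for the
    # column [1, 4, 7], sqrt(((1 - 4)**2 + (4 - 4)**2 + (7 - 4)**2) / 3) = sqrt(6) ~= 2.44948974.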


def test_get_feature_stats_axis_1(sample_array):
    expected = {
        "min": np.array([1, 4, 7]),
        "max": np.array([3, 6, 9]),
        "mean": np.array([2.0, 5.0, 8.0]),
        "std": np.array([0.81649658, 0.81649658, 0.81649658]),
        "count": np.array([3]),
    }
    result = get_feature_stats(sample_array, axis=(1,), keepdims=False)
    for key in expected:
        np.testing.assert_allclose(result[key], expected[key])


def test_get_feature_stats_no_axis(sample_array):
    expected = {
        "min": np.array(1),
        "max": np.array(9),
        "mean": np.array(5.0),
        "std": np.array(2.5819889),
        "count": np.array([3]),
    }
    result = get_feature_stats(sample_array, axis=None, keepdims=False)
    for key in expected:
        np.testing.assert_allclose(result[key], expected[key])


def test_get_feature_stats_empty_array():
    array = np.array([])
    with pytest.raises(ValueError):
        get_feature_stats(array, axis=(0,), keepdims=True)


def test_get_feature_stats_single_value():
    array = np.array([[1337]])
    result = get_feature_stats(array, axis=None, keepdims=True)
    np.testing.assert_equal(result["min"], np.array(1337))
    np.testing.assert_equal(result["max"], np.array(1337))
    np.testing.assert_equal(result["mean"], np.array(1337.0))
    np.testing.assert_equal(result["std"], np.array(0.0))
    np.testing.assert_equal(result["count"], np.array([1]))


def test_compute_episode_stats():
    episode_data = {
        "observation.image": [f"image_{i}.jpg" for i in range(100)],
        "observation.state": np.random.rand(100, 10),
    }
    features = {
        "observation.image": {"dtype": "image"},
        "observation.state": {"dtype": "numeric"},
    }

    with patch(
        "lerobot.common.datasets.compute_stats.load_image_as_numpy", side_effect=mock_load_image_as_numpy
    ):
        stats = compute_episode_stats(episode_data, features)

    assert "observation.image" in stats and "observation.state" in stats
    assert stats["observation.image"]["count"].item() == 100
    assert stats["observation.state"]["count"].item() == 100
    assert stats["observation.image"]["mean"].shape == (3, 1, 1)


def test_assert_type_and_shape_valid():
    valid_stats = [
        {
            "feature1": {
                "min": np.array([1.0]),
                "max": np.array([10.0]),
                "mean": np.array([5.0]),
                "std": np.array([2.0]),
                "count": np.array([1]),
            }
        }
    ]
    _assert_type_and_shape(valid_stats)


def test_assert_type_and_shape_invalid_type():
    invalid_stats = [
        {
            "feature1": {
                "min": [1.0],  # Not a numpy array
                "max": np.array([10.0]),
                "mean": np.array([5.0]),
                "std": np.array([2.0]),
                "count": np.array([1]),
            }
        }
    ]
    with pytest.raises(ValueError, match="Stats must be composed of numpy array"):
        _assert_type_and_shape(invalid_stats)


def test_assert_type_and_shape_invalid_shape():
    invalid_stats = [
        {
            "feature1": {
                "count": np.array([1, 2]),  # Wrong shape
            }
        }
    ]
    with pytest.raises(ValueError, match=r"Shape of 'count' must be \(1\)"):
        _assert_type_and_shape(invalid_stats)


def test_aggregate_feature_stats():
    stats_ft_list = [
        {
            "min": np.array([1.0]),
            "max": np.array([10.0]),
            "mean": np.array([5.0]),
            "std": np.array([2.0]),
            "count": np.array([1]),
        },
        {
            "min": np.array([2.0]),
            "max": np.array([12.0]),
            "mean": np.array([6.0]),
            "std": np.array([2.5]),
            "count": np.array([1]),
        },
    ]
    result = aggregate_feature_stats(stats_ft_list)
    np.testing.assert_allclose(result["min"], np.array([1.0]))
    np.testing.assert_allclose(result["max"], np.array([12.0]))
    np.testing.assert_allclose(result["mean"], np.array([5.5]))
    np.testing.assert_allclose(result["std"], np.array([2.318405]), atol=1e-6)
    np.testing.assert_allclose(result["count"], np.array([2]))


def test_aggregate_stats():
    all_stats = [
        {
            "observation.image": {
                "min": [1, 2, 3],
                "max": [10, 20, 30],
                "mean": [5.5, 10.5, 15.5],
                "std": [2.87, 5.87, 8.87],
                "count": 10,
            },
            "observation.state": {"min": 1, "max": 10, "mean": 5.5, "std": 2.87, "count": 10},
            "extra_key_0": {"min": 5, "max": 25, "mean": 15, "std": 6, "count": 6},
        },
        {
            "observation.image": {
                "min": [2, 1, 0],
                "max": [15, 10, 5],
                "mean": [8.5, 5.5, 2.5],
                "std": [3.42, 2.42, 1.42],
                "count": 15,
            },
            "observation.state": {"min": 2, "max": 15, "mean": 8.5, "std": 3.42, "count": 15},
            "extra_key_1": {"min": 0, "max": 20, "mean": 10, "std": 5, "count": 5},
        },
    ]

    expected_agg_stats = {
        "observation.image": {
            "min": [1, 1, 0],
            "max": [15, 20, 30],
            "mean": [7.3, 7.5, 7.7],
            "std": [3.5317, 4.8267, 8.5581],
            "count": 25,
        },
        "observation.state": {
            "min": 1,
            "max": 15,
            "mean": 7.3,
            "std": 3.5317,
            "count": 25,
        },
        "extra_key_0": {
            "min": 5,
            "max": 25,
            "mean": 15.0,
            "std": 6.0,
            "count": 6,
        },
        "extra_key_1": {
            "min": 0,
            "max": 20,
            "mean": 10.0,
            "std": 5.0,
            "count": 5,
        },
    }
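
    # Sanity check of the expected values, using the same count-weighted pooling convention
    # as in test_aggregate_feature_stats. For "observation.state":
    #     mean = (10 * 5.5 + 15 * 8.5) / 25 = 7.3
    #     var  = (10 * (2.87**2 + (5.5 - 7.3)**2) + 15 * (3.42**2 + (8.5 - 7.3)**2)) / 25 ~= 12.4726
    #     std  = sqrt(12.4726) ~= 3.5317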

    # cast the episode stats to numpy
    for ep_stats in all_stats:
        for fkey, stats in ep_stats.items():
            for k in stats:
                stats[k] = np.array(stats[k], dtype=np.int64 if k == "count" else np.float32)
                if fkey == "observation.image" and k != "count":
                    stats[k] = stats[k].reshape(3, 1, 1)  # for normalization on image channels
                else:
                    stats[k] = stats[k].reshape(1)

    # cast the expected stats to numpy
    for fkey, stats in expected_agg_stats.items():
        for k in stats:
            stats[k] = np.array(stats[k], dtype=np.int64 if k == "count" else np.float32)
            if fkey == "observation.image" and k != "count":
                stats[k] = stats[k].reshape(3, 1, 1)  # for normalization on image channels
            else:
                stats[k] = stats[k].reshape(1)

    results = aggregate_stats(all_stats)

    for fkey in expected_agg_stats:
        np.testing.assert_allclose(results[fkey]["min"], expected_agg_stats[fkey]["min"])
        np.testing.assert_allclose(results[fkey]["max"], expected_agg_stats[fkey]["max"])
        np.testing.assert_allclose(results[fkey]["mean"], expected_agg_stats[fkey]["mean"])
        np.testing.assert_allclose(
            results[fkey]["std"], expected_agg_stats[fkey]["std"], atol=1e-04, rtol=1e-04
        )
        np.testing.assert_allclose(results[fkey]["count"], expected_agg_stats[fkey]["count"])
582 tests/datasets/test_datasets.py Normal file
@@ -0,0 +1,582 @@
#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import logging
import re
from copy import deepcopy
from itertools import chain
from pathlib import Path

import numpy as np
import pytest
import torch
from huggingface_hub import HfApi
from PIL import Image
from safetensors.torch import load_file

import lerobot
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.datasets.image_writer import image_array_to_pil_image
from lerobot.common.datasets.lerobot_dataset import (
    LeRobotDataset,
    MultiLeRobotDataset,
)
from lerobot.common.datasets.utils import (
    create_branch,
    flatten_dict,
    unflatten_dict,
)
from lerobot.common.envs.factory import make_env_config
from lerobot.common.policies.factory import make_policy_config
from lerobot.common.robots.utils import make_robot
from lerobot.configs.default import DatasetConfig
from lerobot.configs.train import TrainPipelineConfig
from tests.fixtures.constants import DUMMY_CHW, DUMMY_HWC, DUMMY_REPO_ID
from tests.utils import require_x86_64_kernel


@pytest.fixture
def image_dataset(tmp_path, empty_lerobot_dataset_factory):
    features = {
        "image": {
            "dtype": "image",
            "shape": DUMMY_CHW,
            "names": [
                "channels",
                "height",
                "width",
            ],
        }
    }
    return empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)


def test_same_attributes_defined(tmp_path, lerobot_dataset_factory):
    """
    Instantiate a LeRobotDataset both ways, with '__init__()' and with 'create()', and verify that the
    instantiated objects define the same set of attributes.
    """
    # Instantiate both ways
    robot = make_robot("koch", mock=True)
    root_create = tmp_path / "create"
    dataset_create = LeRobotDataset.create(repo_id=DUMMY_REPO_ID, fps=30, robot=robot, root=root_create)

    root_init = tmp_path / "init"
    dataset_init = lerobot_dataset_factory(root=root_init)

    init_attr = set(vars(dataset_init).keys())
    create_attr = set(vars(dataset_create).keys())

    assert init_attr == create_attr


def test_dataset_initialization(tmp_path, lerobot_dataset_factory):
    kwargs = {
        "repo_id": DUMMY_REPO_ID,
        "total_episodes": 10,
        "total_frames": 400,
        "episodes": [2, 5, 6],
    }
    dataset = lerobot_dataset_factory(root=tmp_path / "test", **kwargs)

    assert dataset.repo_id == kwargs["repo_id"]
    assert dataset.meta.total_episodes == kwargs["total_episodes"]
    assert dataset.meta.total_frames == kwargs["total_frames"]
    assert dataset.episodes == kwargs["episodes"]
    assert dataset.num_episodes == len(kwargs["episodes"])
    assert dataset.num_frames == len(dataset)


def test_add_frame_missing_task(tmp_path, empty_lerobot_dataset_factory):
    features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
    dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
    with pytest.raises(
        ValueError, match="Feature mismatch in `frame` dictionary:\nMissing features: {'task'}\n"
    ):
        dataset.add_frame({"state": torch.randn(1)})


def test_add_frame_missing_feature(tmp_path, empty_lerobot_dataset_factory):
    features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
    dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
    with pytest.raises(
        ValueError, match="Feature mismatch in `frame` dictionary:\nMissing features: {'state'}\n"
    ):
        dataset.add_frame({"task": "Dummy task"})


def test_add_frame_extra_feature(tmp_path, empty_lerobot_dataset_factory):
    features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
    dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
    with pytest.raises(
        ValueError, match="Feature mismatch in `frame` dictionary:\nExtra features: {'extra'}\n"
    ):
        dataset.add_frame({"state": torch.randn(1), "task": "Dummy task", "extra": "dummy_extra"})


def test_add_frame_wrong_type(tmp_path, empty_lerobot_dataset_factory):
    features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
    dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
    with pytest.raises(
        ValueError, match="The feature 'state' of dtype 'float16' is not of the expected dtype 'float32'.\n"
    ):
        dataset.add_frame({"state": torch.randn(1, dtype=torch.float16), "task": "Dummy task"})


def test_add_frame_wrong_shape(tmp_path, empty_lerobot_dataset_factory):
    features = {"state": {"dtype": "float32", "shape": (2,), "names": None}}
    dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
    with pytest.raises(
        ValueError,
        match=re.escape("The feature 'state' of shape '(1,)' does not have the expected shape '(2,)'.\n"),
    ):
        dataset.add_frame({"state": torch.randn(1), "task": "Dummy task"})


def test_add_frame_wrong_shape_python_float(tmp_path, empty_lerobot_dataset_factory):
    features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
    dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
    with pytest.raises(
        ValueError,
        match=re.escape(
            "The feature 'state' is not a 'np.ndarray'. Expected type is 'float32', but type '<class 'float'>' provided instead.\n"
        ),
    ):
        dataset.add_frame({"state": 1.0, "task": "Dummy task"})


def test_add_frame_wrong_shape_torch_ndim_0(tmp_path, empty_lerobot_dataset_factory):
    features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
    dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
    with pytest.raises(
        ValueError,
        match=re.escape("The feature 'state' of shape '()' does not have the expected shape '(1,)'.\n"),
    ):
        dataset.add_frame({"state": torch.tensor(1.0), "task": "Dummy task"})


def test_add_frame_wrong_shape_numpy_ndim_0(tmp_path, empty_lerobot_dataset_factory):
    features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
    dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
    with pytest.raises(
        ValueError,
        match=re.escape(
            "The feature 'state' is not a 'np.ndarray'. Expected type is 'float32', but type '<class 'numpy.float32'>' provided instead.\n"
        ),
    ):
        dataset.add_frame({"state": np.float32(1.0), "task": "Dummy task"})


def test_add_frame(tmp_path, empty_lerobot_dataset_factory):
    features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
    dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
    dataset.add_frame({"state": torch.randn(1), "task": "Dummy task"})
    dataset.save_episode()

    assert len(dataset) == 1
    assert dataset[0]["task"] == "Dummy task"
    assert dataset[0]["task_index"] == 0
    assert dataset[0]["state"].ndim == 0


def test_add_frame_state_1d(tmp_path, empty_lerobot_dataset_factory):
    features = {"state": {"dtype": "float32", "shape": (2,), "names": None}}
    dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
    dataset.add_frame({"state": torch.randn(2), "task": "Dummy task"})
    dataset.save_episode()

    assert dataset[0]["state"].shape == torch.Size([2])


def test_add_frame_state_2d(tmp_path, empty_lerobot_dataset_factory):
    features = {"state": {"dtype": "float32", "shape": (2, 4), "names": None}}
    dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
    dataset.add_frame({"state": torch.randn(2, 4), "task": "Dummy task"})
    dataset.save_episode()

    assert dataset[0]["state"].shape == torch.Size([2, 4])


def test_add_frame_state_3d(tmp_path, empty_lerobot_dataset_factory):
    features = {"state": {"dtype": "float32", "shape": (2, 4, 3), "names": None}}
    dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
    dataset.add_frame({"state": torch.randn(2, 4, 3), "task": "Dummy task"})
    dataset.save_episode()

    assert dataset[0]["state"].shape == torch.Size([2, 4, 3])


def test_add_frame_state_4d(tmp_path, empty_lerobot_dataset_factory):
    features = {"state": {"dtype": "float32", "shape": (2, 4, 3, 5), "names": None}}
    dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
    dataset.add_frame({"state": torch.randn(2, 4, 3, 5), "task": "Dummy task"})
    dataset.save_episode()

    assert dataset[0]["state"].shape == torch.Size([2, 4, 3, 5])


def test_add_frame_state_5d(tmp_path, empty_lerobot_dataset_factory):
    features = {"state": {"dtype": "float32", "shape": (2, 4, 3, 5, 1), "names": None}}
    dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
    dataset.add_frame({"state": torch.randn(2, 4, 3, 5, 1), "task": "Dummy task"})
    dataset.save_episode()

    assert dataset[0]["state"].shape == torch.Size([2, 4, 3, 5, 1])


def test_add_frame_state_numpy(tmp_path, empty_lerobot_dataset_factory):
    features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
    dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
    dataset.add_frame({"state": np.array([1], dtype=np.float32), "task": "Dummy task"})
    dataset.save_episode()

    assert dataset[0]["state"].ndim == 0


def test_add_frame_string(tmp_path, empty_lerobot_dataset_factory):
    features = {"caption": {"dtype": "string", "shape": (1,), "names": None}}
    dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
    dataset.add_frame({"caption": "Dummy caption", "task": "Dummy task"})
    dataset.save_episode()

    assert dataset[0]["caption"] == "Dummy caption"


def test_add_frame_image_wrong_shape(image_dataset):
    dataset = image_dataset
    with pytest.raises(
        ValueError,
        match=re.escape(
            "The feature 'image' of shape '(3, 128, 96)' does not have the expected shape '(3, 96, 128)' or '(96, 128, 3)'.\n"
        ),
    ):
        c, h, w = DUMMY_CHW
        dataset.add_frame({"image": torch.randn(c, w, h), "task": "Dummy task"})


def test_add_frame_image_wrong_range(image_dataset):
    """This test will display the following error message from a thread:
    ```
    Error writing image ...test_add_frame_image_wrong_ran0/test/images/image/episode_000000/frame_000000.png:
    The image data type is float, which requires values in the range [0.0, 1.0]. However, the provided range is [0.009678772038470007, 254.9776492089887].
    Please adjust the range or provide a uint8 image with values in the range [0, 255]
    ```
    Hence the image won't be saved on disk and save_episode will raise `FileNotFoundError`.
    """
    dataset = image_dataset
    dataset.add_frame({"image": np.random.rand(*DUMMY_CHW) * 255, "task": "Dummy task"})
    with pytest.raises(FileNotFoundError):
        dataset.save_episode()


def test_add_frame_image(image_dataset):
    dataset = image_dataset
    dataset.add_frame({"image": np.random.rand(*DUMMY_CHW), "task": "Dummy task"})
    dataset.save_episode()

    assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)


def test_add_frame_image_h_w_c(image_dataset):
    dataset = image_dataset
    dataset.add_frame({"image": np.random.rand(*DUMMY_HWC), "task": "Dummy task"})
    dataset.save_episode()

    assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)


def test_add_frame_image_uint8(image_dataset):
    dataset = image_dataset
    image = np.random.randint(0, 256, DUMMY_HWC, dtype=np.uint8)
    dataset.add_frame({"image": image, "task": "Dummy task"})
    dataset.save_episode()

    assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)


def test_add_frame_image_pil(image_dataset):
    dataset = image_dataset
    image = np.random.randint(0, 256, DUMMY_HWC, dtype=np.uint8)
    dataset.add_frame({"image": Image.fromarray(image), "task": "Dummy task"})
    dataset.save_episode()

    assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)


def test_image_array_to_pil_image_wrong_range_float_0_255():
    image = np.random.rand(*DUMMY_HWC) * 255
    with pytest.raises(ValueError):
        image_array_to_pil_image(image)


# TODO(aliberts):
# - [ ] test various attributes & state from init and create
# - [ ] test init with episodes and check num_frames
# - [ ] test add_episode
# - [ ] test push_to_hub
# - [ ] test smaller methods


@pytest.mark.parametrize(
    "env_name, repo_id, policy_name",
    # Single dataset
    lerobot.env_dataset_policy_triplets,
    # Multi-dataset
    # TODO after fix multidataset
    # + [("aloha", ["lerobot/aloha_sim_insertion_human", "lerobot/aloha_sim_transfer_cube_human"], "act")],
)
def test_factory(env_name, repo_id, policy_name):
    """
    Tests that:
    - we can create a dataset with the factory.
    - for a commonly used set of data keys, the data dimensions are correct.
    """
    cfg = TrainPipelineConfig(
        # TODO(rcadene, aliberts): remove dataset download
        dataset=DatasetConfig(repo_id=repo_id, episodes=[0]),
        env=make_env_config(env_name),
        policy=make_policy_config(policy_name),
    )

    dataset = make_dataset(cfg)
    delta_timestamps = dataset.delta_timestamps
    camera_keys = dataset.meta.camera_keys

    item = dataset[0]

    keys_ndim_required = [
        ("action", 1, True),
        ("episode_index", 0, True),
        ("frame_index", 0, True),
        ("timestamp", 0, True),
        # TODO(rcadene): should we rename it agent_pos?
        ("observation.state", 1, True),
        ("next.reward", 0, False),
        ("next.done", 0, False),
    ]

    # test number of dimensions
    for key, ndim, required in keys_ndim_required:
        if key not in item:
            if required:
                assert key in item, f"{key}"
            else:
                logging.warning(f'Missing key in dataset: "{key}" not in {dataset}.')
                continue

        if delta_timestamps is not None and key in delta_timestamps:
            assert item[key].ndim == ndim + 1, f"{key}"
            assert item[key].shape[0] == len(delta_timestamps[key]), f"{key}"
        else:
            assert item[key].ndim == ndim, f"{key}"

        if key in camera_keys:
            assert item[key].dtype == torch.float32, f"{key}"
            # TODO(rcadene): we assume for now that image normalization takes place in the model
            assert item[key].max() <= 1.0, f"{key}"
            assert item[key].min() >= 0.0, f"{key}"

            if delta_timestamps is not None and key in delta_timestamps:
                # test t,c,h,w
                assert item[key].shape[1] == 3, f"{key}"
            else:
                # test c,h,w
                assert item[key].shape[0] == 3, f"{key}"

    if delta_timestamps is not None:
        # test that no key in delta_timestamps is missing from the item
        for key in delta_timestamps:
            assert key in item, f"{key}"


# TODO(alexander-soare): If you're hunting for savings on testing time, this takes about 5 seconds.
@pytest.mark.skip("TODO after fix multidataset")
def test_multidataset_frames():
    """Check that all dataset frames are incorporated."""
    # Note: use the image variants of the datasets to make the test approx 3x faster.
    # Note: We really do need three repo_ids here as at some point this caught an issue with the chaining
    # logic that wouldn't be caught with two repo IDs.
    repo_ids = [
        "lerobot/aloha_sim_insertion_human_image",
        "lerobot/aloha_sim_transfer_cube_human_image",
        "lerobot/aloha_sim_insertion_scripted_image",
    ]
    sub_datasets = [LeRobotDataset(repo_id) for repo_id in repo_ids]
    dataset = MultiLeRobotDataset(repo_ids)
    assert len(dataset) == sum(len(d) for d in sub_datasets)
    assert dataset.num_frames == sum(d.num_frames for d in sub_datasets)
    assert dataset.num_episodes == sum(d.num_episodes for d in sub_datasets)

    # Run through all items of the LeRobotDatasets in parallel with the items of the
    # MultiLeRobotDataset and check they match.
    expected_dataset_indices = []
    for i, sub_dataset in enumerate(sub_datasets):
        expected_dataset_indices.extend([i] * len(sub_dataset))

    for expected_dataset_index, sub_dataset_item, dataset_item in zip(
        expected_dataset_indices, chain(*sub_datasets), dataset, strict=True
    ):
        dataset_index = dataset_item.pop("dataset_index")
        assert dataset_index == expected_dataset_index
        assert sub_dataset_item.keys() == dataset_item.keys()
        for k in sub_dataset_item:
            assert torch.equal(sub_dataset_item[k], dataset_item[k])


# TODO(aliberts): Move to a more appropriate location
def test_flatten_unflatten_dict():
    d = {
        "obs": {
            "min": 0,
            "max": 1,
            "mean": 2,
            "std": 3,
        },
        "action": {
            "min": 4,
            "max": 5,
            "mean": 6,
            "std": 7,
        },
    }

    original_d = deepcopy(d)
    d = unflatten_dict(flatten_dict(d))

    # test equality between nested dicts
    assert json.dumps(original_d, sort_keys=True) == json.dumps(d, sort_keys=True), f"{original_d} != {d}"
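
    # For reference, the flattened intermediate would look like {"obs/min": 0, ..., "action/std": 7},
    # assuming a "/" key separator (an assumption about flatten_dict's convention; the test itself
    # only checks that the round trip reproduces the nested structure).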


@pytest.mark.parametrize(
    "repo_id",
    [
        "lerobot/pusht",
        "lerobot/aloha_sim_insertion_human",
        "lerobot/xarm_lift_medium",
        # (michel-aractingi) commenting out the two datasets from openx as the test is failing
        # "lerobot/nyu_franka_play_dataset",
        # "lerobot/cmu_stretch",
    ],
)
@require_x86_64_kernel
def test_backward_compatibility(repo_id):
    """The artifacts for this test have been generated by `tests/artifacts/datasets/save_dataset_to_safetensors.py`."""

    # TODO(rcadene, aliberts): remove dataset download
    dataset = LeRobotDataset(repo_id, episodes=[0])

    test_dir = Path("tests/artifacts/datasets") / repo_id

    def load_and_compare(i):
        new_frame = dataset[i]  # noqa: B023
        old_frame = load_file(test_dir / f"frame_{i}.safetensors")  # noqa: B023

        # ignore language instructions (if they exist) in language-conditioned datasets
        # TODO (michel-aractingi): transform language obs to language embeddings via tokenizer
        new_frame.pop("language_instruction", None)
        old_frame.pop("language_instruction", None)
        new_frame.pop("task", None)
        old_frame.pop("task", None)

        # Remove task_index to allow for backward compatibility
        # TODO(rcadene): remove when new features have been generated
        if "task_index" not in old_frame:
            del new_frame["task_index"]

        new_keys = set(new_frame.keys())
        old_keys = set(old_frame.keys())
        assert new_keys == old_keys, f"{new_keys=} and {old_keys=} are not the same"

        for key in new_frame:
            assert torch.isclose(new_frame[key], old_frame[key]).all(), (
                f"{key=} for index={i} does not contain the same value"
            )

    # test 2 first frames of first episode
    i = dataset.episode_data_index["from"][0].item()
    load_and_compare(i)
    load_and_compare(i + 1)

    # test 2 frames at the middle of first episode
    i = int((dataset.episode_data_index["to"][0].item() - dataset.episode_data_index["from"][0].item()) / 2)
    load_and_compare(i)
    load_and_compare(i + 1)

    # test 2 last frames of first episode
    i = dataset.episode_data_index["to"][0].item()
    load_and_compare(i - 2)
    load_and_compare(i - 1)

    # TODO(rcadene): Enable testing on second and last episode
    # We currently can't because our test dataset only contains the first episode

    # # test 2 first frames of second episode
    # i = dataset.episode_data_index["from"][1].item()
    # load_and_compare(i)
    # load_and_compare(i + 1)

    # # test 2 last frames of second episode
    # i = dataset.episode_data_index["to"][1].item()
    # load_and_compare(i - 2)
    # load_and_compare(i - 1)

    # # test 2 last frames of last episode
    # i = dataset.episode_data_index["to"][-1].item()
    # load_and_compare(i - 2)
    # load_and_compare(i - 1)


@pytest.mark.skip("Requires internet access")
def test_create_branch():
    api = HfApi()

    repo_id = "cadene/test_create_branch"
    repo_type = "dataset"
    branch = "test"
    ref = f"refs/heads/{branch}"

    # Prepare a repo with a test branch
    api.delete_repo(repo_id, repo_type=repo_type, missing_ok=True)
    api.create_repo(repo_id, repo_type=repo_type)
    create_branch(repo_id, repo_type=repo_type, branch=branch)

    # Make sure the test branch exists
    branches = api.list_repo_refs(repo_id, repo_type=repo_type).branches
    refs = [branch.ref for branch in branches]
    assert ref in refs

    # Overwrite it
    create_branch(repo_id, repo_type=repo_type, branch=branch)

    # Clean up
    api.delete_repo(repo_id, repo_type=repo_type)


def test_dataset_feature_with_forward_slash_raises_error():
    from lerobot.common.constants import HF_LEROBOT_HOME

    # make sure the dataset directory does not exist
    dataset_dir = HF_LEROBOT_HOME / "lerobot/test/with/slash"
    if dataset_dir.exists():
        dataset_dir.rmdir()

    with pytest.raises(ValueError):
        LeRobotDataset.create(
            repo_id="lerobot/test/with/slash",
            fps=30,
            features={"a/b": {"dtype": "float32", "shape": 2, "names": None}},
        )
278 tests/datasets/test_delta_timestamps.py Normal file
@@ -0,0 +1,278 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from itertools import accumulate

import datasets
import numpy as np
import pyarrow.compute as pc
import pytest
import torch

from lerobot.common.datasets.utils import (
    check_delta_timestamps,
    check_timestamps_sync,
    get_delta_indices,
)
from tests.fixtures.constants import DUMMY_MOTOR_FEATURES


def calculate_total_episode(hf_dataset: datasets.Dataset, raise_if_not_contiguous: bool = True) -> int:
    episode_indices = sorted(hf_dataset.unique("episode_index"))
    total_episodes = len(episode_indices)
    if raise_if_not_contiguous and episode_indices != list(range(total_episodes)):
        raise ValueError("episode_index values are not sorted and contiguous.")
    return total_episodes


def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> dict[str, np.ndarray]:
    episode_lengths = []
    table = hf_dataset.data.table
    total_episodes = calculate_total_episode(hf_dataset)
    for ep_idx in range(total_episodes):
        ep_table = table.filter(pc.equal(table["episode_index"], ep_idx))
        episode_lengths.insert(ep_idx, len(ep_table))

    cumulative_lengths = list(accumulate(episode_lengths))
    return {
        "from": np.array([0] + cumulative_lengths[:-1], dtype=np.int64),
        "to": np.array(cumulative_lengths, dtype=np.int64),
    }
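
# Example: for two episodes of lengths [3, 2], calculate_episode_data_index returns
#     {"from": np.array([0, 3]), "to": np.array([3, 5])}
# i.e. frame index i belongs to episode e iff from[e] <= i < to[e].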


@pytest.fixture(scope="module")
def synced_timestamps_factory(hf_dataset_factory):
    def _create_synced_timestamps(fps: int = 30) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
        hf_dataset = hf_dataset_factory(fps=fps)
        timestamps = torch.stack(hf_dataset["timestamp"]).numpy()
        episode_indices = torch.stack(hf_dataset["episode_index"]).numpy()
        episode_data_index = calculate_episode_data_index(hf_dataset)
        return timestamps, episode_indices, episode_data_index

    return _create_synced_timestamps


@pytest.fixture(scope="module")
def unsynced_timestamps_factory(synced_timestamps_factory):
    def _create_unsynced_timestamps(
        fps: int = 30, tolerance_s: float = 1e-4
    ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
        timestamps, episode_indices, episode_data_index = synced_timestamps_factory(fps=fps)
        timestamps[30] += tolerance_s * 1.1  # Modify a single timestamp just outside tolerance
        return timestamps, episode_indices, episode_data_index

    return _create_unsynced_timestamps


@pytest.fixture(scope="module")
def slightly_off_timestamps_factory(synced_timestamps_factory):
    def _create_slightly_off_timestamps(
        fps: int = 30, tolerance_s: float = 1e-4
    ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
        timestamps, episode_indices, episode_data_index = synced_timestamps_factory(fps=fps)
        timestamps[30] += tolerance_s * 0.9  # Modify a single timestamp just inside tolerance
        return timestamps, episode_indices, episode_data_index

    return _create_slightly_off_timestamps


@pytest.fixture(scope="module")
def valid_delta_timestamps_factory():
    def _create_valid_delta_timestamps(
        fps: int = 30, keys: list = DUMMY_MOTOR_FEATURES, min_max_range: tuple[int, int] = (-10, 10)
    ) -> dict:
        delta_timestamps = {key: [i * (1 / fps) for i in range(*min_max_range)] for key in keys}
        return delta_timestamps

    return _create_valid_delta_timestamps
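
# Example: with fps=30 and min_max_range=(-2, 2), each key maps to
#     [-2/30, -1/30, 0.0, 1/30]
# i.e. candidate timestamps spaced exactly one frame (1/fps seconds) apart.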


@pytest.fixture(scope="module")
def invalid_delta_timestamps_factory(valid_delta_timestamps_factory):
    def _create_invalid_delta_timestamps(
        fps: int = 30, tolerance_s: float = 1e-4, keys: list = DUMMY_MOTOR_FEATURES
    ) -> dict:
        delta_timestamps = valid_delta_timestamps_factory(fps, keys)
        # Modify a single timestamp just outside tolerance
        for key in keys:
            delta_timestamps[key][3] += tolerance_s * 1.1
        return delta_timestamps

    return _create_invalid_delta_timestamps


@pytest.fixture(scope="module")
def slightly_off_delta_timestamps_factory(valid_delta_timestamps_factory):
    def _create_slightly_off_delta_timestamps(
        fps: int = 30, tolerance_s: float = 1e-4, keys: list = DUMMY_MOTOR_FEATURES
    ) -> dict:
        delta_timestamps = valid_delta_timestamps_factory(fps, keys)
        # Modify a single timestamp just inside tolerance
        for key in delta_timestamps:
            delta_timestamps[key][3] += tolerance_s * 0.9
            delta_timestamps[key][-3] += tolerance_s * 0.9
        return delta_timestamps

    return _create_slightly_off_delta_timestamps


@pytest.fixture(scope="module")
def delta_indices_factory():
    def _delta_indices(keys: list = DUMMY_MOTOR_FEATURES, min_max_range: tuple[int, int] = (-10, 10)) -> dict:
        return {key: list(range(*min_max_range)) for key in keys}

    return _delta_indices


def test_check_timestamps_sync_synced(synced_timestamps_factory):
    fps = 30
    tolerance_s = 1e-4
    timestamps, ep_idx, ep_data_index = synced_timestamps_factory(fps)
    result = check_timestamps_sync(
        timestamps=timestamps,
        episode_indices=ep_idx,
        episode_data_index=ep_data_index,
        fps=fps,
        tolerance_s=tolerance_s,
    )
    assert result is True


def test_check_timestamps_sync_unsynced(unsynced_timestamps_factory):
    fps = 30
    tolerance_s = 1e-4
    timestamps, ep_idx, ep_data_index = unsynced_timestamps_factory(fps, tolerance_s)
    with pytest.raises(ValueError):
        check_timestamps_sync(
            timestamps=timestamps,
            episode_indices=ep_idx,
            episode_data_index=ep_data_index,
            fps=fps,
            tolerance_s=tolerance_s,
        )


def test_check_timestamps_sync_unsynced_no_exception(unsynced_timestamps_factory):
    fps = 30
    tolerance_s = 1e-4
    timestamps, ep_idx, ep_data_index = unsynced_timestamps_factory(fps, tolerance_s)
    result = check_timestamps_sync(
        timestamps=timestamps,
        episode_indices=ep_idx,
        episode_data_index=ep_data_index,
        fps=fps,
        tolerance_s=tolerance_s,
        raise_value_error=False,
    )
    assert result is False


def test_check_timestamps_sync_slightly_off(slightly_off_timestamps_factory):
    fps = 30
    tolerance_s = 1e-4
    timestamps, ep_idx, ep_data_index = slightly_off_timestamps_factory(fps, tolerance_s)
    result = check_timestamps_sync(
        timestamps=timestamps,
        episode_indices=ep_idx,
        episode_data_index=ep_data_index,
        fps=fps,
        tolerance_s=tolerance_s,
    )
    assert result is True


def test_check_timestamps_sync_single_timestamp():
    fps = 30
    tolerance_s = 1e-4
    timestamps, ep_idx = np.array([0.0]), np.array([0])
    episode_data_index = {"to": np.array([1]), "from": np.array([0])}
    result = check_timestamps_sync(
        timestamps=timestamps,
        episode_indices=ep_idx,
        episode_data_index=episode_data_index,
        fps=fps,
        tolerance_s=tolerance_s,
    )
    assert result is True


def test_check_delta_timestamps_valid(valid_delta_timestamps_factory):
    fps = 30
    tolerance_s = 1e-4
    valid_delta_timestamps = valid_delta_timestamps_factory(fps)
    result = check_delta_timestamps(
        delta_timestamps=valid_delta_timestamps,
        fps=fps,
        tolerance_s=tolerance_s,
    )
    assert result is True


def test_check_delta_timestamps_slightly_off(slightly_off_delta_timestamps_factory):
    fps = 30
    tolerance_s = 1e-4
    slightly_off_delta_timestamps = slightly_off_delta_timestamps_factory(fps, tolerance_s)
    result = check_delta_timestamps(
        delta_timestamps=slightly_off_delta_timestamps,
        fps=fps,
        tolerance_s=tolerance_s,
    )
    assert result is True


def test_check_delta_timestamps_invalid(invalid_delta_timestamps_factory):
    fps = 30
    tolerance_s = 1e-4
    invalid_delta_timestamps = invalid_delta_timestamps_factory(fps, tolerance_s)
    with pytest.raises(ValueError):
        check_delta_timestamps(
            delta_timestamps=invalid_delta_timestamps,
            fps=fps,
            tolerance_s=tolerance_s,
        )


def test_check_delta_timestamps_invalid_no_exception(invalid_delta_timestamps_factory):
    fps = 30
    tolerance_s = 1e-4
    invalid_delta_timestamps = invalid_delta_timestamps_factory(fps, tolerance_s)
    result = check_delta_timestamps(
        delta_timestamps=invalid_delta_timestamps,
        fps=fps,
        tolerance_s=tolerance_s,
        raise_value_error=False,
    )
    assert result is False


def test_check_delta_timestamps_empty():
    delta_timestamps = {}
    fps = 30
    tolerance_s = 1e-4
    result = check_delta_timestamps(
        delta_timestamps=delta_timestamps,
        fps=fps,
        tolerance_s=tolerance_s,
    )
    assert result is True


def test_delta_indices(valid_delta_timestamps_factory, delta_indices_factory):
    fps = 50
    min_max_range = (-100, 100)
    delta_timestamps = valid_delta_timestamps_factory(fps, min_max_range=min_max_range)
    expected_delta_indices = delta_indices_factory(min_max_range=min_max_range)
    actual_delta_indices = get_delta_indices(delta_timestamps, fps)
    assert expected_delta_indices == actual_delta_indices
374 tests/datasets/test_image_transforms.py Normal file
@@ -0,0 +1,374 @@
#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
import torch
from safetensors.torch import load_file
from torchvision.transforms import v2
from torchvision.transforms.v2 import functional as F  # noqa: N812

from lerobot.common.datasets.transforms import (
    ImageTransformConfig,
    ImageTransforms,
    ImageTransformsConfig,
    RandomSubsetApply,
    SharpnessJitter,
    make_transform_from_config,
)
from lerobot.common.utils.random_utils import seeded_context
from lerobot.scripts.visualize_image_transforms import (
    save_all_transforms,
    save_each_transform,
)
from tests.artifacts.image_transforms.save_image_transforms_to_safetensors import ARTIFACT_DIR
from tests.utils import require_x86_64_kernel


@pytest.fixture
def color_jitters():
    return [
        v2.ColorJitter(brightness=0.5),
        v2.ColorJitter(contrast=0.5),
        v2.ColorJitter(saturation=0.5),
    ]


@pytest.fixture
def single_transforms():
    return load_file(ARTIFACT_DIR / "single_transforms.safetensors")


@pytest.fixture
def img_tensor(single_transforms):
    return single_transforms["original_frame"]


@pytest.fixture
def default_transforms():
    return load_file(ARTIFACT_DIR / "default_transforms.safetensors")


def test_get_image_transforms_no_transform_enable_false(img_tensor_factory):
    img_tensor = img_tensor_factory()
    tf_cfg = ImageTransformsConfig()  # default is enable=False
    tf_actual = ImageTransforms(tf_cfg)
    torch.testing.assert_close(tf_actual(img_tensor), img_tensor)


def test_get_image_transforms_no_transform_max_num_transforms_0(img_tensor_factory):
    img_tensor = img_tensor_factory()
    tf_cfg = ImageTransformsConfig(enable=True, max_num_transforms=0)
    tf_actual = ImageTransforms(tf_cfg)
    torch.testing.assert_close(tf_actual(img_tensor), img_tensor)


@pytest.mark.parametrize("min_max", [(0.5, 0.5), (2.0, 2.0)])
def test_get_image_transforms_brightness(img_tensor_factory, min_max):
    img_tensor = img_tensor_factory()
    tf_cfg = ImageTransformsConfig(
        enable=True,
        tfs={"brightness": ImageTransformConfig(type="ColorJitter", kwargs={"brightness": min_max})},
    )
    tf_actual = ImageTransforms(tf_cfg)
    tf_expected = v2.ColorJitter(brightness=min_max)
    torch.testing.assert_close(tf_actual(img_tensor), tf_expected(img_tensor))


@pytest.mark.parametrize("min_max", [(0.5, 0.5), (2.0, 2.0)])
def test_get_image_transforms_contrast(img_tensor_factory, min_max):
    img_tensor = img_tensor_factory()
    tf_cfg = ImageTransformsConfig(
        enable=True, tfs={"contrast": ImageTransformConfig(type="ColorJitter", kwargs={"contrast": min_max})}
    )
    tf_actual = ImageTransforms(tf_cfg)
    tf_expected = v2.ColorJitter(contrast=min_max)
    torch.testing.assert_close(tf_actual(img_tensor), tf_expected(img_tensor))


@pytest.mark.parametrize("min_max", [(0.5, 0.5), (2.0, 2.0)])
def test_get_image_transforms_saturation(img_tensor_factory, min_max):
    img_tensor = img_tensor_factory()
    tf_cfg = ImageTransformsConfig(
        enable=True,
        tfs={"saturation": ImageTransformConfig(type="ColorJitter", kwargs={"saturation": min_max})},
    )
    tf_actual = ImageTransforms(tf_cfg)
    tf_expected = v2.ColorJitter(saturation=min_max)
    torch.testing.assert_close(tf_actual(img_tensor), tf_expected(img_tensor))


@pytest.mark.parametrize("min_max", [(-0.25, -0.25), (0.25, 0.25)])
def test_get_image_transforms_hue(img_tensor_factory, min_max):
    img_tensor = img_tensor_factory()
    tf_cfg = ImageTransformsConfig(
        enable=True, tfs={"hue": ImageTransformConfig(type="ColorJitter", kwargs={"hue": min_max})}
    )
    tf_actual = ImageTransforms(tf_cfg)
    tf_expected = v2.ColorJitter(hue=min_max)
    torch.testing.assert_close(tf_actual(img_tensor), tf_expected(img_tensor))


@pytest.mark.parametrize("min_max", [(0.5, 0.5), (2.0, 2.0)])
def test_get_image_transforms_sharpness(img_tensor_factory, min_max):
    img_tensor = img_tensor_factory()
    tf_cfg = ImageTransformsConfig(
        enable=True,
        tfs={"sharpness": ImageTransformConfig(type="SharpnessJitter", kwargs={"sharpness": min_max})},
    )
    tf_actual = ImageTransforms(tf_cfg)
    tf_expected = SharpnessJitter(sharpness=min_max)
    torch.testing.assert_close(tf_actual(img_tensor), tf_expected(img_tensor))


def test_get_image_transforms_max_num_transforms(img_tensor_factory):
    img_tensor = img_tensor_factory()
    tf_cfg = ImageTransformsConfig(
        enable=True,
        max_num_transforms=5,
        tfs={
            "brightness": ImageTransformConfig(
                weight=1.0,
                type="ColorJitter",
                kwargs={"brightness": (0.5, 0.5)},
            ),
            "contrast": ImageTransformConfig(
                weight=1.0,
                type="ColorJitter",
                kwargs={"contrast": (0.5, 0.5)},
            ),
            "saturation": ImageTransformConfig(
                weight=1.0,
                type="ColorJitter",
                kwargs={"saturation": (0.5, 0.5)},
            ),
            "hue": ImageTransformConfig(
                weight=1.0,
                type="ColorJitter",
                kwargs={"hue": (0.5, 0.5)},
            ),
            "sharpness": ImageTransformConfig(
                weight=1.0,
                type="SharpnessJitter",
                kwargs={"sharpness": (0.5, 0.5)},
            ),
        },
    )
    tf_actual = ImageTransforms(tf_cfg)
    tf_expected = v2.Compose(
        [
            v2.ColorJitter(brightness=(0.5, 0.5)),
            v2.ColorJitter(contrast=(0.5, 0.5)),
            v2.ColorJitter(saturation=(0.5, 0.5)),
            v2.ColorJitter(hue=(0.5, 0.5)),
            SharpnessJitter(sharpness=(0.5, 0.5)),
        ]
    )
    torch.testing.assert_close(tf_actual(img_tensor), tf_expected(img_tensor))
|
||||
|
||||
|
||||
@require_x86_64_kernel
|
||||
def test_get_image_transforms_random_order(img_tensor_factory):
|
||||
out_imgs = []
|
||||
img_tensor = img_tensor_factory()
|
||||
tf_cfg = ImageTransformsConfig(
|
||||
enable=True,
|
||||
random_order=True,
|
||||
tfs={
|
||||
"brightness": ImageTransformConfig(
|
||||
weight=1.0,
|
||||
type="ColorJitter",
|
||||
kwargs={"brightness": (0.5, 0.5)},
|
||||
),
|
||||
"contrast": ImageTransformConfig(
|
||||
weight=1.0,
|
||||
type="ColorJitter",
|
||||
kwargs={"contrast": (0.5, 0.5)},
|
||||
),
|
||||
"saturation": ImageTransformConfig(
|
||||
weight=1.0,
|
||||
type="ColorJitter",
|
||||
kwargs={"saturation": (0.5, 0.5)},
|
||||
),
|
||||
"hue": ImageTransformConfig(
|
||||
weight=1.0,
|
||||
type="ColorJitter",
|
||||
kwargs={"hue": (0.5, 0.5)},
|
||||
),
|
||||
"sharpness": ImageTransformConfig(
|
||||
weight=1.0,
|
||||
type="SharpnessJitter",
|
||||
kwargs={"sharpness": (0.5, 0.5)},
|
||||
),
|
||||
},
|
||||
)
|
||||
tf = ImageTransforms(tf_cfg)
|
||||
|
||||
with seeded_context(1338):
|
||||
for _ in range(10):
|
||||
out_imgs.append(tf(img_tensor))
|
||||
|
||||
tmp_img_tensor = img_tensor
|
||||
for sub_tf in tf.tf.selected_transforms:
|
||||
tmp_img_tensor = sub_tf(tmp_img_tensor)
|
||||
torch.testing.assert_close(tmp_img_tensor, out_imgs[-1])
|
||||
|
||||
for i in range(1, len(out_imgs)):
|
||||
with pytest.raises(AssertionError):
|
||||
torch.testing.assert_close(out_imgs[0], out_imgs[i])
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"tf_type, tf_name, min_max_values",
|
||||
[
|
||||
("ColorJitter", "brightness", [(0.5, 0.5), (2.0, 2.0)]),
|
||||
("ColorJitter", "contrast", [(0.5, 0.5), (2.0, 2.0)]),
|
||||
("ColorJitter", "saturation", [(0.5, 0.5), (2.0, 2.0)]),
|
||||
("ColorJitter", "hue", [(-0.25, -0.25), (0.25, 0.25)]),
|
||||
("SharpnessJitter", "sharpness", [(0.5, 0.5), (2.0, 2.0)]),
|
||||
],
|
||||
)
|
||||
def test_backward_compatibility_single_transforms(
|
||||
img_tensor, tf_type, tf_name, min_max_values, single_transforms
|
||||
):
|
||||
for min_max in min_max_values:
|
||||
tf_cfg = ImageTransformConfig(type=tf_type, kwargs={tf_name: min_max})
|
||||
tf = make_transform_from_config(tf_cfg)
|
||||
actual = tf(img_tensor)
|
||||
key = f"{tf_name}_{min_max[0]}_{min_max[1]}"
|
||||
expected = single_transforms[key]
|
||||
torch.testing.assert_close(actual, expected)
|
||||
|
||||
|
||||
@require_x86_64_kernel
|
||||
def test_backward_compatibility_default_config(img_tensor, default_transforms):
|
||||
cfg = ImageTransformsConfig(enable=True)
|
||||
default_tf = ImageTransforms(cfg)
|
||||
|
||||
with seeded_context(1337):
|
||||
actual = default_tf(img_tensor)
|
||||
|
||||
expected = default_transforms["default"]
|
||||
|
||||
torch.testing.assert_close(actual, expected)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("p", [[0, 1], [1, 0]])
|
||||
def test_random_subset_apply_single_choice(img_tensor_factory, p):
|
||||
img_tensor = img_tensor_factory()
|
||||
flips = [v2.RandomHorizontalFlip(p=1), v2.RandomVerticalFlip(p=1)]
|
||||
random_choice = RandomSubsetApply(flips, p=p, n_subset=1, random_order=False)
|
||||
actual = random_choice(img_tensor)
|
||||
|
||||
p_horz, _ = p
|
||||
if p_horz:
|
||||
torch.testing.assert_close(actual, F.horizontal_flip(img_tensor))
|
||||
else:
|
||||
torch.testing.assert_close(actual, F.vertical_flip(img_tensor))


def test_random_subset_apply_random_order(img_tensor_factory):
    img_tensor = img_tensor_factory()
    flips = [v2.RandomHorizontalFlip(p=1), v2.RandomVerticalFlip(p=1)]
    random_order = RandomSubsetApply(flips, p=[0.5, 0.5], n_subset=2, random_order=True)
    # We can't really check whether the transforms are actually applied in random order. However,
    # horizontal and vertical flip are commutative. Meaning, even under the assumption that the transform
    # applies them in random order, we can use a fixed order to compute the expected value.
    actual = random_order(img_tensor)
    expected = v2.Compose(flips)(img_tensor)
    torch.testing.assert_close(actual, expected)


def test_random_subset_apply_valid_transforms(img_tensor_factory, color_jitters):
    img_tensor = img_tensor_factory()
    transform = RandomSubsetApply(color_jitters)
    output = transform(img_tensor)
    assert output.shape == img_tensor.shape


def test_random_subset_apply_probability_length_mismatch(color_jitters):
    with pytest.raises(ValueError):
        RandomSubsetApply(color_jitters, p=[0.5, 0.5])


@pytest.mark.parametrize("n_subset", [0, 5])
def test_random_subset_apply_invalid_n_subset(color_jitters, n_subset):
    with pytest.raises(ValueError):
        RandomSubsetApply(color_jitters, n_subset=n_subset)
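
# Presumably n_subset must lie in [1, len(transforms)]: 0 would select nothing
# and 5 exceeds the number of available color jitter transforms, so both are
# rejected with a ValueError.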


def test_sharpness_jitter_valid_range_tuple(img_tensor_factory):
    img_tensor = img_tensor_factory()
    tf = SharpnessJitter((0.1, 2.0))
    output = tf(img_tensor)
    assert output.shape == img_tensor.shape


def test_sharpness_jitter_valid_range_float(img_tensor_factory):
    img_tensor = img_tensor_factory()
    tf = SharpnessJitter(0.5)
    output = tf(img_tensor)
    assert output.shape == img_tensor.shape


def test_sharpness_jitter_invalid_range_min_negative():
    with pytest.raises(ValueError):
        SharpnessJitter((-0.1, 2.0))


def test_sharpness_jitter_invalid_range_max_smaller():
    with pytest.raises(ValueError):
        SharpnessJitter((2.0, 0.1))


def test_save_all_transforms(img_tensor_factory, tmp_path):
    img_tensor = img_tensor_factory()
    tf_cfg = ImageTransformsConfig(enable=True)
    n_examples = 3

    save_all_transforms(tf_cfg, img_tensor, tmp_path, n_examples)

    # Check if the combined transforms directory exists and contains the right files
    combined_transforms_dir = tmp_path / "all"
    assert combined_transforms_dir.exists(), "Combined transforms directory was not created."
    assert any(combined_transforms_dir.iterdir()), (
        "No transformed images found in combined transforms directory."
    )
    for i in range(1, n_examples + 1):
        assert (combined_transforms_dir / f"{i}.png").exists(), (
            f"Combined transform image {i}.png was not found."
        )


def test_save_each_transform(img_tensor_factory, tmp_path):
    img_tensor = img_tensor_factory()
    tf_cfg = ImageTransformsConfig(enable=True)
    n_examples = 3

    save_each_transform(tf_cfg, img_tensor, tmp_path, n_examples)

    # Check if the transformed images exist for each transform type
    transforms = ["brightness", "contrast", "saturation", "hue", "sharpness"]
    for transform in transforms:
        transform_dir = tmp_path / transform
        assert transform_dir.exists(), f"{transform} directory was not created."
        assert any(transform_dir.iterdir()), f"No transformed images found in {transform} directory."

        # Check for specific files within each transform directory
        expected_files = [f"{i}.png" for i in range(1, n_examples + 1)] + ["min.png", "max.png", "mean.png"]
        for file_name in expected_files:
            assert (transform_dir / file_name).exists(), (
                f"{file_name} was not found in {transform} directory."
            )
386
tests/datasets/test_image_writer.py
Normal file
@@ -0,0 +1,386 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import queue
import time
from multiprocessing import queues
from unittest.mock import MagicMock, patch

import numpy as np
import pytest
from PIL import Image

from lerobot.common.datasets.image_writer import (
    AsyncImageWriter,
    image_array_to_pil_image,
    safe_stop_image_writer,
    write_image,
)
from tests.fixtures.constants import DUMMY_HWC

DUMMY_IMAGE = "test_image.png"
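
# A minimal usage sketch of AsyncImageWriter, based only on the constructor
# arguments and methods exercised by the tests below:
#
#     writer = AsyncImageWriter(num_processes=0, num_threads=2)
#     try:
#         writer.save_image(image_array, fpath)  # enqueue; returns immediately
#         writer.wait_until_done()               # block until the queue drains
#     finally:
#         writer.stop()                          # tear down worker threads/processes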


def test_init_threading():
    writer = AsyncImageWriter(num_processes=0, num_threads=2)
    try:
        assert writer.num_processes == 0
        assert writer.num_threads == 2
        assert isinstance(writer.queue, queue.Queue)
        assert len(writer.threads) == 2
        assert len(writer.processes) == 0
        assert all(t.is_alive() for t in writer.threads)
    finally:
        writer.stop()


def test_init_multiprocessing():
    writer = AsyncImageWriter(num_processes=2, num_threads=2)
    try:
        assert writer.num_processes == 2
        assert writer.num_threads == 2
        assert isinstance(writer.queue, queues.JoinableQueue)
        assert len(writer.threads) == 0
        assert len(writer.processes) == 2
        assert all(p.is_alive() for p in writer.processes)
    finally:
        writer.stop()


def test_zero_threads():
    with pytest.raises(ValueError):
        AsyncImageWriter(num_processes=0, num_threads=0)


def test_image_array_to_pil_image_float_array_wrong_range_0_255():
    image = np.random.rand(*DUMMY_HWC) * 255
    with pytest.raises(ValueError):
        image_array_to_pil_image(image)


def test_image_array_to_pil_image_float_array_wrong_range_neg_1_1():
    image = np.random.rand(*DUMMY_HWC) * 2 - 1
    with pytest.raises(ValueError):
        image_array_to_pil_image(image)


def test_image_array_to_pil_image_rgb(img_array_factory):
    img_array = img_array_factory(100, 100)
    result_image = image_array_to_pil_image(img_array)
    assert isinstance(result_image, Image.Image)
    assert result_image.size == (100, 100)
    assert result_image.mode == "RGB"


def test_image_array_to_pil_image_pytorch_format(img_array_factory):
    img_array = img_array_factory(100, 100).transpose(2, 0, 1)
    result_image = image_array_to_pil_image(img_array)
    assert isinstance(result_image, Image.Image)
    assert result_image.size == (100, 100)
    assert result_image.mode == "RGB"


def test_image_array_to_pil_image_single_channel(img_array_factory):
    img_array = img_array_factory(channels=1)
    with pytest.raises(NotImplementedError):
        image_array_to_pil_image(img_array)


def test_image_array_to_pil_image_4_channels(img_array_factory):
    img_array = img_array_factory(channels=4)
    with pytest.raises(NotImplementedError):
        image_array_to_pil_image(img_array)


def test_image_array_to_pil_image_float_array(img_array_factory):
    img_array = img_array_factory(dtype=np.float32)
    result_image = image_array_to_pil_image(img_array)
    assert isinstance(result_image, Image.Image)
    assert result_image.size == (100, 100)
    assert result_image.mode == "RGB"
    assert np.array(result_image).dtype == np.uint8


def test_image_array_to_pil_image_uint8_array(img_array_factory):
    img_array = img_array_factory(dtype=np.uint8)
    result_image = image_array_to_pil_image(img_array)
    assert isinstance(result_image, Image.Image)
    assert result_image.size == (100, 100)
    assert result_image.mode == "RGB"
    assert np.array(result_image).dtype == np.uint8


def test_write_image_numpy(tmp_path, img_array_factory):
    image_array = img_array_factory()
    fpath = tmp_path / DUMMY_IMAGE
    write_image(image_array, fpath)
    assert fpath.exists()
    saved_image = np.array(Image.open(fpath))
    assert np.array_equal(image_array, saved_image)


def test_write_image_image(tmp_path, img_factory):
    image_pil = img_factory()
    fpath = tmp_path / DUMMY_IMAGE
    write_image(image_pil, fpath)
    assert fpath.exists()
    saved_image = Image.open(fpath)
    assert list(saved_image.getdata()) == list(image_pil.getdata())
    assert np.array_equal(image_pil, saved_image)


def test_write_image_exception(tmp_path):
    image_array = "invalid data"
    fpath = tmp_path / DUMMY_IMAGE
    with patch("builtins.print") as mock_print:
        write_image(image_array, fpath)
        mock_print.assert_called()
    assert not fpath.exists()
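
# The test above relies on write_image catching the failure internally and
# reporting it via print rather than raising, which is why no file is created
# and no exception propagates.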


def test_save_image_numpy(tmp_path, img_array_factory):
    writer = AsyncImageWriter()
    try:
        image_array = img_array_factory()
        fpath = tmp_path / DUMMY_IMAGE
        fpath.parent.mkdir(parents=True, exist_ok=True)
        writer.save_image(image_array, fpath)
        writer.wait_until_done()
        assert fpath.exists()
        saved_image = np.array(Image.open(fpath))
        assert np.array_equal(image_array, saved_image)
    finally:
        writer.stop()


def test_save_image_numpy_multiprocessing(tmp_path, img_array_factory):
    writer = AsyncImageWriter(num_processes=2, num_threads=2)
    try:
        image_array = img_array_factory()
        fpath = tmp_path / DUMMY_IMAGE
        writer.save_image(image_array, fpath)
        writer.wait_until_done()
        assert fpath.exists()
        saved_image = np.array(Image.open(fpath))
        assert np.array_equal(image_array, saved_image)
    finally:
        writer.stop()


def test_save_image_torch(tmp_path, img_tensor_factory):
    writer = AsyncImageWriter()
    try:
        image_tensor = img_tensor_factory()
        fpath = tmp_path / DUMMY_IMAGE
        fpath.parent.mkdir(parents=True, exist_ok=True)
        writer.save_image(image_tensor, fpath)
        writer.wait_until_done()
        assert fpath.exists()
        saved_image = np.array(Image.open(fpath))
        expected_image = (image_tensor.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
        assert np.array_equal(expected_image, saved_image)
    finally:
        writer.stop()


def test_save_image_torch_multiprocessing(tmp_path, img_tensor_factory):
    writer = AsyncImageWriter(num_processes=2, num_threads=2)
    try:
        image_tensor = img_tensor_factory()
        fpath = tmp_path / DUMMY_IMAGE
        writer.save_image(image_tensor, fpath)
        writer.wait_until_done()
        assert fpath.exists()
        saved_image = np.array(Image.open(fpath))
        expected_image = (image_tensor.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
        assert np.array_equal(expected_image, saved_image)
    finally:
        writer.stop()


def test_save_image_pil(tmp_path, img_factory):
    writer = AsyncImageWriter()
    try:
        image_pil = img_factory()
        fpath = tmp_path / DUMMY_IMAGE
        fpath.parent.mkdir(parents=True, exist_ok=True)
        writer.save_image(image_pil, fpath)
        writer.wait_until_done()
        assert fpath.exists()
        saved_image = Image.open(fpath)
        assert list(saved_image.getdata()) == list(image_pil.getdata())
    finally:
        writer.stop()


def test_save_image_pil_multiprocessing(tmp_path, img_factory):
    writer = AsyncImageWriter(num_processes=2, num_threads=2)
    try:
        image_pil = img_factory()
        fpath = tmp_path / DUMMY_IMAGE
        writer.save_image(image_pil, fpath)
        writer.wait_until_done()
        assert fpath.exists()
        saved_image = Image.open(fpath)
        assert list(saved_image.getdata()) == list(image_pil.getdata())
    finally:
        writer.stop()


def test_save_image_invalid_data(tmp_path):
    writer = AsyncImageWriter()
    try:
        image_array = "invalid data"
        fpath = tmp_path / DUMMY_IMAGE
        fpath.parent.mkdir(parents=True, exist_ok=True)
        with patch("builtins.print") as mock_print:
            writer.save_image(image_array, fpath)
            writer.wait_until_done()
            mock_print.assert_called()
            assert not fpath.exists()
    finally:
        writer.stop()


def test_save_image_after_stop(tmp_path, img_array_factory):
    writer = AsyncImageWriter()
    writer.stop()
    image_array = img_array_factory()
    fpath = tmp_path / DUMMY_IMAGE
    writer.save_image(image_array, fpath)
    time.sleep(1)
    assert not fpath.exists()


def test_stop():
    writer = AsyncImageWriter(num_processes=0, num_threads=2)
    writer.stop()
    assert not any(t.is_alive() for t in writer.threads)


def test_stop_multiprocessing():
    writer = AsyncImageWriter(num_processes=2, num_threads=2)
    writer.stop()
    assert not any(p.is_alive() for p in writer.processes)


def test_multiple_stops():
    writer = AsyncImageWriter()
    writer.stop()
    writer.stop()  # Should not raise an exception
    assert not any(t.is_alive() for t in writer.threads)


def test_multiple_stops_multiprocessing():
    writer = AsyncImageWriter(num_processes=2, num_threads=2)
    writer.stop()
    writer.stop()  # Should not raise an exception
    assert not any(t.is_alive() for t in writer.threads)


def test_wait_until_done(tmp_path, img_array_factory):
    writer = AsyncImageWriter(num_processes=0, num_threads=4)
    try:
        num_images = 100
        image_arrays = [img_array_factory(height=500, width=500) for _ in range(num_images)]
        fpaths = [tmp_path / f"frame_{i:06d}.png" for i in range(num_images)]
        for image_array, fpath in zip(image_arrays, fpaths, strict=True):
            fpath.parent.mkdir(parents=True, exist_ok=True)
            writer.save_image(image_array, fpath)
        writer.wait_until_done()
        for i, fpath in enumerate(fpaths):
            assert fpath.exists()
            saved_image = np.array(Image.open(fpath))
            assert np.array_equal(saved_image, image_arrays[i])
    finally:
        writer.stop()


def test_wait_until_done_multiprocessing(tmp_path, img_array_factory):
    writer = AsyncImageWriter(num_processes=2, num_threads=2)
    try:
        num_images = 100
        image_arrays = [img_array_factory() for _ in range(num_images)]
        fpaths = [tmp_path / f"frame_{i:06d}.png" for i in range(num_images)]
        for image_array, fpath in zip(image_arrays, fpaths, strict=True):
            fpath.parent.mkdir(parents=True, exist_ok=True)
            writer.save_image(image_array, fpath)
        writer.wait_until_done()
        for i, fpath in enumerate(fpaths):
            assert fpath.exists()
            saved_image = np.array(Image.open(fpath))
            assert np.array_equal(saved_image, image_arrays[i])
    finally:
        writer.stop()


def test_exception_handling(tmp_path, img_array_factory):
    writer = AsyncImageWriter()
    try:
        image_array = img_array_factory()
        with (
            patch.object(writer.queue, "put", side_effect=queue.Full("Queue is full")),
            pytest.raises(queue.Full) as exc_info,
        ):
            writer.save_image(image_array, tmp_path / "test.png")
        assert str(exc_info.value) == "Queue is full"
    finally:
        writer.stop()


def test_with_different_image_formats(tmp_path, img_array_factory):
    writer = AsyncImageWriter()
    try:
        image_array = img_array_factory()
        formats = ["png", "jpeg", "bmp"]
        for fmt in formats:
            fpath = tmp_path / f"test_image.{fmt}"
            write_image(image_array, fpath)
            assert fpath.exists()
    finally:
        writer.stop()


def test_safe_stop_image_writer_decorator():
    class MockDataset:
        def __init__(self):
            self.image_writer = MagicMock(spec=AsyncImageWriter)

    @safe_stop_image_writer
    def function_that_raises_exception(dataset=None):
        raise Exception("Test exception")

    dataset = MockDataset()

    with pytest.raises(Exception) as exc_info:
        function_that_raises_exception(dataset=dataset)

    assert str(exc_info.value) == "Test exception"
    dataset.image_writer.stop.assert_called_once()


def test_main_process_time(tmp_path, img_tensor_factory):
    writer = AsyncImageWriter()
    try:
        image_tensor = img_tensor_factory()
        fpath = tmp_path / DUMMY_IMAGE
        start_time = time.perf_counter()
        writer.save_image(image_tensor, fpath)
        end_time = time.perf_counter()
        time_spent = end_time - start_time
        # save_image only enqueues the frame, so it should return almost
        # immediately. Might need to adjust this threshold depending on hardware.
        assert time_spent < 0.01, f"Main process time exceeded threshold: {time_spent}s"
        writer.wait_until_done()
        assert fpath.exists()
    finally:
        writer.stop()
282
tests/datasets/test_online_buffer.py
Normal file
@@ -0,0 +1,282 @@
#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from copy import deepcopy
from uuid import uuid4

import numpy as np
import pytest
import torch

from lerobot.common.datasets.online_buffer import OnlineBuffer, compute_sampler_weights

# Some constants for OnlineBuffer tests.
data_key = "data"
data_shape = (2, 3)  # just some arbitrary > 1D shape
buffer_capacity = 100
fps = 10


def make_new_buffer(
    write_dir: str | None = None, delta_timestamps: dict[str, list[float]] | None = None
) -> tuple[OnlineBuffer, str]:
    if write_dir is None:
        write_dir = f"/tmp/online_buffer_{uuid4().hex}"
    buffer = OnlineBuffer(
        write_dir,
        data_spec={data_key: {"shape": data_shape, "dtype": np.dtype("float32")}},
        buffer_capacity=buffer_capacity,
        fps=fps,
        delta_timestamps=delta_timestamps,
    )
    return buffer, write_dir


def make_spoof_data_frames(n_episodes: int, n_frames_per_episode: int) -> dict[str, np.ndarray]:
    new_data = {
        data_key: np.arange(n_frames_per_episode * n_episodes * np.prod(data_shape)).reshape(-1, *data_shape),
        OnlineBuffer.INDEX_KEY: np.arange(n_frames_per_episode * n_episodes),
        OnlineBuffer.EPISODE_INDEX_KEY: np.repeat(np.arange(n_episodes), n_frames_per_episode),
        OnlineBuffer.FRAME_INDEX_KEY: np.tile(np.arange(n_frames_per_episode), n_episodes),
        OnlineBuffer.TIMESTAMP_KEY: np.tile(np.arange(n_frames_per_episode) / fps, n_episodes),
    }
    return new_data
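
# For example, make_spoof_data_frames(2, 25) returns 50 frames: "data" with
# shape (50, 2, 3), plus matching index, episode_index, frame_index and
# timestamp arrays (timestamps spaced 1 / fps apart within each episode).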


def test_non_mutate():
    """Checks that the data provided to the add_data method is copied rather than passed by reference.

    This means that mutating the data in the buffer does not mutate the original data.

    NOTE: If this test fails, it means some of the other tests may be compromised. For example, we can't trust
    a success case for `test_write_read`.
    """
    buffer, _ = make_new_buffer()
    new_data = make_spoof_data_frames(2, buffer_capacity // 4)
    new_data_copy = deepcopy(new_data)
    buffer.add_data(new_data)
    buffer._data[data_key][:] += 1
    assert all(np.array_equal(new_data[k], new_data_copy[k]) for k in new_data)


def test_index_error_no_data():
    buffer, _ = make_new_buffer()
    with pytest.raises(IndexError):
        buffer[0]


def test_index_error_with_data():
    buffer, _ = make_new_buffer()
    n_frames = buffer_capacity // 2
    new_data = make_spoof_data_frames(1, n_frames)
    buffer.add_data(new_data)
    with pytest.raises(IndexError):
        buffer[n_frames]
    with pytest.raises(IndexError):
        buffer[-n_frames - 1]


@pytest.mark.parametrize("do_reload", [False, True])
def test_write_read(do_reload: bool):
    """Checks that data can be added to the buffer and read back.

    If do_reload we delete the buffer object and load the buffer back from disk before reading.
    """
    buffer, write_dir = make_new_buffer()
    n_episodes = 2
    n_frames_per_episode = buffer_capacity // 4
    new_data = make_spoof_data_frames(n_episodes, n_frames_per_episode)
    buffer.add_data(new_data)

    if do_reload:
        del buffer
        buffer, _ = make_new_buffer(write_dir)

    assert len(buffer) == n_frames_per_episode * n_episodes
    for i, item in enumerate(buffer):
        assert all(isinstance(item[k], torch.Tensor) for k in item)
        assert np.array_equal(item[data_key].numpy(), new_data[data_key][i])


def test_read_data_key():
    """Tests that data can be added to a buffer and all data for a specific key can be read back."""
    buffer, _ = make_new_buffer()
    n_episodes = 2
    n_frames_per_episode = buffer_capacity // 4
    new_data = make_spoof_data_frames(n_episodes, n_frames_per_episode)
    buffer.add_data(new_data)

    data_from_buffer = buffer.get_data_by_key(data_key)
    assert isinstance(data_from_buffer, torch.Tensor)
    assert np.array_equal(data_from_buffer.numpy(), new_data[data_key])


def test_fifo():
    """Checks that if data is added beyond the buffer capacity, we discard the oldest data first."""
    buffer, _ = make_new_buffer()
    n_frames_per_episode = buffer_capacity // 4
    n_episodes = 3
    new_data = make_spoof_data_frames(n_episodes, n_frames_per_episode)
    buffer.add_data(new_data)
    n_more_episodes = 2
    # Developer sanity check (in case someone changes the global `buffer_capacity`).
    assert (n_episodes + n_more_episodes) * n_frames_per_episode > buffer_capacity, (
        "Something went wrong with the test code."
    )
    more_new_data = make_spoof_data_frames(n_more_episodes, n_frames_per_episode)
    buffer.add_data(more_new_data)
    assert len(buffer) == buffer_capacity, "The buffer should be full."

    expected_data = {}
    for k in new_data:
        # Concatenate, left-truncate, then roll, to imitate the cyclical FIFO pattern in OnlineBuffer.
        expected_data[k] = np.roll(
            np.concatenate([new_data[k], more_new_data[k]])[-buffer_capacity:],
            shift=len(new_data[k]) + len(more_new_data[k]) - buffer_capacity,
            axis=0,
        )

    for i, item in enumerate(buffer):
        assert all(isinstance(item[k], torch.Tensor) for k in item)
        assert np.array_equal(item[data_key].numpy(), expected_data[data_key][i])


def test_delta_timestamps_within_tolerance():
    """Check that getting an item with delta_timestamps within tolerance succeeds.

    Note: Copied from `test_datasets.py::test_load_previous_and_future_frames_within_tolerance`.
    """
    # Sanity check on global fps as we are assuming it is 10 here.
    assert fps == 10, "This test assumes fps==10"
    buffer, _ = make_new_buffer(delta_timestamps={"index": [-0.2, 0, 0.139]})
    new_data = make_spoof_data_frames(n_episodes=1, n_frames_per_episode=5)
    buffer.add_data(new_data)
    buffer.tolerance_s = 0.04
    item = buffer[2]
    data, is_pad = item["index"], item[f"index{OnlineBuffer.IS_PAD_POSTFIX}"]
    torch.testing.assert_close(data, torch.tensor([0, 2, 3]), msg="Data does not match expected values")
    assert not is_pad.any(), "Unexpected padding detected"
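
# Why [0, 2, 3]: at fps=10 frames are 0.1 s apart, so from frame 2 (t=0.2 s)
# the deltas [-0.2, 0, 0.139] request t = 0.0, 0.2 and 0.339 s. The first two
# hit frames 0 and 2 exactly; 0.339 s is 0.039 s from frame 3 (t=0.3 s),
# inside the 0.04 s tolerance.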


def test_delta_timestamps_outside_tolerance_inside_episode_range():
    """Check that getting an item with delta_timestamps outside of tolerance fails.

    We expect it to fail if and only if the requested timestamps are within the episode range.

    Note: Copied from
    `test_datasets.py::test_load_previous_and_future_frames_outside_tolerance_inside_episode_range`
    """
    # Sanity check on global fps as we are assuming it is 10 here.
    assert fps == 10, "This test assumes fps==10"
    buffer, _ = make_new_buffer(delta_timestamps={"index": [-0.2, 0, 0.141]})
    new_data = make_spoof_data_frames(n_episodes=1, n_frames_per_episode=5)
    buffer.add_data(new_data)
    buffer.tolerance_s = 0.04
    with pytest.raises(AssertionError):
        buffer[2]
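
# Here the only change from the previous test is 0.139 -> 0.141: the request
# now lands at t = 0.341 s, which is 0.041 s from frame 3 and just beyond the
# 0.04 s tolerance, so indexing raises.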


def test_delta_timestamps_outside_tolerance_outside_episode_range():
    """Check that copy-padding of timestamps outside of the episode range works.

    Note: Copied from
    `test_datasets.py::test_load_previous_and_future_frames_outside_tolerance_outside_episode_range`
    """
    # Sanity check on global fps as we are assuming it is 10 here.
    assert fps == 10, "This test assumes fps==10"
    buffer, _ = make_new_buffer(delta_timestamps={"index": [-0.3, -0.24, 0, 0.26, 0.3]})
    new_data = make_spoof_data_frames(n_episodes=1, n_frames_per_episode=5)
    buffer.add_data(new_data)
    buffer.tolerance_s = 0.04
    item = buffer[2]
    data, is_pad = item["index"], item["index_is_pad"]
    assert torch.equal(data, torch.tensor([0, 0, 2, 4, 4])), "Data does not match expected values"
    assert torch.equal(is_pad, torch.tensor([True, False, False, True, True])), (
        "Padding does not match expected values"
    )


# Arbitrarily set small dataset sizes, making sure to have uneven sizes.
@pytest.mark.parametrize("offline_dataset_size", [1, 6])
@pytest.mark.parametrize("online_dataset_size", [0, 4])
@pytest.mark.parametrize("online_sampling_ratio", [0.0, 1.0])
def test_compute_sampler_weights_trivial(
    lerobot_dataset_factory,
    tmp_path,
    offline_dataset_size: int,
    online_dataset_size: int,
    online_sampling_ratio: float,
):
    offline_dataset = lerobot_dataset_factory(tmp_path, total_episodes=1, total_frames=offline_dataset_size)
    online_dataset, _ = make_new_buffer()
    if online_dataset_size > 0:
        online_dataset.add_data(
            make_spoof_data_frames(n_episodes=2, n_frames_per_episode=online_dataset_size // 2)
        )

    weights = compute_sampler_weights(
        offline_dataset, online_dataset=online_dataset, online_sampling_ratio=online_sampling_ratio
    )
    if offline_dataset_size == 0 or online_dataset_size == 0:
        expected_weights = torch.ones(offline_dataset_size + online_dataset_size)
    elif online_sampling_ratio == 0:
        expected_weights = torch.cat([torch.ones(offline_dataset_size), torch.zeros(online_dataset_size)])
    elif online_sampling_ratio == 1:
        expected_weights = torch.cat([torch.zeros(offline_dataset_size), torch.ones(online_dataset_size)])
    expected_weights /= expected_weights.sum()
    torch.testing.assert_close(weights, expected_weights)


def test_compute_sampler_weights_nontrivial_ratio(lerobot_dataset_factory, tmp_path):
    # Arbitrarily set small dataset sizes, making sure to have uneven sizes.
    offline_dataset = lerobot_dataset_factory(tmp_path, total_episodes=1, total_frames=4)
    online_dataset, _ = make_new_buffer()
    online_dataset.add_data(make_spoof_data_frames(n_episodes=4, n_frames_per_episode=2))
    online_sampling_ratio = 0.8
    weights = compute_sampler_weights(
        offline_dataset, online_dataset=online_dataset, online_sampling_ratio=online_sampling_ratio
    )
    torch.testing.assert_close(
        weights, torch.tensor([0.05, 0.05, 0.05, 0.05, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1])
    )
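
# The expected weights follow from the 0.8 online sampling ratio: the 4
# offline frames split the remaining 0.2 evenly (0.05 each), while the 8
# online frames split 0.8 evenly (0.1 each).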


def test_compute_sampler_weights_nontrivial_ratio_and_drop_last_n(lerobot_dataset_factory, tmp_path):
    # Arbitrarily set small dataset sizes, making sure to have uneven sizes.
    offline_dataset = lerobot_dataset_factory(tmp_path, total_episodes=1, total_frames=4)
    online_dataset, _ = make_new_buffer()
    online_dataset.add_data(make_spoof_data_frames(n_episodes=4, n_frames_per_episode=2))
    weights = compute_sampler_weights(
        offline_dataset, online_dataset=online_dataset, online_sampling_ratio=0.8, online_drop_n_last_frames=1
    )
    torch.testing.assert_close(
        weights, torch.tensor([0.05, 0.05, 0.05, 0.05, 0.2, 0.0, 0.2, 0.0, 0.2, 0.0, 0.2, 0.0])
    )


def test_compute_sampler_weights_drop_n_last_frames(lerobot_dataset_factory, tmp_path):
    """Note: test copied from test_sampler."""
    offline_dataset = lerobot_dataset_factory(tmp_path, total_episodes=1, total_frames=2)
    online_dataset, _ = make_new_buffer()
    online_dataset.add_data(make_spoof_data_frames(n_episodes=4, n_frames_per_episode=2))

    weights = compute_sampler_weights(
        offline_dataset,
        offline_drop_n_last_frames=1,
        online_dataset=online_dataset,
        online_sampling_ratio=0.5,
        online_drop_n_last_frames=1,
    )
    torch.testing.assert_close(weights, torch.tensor([0.5, 0, 0.125, 0, 0.125, 0, 0.125, 0, 0.125, 0]))
90
tests/datasets/test_sampler.py
Normal file
@@ -0,0 +1,90 @@
#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from datasets import Dataset

from lerobot.common.datasets.push_dataset_to_hub.utils import calculate_episode_data_index
from lerobot.common.datasets.sampler import EpisodeAwareSampler
from lerobot.common.datasets.utils import (
    hf_transform_to_torch,
)


def test_drop_n_first_frames():
    dataset = Dataset.from_dict(
        {
            "timestamp": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
            "index": [0, 1, 2, 3, 4, 5],
            "episode_index": [0, 0, 1, 2, 2, 2],
        },
    )
    dataset.set_transform(hf_transform_to_torch)
    episode_data_index = calculate_episode_data_index(dataset)
    sampler = EpisodeAwareSampler(episode_data_index, drop_n_first_frames=1)
    assert sampler.indices == [1, 4, 5]
    assert len(sampler) == 3
    assert list(sampler) == [1, 4, 5]
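
# Why [1, 4, 5]: the episodes cover indices [0, 1], [2] and [3, 4, 5].
# Dropping each episode's first frame removes indices 0, 2 and 3, and the
# single-frame middle episode contributes nothing.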


def test_drop_n_last_frames():
    dataset = Dataset.from_dict(
        {
            "timestamp": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
            "index": [0, 1, 2, 3, 4, 5],
            "episode_index": [0, 0, 1, 2, 2, 2],
        },
    )
    dataset.set_transform(hf_transform_to_torch)
    episode_data_index = calculate_episode_data_index(dataset)
    sampler = EpisodeAwareSampler(episode_data_index, drop_n_last_frames=1)
    assert sampler.indices == [0, 3, 4]
    assert len(sampler) == 3
    assert list(sampler) == [0, 3, 4]


def test_episode_indices_to_use():
    dataset = Dataset.from_dict(
        {
            "timestamp": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
            "index": [0, 1, 2, 3, 4, 5],
            "episode_index": [0, 0, 1, 2, 2, 2],
        },
    )
    dataset.set_transform(hf_transform_to_torch)
    episode_data_index = calculate_episode_data_index(dataset)
    sampler = EpisodeAwareSampler(episode_data_index, episode_indices_to_use=[0, 2])
    assert sampler.indices == [0, 1, 3, 4, 5]
    assert len(sampler) == 5
    assert list(sampler) == [0, 1, 3, 4, 5]


def test_shuffle():
    dataset = Dataset.from_dict(
        {
            "timestamp": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
            "index": [0, 1, 2, 3, 4, 5],
            "episode_index": [0, 0, 1, 2, 2, 2],
        },
    )
    dataset.set_transform(hf_transform_to_torch)
    episode_data_index = calculate_episode_data_index(dataset)
    sampler = EpisodeAwareSampler(episode_data_index, shuffle=False)
    assert sampler.indices == [0, 1, 2, 3, 4, 5]
    assert len(sampler) == 6
    assert list(sampler) == [0, 1, 2, 3, 4, 5]
    sampler = EpisodeAwareSampler(episode_data_index, shuffle=True)
    assert sampler.indices == [0, 1, 2, 3, 4, 5]
    assert len(sampler) == 6
    assert set(sampler) == {0, 1, 2, 3, 4, 5}
55
tests/datasets/test_utils.py
Normal file
@@ -0,0 +1,55 @@
#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from datasets import Dataset
from huggingface_hub import DatasetCard

from lerobot.common.datasets.push_dataset_to_hub.utils import calculate_episode_data_index
from lerobot.common.datasets.utils import create_lerobot_dataset_card, hf_transform_to_torch


def test_default_parameters():
    card = create_lerobot_dataset_card()
    assert isinstance(card, DatasetCard)
    assert card.data.tags == ["LeRobot"]
    assert card.data.task_categories == ["robotics"]
    assert card.data.configs == [
        {
            "config_name": "default",
            "data_files": "data/*/*.parquet",
        }
    ]


def test_with_tags():
    tags = ["tag1", "tag2"]
    card = create_lerobot_dataset_card(tags=tags)
    assert card.data.tags == ["LeRobot", "tag1", "tag2"]


def test_calculate_episode_data_index():
    dataset = Dataset.from_dict(
        {
            "timestamp": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
            "index": [0, 1, 2, 3, 4, 5],
            "episode_index": [0, 0, 1, 2, 2, 2],
        },
    )
    dataset.set_transform(hf_transform_to_torch)
    episode_data_index = calculate_episode_data_index(dataset)
    assert torch.equal(episode_data_index["from"], torch.tensor([0, 2, 3]))
    assert torch.equal(episode_data_index["to"], torch.tensor([2, 3, 6]))
33
tests/datasets/test_visualize_dataset.py
Normal file
@@ -0,0 +1,33 @@
#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest

from lerobot.scripts.visualize_dataset import visualize_dataset


@pytest.mark.skip("TODO: add dummy videos")
def test_visualize_local_dataset(tmp_path, lerobot_dataset_factory):
    root = tmp_path / "dataset"
    output_dir = tmp_path / "outputs"
    dataset = lerobot_dataset_factory(root=root)
    rrd_path = visualize_dataset(
        dataset,
        episode_index=0,
        batch_size=32,
        save=True,
        output_dir=output_dir,
    )
    assert rrd_path.exists()