(fix): test

This commit is contained in:
AdilZouitine
2025-06-03 18:42:41 +02:00
parent 8d4fe1ad6a
commit 00e9f61509
2 changed files with 230 additions and 33 deletions

View File

@@ -41,7 +41,6 @@ from lerobot.common.datasets.utils import (
)
from lerobot.common.envs.factory import make_env_config
from lerobot.common.policies.factory import make_policy_config
from lerobot.common.robot_devices.robots.utils import make_robot
from lerobot.configs.default import DatasetConfig
from lerobot.configs.train import TrainPipelineConfig
from tests.fixtures.constants import DUMMY_CHW, DUMMY_HWC, DUMMY_REPO_ID
@@ -70,9 +69,9 @@ def test_same_attributes_defined(tmp_path, lerobot_dataset_factory):
objects have the same sets of attributes defined.
"""
# Instantiate both ways
robot = make_robot("koch", mock=True)
features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
root_create = tmp_path / "create"
dataset_create = LeRobotDataset.create(repo_id=DUMMY_REPO_ID, fps=30, robot=robot, root=root_create)
dataset_create = LeRobotDataset.create(repo_id=DUMMY_REPO_ID, fps=30, features=features, root=root_create)
root_init = tmp_path / "init"
dataset_init = lerobot_dataset_factory(root=root_init)
@@ -100,22 +99,13 @@ def test_dataset_initialization(tmp_path, lerobot_dataset_factory):
assert dataset.num_frames == len(dataset)
def test_add_frame_missing_task(tmp_path, empty_lerobot_dataset_factory):
features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
with pytest.raises(
ValueError, match="Feature mismatch in `frame` dictionary:\nMissing features: {'task'}\n"
):
dataset.add_frame({"state": torch.randn(1)})
def test_add_frame_missing_feature(tmp_path, empty_lerobot_dataset_factory):
features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
with pytest.raises(
ValueError, match="Feature mismatch in `frame` dictionary:\nMissing features: {'state'}\n"
):
dataset.add_frame({"task": "Dummy task"})
dataset.add_frame({"wrong_feature": torch.randn(1)}, task="Dummy task")
def test_add_frame_extra_feature(tmp_path, empty_lerobot_dataset_factory):
@@ -124,7 +114,7 @@ def test_add_frame_extra_feature(tmp_path, empty_lerobot_dataset_factory):
with pytest.raises(
ValueError, match="Feature mismatch in `frame` dictionary:\nExtra features: {'extra'}\n"
):
dataset.add_frame({"state": torch.randn(1), "task": "Dummy task", "extra": "dummy_extra"})
dataset.add_frame({"state": torch.randn(1), "extra": "dummy_extra"}, task="Dummy task")
def test_add_frame_wrong_type(tmp_path, empty_lerobot_dataset_factory):
@@ -133,7 +123,7 @@ def test_add_frame_wrong_type(tmp_path, empty_lerobot_dataset_factory):
with pytest.raises(
ValueError, match="The feature 'state' of dtype 'float16' is not of the expected dtype 'float32'.\n"
):
dataset.add_frame({"state": torch.randn(1, dtype=torch.float16), "task": "Dummy task"})
dataset.add_frame({"state": torch.randn(1, dtype=torch.float16)}, task="Dummy task")
def test_add_frame_wrong_shape(tmp_path, empty_lerobot_dataset_factory):
@@ -143,7 +133,7 @@ def test_add_frame_wrong_shape(tmp_path, empty_lerobot_dataset_factory):
ValueError,
match=re.escape("The feature 'state' of shape '(1,)' does not have the expected shape '(2,)'.\n"),
):
dataset.add_frame({"state": torch.randn(1), "task": "Dummy task"})
dataset.add_frame({"state": torch.randn(1)}, task="Dummy task")
def test_add_frame_wrong_shape_python_float(tmp_path, empty_lerobot_dataset_factory):
@@ -155,7 +145,7 @@ def test_add_frame_wrong_shape_python_float(tmp_path, empty_lerobot_dataset_fact
"The feature 'state' is not a 'np.ndarray'. Expected type is 'float32', but type '<class 'float'>' provided instead.\n"
),
):
dataset.add_frame({"state": 1.0, "task": "Dummy task"})
dataset.add_frame({"state": 1.0}, task="Dummy task")
def test_add_frame_wrong_shape_torch_ndim_0(tmp_path, empty_lerobot_dataset_factory):
@@ -165,7 +155,7 @@ def test_add_frame_wrong_shape_torch_ndim_0(tmp_path, empty_lerobot_dataset_fact
ValueError,
match=re.escape("The feature 'state' of shape '()' does not have the expected shape '(1,)'.\n"),
):
dataset.add_frame({"state": torch.tensor(1.0), "task": "Dummy task"})
dataset.add_frame({"state": torch.tensor(1.0)}, task="Dummy task")
def test_add_frame_wrong_shape_numpy_ndim_0(tmp_path, empty_lerobot_dataset_factory):
@@ -177,13 +167,13 @@ def test_add_frame_wrong_shape_numpy_ndim_0(tmp_path, empty_lerobot_dataset_fact
"The feature 'state' is not a 'np.ndarray'. Expected type is 'float32', but type '<class 'numpy.float32'>' provided instead.\n"
),
):
dataset.add_frame({"state": np.float32(1.0), "task": "Dummy task"})
dataset.add_frame({"state": np.float32(1.0)}, task="Dummy task")
def test_add_frame(tmp_path, empty_lerobot_dataset_factory):
features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
dataset.add_frame({"state": torch.randn(1), "task": "Dummy task"})
dataset.add_frame({"state": torch.randn(1)}, task="Dummy task")
dataset.save_episode()
assert len(dataset) == 1
@@ -195,7 +185,7 @@ def test_add_frame(tmp_path, empty_lerobot_dataset_factory):
def test_add_frame_state_1d(tmp_path, empty_lerobot_dataset_factory):
features = {"state": {"dtype": "float32", "shape": (2,), "names": None}}
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
dataset.add_frame({"state": torch.randn(2), "task": "Dummy task"})
dataset.add_frame({"state": torch.randn(2)}, task="Dummy task")
dataset.save_episode()
assert dataset[0]["state"].shape == torch.Size([2])
@@ -204,7 +194,7 @@ def test_add_frame_state_1d(tmp_path, empty_lerobot_dataset_factory):
def test_add_frame_state_2d(tmp_path, empty_lerobot_dataset_factory):
features = {"state": {"dtype": "float32", "shape": (2, 4), "names": None}}
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
dataset.add_frame({"state": torch.randn(2, 4), "task": "Dummy task"})
dataset.add_frame({"state": torch.randn(2, 4)}, task="Dummy task")
dataset.save_episode()
assert dataset[0]["state"].shape == torch.Size([2, 4])
@@ -213,7 +203,7 @@ def test_add_frame_state_2d(tmp_path, empty_lerobot_dataset_factory):
def test_add_frame_state_3d(tmp_path, empty_lerobot_dataset_factory):
features = {"state": {"dtype": "float32", "shape": (2, 4, 3), "names": None}}
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
dataset.add_frame({"state": torch.randn(2, 4, 3), "task": "Dummy task"})
dataset.add_frame({"state": torch.randn(2, 4, 3)}, task="Dummy task")
dataset.save_episode()
assert dataset[0]["state"].shape == torch.Size([2, 4, 3])
@@ -222,7 +212,7 @@ def test_add_frame_state_3d(tmp_path, empty_lerobot_dataset_factory):
def test_add_frame_state_4d(tmp_path, empty_lerobot_dataset_factory):
features = {"state": {"dtype": "float32", "shape": (2, 4, 3, 5), "names": None}}
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
dataset.add_frame({"state": torch.randn(2, 4, 3, 5), "task": "Dummy task"})
dataset.add_frame({"state": torch.randn(2, 4, 3, 5)}, task="Dummy task")
dataset.save_episode()
assert dataset[0]["state"].shape == torch.Size([2, 4, 3, 5])
@@ -231,7 +221,7 @@ def test_add_frame_state_4d(tmp_path, empty_lerobot_dataset_factory):
def test_add_frame_state_5d(tmp_path, empty_lerobot_dataset_factory):
features = {"state": {"dtype": "float32", "shape": (2, 4, 3, 5, 1), "names": None}}
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
dataset.add_frame({"state": torch.randn(2, 4, 3, 5, 1), "task": "Dummy task"})
dataset.add_frame({"state": torch.randn(2, 4, 3, 5, 1)}, task="Dummy task")
dataset.save_episode()
assert dataset[0]["state"].shape == torch.Size([2, 4, 3, 5, 1])
@@ -240,7 +230,7 @@ def test_add_frame_state_5d(tmp_path, empty_lerobot_dataset_factory):
def test_add_frame_state_numpy(tmp_path, empty_lerobot_dataset_factory):
features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
dataset.add_frame({"state": np.array([1], dtype=np.float32), "task": "Dummy task"})
dataset.add_frame({"state": np.array([1], dtype=np.float32)}, task="Dummy task")
dataset.save_episode()
assert dataset[0]["state"].ndim == 0
@@ -249,7 +239,7 @@ def test_add_frame_state_numpy(tmp_path, empty_lerobot_dataset_factory):
def test_add_frame_string(tmp_path, empty_lerobot_dataset_factory):
features = {"caption": {"dtype": "string", "shape": (1,), "names": None}}
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
dataset.add_frame({"caption": "Dummy caption", "task": "Dummy task"})
dataset.add_frame({"caption": "Dummy caption"}, task="Dummy task")
dataset.save_episode()
assert dataset[0]["caption"] == "Dummy caption"
@@ -264,7 +254,7 @@ def test_add_frame_image_wrong_shape(image_dataset):
),
):
c, h, w = DUMMY_CHW
dataset.add_frame({"image": torch.randn(c, w, h), "task": "Dummy task"})
dataset.add_frame({"image": torch.randn(c, w, h)}, task="Dummy task")
def test_add_frame_image_wrong_range(image_dataset):
@@ -277,14 +267,14 @@ def test_add_frame_image_wrong_range(image_dataset):
Hence the image won't be saved on disk and save_episode will raise `FileNotFoundError`.
"""
dataset = image_dataset
dataset.add_frame({"image": np.random.rand(*DUMMY_CHW) * 255, "task": "Dummy task"})
dataset.add_frame({"image": np.random.rand(*DUMMY_CHW) * 255}, task="Dummy task")
with pytest.raises(FileNotFoundError):
dataset.save_episode()
def test_add_frame_image(image_dataset):
dataset = image_dataset
dataset.add_frame({"image": np.random.rand(*DUMMY_CHW), "task": "Dummy task"})
dataset.add_frame({"image": np.random.rand(*DUMMY_CHW)}, task="Dummy task")
dataset.save_episode()
assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)
@@ -292,7 +282,7 @@ def test_add_frame_image(image_dataset):
def test_add_frame_image_h_w_c(image_dataset):
dataset = image_dataset
dataset.add_frame({"image": np.random.rand(*DUMMY_HWC), "task": "Dummy task"})
dataset.add_frame({"image": np.random.rand(*DUMMY_HWC)}, task="Dummy task")
dataset.save_episode()
assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)
@@ -301,7 +291,7 @@ def test_add_frame_image_h_w_c(image_dataset):
def test_add_frame_image_uint8(image_dataset):
dataset = image_dataset
image = np.random.randint(0, 256, DUMMY_HWC, dtype=np.uint8)
dataset.add_frame({"image": image, "task": "Dummy task"})
dataset.add_frame({"image": image}, task="Dummy task")
dataset.save_episode()
assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)
@@ -310,7 +300,7 @@ def test_add_frame_image_uint8(image_dataset):
def test_add_frame_image_pil(image_dataset):
dataset = image_dataset
image = np.random.randint(0, 256, DUMMY_HWC, dtype=np.uint8)
dataset.add_frame({"image": Image.fromarray(image), "task": "Dummy task"})
dataset.add_frame({"image": Image.fromarray(image)}, task="Dummy task")
dataset.save_episode()
assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)