Merge branch 'main' into user/adil-zouitine/2025-1-7-port-hil-serl-new

This commit is contained in:
Adil Zouitine
2025-05-06 16:43:44 +02:00
committed by AdilZouitine
272 changed files with 14894 additions and 11422 deletions

View File

@@ -72,7 +72,9 @@ def test_same_attributes_defined(tmp_path, lerobot_dataset_factory):
# Instantiate both ways
robot = make_robot("koch", mock=True)
root_create = tmp_path / "create"
dataset_create = LeRobotDataset.create(repo_id=DUMMY_REPO_ID, fps=30, robot=robot, root=root_create)
dataset_create = LeRobotDataset.create(
repo_id=DUMMY_REPO_ID, fps=30, robot=robot, root=root_create
)
root_init = tmp_path / "init"
dataset_init = lerobot_dataset_factory(root=root_init)
@@ -127,7 +129,9 @@ def test_add_frame_extra_feature(tmp_path, empty_lerobot_dataset_factory):
ValueError,
match="Feature mismatch in `frame` dictionary:\nExtra features: {'extra'}\n",
):
dataset.add_frame({"state": torch.randn(1), "task": "Dummy task", "extra": "dummy_extra"})
dataset.add_frame(
{"state": torch.randn(1), "task": "Dummy task", "extra": "dummy_extra"}
)
def test_add_frame_wrong_type(tmp_path, empty_lerobot_dataset_factory):
@@ -137,7 +141,9 @@ def test_add_frame_wrong_type(tmp_path, empty_lerobot_dataset_factory):
ValueError,
match="The feature 'state' of dtype 'float16' is not of the expected dtype 'float32'.\n",
):
dataset.add_frame({"state": torch.randn(1, dtype=torch.float16), "task": "Dummy task"})
dataset.add_frame(
{"state": torch.randn(1, dtype=torch.float16), "task": "Dummy task"}
)
def test_add_frame_wrong_shape(tmp_path, empty_lerobot_dataset_factory):
@@ -145,7 +151,9 @@ def test_add_frame_wrong_shape(tmp_path, empty_lerobot_dataset_factory):
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
with pytest.raises(
ValueError,
match=re.escape("The feature 'state' of shape '(1,)' does not have the expected shape '(2,)'.\n"),
match=re.escape(
"The feature 'state' of shape '(1,)' does not have the expected shape '(2,)'.\n"
),
):
dataset.add_frame({"state": torch.randn(1), "task": "Dummy task"})
@@ -167,7 +175,9 @@ def test_add_frame_wrong_shape_torch_ndim_0(tmp_path, empty_lerobot_dataset_fact
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
with pytest.raises(
ValueError,
match=re.escape("The feature 'state' of shape '()' does not have the expected shape '(1,)'.\n"),
match=re.escape(
"The feature 'state' of shape '()' does not have the expected shape '(1,)'.\n"
),
):
dataset.add_frame({"state": torch.tensor(1.0), "task": "Dummy task"})
@@ -461,7 +471,9 @@ def test_flatten_unflatten_dict():
d = unflatten_dict(flatten_dict(d))
# test equality between nested dicts
assert json.dumps(original_d, sort_keys=True) == json.dumps(d, sort_keys=True), f"{original_d} != {d}"
assert json.dumps(original_d, sort_keys=True) == json.dumps(d, sort_keys=True), (
f"{original_d} != {d}"
)
@pytest.mark.parametrize(
@@ -515,7 +527,13 @@ def test_backward_compatibility(repo_id):
load_and_compare(i + 1)
# test 2 frames at the middle of first episode
i = int((dataset.episode_data_index["to"][0].item() - dataset.episode_data_index["from"][0].item()) / 2)
i = int(
(
dataset.episode_data_index["to"][0].item()
- dataset.episode_data_index["from"][0].item()
)
/ 2
)
load_and_compare(i)
load_and_compare(i + 1)

View File

@@ -16,6 +16,7 @@
import pytest
import torch
from packaging import version
from safetensors.torch import load_file
from torchvision.transforms import v2
from torchvision.transforms.v2 import functional as F # noqa: N812
@@ -257,7 +258,14 @@ def test_backward_compatibility_single_transforms(
@require_x86_64_kernel
@pytest.mark.skipif(
version.parse(torch.__version__) < version.parse("2.7.0"),
reason="Test artifacts were generated with PyTorch >= 2.7.0 which has different multinomial behavior",
)
def test_backward_compatibility_default_config(img_tensor, default_transforms):
# NOTE: PyTorch versions have different randomness, it might break this test.
# See this PR: https://github.com/huggingface/lerobot/pull/1127.
cfg = ImageTransformsConfig(enable=True)
default_tf = ImageTransforms(cfg)