(fix): linter

This commit is contained in:
AdilZouitine
2025-06-03 17:45:10 +02:00
parent 6eeab64f8a
commit 8d4fe1ad6a
7 changed files with 16 additions and 37 deletions

View File

@@ -72,9 +72,7 @@ def test_same_attributes_defined(tmp_path, lerobot_dataset_factory):
# Instantiate both ways
robot = make_robot("koch", mock=True)
root_create = tmp_path / "create"
dataset_create = LeRobotDataset.create(
repo_id=DUMMY_REPO_ID, fps=30, robot=robot, root=root_create
)
dataset_create = LeRobotDataset.create(repo_id=DUMMY_REPO_ID, fps=30, robot=robot, root=root_create)
root_init = tmp_path / "init"
dataset_init = lerobot_dataset_factory(root=root_init)
@@ -126,9 +124,7 @@ def test_add_frame_extra_feature(tmp_path, empty_lerobot_dataset_factory):
with pytest.raises(
ValueError, match="Feature mismatch in `frame` dictionary:\nExtra features: {'extra'}\n"
):
dataset.add_frame(
{"state": torch.randn(1), "task": "Dummy task", "extra": "dummy_extra"}
)
dataset.add_frame({"state": torch.randn(1), "task": "Dummy task", "extra": "dummy_extra"})
def test_add_frame_wrong_type(tmp_path, empty_lerobot_dataset_factory):
@@ -137,9 +133,7 @@ def test_add_frame_wrong_type(tmp_path, empty_lerobot_dataset_factory):
with pytest.raises(
ValueError, match="The feature 'state' of dtype 'float16' is not of the expected dtype 'float32'.\n"
):
dataset.add_frame(
{"state": torch.randn(1, dtype=torch.float16), "task": "Dummy task"}
)
dataset.add_frame({"state": torch.randn(1, dtype=torch.float16), "task": "Dummy task"})
def test_add_frame_wrong_shape(tmp_path, empty_lerobot_dataset_factory):
@@ -147,9 +141,7 @@ def test_add_frame_wrong_shape(tmp_path, empty_lerobot_dataset_factory):
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
with pytest.raises(
ValueError,
match=re.escape(
"The feature 'state' of shape '(1,)' does not have the expected shape '(2,)'.\n"
),
match=re.escape("The feature 'state' of shape '(1,)' does not have the expected shape '(2,)'.\n"),
):
dataset.add_frame({"state": torch.randn(1), "task": "Dummy task"})
@@ -171,9 +163,7 @@ def test_add_frame_wrong_shape_torch_ndim_0(tmp_path, empty_lerobot_dataset_fact
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
with pytest.raises(
ValueError,
match=re.escape(
"The feature 'state' of shape '()' does not have the expected shape '(1,)'.\n"
),
match=re.escape("The feature 'state' of shape '()' does not have the expected shape '(1,)'.\n"),
):
dataset.add_frame({"state": torch.tensor(1.0), "task": "Dummy task"})
@@ -467,9 +457,7 @@ def test_flatten_unflatten_dict():
d = unflatten_dict(flatten_dict(d))
# test equality between nested dicts
assert json.dumps(original_d, sort_keys=True) == json.dumps(d, sort_keys=True), (
f"{original_d} != {d}"
)
assert json.dumps(original_d, sort_keys=True) == json.dumps(d, sort_keys=True), f"{original_d} != {d}"
@pytest.mark.parametrize(
@@ -523,13 +511,7 @@ def test_backward_compatibility(repo_id):
load_and_compare(i + 1)
# test 2 frames at the middle of first episode
i = int(
(
dataset.episode_data_index["to"][0].item()
- dataset.episode_data_index["from"][0].item()
)
/ 2
)
i = int((dataset.episode_data_index["to"][0].item() - dataset.episode_data_index["from"][0].item()) / 2)
load_and_compare(i)
load_and_compare(i + 1)