Format file

This commit is contained in:
AdilZouitine
2025-05-07 10:26:18 +02:00
parent adbf8bb85e
commit b36ec31fea
13 changed files with 43 additions and 169 deletions

View File

@@ -37,10 +37,7 @@ pytest -sx 'tests/test_cameras.py::test_camera[intelrealsense-True]'
import numpy as np import numpy as np
import pytest import pytest
from lerobot.common.robot_devices.utils import ( from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError
RobotDeviceAlreadyConnectedError,
RobotDeviceNotConnectedError,
)
from tests.utils import TEST_CAMERA_TYPES, make_camera, require_camera from tests.utils import TEST_CAMERA_TYPES, make_camera, require_camera
# Maximum absolute difference between two consecutive images recorded by a camera. # Maximum absolute difference between two consecutive images recorded by a camera.
@@ -115,11 +112,7 @@ def test_camera(request, camera_type, mock):
) )
# TODO(rcadene): properly set `rtol` # TODO(rcadene): properly set `rtol`
np.testing.assert_allclose( np.testing.assert_allclose(
color_image, color_image, async_color_image, rtol=1e-5, atol=MAX_PIXEL_DIFFERENCE, err_msg=error_msg
async_color_image,
rtol=1e-5,
atol=MAX_PIXEL_DIFFERENCE,
err_msg=error_msg,
) )
# Test disconnecting # Test disconnecting
@@ -138,11 +131,7 @@ def test_camera(request, camera_type, mock):
assert camera.color_mode == "bgr" assert camera.color_mode == "bgr"
bgr_color_image = camera.read() bgr_color_image = camera.read()
np.testing.assert_allclose( np.testing.assert_allclose(
color_image, color_image, bgr_color_image[:, :, [2, 1, 0]], rtol=1e-5, atol=MAX_PIXEL_DIFFERENCE, err_msg=error_msg
bgr_color_image[:, :, [2, 1, 0]],
rtol=1e-5,
atol=MAX_PIXEL_DIFFERENCE,
err_msg=error_msg,
) )
del camera del camera
@@ -177,11 +166,7 @@ def test_camera(request, camera_type, mock):
rot_color_image = camera.read() rot_color_image = camera.read()
np.testing.assert_allclose( np.testing.assert_allclose(
rot_color_image, rot_color_image, manual_rot_img, rtol=1e-5, atol=MAX_PIXEL_DIFFERENCE, err_msg=error_msg
manual_rot_img,
rtol=1e-5,
atol=MAX_PIXEL_DIFFERENCE,
err_msg=error_msg,
) )
del camera del camera
@@ -215,9 +200,7 @@ def test_save_images_from_cameras(tmp_path, request, camera_type, mock):
if camera_type == "opencv": if camera_type == "opencv":
from lerobot.common.robot_devices.cameras.opencv import save_images_from_cameras from lerobot.common.robot_devices.cameras.opencv import save_images_from_cameras
elif camera_type == "intelrealsense": elif camera_type == "intelrealsense":
from lerobot.common.robot_devices.cameras.intelrealsense import ( from lerobot.common.robot_devices.cameras.intelrealsense import save_images_from_cameras
save_images_from_cameras,
)
# Small `record_time_s` to speedup unit tests # Small `record_time_s` to speedup unit tests
save_images_from_cameras(tmp_path, record_time_s=0.02, mock=mock) save_images_from_cameras(tmp_path, record_time_s=0.02, mock=mock)
@@ -226,13 +209,7 @@ def test_save_images_from_cameras(tmp_path, request, camera_type, mock):
@pytest.mark.parametrize("camera_type, mock", TEST_CAMERA_TYPES) @pytest.mark.parametrize("camera_type, mock", TEST_CAMERA_TYPES)
@require_camera @require_camera
def test_camera_rotation(request, camera_type, mock): def test_camera_rotation(request, camera_type, mock):
config_kwargs = { config_kwargs = {"camera_type": camera_type, "mock": mock, "width": 640, "height": 480, "fps": 30}
"camera_type": camera_type,
"mock": mock,
"width": 640,
"height": 480,
"fps": 30,
}
# No rotation. # No rotation.
camera = make_camera(**config_kwargs, rotation=None) camera = make_camera(**config_kwargs, rotation=None)

View File

@@ -61,10 +61,7 @@ def test_sample_indices():
assert len(indices) == estimate_num_samples(10) assert len(indices) == estimate_num_samples(10)
@patch( @patch("lerobot.common.datasets.compute_stats.load_image_as_numpy", side_effect=mock_load_image_as_numpy)
"lerobot.common.datasets.compute_stats.load_image_as_numpy",
side_effect=mock_load_image_as_numpy,
)
def test_sample_images(mock_load): def test_sample_images(mock_load):
image_paths = [f"image_{i}.jpg" for i in range(100)] image_paths = [f"image_{i}.jpg" for i in range(100)]
images = sample_images(image_paths) images = sample_images(image_paths)
@@ -148,8 +145,7 @@ def test_compute_episode_stats():
} }
with patch( with patch(
"lerobot.common.datasets.compute_stats.load_image_as_numpy", "lerobot.common.datasets.compute_stats.load_image_as_numpy", side_effect=mock_load_image_as_numpy
side_effect=mock_load_image_as_numpy,
): ):
stats = compute_episode_stats(episode_data, features) stats = compute_episode_stats(episode_data, features)
@@ -237,13 +233,7 @@ def test_aggregate_stats():
"std": [2.87, 5.87, 8.87], "std": [2.87, 5.87, 8.87],
"count": 10, "count": 10,
}, },
"observation.state": { "observation.state": {"min": 1, "max": 10, "mean": 5.5, "std": 2.87, "count": 10},
"min": 1,
"max": 10,
"mean": 5.5,
"std": 2.87,
"count": 10,
},
"extra_key_0": {"min": 5, "max": 25, "mean": 15, "std": 6, "count": 6}, "extra_key_0": {"min": 5, "max": 25, "mean": 15, "std": 6, "count": 6},
}, },
{ {
@@ -254,13 +244,7 @@ def test_aggregate_stats():
"std": [3.42, 2.42, 1.42], "std": [3.42, 2.42, 1.42],
"count": 15, "count": 15,
}, },
"observation.state": { "observation.state": {"min": 2, "max": 15, "mean": 8.5, "std": 3.42, "count": 15},
"min": 2,
"max": 15,
"mean": 8.5,
"std": 3.42,
"count": 15,
},
"extra_key_1": {"min": 0, "max": 20, "mean": 10, "std": 5, "count": 5}, "extra_key_1": {"min": 0, "max": 20, "mean": 10, "std": 5, "count": 5},
}, },
] ]
@@ -322,9 +306,6 @@ def test_aggregate_stats():
np.testing.assert_allclose(results[fkey]["max"], expected_agg_stats[fkey]["max"]) np.testing.assert_allclose(results[fkey]["max"], expected_agg_stats[fkey]["max"])
np.testing.assert_allclose(results[fkey]["mean"], expected_agg_stats[fkey]["mean"]) np.testing.assert_allclose(results[fkey]["mean"], expected_agg_stats[fkey]["mean"])
np.testing.assert_allclose( np.testing.assert_allclose(
results[fkey]["std"], results[fkey]["std"], expected_agg_stats[fkey]["std"], atol=1e-04, rtol=1e-04
expected_agg_stats[fkey]["std"],
atol=1e-04,
rtol=1e-04,
) )
np.testing.assert_allclose(results[fkey]["count"], expected_agg_stats[fkey]["count"]) np.testing.assert_allclose(results[fkey]["count"], expected_agg_stats[fkey]["count"])

View File

@@ -104,8 +104,7 @@ def test_add_frame_missing_task(tmp_path, empty_lerobot_dataset_factory):
features = {"state": {"dtype": "float32", "shape": (1,), "names": None}} features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
with pytest.raises( with pytest.raises(
ValueError, ValueError, match="Feature mismatch in `frame` dictionary:\nMissing features: {'task'}\n"
match="Feature mismatch in `frame` dictionary:\nMissing features: {'task'}\n",
): ):
dataset.add_frame({"state": torch.randn(1)}) dataset.add_frame({"state": torch.randn(1)})
@@ -114,8 +113,7 @@ def test_add_frame_missing_feature(tmp_path, empty_lerobot_dataset_factory):
features = {"state": {"dtype": "float32", "shape": (1,), "names": None}} features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
with pytest.raises( with pytest.raises(
ValueError, ValueError, match="Feature mismatch in `frame` dictionary:\nMissing features: {'state'}\n"
match="Feature mismatch in `frame` dictionary:\nMissing features: {'state'}\n",
): ):
dataset.add_frame({"task": "Dummy task"}) dataset.add_frame({"task": "Dummy task"})
@@ -124,8 +122,7 @@ def test_add_frame_extra_feature(tmp_path, empty_lerobot_dataset_factory):
features = {"state": {"dtype": "float32", "shape": (1,), "names": None}} features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
with pytest.raises( with pytest.raises(
ValueError, ValueError, match="Feature mismatch in `frame` dictionary:\nExtra features: {'extra'}\n"
match="Feature mismatch in `frame` dictionary:\nExtra features: {'extra'}\n",
): ):
dataset.add_frame({"state": torch.randn(1), "task": "Dummy task", "extra": "dummy_extra"}) dataset.add_frame({"state": torch.randn(1), "task": "Dummy task", "extra": "dummy_extra"})
@@ -134,8 +131,7 @@ def test_add_frame_wrong_type(tmp_path, empty_lerobot_dataset_factory):
features = {"state": {"dtype": "float32", "shape": (1,), "names": None}} features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
with pytest.raises( with pytest.raises(
ValueError, ValueError, match="The feature 'state' of dtype 'float16' is not of the expected dtype 'float32'.\n"
match="The feature 'state' of dtype 'float16' is not of the expected dtype 'float32'.\n",
): ):
dataset.add_frame({"state": torch.randn(1, dtype=torch.float16), "task": "Dummy task"}) dataset.add_frame({"state": torch.randn(1, dtype=torch.float16), "task": "Dummy task"})

View File

@@ -54,9 +54,7 @@ def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> dict[str, np.n
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def synced_timestamps_factory(hf_dataset_factory): def synced_timestamps_factory(hf_dataset_factory):
def _create_synced_timestamps( def _create_synced_timestamps(fps: int = 30) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
fps: int = 30,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
hf_dataset = hf_dataset_factory(fps=fps) hf_dataset = hf_dataset_factory(fps=fps)
timestamps = torch.stack(hf_dataset["timestamp"]).numpy() timestamps = torch.stack(hf_dataset["timestamp"]).numpy()
episode_indices = torch.stack(hf_dataset["episode_index"]).numpy() episode_indices = torch.stack(hf_dataset["episode_index"]).numpy()
@@ -93,9 +91,7 @@ def slightly_off_timestamps_factory(synced_timestamps_factory):
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def valid_delta_timestamps_factory(): def valid_delta_timestamps_factory():
def _create_valid_delta_timestamps( def _create_valid_delta_timestamps(
fps: int = 30, fps: int = 30, keys: list = DUMMY_MOTOR_FEATURES, min_max_range: tuple[int, int] = (-10, 10)
keys: list = DUMMY_MOTOR_FEATURES,
min_max_range: tuple[int, int] = (-10, 10),
) -> dict: ) -> dict:
delta_timestamps = {key: [i * (1 / fps) for i in range(*min_max_range)] for key in keys} delta_timestamps = {key: [i * (1 / fps) for i in range(*min_max_range)] for key in keys}
return delta_timestamps return delta_timestamps

View File

@@ -33,9 +33,7 @@ from lerobot.scripts.visualize_image_transforms import (
save_all_transforms, save_all_transforms,
save_each_transform, save_each_transform,
) )
from tests.artifacts.image_transforms.save_image_transforms_to_safetensors import ( from tests.artifacts.image_transforms.save_image_transforms_to_safetensors import ARTIFACT_DIR
ARTIFACT_DIR,
)
from tests.utils import require_x86_64_kernel from tests.utils import require_x86_64_kernel
@@ -93,8 +91,7 @@ def test_get_image_transforms_brightness(img_tensor_factory, min_max):
def test_get_image_transforms_contrast(img_tensor_factory, min_max): def test_get_image_transforms_contrast(img_tensor_factory, min_max):
img_tensor = img_tensor_factory() img_tensor = img_tensor_factory()
tf_cfg = ImageTransformsConfig( tf_cfg = ImageTransformsConfig(
enable=True, enable=True, tfs={"contrast": ImageTransformConfig(type="ColorJitter", kwargs={"contrast": min_max})}
tfs={"contrast": ImageTransformConfig(type="ColorJitter", kwargs={"contrast": min_max})},
) )
tf_actual = ImageTransforms(tf_cfg) tf_actual = ImageTransforms(tf_cfg)
tf_expected = v2.ColorJitter(contrast=min_max) tf_expected = v2.ColorJitter(contrast=min_max)
@@ -117,8 +114,7 @@ def test_get_image_transforms_saturation(img_tensor_factory, min_max):
def test_get_image_transforms_hue(img_tensor_factory, min_max): def test_get_image_transforms_hue(img_tensor_factory, min_max):
img_tensor = img_tensor_factory() img_tensor = img_tensor_factory()
tf_cfg = ImageTransformsConfig( tf_cfg = ImageTransformsConfig(
enable=True, enable=True, tfs={"hue": ImageTransformConfig(type="ColorJitter", kwargs={"hue": min_max})}
tfs={"hue": ImageTransformConfig(type="ColorJitter", kwargs={"hue": min_max})},
) )
tf_actual = ImageTransforms(tf_cfg) tf_actual = ImageTransforms(tf_cfg)
tf_expected = v2.ColorJitter(hue=min_max) tf_expected = v2.ColorJitter(hue=min_max)
@@ -371,11 +367,7 @@ def test_save_each_transform(img_tensor_factory, tmp_path):
assert any(transform_dir.iterdir()), f"No transformed images found in {transform} directory." assert any(transform_dir.iterdir()), f"No transformed images found in {transform} directory."
# Check for specific files within each transform directory # Check for specific files within each transform directory
expected_files = [f"{i}.png" for i in range(1, n_examples + 1)] + [ expected_files = [f"{i}.png" for i in range(1, n_examples + 1)] + ["min.png", "max.png", "mean.png"]
"min.png",
"max.png",
"mean.png",
]
for file_name in expected_files: for file_name in expected_files:
assert (transform_dir / file_name).exists(), ( assert (transform_dir / file_name).exists(), (
f"{file_name} was not found in {transform} directory." f"{file_name} was not found in {transform} directory."

View File

@@ -227,9 +227,7 @@ def test_compute_sampler_weights_trivial(
) )
weights = compute_sampler_weights( weights = compute_sampler_weights(
offline_dataset, offline_dataset, online_dataset=online_dataset, online_sampling_ratio=online_sampling_ratio
online_dataset=online_dataset,
online_sampling_ratio=online_sampling_ratio,
) )
if offline_dataset_size == 0 or online_dataset_size == 0: if offline_dataset_size == 0 or online_dataset_size == 0:
expected_weights = torch.ones(offline_dataset_size + online_dataset_size) expected_weights = torch.ones(offline_dataset_size + online_dataset_size)
@@ -248,13 +246,10 @@ def test_compute_sampler_weights_nontrivial_ratio(lerobot_dataset_factory, tmp_p
online_dataset.add_data(make_spoof_data_frames(n_episodes=4, n_frames_per_episode=2)) online_dataset.add_data(make_spoof_data_frames(n_episodes=4, n_frames_per_episode=2))
online_sampling_ratio = 0.8 online_sampling_ratio = 0.8
weights = compute_sampler_weights( weights = compute_sampler_weights(
offline_dataset, offline_dataset, online_dataset=online_dataset, online_sampling_ratio=online_sampling_ratio
online_dataset=online_dataset,
online_sampling_ratio=online_sampling_ratio,
) )
torch.testing.assert_close( torch.testing.assert_close(
weights, weights, torch.tensor([0.05, 0.05, 0.05, 0.05, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1])
torch.tensor([0.05, 0.05, 0.05, 0.05, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]),
) )
@@ -264,14 +259,10 @@ def test_compute_sampler_weights_nontrivial_ratio_and_drop_last_n(lerobot_datase
online_dataset, _ = make_new_buffer() online_dataset, _ = make_new_buffer()
online_dataset.add_data(make_spoof_data_frames(n_episodes=4, n_frames_per_episode=2)) online_dataset.add_data(make_spoof_data_frames(n_episodes=4, n_frames_per_episode=2))
weights = compute_sampler_weights( weights = compute_sampler_weights(
offline_dataset, offline_dataset, online_dataset=online_dataset, online_sampling_ratio=0.8, online_drop_n_last_frames=1
online_dataset=online_dataset,
online_sampling_ratio=0.8,
online_drop_n_last_frames=1,
) )
torch.testing.assert_close( torch.testing.assert_close(
weights, weights, torch.tensor([0.05, 0.05, 0.05, 0.05, 0.2, 0.0, 0.2, 0.0, 0.2, 0.0, 0.2, 0.0])
torch.tensor([0.05, 0.05, 0.05, 0.05, 0.2, 0.0, 0.2, 0.0, 0.2, 0.0, 0.2, 0.0]),
) )

View File

@@ -15,9 +15,7 @@
# limitations under the License. # limitations under the License.
from datasets import Dataset from datasets import Dataset
from lerobot.common.datasets.push_dataset_to_hub.utils import ( from lerobot.common.datasets.push_dataset_to_hub.utils import calculate_episode_data_index
calculate_episode_data_index,
)
from lerobot.common.datasets.sampler import EpisodeAwareSampler from lerobot.common.datasets.sampler import EpisodeAwareSampler
from lerobot.common.datasets.utils import ( from lerobot.common.datasets.utils import (
hf_transform_to_torch, hf_transform_to_torch,

View File

@@ -18,13 +18,8 @@ import torch
from datasets import Dataset from datasets import Dataset
from huggingface_hub import DatasetCard from huggingface_hub import DatasetCard
from lerobot.common.datasets.push_dataset_to_hub.utils import ( from lerobot.common.datasets.push_dataset_to_hub.utils import calculate_episode_data_index
calculate_episode_data_index, from lerobot.common.datasets.utils import create_lerobot_dataset_card, hf_transform_to_torch
)
from lerobot.common.datasets.utils import (
create_lerobot_dataset_card,
hf_transform_to_torch,
)
def test_default_parameters(): def test_default_parameters():

View File

@@ -20,39 +20,17 @@ DUMMY_MOTOR_FEATURES = {
"action": { "action": {
"dtype": "float32", "dtype": "float32",
"shape": (6,), "shape": (6,),
"names": [ "names": ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"],
"shoulder_pan",
"shoulder_lift",
"elbow_flex",
"wrist_flex",
"wrist_roll",
"gripper",
],
}, },
"state": { "state": {
"dtype": "float32", "dtype": "float32",
"shape": (6,), "shape": (6,),
"names": [ "names": ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"],
"shoulder_pan",
"shoulder_lift",
"elbow_flex",
"wrist_flex",
"wrist_roll",
"gripper",
],
}, },
} }
DUMMY_CAMERA_FEATURES = { DUMMY_CAMERA_FEATURES = {
"laptop": { "laptop": {"shape": (480, 640, 3), "names": ["height", "width", "channels"], "info": None},
"shape": (480, 640, 3), "phone": {"shape": (480, 640, 3), "names": ["height", "width", "channels"], "info": None},
"names": ["height", "width", "channels"],
"info": None,
},
"phone": {
"shape": (480, 640, 3),
"names": ["height", "width", "channels"],
"info": None,
},
} }
DEFAULT_FPS = 30 DEFAULT_FPS = 30
DUMMY_VIDEO_INFO = { DUMMY_VIDEO_INFO = {

View File

@@ -23,11 +23,7 @@ import PIL.Image
import pytest import pytest
import torch import torch
from lerobot.common.datasets.lerobot_dataset import ( from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset, LeRobotDatasetMetadata
CODEBASE_VERSION,
LeRobotDataset,
LeRobotDatasetMetadata,
)
from lerobot.common.datasets.utils import ( from lerobot.common.datasets.utils import (
DEFAULT_CHUNK_SIZE, DEFAULT_CHUNK_SIZE,
DEFAULT_FEATURES, DEFAULT_FEATURES,
@@ -201,10 +197,7 @@ def tasks_factory():
def _create_tasks(total_tasks: int = 3) -> int: def _create_tasks(total_tasks: int = 3) -> int:
tasks = {} tasks = {}
for task_index in range(total_tasks): for task_index in range(total_tasks):
task_dict = { task_dict = {"task_index": task_index, "task": f"Perform action {task_index}."}
"task_index": task_index,
"task": f"Perform action {task_index}.",
}
tasks[task_index] = task_dict tasks[task_index] = task_dict
return tasks return tasks
@@ -282,10 +275,7 @@ def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_ar
timestamp_col = np.concatenate((timestamp_col, np.arange(ep_dict["length"]) / fps)) timestamp_col = np.concatenate((timestamp_col, np.arange(ep_dict["length"]) / fps))
frame_index_col = np.concatenate((frame_index_col, np.arange(ep_dict["length"], dtype=int))) frame_index_col = np.concatenate((frame_index_col, np.arange(ep_dict["length"], dtype=int)))
episode_index_col = np.concatenate( episode_index_col = np.concatenate(
( (episode_index_col, np.full(ep_dict["length"], ep_dict["episode_index"], dtype=int))
episode_index_col,
np.full(ep_dict["length"], ep_dict["episode_index"], dtype=int),
)
) )
ep_task_index = get_task_index(tasks, ep_dict["tasks"][0]) ep_task_index = get_task_index(tasks, ep_dict["tasks"][0])
task_index = np.concatenate((task_index, np.full(ep_dict["length"], ep_task_index, dtype=int))) task_index = np.concatenate((task_index, np.full(ep_dict["length"], ep_task_index, dtype=int)))
@@ -350,9 +340,7 @@ def lerobot_dataset_metadata_factory(
tasks = tasks_factory(total_tasks=info["total_tasks"]) tasks = tasks_factory(total_tasks=info["total_tasks"])
if not episodes: if not episodes:
episodes = episodes_factory( episodes = episodes_factory(
total_episodes=info["total_episodes"], total_episodes=info["total_episodes"], total_frames=info["total_frames"], tasks=tasks
total_frames=info["total_frames"],
tasks=tasks,
) )
mock_snapshot_download = mock_snapshot_download_factory( mock_snapshot_download = mock_snapshot_download_factory(
@@ -404,9 +392,7 @@ def lerobot_dataset_factory(
) -> LeRobotDataset: ) -> LeRobotDataset:
if not info: if not info:
info = info_factory( info = info_factory(
total_episodes=total_episodes, total_episodes=total_episodes, total_frames=total_frames, total_tasks=total_tasks
total_frames=total_frames,
total_tasks=total_tasks,
) )
if not stats: if not stats:
stats = stats_factory(features=info["features"]) stats = stats_factory(features=info["features"])

View File

@@ -102,10 +102,7 @@ def episode_path(episodes_factory):
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def single_episode_parquet_path(hf_dataset_factory, info_factory): def single_episode_parquet_path(hf_dataset_factory, info_factory):
def _create_single_episode_parquet( def _create_single_episode_parquet(
dir: Path, dir: Path, ep_idx: int = 0, hf_dataset: datasets.Dataset | None = None, info: dict | None = None
ep_idx: int = 0,
hf_dataset: datasets.Dataset | None = None,
info: dict | None = None,
) -> Path: ) -> Path:
if not info: if not info:
info = info_factory() info = info_factory()

16
tests/fixtures/hub.py vendored
View File

@@ -67,9 +67,7 @@ def mock_snapshot_download_factory(
tasks = tasks_factory(total_tasks=info["total_tasks"]) tasks = tasks_factory(total_tasks=info["total_tasks"])
if not episodes: if not episodes:
episodes = episodes_factory( episodes = episodes_factory(
total_episodes=info["total_episodes"], total_episodes=info["total_episodes"], total_frames=info["total_frames"], tasks=tasks
total_frames=info["total_frames"],
tasks=tasks,
) )
if not hf_dataset: if not hf_dataset:
hf_dataset = hf_dataset_factory(tasks=tasks, episodes=episodes, fps=info["fps"]) hf_dataset = hf_dataset_factory(tasks=tasks, episodes=episodes, fps=info["fps"])
@@ -95,13 +93,7 @@ def mock_snapshot_download_factory(
# List all possible files # List all possible files
all_files = [] all_files = []
meta_files = [ meta_files = [INFO_PATH, STATS_PATH, EPISODES_STATS_PATH, TASKS_PATH, EPISODES_PATH]
INFO_PATH,
STATS_PATH,
EPISODES_STATS_PATH,
TASKS_PATH,
EPISODES_PATH,
]
all_files.extend(meta_files) all_files.extend(meta_files)
data_files = [] data_files = []
@@ -113,9 +105,7 @@ def mock_snapshot_download_factory(
all_files.extend(data_files) all_files.extend(data_files)
allowed_files = filter_repo_objects( allowed_files = filter_repo_objects(
all_files, all_files, allow_patterns=allow_patterns, ignore_patterns=ignore_patterns
allow_patterns=allow_patterns,
ignore_patterns=ignore_patterns,
) )
# Create allowed files # Create allowed files

View File

@@ -43,10 +43,7 @@ import time
import numpy as np import numpy as np
import pytest import pytest
from lerobot.common.robot_devices.utils import ( from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError
RobotDeviceAlreadyConnectedError,
RobotDeviceNotConnectedError,
)
from lerobot.scripts.find_motors_bus_port import find_port from lerobot.scripts.find_motors_bus_port import find_port
from tests.utils import TEST_MOTOR_TYPES, make_motors_bus, require_motor from tests.utils import TEST_MOTOR_TYPES, make_motors_bus, require_motor