forked from tangger/lerobot
fix(tests): remove lint warnings/errors
This commit is contained in:
68
tests/fixtures/dataset_factories.py
vendored
68
tests/fixtures/dataset_factories.py
vendored
@@ -52,16 +52,16 @@ def get_task_index(task_dicts: dict, task: str) -> int:
|
||||
return task_to_task_index[task]
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def img_tensor_factory():
|
||||
@pytest.fixture(name="img_tensor_factory", scope="session")
|
||||
def fixture_img_tensor_factory():
|
||||
def _create_img_tensor(height=100, width=100, channels=3, dtype=torch.float32) -> torch.Tensor:
|
||||
return torch.rand((channels, height, width), dtype=dtype)
|
||||
|
||||
return _create_img_tensor
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def img_array_factory():
|
||||
@pytest.fixture(name="img_array_factory", scope="session")
|
||||
def fixture_img_array_factory():
|
||||
def _create_img_array(height=100, width=100, channels=3, dtype=np.uint8) -> np.ndarray:
|
||||
if np.issubdtype(dtype, np.unsignedinteger):
|
||||
# Int array in [0, 255] range
|
||||
@@ -76,8 +76,8 @@ def img_array_factory():
|
||||
return _create_img_array
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def img_factory(img_array_factory):
|
||||
@pytest.fixture(name="img_factory", scope="session")
|
||||
def fixture_img_factory(img_array_factory):
|
||||
def _create_img(height=100, width=100) -> PIL.Image.Image:
|
||||
img_array = img_array_factory(height=height, width=width)
|
||||
return PIL.Image.fromarray(img_array)
|
||||
@@ -85,13 +85,17 @@ def img_factory(img_array_factory):
|
||||
return _create_img
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def features_factory():
|
||||
@pytest.fixture(name="features_factory", scope="session")
|
||||
def fixture_features_factory():
|
||||
def _create_features(
|
||||
motor_features: dict = DUMMY_MOTOR_FEATURES,
|
||||
camera_features: dict = DUMMY_CAMERA_FEATURES,
|
||||
motor_features: dict | None = None,
|
||||
camera_features: dict | None = None,
|
||||
use_videos: bool = True,
|
||||
) -> dict:
|
||||
if motor_features is None:
|
||||
motor_features = DUMMY_MOTOR_FEATURES
|
||||
if camera_features is None:
|
||||
camera_features = DUMMY_CAMERA_FEATURES
|
||||
if use_videos:
|
||||
camera_ft = {
|
||||
key: {"dtype": "video", **ft, **DUMMY_VIDEO_INFO} for key, ft in camera_features.items()
|
||||
@@ -107,8 +111,8 @@ def features_factory():
|
||||
return _create_features
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def info_factory(features_factory):
|
||||
@pytest.fixture(name="info_factory", scope="session")
|
||||
def fixture_info_factory(features_factory):
|
||||
def _create_info(
|
||||
codebase_version: str = CODEBASE_VERSION,
|
||||
fps: int = DEFAULT_FPS,
|
||||
@@ -121,10 +125,14 @@ def info_factory(features_factory):
|
||||
chunks_size: int = DEFAULT_CHUNK_SIZE,
|
||||
data_path: str = DEFAULT_PARQUET_PATH,
|
||||
video_path: str = DEFAULT_VIDEO_PATH,
|
||||
motor_features: dict = DUMMY_MOTOR_FEATURES,
|
||||
camera_features: dict = DUMMY_CAMERA_FEATURES,
|
||||
motor_features: dict | None = None,
|
||||
camera_features: dict | None = None,
|
||||
use_videos: bool = True,
|
||||
) -> dict:
|
||||
if motor_features is None:
|
||||
motor_features = DUMMY_MOTOR_FEATURES
|
||||
if camera_features is None:
|
||||
camera_features = DUMMY_CAMERA_FEATURES
|
||||
features = features_factory(motor_features, camera_features, use_videos)
|
||||
return {
|
||||
"codebase_version": codebase_version,
|
||||
@@ -145,8 +153,8 @@ def info_factory(features_factory):
|
||||
return _create_info
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def stats_factory():
|
||||
@pytest.fixture(name="stats_factory", scope="session")
|
||||
def fixture_stats_factory():
|
||||
def _create_stats(
|
||||
features: dict[str] | None = None,
|
||||
) -> dict:
|
||||
@@ -175,8 +183,8 @@ def stats_factory():
|
||||
return _create_stats
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def episodes_stats_factory(stats_factory):
|
||||
@pytest.fixture(name="episodes_stats_factory", scope="session")
|
||||
def fixture_episodes_stats_factory(stats_factory):
|
||||
def _create_episodes_stats(
|
||||
features: dict[str],
|
||||
total_episodes: int = 3,
|
||||
@@ -192,8 +200,8 @@ def episodes_stats_factory(stats_factory):
|
||||
return _create_episodes_stats
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def tasks_factory():
|
||||
@pytest.fixture(name="tasks_factory", scope="session")
|
||||
def fixture_tasks_factory():
|
||||
def _create_tasks(total_tasks: int = 3) -> int:
|
||||
tasks = {}
|
||||
for task_index in range(total_tasks):
|
||||
@@ -204,8 +212,8 @@ def tasks_factory():
|
||||
return _create_tasks
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def episodes_factory(tasks_factory):
|
||||
@pytest.fixture(name="episodes_factory", scope="session")
|
||||
def fixture_episodes_factory(tasks_factory):
|
||||
def _create_episodes(
|
||||
total_episodes: int = 3,
|
||||
total_frames: int = 400,
|
||||
@@ -252,8 +260,8 @@ def episodes_factory(tasks_factory):
|
||||
return _create_episodes
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_array_factory):
|
||||
@pytest.fixture(name="hf_dataset_factory", scope="session")
|
||||
def fixture_hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_array_factory):
|
||||
def _create_hf_dataset(
|
||||
features: dict | None = None,
|
||||
tasks: list[dict] | None = None,
|
||||
@@ -310,8 +318,8 @@ def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_ar
|
||||
return _create_hf_dataset
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def lerobot_dataset_metadata_factory(
|
||||
@pytest.fixture(name="lerobot_dataset_metadata_factory", scope="session")
|
||||
def fixture_lerobot_dataset_metadata_factory(
|
||||
info_factory,
|
||||
stats_factory,
|
||||
episodes_stats_factory,
|
||||
@@ -364,8 +372,8 @@ def lerobot_dataset_metadata_factory(
|
||||
return _create_lerobot_dataset_metadata
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def lerobot_dataset_factory(
|
||||
@pytest.fixture(name="lerobot_dataset_factory", scope="session")
|
||||
def fixture_lerobot_dataset_factory(
|
||||
info_factory,
|
||||
stats_factory,
|
||||
episodes_stats_factory,
|
||||
@@ -443,6 +451,6 @@ def lerobot_dataset_factory(
|
||||
return _create_lerobot_dataset
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def empty_lerobot_dataset_factory() -> LeRobotDatasetFactory:
|
||||
@pytest.fixture(name="empty_lerobot_dataset_factory", scope="session")
|
||||
def fixture_empty_lerobot_dataset_factory() -> LeRobotDatasetFactory:
|
||||
return partial(LeRobotDataset.create, repo_id=DUMMY_REPO_ID, fps=DEFAULT_FPS)
|
||||
|
||||
Reference in New Issue
Block a user