Most unit tests are passing

This commit is contained in:
Remi Cadene
2025-04-11 14:04:22 +02:00
parent c1b28f0b58
commit 34c5d4ce07
6 changed files with 391 additions and 322 deletions

38
tests/fixtures/hub.py vendored
View File

@@ -1,13 +1,13 @@
from pathlib import Path
import datasets
import pandas as pd
import pytest
from huggingface_hub.utils import filter_repo_objects
from lerobot.common.datasets.utils import (
DEFAULT_DATA_PATH,
DEFAULT_EPISODES_PATH,
DEFAULT_EPISODES_STATS_PATH,
DEFAULT_TASKS_PATH,
INFO_PATH,
LEGACY_STATS_PATH,
@@ -21,8 +21,6 @@ def mock_snapshot_download_factory(
create_info,
stats_factory,
create_stats,
episodes_stats_factory,
create_episodes_stats,
tasks_factory,
create_tasks,
episodes_factory,
@@ -38,46 +36,43 @@ def mock_snapshot_download_factory(
def _mock_snapshot_download_func(
info: dict | None = None,
stats: dict | None = None,
episodes_stats: datasets.Dataset | None = None,
tasks: datasets.Dataset | None = None,
tasks: pd.DataFrame | None = None,
episodes: datasets.Dataset | None = None,
hf_dataset: datasets.Dataset | None = None,
):
if not info:
if info is None:
info = info_factory()
if not stats:
if stats is None:
stats = stats_factory(features=info["features"])
if not episodes_stats:
episodes_stats = episodes_stats_factory(
features=info["features"], total_episodes=info["total_episodes"]
)
if not tasks:
if tasks is None:
tasks = tasks_factory(total_tasks=info["total_tasks"])
if not episodes:
if episodes is None:
episodes = episodes_factory(
total_episodes=info["total_episodes"], total_frames=info["total_frames"], tasks=tasks
features=info["features"],
total_episodes=info["total_episodes"],
total_frames=info["total_frames"],
tasks=tasks,
)
if not hf_dataset:
if hf_dataset is None:
hf_dataset = hf_dataset_factory(tasks=tasks, episodes=episodes, fps=info["fps"])
def _mock_snapshot_download(
repo_id: str,
repo_id: str, # TODO(rcadene): repo_id should be used no?
local_dir: str | Path | None = None,
allow_patterns: str | list[str] | None = None,
ignore_patterns: str | list[str] | None = None,
*args,
**kwargs,
) -> str:
if not local_dir:
if local_dir is None:
local_dir = LEROBOT_TEST_DIR
# List all possible files
all_files = [
INFO_PATH,
LEGACY_STATS_PATH,
# TODO(rcadene)
# TODO(rcadene): remove naive chunk 0 file 0 ?
DEFAULT_TASKS_PATH.format(chunk_index=0, file_index=0),
DEFAULT_EPISODES_STATS_PATH.format(chunk_index=0, file_index=0),
DEFAULT_EPISODES_PATH.format(chunk_index=0, file_index=0),
DEFAULT_DATA_PATH.format(chunk_index=0, file_index=0),
]
@@ -89,7 +84,6 @@ def mock_snapshot_download_factory(
has_info = False
has_tasks = False
has_episodes = False
has_episodes_stats = False
has_stats = False
has_data = False
for rel_path in allowed_files:
@@ -99,8 +93,6 @@ def mock_snapshot_download_factory(
has_stats = True
elif rel_path.startswith("meta/tasks"):
has_tasks = True
elif rel_path.startswith("meta/episodes_stats"):
has_episodes_stats = True
elif rel_path.startswith("meta/episodes"):
has_episodes = True
elif rel_path.startswith("data/"):
@@ -116,8 +108,6 @@ def mock_snapshot_download_factory(
create_tasks(local_dir, tasks)
if has_episodes:
create_episodes(local_dir, episodes)
if has_episodes_stats:
create_episodes_stats(local_dir, episodes_stats)
if has_data:
create_hf_dataset(local_dir, hf_dataset)