Most unit tests are passing
This commit is contained in:
38
tests/fixtures/hub.py
vendored
38
tests/fixtures/hub.py
vendored
@@ -1,13 +1,13 @@
|
||||
from pathlib import Path
|
||||
|
||||
import datasets
|
||||
import pandas as pd
|
||||
import pytest
|
||||
from huggingface_hub.utils import filter_repo_objects
|
||||
|
||||
from lerobot.common.datasets.utils import (
|
||||
DEFAULT_DATA_PATH,
|
||||
DEFAULT_EPISODES_PATH,
|
||||
DEFAULT_EPISODES_STATS_PATH,
|
||||
DEFAULT_TASKS_PATH,
|
||||
INFO_PATH,
|
||||
LEGACY_STATS_PATH,
|
||||
@@ -21,8 +21,6 @@ def mock_snapshot_download_factory(
|
||||
create_info,
|
||||
stats_factory,
|
||||
create_stats,
|
||||
episodes_stats_factory,
|
||||
create_episodes_stats,
|
||||
tasks_factory,
|
||||
create_tasks,
|
||||
episodes_factory,
|
||||
@@ -38,46 +36,43 @@ def mock_snapshot_download_factory(
|
||||
def _mock_snapshot_download_func(
|
||||
info: dict | None = None,
|
||||
stats: dict | None = None,
|
||||
episodes_stats: datasets.Dataset | None = None,
|
||||
tasks: datasets.Dataset | None = None,
|
||||
tasks: pd.DataFrame | None = None,
|
||||
episodes: datasets.Dataset | None = None,
|
||||
hf_dataset: datasets.Dataset | None = None,
|
||||
):
|
||||
if not info:
|
||||
if info is None:
|
||||
info = info_factory()
|
||||
if not stats:
|
||||
if stats is None:
|
||||
stats = stats_factory(features=info["features"])
|
||||
if not episodes_stats:
|
||||
episodes_stats = episodes_stats_factory(
|
||||
features=info["features"], total_episodes=info["total_episodes"]
|
||||
)
|
||||
if not tasks:
|
||||
if tasks is None:
|
||||
tasks = tasks_factory(total_tasks=info["total_tasks"])
|
||||
if not episodes:
|
||||
if episodes is None:
|
||||
episodes = episodes_factory(
|
||||
total_episodes=info["total_episodes"], total_frames=info["total_frames"], tasks=tasks
|
||||
features=info["features"],
|
||||
total_episodes=info["total_episodes"],
|
||||
total_frames=info["total_frames"],
|
||||
tasks=tasks,
|
||||
)
|
||||
if not hf_dataset:
|
||||
if hf_dataset is None:
|
||||
hf_dataset = hf_dataset_factory(tasks=tasks, episodes=episodes, fps=info["fps"])
|
||||
|
||||
def _mock_snapshot_download(
|
||||
repo_id: str,
|
||||
repo_id: str, # TODO(rcadene): repo_id should be used no?
|
||||
local_dir: str | Path | None = None,
|
||||
allow_patterns: str | list[str] | None = None,
|
||||
ignore_patterns: str | list[str] | None = None,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
if not local_dir:
|
||||
if local_dir is None:
|
||||
local_dir = LEROBOT_TEST_DIR
|
||||
|
||||
# List all possible files
|
||||
all_files = [
|
||||
INFO_PATH,
|
||||
LEGACY_STATS_PATH,
|
||||
# TODO(rcadene)
|
||||
# TODO(rcadene): remove naive chunk 0 file 0 ?
|
||||
DEFAULT_TASKS_PATH.format(chunk_index=0, file_index=0),
|
||||
DEFAULT_EPISODES_STATS_PATH.format(chunk_index=0, file_index=0),
|
||||
DEFAULT_EPISODES_PATH.format(chunk_index=0, file_index=0),
|
||||
DEFAULT_DATA_PATH.format(chunk_index=0, file_index=0),
|
||||
]
|
||||
@@ -89,7 +84,6 @@ def mock_snapshot_download_factory(
|
||||
has_info = False
|
||||
has_tasks = False
|
||||
has_episodes = False
|
||||
has_episodes_stats = False
|
||||
has_stats = False
|
||||
has_data = False
|
||||
for rel_path in allowed_files:
|
||||
@@ -99,8 +93,6 @@ def mock_snapshot_download_factory(
|
||||
has_stats = True
|
||||
elif rel_path.startswith("meta/tasks"):
|
||||
has_tasks = True
|
||||
elif rel_path.startswith("meta/episodes_stats"):
|
||||
has_episodes_stats = True
|
||||
elif rel_path.startswith("meta/episodes"):
|
||||
has_episodes = True
|
||||
elif rel_path.startswith("data/"):
|
||||
@@ -116,8 +108,6 @@ def mock_snapshot_download_factory(
|
||||
create_tasks(local_dir, tasks)
|
||||
if has_episodes:
|
||||
create_episodes(local_dir, episodes)
|
||||
if has_episodes_stats:
|
||||
create_episodes_stats(local_dir, episodes_stats)
|
||||
if has_data:
|
||||
create_hf_dataset(local_dir, hf_dataset)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user