Fix tests

This commit is contained in:
Simon Alibert
2024-11-05 19:09:12 +01:00
parent aed9f4036a
commit f3630ad910
13 changed files with 437 additions and 496 deletions

46
tests/fixtures/hub.py vendored
View File

@@ -1,5 +1,6 @@
from pathlib import Path
import datasets
import pytest
from huggingface_hub.utils import filter_repo_objects
@@ -9,16 +10,16 @@ from tests.fixtures.defaults import LEROBOT_TEST_DIR
@pytest.fixture(scope="session")
def mock_snapshot_download_factory(
info,
info_factory,
info_path,
stats,
stats_factory,
stats_path,
tasks,
tasks_factory,
tasks_path,
episodes,
episodes_factory,
episode_path,
single_episode_parquet_path,
hf_dataset,
hf_dataset_factory,
):
"""
This factory allows to patch snapshot_download such that when called, it will create expected files rather
@@ -26,8 +27,25 @@ def mock_snapshot_download_factory(
"""
def _mock_snapshot_download_func(
info_dict=info, stats_dict=stats, task_dicts=tasks, episode_dicts=episodes, hf_ds=hf_dataset
info: dict | None = None,
stats: dict | None = None,
tasks: list[dict] | None = None,
episodes: list[dict] | None = None,
hf_dataset: datasets.Dataset | None = None,
):
if not info:
info = info_factory()
if not stats:
stats = stats_factory(features=info["features"])
if not tasks:
tasks = tasks_factory(total_tasks=info["total_tasks"])
if not episodes:
episodes = episodes_factory(
total_episodes=info["total_episodes"], total_frames=info["total_frames"], tasks=tasks
)
if not hf_dataset:
hf_dataset = hf_dataset_factory(tasks=tasks, episodes=episodes, fps=info["fps"])
def _extract_episode_index_from_path(fpath: str) -> int:
path = Path(fpath)
if path.suffix == ".parquet" and path.stem.startswith("episode_"):
@@ -53,10 +71,10 @@ def mock_snapshot_download_factory(
all_files.extend(meta_files)
data_files = []
for episode_dict in episode_dicts:
for episode_dict in episodes:
ep_idx = episode_dict["episode_index"]
ep_chunk = ep_idx // info_dict["chunks_size"]
data_path = info_dict["data_path"].format(episode_chunk=ep_chunk, episode_index=ep_idx)
ep_chunk = ep_idx // info["chunks_size"]
data_path = info["data_path"].format(episode_chunk=ep_chunk, episode_index=ep_idx)
data_files.append(data_path)
all_files.extend(data_files)
@@ -69,15 +87,15 @@ def mock_snapshot_download_factory(
if rel_path.startswith("data/"):
episode_index = _extract_episode_index_from_path(rel_path)
if episode_index is not None:
_ = single_episode_parquet_path(local_dir, hf_ds, ep_idx=episode_index)
_ = single_episode_parquet_path(local_dir, episode_index, hf_dataset, info)
if rel_path == INFO_PATH:
_ = info_path(local_dir, info_dict)
_ = info_path(local_dir, info)
elif rel_path == STATS_PATH:
_ = stats_path(local_dir, stats_dict)
_ = stats_path(local_dir, stats)
elif rel_path == TASKS_PATH:
_ = tasks_path(local_dir, task_dicts)
_ = tasks_path(local_dir, tasks)
elif rel_path == EPISODES_PATH:
_ = episode_path(local_dir, episode_dicts)
_ = episode_path(local_dir, episodes)
else:
pass
return str(local_dir)