Mock snapshot_download

This commit is contained in:
Simon Alibert
2024-11-01 10:58:09 +01:00
parent 5ea7c78237
commit cd1509d805
4 changed files with 119 additions and 15 deletions

View File

@@ -1,4 +1,5 @@
from pathlib import Path
from unittest.mock import patch
import datasets
import numpy as np
@@ -223,31 +224,39 @@ def hf_dataset_factory(img_array_factory, episodes, tasks):
@pytest.fixture(scope="session")
def lerobot_dataset_factory(
info,
info_path,
stats,
stats_path,
episodes,
episode_path,
tasks,
tasks_path,
hf_dataset,
multi_episode_parquet_path,
mock_snapshot_download_factory,
):
def _create_lerobot_dataset(
root: Path,
info_dict: dict = info,
stats_dict: dict = stats,
episode_dicts: list[dict] = episodes,
task_dicts: list[dict] = tasks,
episode_dicts: list[dict] = episodes,
hf_ds: datasets.Dataset = hf_dataset,
**kwargs,
) -> LeRobotDataset:
root.mkdir(parents=True, exist_ok=True)
# Create local files
_ = info_path(root, info_dict)
_ = stats_path(root, stats_dict)
_ = tasks_path(root, task_dicts)
_ = episode_path(root, episode_dicts)
_ = multi_episode_parquet_path(root, hf_ds)
return LeRobotDataset(repo_id=DUMMY_REPO_ID, root=root, local_files_only=True)
mock_snapshot_download = mock_snapshot_download_factory(
info_dict=info_dict,
stats_dict=stats_dict,
tasks_dicts=task_dicts,
episodes_dicts=episode_dicts,
hf_ds=hf_ds,
)
with (
patch(
"lerobot.common.datasets.lerobot_dataset.get_hub_safe_version"
) as mock_get_hub_safe_version_patch,
patch(
"lerobot.common.datasets.lerobot_dataset.snapshot_download"
) as mock_snapshot_download_patch,
):
mock_get_hub_safe_version_patch.side_effect = lambda repo_id, version, enforce_v2=True: version
mock_snapshot_download_patch.side_effect = mock_snapshot_download
return LeRobotDataset(repo_id=DUMMY_REPO_ID, root=root, **kwargs)
return _create_lerobot_dataset