Mock snapshot_download
This commit is contained in:
37
tests/fixtures/dataset_factories.py
vendored
37
tests/fixtures/dataset_factories.py
vendored
@@ -1,4 +1,5 @@
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import datasets
|
||||
import numpy as np
|
||||
@@ -223,31 +224,39 @@ def hf_dataset_factory(img_array_factory, episodes, tasks):
|
||||
@pytest.fixture(scope="session")
|
||||
def lerobot_dataset_factory(
|
||||
info,
|
||||
info_path,
|
||||
stats,
|
||||
stats_path,
|
||||
episodes,
|
||||
episode_path,
|
||||
tasks,
|
||||
tasks_path,
|
||||
hf_dataset,
|
||||
multi_episode_parquet_path,
|
||||
mock_snapshot_download_factory,
|
||||
):
|
||||
def _create_lerobot_dataset(
|
||||
root: Path,
|
||||
info_dict: dict = info,
|
||||
stats_dict: dict = stats,
|
||||
episode_dicts: list[dict] = episodes,
|
||||
task_dicts: list[dict] = tasks,
|
||||
episode_dicts: list[dict] = episodes,
|
||||
hf_ds: datasets.Dataset = hf_dataset,
|
||||
**kwargs,
|
||||
) -> LeRobotDataset:
|
||||
root.mkdir(parents=True, exist_ok=True)
|
||||
# Create local files
|
||||
_ = info_path(root, info_dict)
|
||||
_ = stats_path(root, stats_dict)
|
||||
_ = tasks_path(root, task_dicts)
|
||||
_ = episode_path(root, episode_dicts)
|
||||
_ = multi_episode_parquet_path(root, hf_ds)
|
||||
return LeRobotDataset(repo_id=DUMMY_REPO_ID, root=root, local_files_only=True)
|
||||
mock_snapshot_download = mock_snapshot_download_factory(
|
||||
info_dict=info_dict,
|
||||
stats_dict=stats_dict,
|
||||
tasks_dicts=task_dicts,
|
||||
episodes_dicts=episode_dicts,
|
||||
hf_ds=hf_ds,
|
||||
)
|
||||
with (
|
||||
patch(
|
||||
"lerobot.common.datasets.lerobot_dataset.get_hub_safe_version"
|
||||
) as mock_get_hub_safe_version_patch,
|
||||
patch(
|
||||
"lerobot.common.datasets.lerobot_dataset.snapshot_download"
|
||||
) as mock_snapshot_download_patch,
|
||||
):
|
||||
mock_get_hub_safe_version_patch.side_effect = lambda repo_id, version, enforce_v2=True: version
|
||||
mock_snapshot_download_patch.side_effect = mock_snapshot_download
|
||||
|
||||
return LeRobotDataset(repo_id=DUMMY_REPO_ID, root=root, **kwargs)
|
||||
|
||||
return _create_lerobot_dataset
|
||||
|
||||
Reference in New Issue
Block a user