Add LeRobotDatasetMetadata
This commit is contained in:
67
tests/fixtures/dataset_factories.py
vendored
67
tests/fixtures/dataset_factories.py
vendored
@@ -8,7 +8,7 @@ import PIL.Image
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
|
||||
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset, LeRobotDatasetMetadata
|
||||
from lerobot.common.datasets.utils import (
|
||||
DEFAULT_CHUNK_SIZE,
|
||||
DEFAULT_PARQUET_PATH,
|
||||
@@ -33,8 +33,8 @@ def make_dummy_shapes(keys: list[str] | None = None, camera_keys: list[str] | No
|
||||
return shapes
|
||||
|
||||
|
||||
def get_task_index(tasks_dicts: dict, task: str) -> int:
|
||||
tasks = {d["task_index"]: d["task"] for d in tasks_dicts}
|
||||
def get_task_index(task_dicts: dict, task: str) -> int:
|
||||
tasks = {d["task_index"]: d["task"] for d in task_dicts}
|
||||
task_to_task_index = {task: task_idx for task_idx, task in tasks.items()}
|
||||
return task_to_task_index[task]
|
||||
|
||||
@@ -313,6 +313,47 @@ def hf_dataset_factory(img_array_factory, episodes, tasks):
|
||||
return _create_hf_dataset
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def lerobot_dataset_metadata_factory(
|
||||
info,
|
||||
stats,
|
||||
tasks,
|
||||
episodes,
|
||||
mock_snapshot_download_factory,
|
||||
):
|
||||
def _create_lerobot_dataset_metadata(
|
||||
root: Path,
|
||||
repo_id: str = DUMMY_REPO_ID,
|
||||
info_dict: dict = info,
|
||||
stats_dict: dict = stats,
|
||||
task_dicts: list[dict] = tasks,
|
||||
episode_dicts: list[dict] = episodes,
|
||||
**kwargs,
|
||||
) -> LeRobotDatasetMetadata:
|
||||
mock_snapshot_download = mock_snapshot_download_factory(
|
||||
info_dict=info_dict,
|
||||
stats_dict=stats_dict,
|
||||
task_dicts=task_dicts,
|
||||
episode_dicts=episode_dicts,
|
||||
)
|
||||
with (
|
||||
patch(
|
||||
"lerobot.common.datasets.lerobot_dataset.get_hub_safe_version"
|
||||
) as mock_get_hub_safe_version_patch,
|
||||
patch(
|
||||
"lerobot.common.datasets.lerobot_dataset.snapshot_download"
|
||||
) as mock_snapshot_download_patch,
|
||||
):
|
||||
mock_get_hub_safe_version_patch.side_effect = lambda repo_id, version, enforce_v2=True: version
|
||||
mock_snapshot_download_patch.side_effect = mock_snapshot_download
|
||||
|
||||
return LeRobotDatasetMetadata(
|
||||
repo_id=repo_id, root=root, local_files_only=kwargs.get("local_files_only", False)
|
||||
)
|
||||
|
||||
return _create_lerobot_dataset_metadata
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def lerobot_dataset_factory(
|
||||
info,
|
||||
@@ -321,6 +362,7 @@ def lerobot_dataset_factory(
|
||||
episodes,
|
||||
hf_dataset,
|
||||
mock_snapshot_download_factory,
|
||||
lerobot_dataset_metadata_factory,
|
||||
):
|
||||
def _create_lerobot_dataset(
|
||||
root: Path,
|
||||
@@ -335,19 +377,26 @@ def lerobot_dataset_factory(
|
||||
mock_snapshot_download = mock_snapshot_download_factory(
|
||||
info_dict=info_dict,
|
||||
stats_dict=stats_dict,
|
||||
tasks_dicts=task_dicts,
|
||||
episodes_dicts=episode_dicts,
|
||||
task_dicts=task_dicts,
|
||||
episode_dicts=episode_dicts,
|
||||
hf_ds=hf_ds,
|
||||
)
|
||||
mock_metadata = lerobot_dataset_metadata_factory(
|
||||
root=root,
|
||||
repo_id=repo_id,
|
||||
info_dict=info_dict,
|
||||
stats_dict=stats_dict,
|
||||
task_dicts=task_dicts,
|
||||
episode_dicts=episode_dicts,
|
||||
**kwargs,
|
||||
)
|
||||
with (
|
||||
patch(
|
||||
"lerobot.common.datasets.lerobot_dataset.get_hub_safe_version"
|
||||
) as mock_get_hub_safe_version_patch,
|
||||
patch("lerobot.common.datasets.lerobot_dataset.LeRobotDatasetMetadata") as mock_metadata_patch,
|
||||
patch(
|
||||
"lerobot.common.datasets.lerobot_dataset.snapshot_download"
|
||||
) as mock_snapshot_download_patch,
|
||||
):
|
||||
mock_get_hub_safe_version_patch.side_effect = lambda repo_id, version, enforce_v2=True: version
|
||||
mock_metadata_patch.return_value = mock_metadata
|
||||
mock_snapshot_download_patch.side_effect = mock_snapshot_download
|
||||
|
||||
return LeRobotDataset(repo_id=repo_id, root=root, **kwargs)
|
||||
|
||||
4
tests/fixtures/files.py
vendored
4
tests/fixtures/files.py
vendored
@@ -36,11 +36,11 @@ def stats_path(stats):
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def tasks_path(tasks):
|
||||
def _create_tasks_jsonl_file(dir: Path, tasks_dicts: list = tasks) -> Path:
|
||||
def _create_tasks_jsonl_file(dir: Path, task_dicts: list = tasks) -> Path:
|
||||
fpath = dir / TASKS_PATH
|
||||
fpath.parent.mkdir(parents=True, exist_ok=True)
|
||||
with jsonlines.open(fpath, "w") as writer:
|
||||
writer.write_all(tasks_dicts)
|
||||
writer.write_all(task_dicts)
|
||||
return fpath
|
||||
|
||||
return _create_tasks_jsonl_file
|
||||
|
||||
8
tests/fixtures/hub.py
vendored
8
tests/fixtures/hub.py
vendored
@@ -26,7 +26,7 @@ def mock_snapshot_download_factory(
|
||||
"""
|
||||
|
||||
def _mock_snapshot_download_func(
|
||||
info_dict=info, stats_dict=stats, tasks_dicts=tasks, episodes_dicts=episodes, hf_ds=hf_dataset
|
||||
info_dict=info, stats_dict=stats, task_dicts=tasks, episode_dicts=episodes, hf_ds=hf_dataset
|
||||
):
|
||||
def _extract_episode_index_from_path(fpath: str) -> int:
|
||||
path = Path(fpath)
|
||||
@@ -53,7 +53,7 @@ def mock_snapshot_download_factory(
|
||||
all_files.extend(meta_files)
|
||||
|
||||
data_files = []
|
||||
for episode_dict in episodes_dicts:
|
||||
for episode_dict in episode_dicts:
|
||||
ep_idx = episode_dict["episode_index"]
|
||||
ep_chunk = ep_idx // info_dict["chunks_size"]
|
||||
data_path = info_dict["data_path"].format(episode_chunk=ep_chunk, episode_index=ep_idx)
|
||||
@@ -75,9 +75,9 @@ def mock_snapshot_download_factory(
|
||||
elif rel_path == STATS_PATH:
|
||||
_ = stats_path(local_dir, stats_dict)
|
||||
elif rel_path == TASKS_PATH:
|
||||
_ = tasks_path(local_dir, tasks_dicts)
|
||||
_ = tasks_path(local_dir, task_dicts)
|
||||
elif rel_path == EPISODES_PATH:
|
||||
_ = episode_path(local_dir, episodes_dicts)
|
||||
_ = episode_path(local_dir, episode_dicts)
|
||||
else:
|
||||
pass
|
||||
return str(local_dir)
|
||||
|
||||
Reference in New Issue
Block a user