Add LeRobotDatasetMetadata

This commit is contained in:
Simon Alibert
2024-11-03 18:07:37 +01:00
parent ac79e8cb36
commit e4ba084e25
25 changed files with 419 additions and 327 deletions

View File

@@ -8,7 +8,7 @@ import PIL.Image
import pytest
import torch
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset, LeRobotDatasetMetadata
from lerobot.common.datasets.utils import (
DEFAULT_CHUNK_SIZE,
DEFAULT_PARQUET_PATH,
@@ -33,8 +33,8 @@ def make_dummy_shapes(keys: list[str] | None = None, camera_keys: list[str] | No
return shapes
def get_task_index(tasks_dicts: dict, task: str) -> int:
tasks = {d["task_index"]: d["task"] for d in tasks_dicts}
def get_task_index(task_dicts: list[dict], task: str) -> int:
    """Return the ``task_index`` of the entry whose ``"task"`` value equals ``task``.

    Args:
        task_dicts: Entries that each provide a ``"task_index"`` and a ``"task"`` key.
        task: Task value to look up.

    Returns:
        The ``"task_index"`` value stored alongside ``task``.

    Raises:
        KeyError: If no entry has ``task`` as its ``"task"`` value.
    """
    tasks = {d["task_index"]: d["task"] for d in task_dicts}
    # Invert index -> task. Distinct loop names keep the ``task`` parameter unshadowed.
    task_to_task_index = {name: idx for idx, name in tasks.items()}
    return task_to_task_index[task]
@@ -313,6 +313,47 @@ def hf_dataset_factory(img_array_factory, episodes, tasks):
return _create_hf_dataset
@pytest.fixture(scope="session")
def lerobot_dataset_metadata_factory(
    info,
    stats,
    tasks,
    episodes,
    mock_snapshot_download_factory,
):
    """Session-scoped fixture returning a factory for ``LeRobotDatasetMetadata`` objects.

    The factory constructs the metadata while ``get_hub_safe_version`` and
    ``snapshot_download`` are patched in ``lerobot.common.datasets.lerobot_dataset``.
    """

    def _passthrough_version(repo_id, version, enforce_v2=True):
        # Replacement for get_hub_safe_version: return the requested version unchanged.
        return version

    def _create_lerobot_dataset_metadata(
        root: Path,
        repo_id: str = DUMMY_REPO_ID,
        info_dict: dict = info,
        stats_dict: dict = stats,
        task_dicts: list[dict] = tasks,
        episode_dicts: list[dict] = episodes,
        **kwargs,
    ) -> LeRobotDatasetMetadata:
        """Build a ``LeRobotDatasetMetadata`` for ``repo_id`` rooted at ``root``.

        The dictionaries default to the fixture values and are handed to
        ``mock_snapshot_download_factory`` to produce the patched download function.
        Of the extra keyword arguments, only ``local_files_only`` (default ``False``)
        is read; anything else is ignored by this factory.
        """
        fake_snapshot_download = mock_snapshot_download_factory(
            info_dict=info_dict,
            stats_dict=stats_dict,
            task_dicts=task_dicts,
            episode_dicts=episode_dicts,
        )
        local_files_only = kwargs.get("local_files_only", False)

        # The side effects are supplied when the patches are created, so the mocks
        # are fully configured before the constructor below can call them.
        with (
            patch(
                "lerobot.common.datasets.lerobot_dataset.get_hub_safe_version",
                side_effect=_passthrough_version,
            ),
            patch(
                "lerobot.common.datasets.lerobot_dataset.snapshot_download",
                side_effect=fake_snapshot_download,
            ),
        ):
            metadata = LeRobotDatasetMetadata(
                repo_id=repo_id,
                root=root,
                local_files_only=local_files_only,
            )
        return metadata

    return _create_lerobot_dataset_metadata
@pytest.fixture(scope="session")
def lerobot_dataset_factory(
info,
@@ -321,6 +362,7 @@ def lerobot_dataset_factory(
episodes,
hf_dataset,
mock_snapshot_download_factory,
lerobot_dataset_metadata_factory,
):
def _create_lerobot_dataset(
root: Path,
@@ -335,19 +377,26 @@ def lerobot_dataset_factory(
mock_snapshot_download = mock_snapshot_download_factory(
info_dict=info_dict,
stats_dict=stats_dict,
tasks_dicts=task_dicts,
episodes_dicts=episode_dicts,
task_dicts=task_dicts,
episode_dicts=episode_dicts,
hf_ds=hf_ds,
)
mock_metadata = lerobot_dataset_metadata_factory(
root=root,
repo_id=repo_id,
info_dict=info_dict,
stats_dict=stats_dict,
task_dicts=task_dicts,
episode_dicts=episode_dicts,
**kwargs,
)
with (
patch(
"lerobot.common.datasets.lerobot_dataset.get_hub_safe_version"
) as mock_get_hub_safe_version_patch,
patch("lerobot.common.datasets.lerobot_dataset.LeRobotDatasetMetadata") as mock_metadata_patch,
patch(
"lerobot.common.datasets.lerobot_dataset.snapshot_download"
) as mock_snapshot_download_patch,
):
mock_get_hub_safe_version_patch.side_effect = lambda repo_id, version, enforce_v2=True: version
mock_metadata_patch.return_value = mock_metadata
mock_snapshot_download_patch.side_effect = mock_snapshot_download
return LeRobotDataset(repo_id=repo_id, root=root, **kwargs)

View File

@@ -36,11 +36,11 @@ def stats_path(stats):
@pytest.fixture(scope="session")
def tasks_path(tasks):
def _create_tasks_jsonl_file(dir: Path, tasks_dicts: list = tasks) -> Path:
def _create_tasks_jsonl_file(dir: Path, task_dicts: list = tasks) -> Path:
fpath = dir / TASKS_PATH
fpath.parent.mkdir(parents=True, exist_ok=True)
with jsonlines.open(fpath, "w") as writer:
writer.write_all(tasks_dicts)
writer.write_all(task_dicts)
return fpath
return _create_tasks_jsonl_file

View File

@@ -26,7 +26,7 @@ def mock_snapshot_download_factory(
"""
def _mock_snapshot_download_func(
info_dict=info, stats_dict=stats, tasks_dicts=tasks, episodes_dicts=episodes, hf_ds=hf_dataset
info_dict=info, stats_dict=stats, task_dicts=tasks, episode_dicts=episodes, hf_ds=hf_dataset
):
def _extract_episode_index_from_path(fpath: str) -> int:
path = Path(fpath)
@@ -53,7 +53,7 @@ def mock_snapshot_download_factory(
all_files.extend(meta_files)
data_files = []
for episode_dict in episodes_dicts:
for episode_dict in episode_dicts:
ep_idx = episode_dict["episode_index"]
ep_chunk = ep_idx // info_dict["chunks_size"]
data_path = info_dict["data_path"].format(episode_chunk=ep_chunk, episode_index=ep_idx)
@@ -75,9 +75,9 @@ def mock_snapshot_download_factory(
elif rel_path == STATS_PATH:
_ = stats_path(local_dir, stats_dict)
elif rel_path == TASKS_PATH:
_ = tasks_path(local_dir, tasks_dicts)
_ = tasks_path(local_dir, task_dicts)
elif rel_path == EPISODES_PATH:
_ = episode_path(local_dir, episodes_dicts)
_ = episode_path(local_dir, episode_dicts)
else:
pass
return str(local_dir)