Most unit tests are passing

This commit is contained in:
Remi Cadene
2025-04-11 14:04:22 +02:00
parent c1b28f0b58
commit 34c5d4ce07
6 changed files with 391 additions and 322 deletions

View File

@@ -6,6 +6,7 @@ from unittest.mock import patch
import datasets
import numpy as np
import pandas as pd
import PIL.Image
import pytest
import torch
@@ -14,13 +15,12 @@ from datasets import Dataset
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset, LeRobotDatasetMetadata
from lerobot.common.datasets.utils import (
DEFAULT_CHUNK_SIZE,
DEFAULT_FEATURES,
DEFAULT_DATA_PATH,
DEFAULT_FEATURES,
DEFAULT_FILE_SIZE_IN_MB,
DEFAULT_VIDEO_PATH,
flatten_dict,
get_hf_features_from_features,
hf_transform_to_torch,
)
from tests.fixtures.constants import (
DEFAULT_FPS,
@@ -36,8 +36,9 @@ class LeRobotDatasetFactory(Protocol):
def __call__(self, *args, **kwargs) -> LeRobotDataset: ...
def get_task_index(tasks: Dataset, task: str) -> int:
task_idx = tasks["task"].index(task)
def get_task_index(tasks: pd.DataFrame, task: str) -> int:
    """Return the integer ``task_index`` stored for ``task``.

    Args:
        tasks: Table whose index holds the task strings and which has a
            ``task_index`` column (the layout built by ``tasks_factory``).
        task: Task string to look up in the index of ``tasks``.

    Returns:
        The task index as a plain Python ``int``.

    Raises:
        KeyError: If ``task`` is not a label in the index of ``tasks``.
    """
    # `.loc[task]` selects the row labelled by the task string; `.item()` then
    # unwraps the scalar so callers get a built-in int rather than a NumPy one.
    # TODO(rcadene): a bit complicated no? ^^
    task_idx = tasks.loc[task].task_index.item()
    return task_idx
@@ -164,42 +165,44 @@ def stats_factory():
return _create_stats
@pytest.fixture(scope="session")
def episodes_stats_factory(stats_factory):
def _create_episodes_stats(
features: dict[str],
total_episodes: int = 3,
) -> dict:
def _generator(total_episodes):
for ep_idx in range(total_episodes):
flat_ep_stats = flatten_dict(stats_factory(features))
flat_ep_stats["episode_index"] = ep_idx
yield flat_ep_stats
# @pytest.fixture(scope="session")
# def episodes_stats_factory(stats_factory):
# def _create_episodes_stats(
# features: dict[str],
# total_episodes: int = 3,
# ) -> dict:
# Simpler to rely on generator instead of from_dict
return Dataset.from_generator(lambda: _generator(total_episodes))
# def _generator(total_episodes):
# for ep_idx in range(total_episodes):
# flat_ep_stats = flatten_dict(stats_factory(features))
# flat_ep_stats["episode_index"] = ep_idx
# yield flat_ep_stats
return _create_episodes_stats
# # Simpler to rely on generator instead of from_dict
# return Dataset.from_generator(lambda: _generator(total_episodes))
# return _create_episodes_stats
@pytest.fixture(scope="session")
def tasks_factory():
def _create_tasks(total_tasks: int = 3) -> Dataset:
def _create_tasks(total_tasks: int = 3) -> pd.DataFrame:
ids = list(range(total_tasks))
tasks = [f"Perform action {i}." for i in ids]
return Dataset.from_dict({"task_index": ids, "task": tasks})
df = pd.DataFrame({"task_index": ids}, index=tasks)
return df
return _create_tasks
@pytest.fixture(scope="session")
def episodes_factory(tasks_factory):
def episodes_factory(tasks_factory, stats_factory):
def _create_episodes(
features: dict[str],
total_episodes: int = 3,
total_frames: int = 400,
video_keys: list[str] | None = None,
tasks: dict | None = None,
tasks: pd.DataFrame | None = None,
multi_task: bool = False,
):
if total_episodes <= 0 or total_frames <= 0:
@@ -207,21 +210,24 @@ def episodes_factory(tasks_factory):
if total_frames < total_episodes:
raise ValueError("total_length must be greater than or equal to num_episodes.")
if not tasks:
if tasks is None:
min_tasks = 2 if multi_task else 1
total_tasks = random.randint(min_tasks, total_episodes)
tasks = tasks_factory(total_tasks)
if total_episodes < len(tasks) and not multi_task:
num_tasks_available = len(tasks)
if total_episodes < num_tasks_available and not multi_task:
raise ValueError("The number of tasks should be less than the number of episodes.")
# Generate random lengths that sum up to total_length
lengths = np.random.multinomial(total_frames, [1 / total_episodes] * total_episodes).tolist()
num_tasks_available = len(tasks["task"])
# Create empty dictionaries with all keys
d = {
"episode_index": [],
"meta/episodes/chunk_index": [],
"meta/episodes/file_index": [],
"data/chunk_index": [],
"data/file_index": [],
"tasks": [],
@@ -232,10 +238,13 @@ def episodes_factory(tasks_factory):
d[f"{video_key}/chunk_index"] = []
d[f"{video_key}/file_index"] = []
remaining_tasks = tasks["task"].copy()
for stats_key in flatten_dict({"stats": stats_factory(features)}):
d[stats_key] = []
remaining_tasks = list(tasks.index)
for ep_idx in range(total_episodes):
num_tasks_in_episode = random.randint(1, min(3, num_tasks_available)) if multi_task else 1
tasks_to_sample = remaining_tasks if remaining_tasks else tasks["task"]
tasks_to_sample = remaining_tasks if len(remaining_tasks) > 0 else list(tasks.index)
episode_tasks = random.sample(tasks_to_sample, min(num_tasks_in_episode, len(tasks_to_sample)))
if remaining_tasks:
for task in episode_tasks:
@@ -243,15 +252,22 @@ def episodes_factory(tasks_factory):
d["episode_index"].append(ep_idx)
# TODO(rcadene): remove heuristic of only one file
d["meta/episodes/chunk_index"].append(0)
d["meta/episodes/file_index"].append(0)
d["data/chunk_index"].append(0)
d["data/file_index"].append(0)
d["tasks"].append(episode_tasks)
d["length"].append(lengths[ep_idx])
if video_keys is not None:
for video_key in video_keys:
d[f"{video_key}/chunk_index"].append(0)
d[f"{video_key}/file_index"].append(0)
# Add stats columns like "stats/action/max"
for stats_key, stats in flatten_dict({"stats": stats_factory(features)}).items():
d[stats_key].append(stats)
return Dataset.from_dict(d)
return _create_episodes
@@ -261,15 +277,15 @@ def episodes_factory(tasks_factory):
def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_array_factory):
def _create_hf_dataset(
features: dict | None = None,
tasks: list[dict] | None = None,
episodes: list[dict] | None = None,
tasks: pd.DataFrame | None = None,
episodes: datasets.Dataset | None = None,
fps: int = DEFAULT_FPS,
) -> datasets.Dataset:
if not tasks:
if tasks is None:
tasks = tasks_factory()
if not episodes:
if episodes is None:
episodes = episodes_factory()
if not features:
if features is None:
features = features_factory()
timestamp_col = np.array([], dtype=np.float32)
@@ -282,6 +298,8 @@ def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_ar
episode_index_col = np.concatenate(
(episode_index_col, np.full(ep_dict["length"], ep_dict["episode_index"], dtype=int))
)
# Slightly incorrect, but for simplicity, we assign to all frames the first task defined in the episode metadata.
# TODO(rcadene): assign the tasks of the episode per chunks of frames
ep_task_index = get_task_index(tasks, ep_dict["tasks"][0])
task_index = np.concatenate((task_index, np.full(ep_dict["length"], ep_task_index, dtype=int)))
@@ -319,7 +337,6 @@ def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_ar
def lerobot_dataset_metadata_factory(
info_factory,
stats_factory,
episodes_stats_factory,
tasks_factory,
episodes_factory,
mock_snapshot_download_factory,
@@ -329,30 +346,28 @@ def lerobot_dataset_metadata_factory(
repo_id: str = DUMMY_REPO_ID,
info: dict | None = None,
stats: dict | None = None,
episodes_stats: list[dict] | None = None,
tasks: list[dict] | None = None,
episodes: list[dict] | None = None,
tasks: pd.DataFrame | None = None,
episodes: datasets.Dataset | None = None,
) -> LeRobotDatasetMetadata:
if not info:
if info is None:
info = info_factory()
if not stats:
if stats is None:
stats = stats_factory(features=info["features"])
if not episodes_stats:
episodes_stats = episodes_stats_factory(
features=info["features"], total_episodes=info["total_episodes"]
)
if not tasks:
if tasks is None:
tasks = tasks_factory(total_tasks=info["total_tasks"])
if not episodes:
if episodes is None:
video_keys = [key for key, ft in info["features"].items() if ft["dtype"] == "video"]
episodes = episodes_factory(
total_episodes=info["total_episodes"], total_frames=info["total_frames"], video_keys=video_keys, tasks=tasks
features=info["features"],
total_episodes=info["total_episodes"],
total_frames=info["total_frames"],
video_keys=video_keys,
tasks=tasks,
)
mock_snapshot_download = mock_snapshot_download_factory(
info=info,
stats=stats,
episodes_stats=episodes_stats,
tasks=tasks,
episodes=episodes,
)
@@ -374,7 +389,6 @@ def lerobot_dataset_metadata_factory(
def lerobot_dataset_factory(
info_factory,
stats_factory,
episodes_stats_factory,
tasks_factory,
episodes_factory,
hf_dataset_factory,
@@ -390,25 +404,23 @@ def lerobot_dataset_factory(
multi_task: bool = False,
info: dict | None = None,
stats: dict | None = None,
episodes_stats: datasets.Dataset | None = None,
tasks: datasets.Dataset | None = None,
episode_dicts: datasets.Dataset | None = None,
tasks: pd.DataFrame | None = None,
episodes_metadata: datasets.Dataset | None = None,
hf_dataset: datasets.Dataset | None = None,
**kwargs,
) -> LeRobotDataset:
if not info:
if info is None:
info = info_factory(
total_episodes=total_episodes, total_frames=total_frames, total_tasks=total_tasks
)
if not stats:
if stats is None:
stats = stats_factory(features=info["features"])
if not episodes_stats:
episodes_stats = episodes_stats_factory(features=info["features"], total_episodes=total_episodes)
if not tasks:
if tasks is None:
tasks = tasks_factory(total_tasks=info["total_tasks"])
if not episode_dicts:
if episodes_metadata is None:
video_keys = [key for key, ft in info["features"].items() if ft["dtype"] == "video"]
episode_dicts = episodes_factory(
episodes_metadata = episodes_factory(
features=info["features"],
total_episodes=info["total_episodes"],
total_frames=info["total_frames"],
video_keys=video_keys,
@@ -416,14 +428,13 @@ def lerobot_dataset_factory(
multi_task=multi_task,
)
if not hf_dataset:
hf_dataset = hf_dataset_factory(tasks=tasks, episodes=episode_dicts, fps=info["fps"])
hf_dataset = hf_dataset_factory(tasks=tasks, episodes=episodes_metadata, fps=info["fps"])
mock_snapshot_download = mock_snapshot_download_factory(
info=info,
stats=stats,
episodes_stats=episodes_stats,
tasks=tasks,
episodes=episode_dicts,
episodes=episodes_metadata,
hf_dataset=hf_dataset,
)
mock_metadata = lerobot_dataset_metadata_factory(
@@ -431,9 +442,8 @@ def lerobot_dataset_factory(
repo_id=repo_id,
info=info,
stats=stats,
episodes_stats=episodes_stats,
tasks=tasks,
episodes=episode_dicts,
episodes=episodes_metadata,
)
with (
patch("lerobot.common.datasets.lerobot_dataset.LeRobotDatasetMetadata") as mock_metadata_patch,

View File

@@ -1,17 +1,13 @@
import json
from pathlib import Path
import datasets
import jsonlines
import pandas as pd
import pyarrow.compute as pc
import pyarrow.parquet as pq
import pytest
from datasets import Dataset
from lerobot.common.datasets.utils import (
write_episodes,
write_episodes_stats,
write_hf_dataset,
write_info,
write_stats,
@@ -22,7 +18,7 @@ from lerobot.common.datasets.utils import (
@pytest.fixture(scope="session")
def create_info(info_factory):
def _create_info(dir: Path, info: dict | None = None):
if not info:
if info is None:
info = info_factory()
write_info(info, dir)
@@ -32,27 +28,27 @@ def create_info(info_factory):
@pytest.fixture(scope="session")
def create_stats(stats_factory):
def _create_stats(dir: Path, stats: dict | None = None):
if not stats:
if stats is None:
stats = stats_factory()
write_stats(stats, dir)
return _create_stats
@pytest.fixture(scope="session")
def create_episodes_stats(episodes_stats_factory):
def _create_episodes_stats(dir: Path, episodes_stats: Dataset | None = None):
if not episodes_stats:
episodes_stats = episodes_stats_factory()
write_episodes_stats(episodes_stats, dir)
# @pytest.fixture(scope="session")
# def create_episodes_stats(episodes_stats_factory):
# def _create_episodes_stats(dir: Path, episodes_stats: Dataset | None = None):
# if episodes_stats is None:
# episodes_stats = episodes_stats_factory()
# write_episodes_stats(episodes_stats, dir)
return _create_episodes_stats
# return _create_episodes_stats
@pytest.fixture(scope="session")
def create_tasks(tasks_factory):
def _create_tasks(dir: Path, tasks: Dataset | None = None):
if not tasks:
def _create_tasks(dir: Path, tasks: pd.DataFrame | None = None):
if tasks is None:
tasks = tasks_factory()
write_tasks(tasks, dir)
@@ -61,17 +57,18 @@ def create_tasks(tasks_factory):
@pytest.fixture(scope="session")
def create_episodes(episodes_factory):
def _create_episodes(dir: Path, episodes: Dataset | None = None):
if not episodes:
def _create_episodes(dir: Path, episodes: datasets.Dataset | None = None):
if episodes is None:
episodes = episodes_factory()
write_episodes(episodes, dir)
return _create_episodes
@pytest.fixture(scope="session")
def create_hf_dataset(hf_dataset_factory):
def _create_hf_dataset(dir: Path, hf_dataset: Dataset | None = None):
if not hf_dataset:
def _create_hf_dataset(dir: Path, hf_dataset: datasets.Dataset | None = None):
if hf_dataset is None:
hf_dataset = hf_dataset_factory()
write_hf_dataset(hf_dataset, dir)
@@ -84,7 +81,7 @@ def single_episode_parquet_path(hf_dataset_factory, info_factory):
dir: Path, ep_idx: int = 0, hf_dataset: datasets.Dataset | None = None, info: dict | None = None
) -> Path:
raise NotImplementedError()
if not info:
if info is None:
info = info_factory()
if hf_dataset is None:
hf_dataset = hf_dataset_factory()
@@ -108,7 +105,7 @@ def multi_episode_parquet_path(hf_dataset_factory, info_factory):
dir: Path, hf_dataset: datasets.Dataset | None = None, info: dict | None = None
) -> Path:
raise NotImplementedError()
if not info:
if info is None:
info = info_factory()
if hf_dataset is None:
hf_dataset = hf_dataset_factory()

38
tests/fixtures/hub.py vendored
View File

@@ -1,13 +1,13 @@
from pathlib import Path
import datasets
import pandas as pd
import pytest
from huggingface_hub.utils import filter_repo_objects
from lerobot.common.datasets.utils import (
DEFAULT_DATA_PATH,
DEFAULT_EPISODES_PATH,
DEFAULT_EPISODES_STATS_PATH,
DEFAULT_TASKS_PATH,
INFO_PATH,
LEGACY_STATS_PATH,
@@ -21,8 +21,6 @@ def mock_snapshot_download_factory(
create_info,
stats_factory,
create_stats,
episodes_stats_factory,
create_episodes_stats,
tasks_factory,
create_tasks,
episodes_factory,
@@ -38,46 +36,43 @@ def mock_snapshot_download_factory(
def _mock_snapshot_download_func(
info: dict | None = None,
stats: dict | None = None,
episodes_stats: datasets.Dataset | None = None,
tasks: datasets.Dataset | None = None,
tasks: pd.DataFrame | None = None,
episodes: datasets.Dataset | None = None,
hf_dataset: datasets.Dataset | None = None,
):
if not info:
if info is None:
info = info_factory()
if not stats:
if stats is None:
stats = stats_factory(features=info["features"])
if not episodes_stats:
episodes_stats = episodes_stats_factory(
features=info["features"], total_episodes=info["total_episodes"]
)
if not tasks:
if tasks is None:
tasks = tasks_factory(total_tasks=info["total_tasks"])
if not episodes:
if episodes is None:
episodes = episodes_factory(
total_episodes=info["total_episodes"], total_frames=info["total_frames"], tasks=tasks
features=info["features"],
total_episodes=info["total_episodes"],
total_frames=info["total_frames"],
tasks=tasks,
)
if not hf_dataset:
if hf_dataset is None:
hf_dataset = hf_dataset_factory(tasks=tasks, episodes=episodes, fps=info["fps"])
def _mock_snapshot_download(
repo_id: str,
repo_id: str, # TODO(rcadene): repo_id should be used no?
local_dir: str | Path | None = None,
allow_patterns: str | list[str] | None = None,
ignore_patterns: str | list[str] | None = None,
*args,
**kwargs,
) -> str:
if not local_dir:
if local_dir is None:
local_dir = LEROBOT_TEST_DIR
# List all possible files
all_files = [
INFO_PATH,
LEGACY_STATS_PATH,
# TODO(rcadene)
# TODO(rcadene): remove naive chunk 0 file 0 ?
DEFAULT_TASKS_PATH.format(chunk_index=0, file_index=0),
DEFAULT_EPISODES_STATS_PATH.format(chunk_index=0, file_index=0),
DEFAULT_EPISODES_PATH.format(chunk_index=0, file_index=0),
DEFAULT_DATA_PATH.format(chunk_index=0, file_index=0),
]
@@ -89,7 +84,6 @@ def mock_snapshot_download_factory(
has_info = False
has_tasks = False
has_episodes = False
has_episodes_stats = False
has_stats = False
has_data = False
for rel_path in allowed_files:
@@ -99,8 +93,6 @@ def mock_snapshot_download_factory(
has_stats = True
elif rel_path.startswith("meta/tasks"):
has_tasks = True
elif rel_path.startswith("meta/episodes_stats"):
has_episodes_stats = True
elif rel_path.startswith("meta/episodes"):
has_episodes = True
elif rel_path.startswith("data/"):
@@ -116,8 +108,6 @@ def mock_snapshot_download_factory(
create_tasks(local_dir, tasks)
if has_episodes:
create_episodes(local_dir, episodes)
if has_episodes_stats:
create_episodes_stats(local_dir, episodes_stats)
if has_data:
create_hf_dataset(local_dir, hf_dataset)