Most unit tests are passing

This commit is contained in:
Remi Cadene
2025-04-11 14:04:22 +02:00
parent c1b28f0b58
commit 34c5d4ce07
6 changed files with 391 additions and 322 deletions

View File

@@ -6,6 +6,7 @@ from unittest.mock import patch
import datasets
import numpy as np
import pandas as pd
import PIL.Image
import pytest
import torch
@@ -14,13 +15,12 @@ from datasets import Dataset
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset, LeRobotDatasetMetadata
from lerobot.common.datasets.utils import (
DEFAULT_CHUNK_SIZE,
DEFAULT_FEATURES,
DEFAULT_DATA_PATH,
DEFAULT_FEATURES,
DEFAULT_FILE_SIZE_IN_MB,
DEFAULT_VIDEO_PATH,
flatten_dict,
get_hf_features_from_features,
hf_transform_to_torch,
)
from tests.fixtures.constants import (
DEFAULT_FPS,
@@ -36,8 +36,9 @@ class LeRobotDatasetFactory(Protocol):
def __call__(self, *args, **kwargs) -> LeRobotDataset: ...
def get_task_index(tasks: Dataset, task: str) -> int:
task_idx = tasks["task"].index(task)
def get_task_index(tasks: datasets.Dataset, task: str) -> int:
# TODO(rcadene): a bit complicated no? ^^
task_idx = tasks.loc[task].task_index.item()
return task_idx
@@ -164,42 +165,44 @@ def stats_factory():
return _create_stats
@pytest.fixture(scope="session")
def episodes_stats_factory(stats_factory):
def _create_episodes_stats(
features: dict[str],
total_episodes: int = 3,
) -> dict:
def _generator(total_episodes):
for ep_idx in range(total_episodes):
flat_ep_stats = flatten_dict(stats_factory(features))
flat_ep_stats["episode_index"] = ep_idx
yield flat_ep_stats
# @pytest.fixture(scope="session")
# def episodes_stats_factory(stats_factory):
# def _create_episodes_stats(
# features: dict[str],
# total_episodes: int = 3,
# ) -> dict:
# Simpler to rely on generator instead of from_dict
return Dataset.from_generator(lambda: _generator(total_episodes))
# def _generator(total_episodes):
# for ep_idx in range(total_episodes):
# flat_ep_stats = flatten_dict(stats_factory(features))
# flat_ep_stats["episode_index"] = ep_idx
# yield flat_ep_stats
return _create_episodes_stats
# # Simpler to rely on generator instead of from_dict
# return Dataset.from_generator(lambda: _generator(total_episodes))
# return _create_episodes_stats
@pytest.fixture(scope="session")
def tasks_factory():
def _create_tasks(total_tasks: int = 3) -> Dataset:
def _create_tasks(total_tasks: int = 3) -> pd.DataFrame:
ids = list(range(total_tasks))
tasks = [f"Perform action {i}." for i in ids]
return Dataset.from_dict({"task_index": ids, "task": tasks})
df = pd.DataFrame({"task_index": ids}, index=tasks)
return df
return _create_tasks
@pytest.fixture(scope="session")
def episodes_factory(tasks_factory):
def episodes_factory(tasks_factory, stats_factory):
def _create_episodes(
features: dict[str],
total_episodes: int = 3,
total_frames: int = 400,
video_keys: list[str] | None = None,
tasks: dict | None = None,
tasks: pd.DataFrame | None = None,
multi_task: bool = False,
):
if total_episodes <= 0 or total_frames <= 0:
@@ -207,21 +210,24 @@ def episodes_factory(tasks_factory):
if total_frames < total_episodes:
raise ValueError("total_length must be greater than or equal to num_episodes.")
if not tasks:
if tasks is None:
min_tasks = 2 if multi_task else 1
total_tasks = random.randint(min_tasks, total_episodes)
tasks = tasks_factory(total_tasks)
if total_episodes < len(tasks) and not multi_task:
num_tasks_available = len(tasks)
if total_episodes < num_tasks_available and not multi_task:
raise ValueError("The number of tasks should be less than the number of episodes.")
# Generate random lengths that sum up to total_length
lengths = np.random.multinomial(total_frames, [1 / total_episodes] * total_episodes).tolist()
num_tasks_available = len(tasks["task"])
# Create empty dictionaries with all keys
d = {
"episode_index": [],
"meta/episodes/chunk_index": [],
"meta/episodes/file_index": [],
"data/chunk_index": [],
"data/file_index": [],
"tasks": [],
@@ -232,10 +238,13 @@ def episodes_factory(tasks_factory):
d[f"{video_key}/chunk_index"] = []
d[f"{video_key}/file_index"] = []
remaining_tasks = tasks["task"].copy()
for stats_key in flatten_dict({"stats": stats_factory(features)}):
d[stats_key] = []
remaining_tasks = list(tasks.index)
for ep_idx in range(total_episodes):
num_tasks_in_episode = random.randint(1, min(3, num_tasks_available)) if multi_task else 1
tasks_to_sample = remaining_tasks if remaining_tasks else tasks["task"]
tasks_to_sample = remaining_tasks if len(remaining_tasks) > 0 else list(tasks.index)
episode_tasks = random.sample(tasks_to_sample, min(num_tasks_in_episode, len(tasks_to_sample)))
if remaining_tasks:
for task in episode_tasks:
@@ -243,15 +252,22 @@ def episodes_factory(tasks_factory):
d["episode_index"].append(ep_idx)
# TODO(rcadene): remove heuristic of only one file
d["meta/episodes/chunk_index"].append(0)
d["meta/episodes/file_index"].append(0)
d["data/chunk_index"].append(0)
d["data/file_index"].append(0)
d["tasks"].append(episode_tasks)
d["length"].append(lengths[ep_idx])
if video_keys is not None:
for video_key in video_keys:
d[f"{video_key}/chunk_index"].append(0)
d[f"{video_key}/file_index"].append(0)
# Add stats columns like "stats/action/max"
for stats_key, stats in flatten_dict({"stats": stats_factory(features)}).items():
d[stats_key].append(stats)
return Dataset.from_dict(d)
return _create_episodes
@@ -261,15 +277,15 @@ def episodes_factory(tasks_factory):
def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_array_factory):
def _create_hf_dataset(
features: dict | None = None,
tasks: list[dict] | None = None,
episodes: list[dict] | None = None,
tasks: pd.DataFrame | None = None,
episodes: datasets.Dataset | None = None,
fps: int = DEFAULT_FPS,
) -> datasets.Dataset:
if not tasks:
if tasks is None:
tasks = tasks_factory()
if not episodes:
if episodes is None:
episodes = episodes_factory()
if not features:
if features is None:
features = features_factory()
timestamp_col = np.array([], dtype=np.float32)
@@ -282,6 +298,8 @@ def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_ar
episode_index_col = np.concatenate(
(episode_index_col, np.full(ep_dict["length"], ep_dict["episode_index"], dtype=int))
)
# Slightly incorrect, but for simplicity, we assign to all frames the first task defined in the episode metadata.
# TODO(rcadene): assign the tasks of the episode per chunks of frames
ep_task_index = get_task_index(tasks, ep_dict["tasks"][0])
task_index = np.concatenate((task_index, np.full(ep_dict["length"], ep_task_index, dtype=int)))
@@ -319,7 +337,6 @@ def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_ar
def lerobot_dataset_metadata_factory(
info_factory,
stats_factory,
episodes_stats_factory,
tasks_factory,
episodes_factory,
mock_snapshot_download_factory,
@@ -329,30 +346,28 @@ def lerobot_dataset_metadata_factory(
repo_id: str = DUMMY_REPO_ID,
info: dict | None = None,
stats: dict | None = None,
episodes_stats: list[dict] | None = None,
tasks: list[dict] | None = None,
episodes: list[dict] | None = None,
tasks: pd.DataFrame | None = None,
episodes: datasets.Dataset | None = None,
) -> LeRobotDatasetMetadata:
if not info:
if info is None:
info = info_factory()
if not stats:
if stats is None:
stats = stats_factory(features=info["features"])
if not episodes_stats:
episodes_stats = episodes_stats_factory(
features=info["features"], total_episodes=info["total_episodes"]
)
if not tasks:
if tasks is None:
tasks = tasks_factory(total_tasks=info["total_tasks"])
if not episodes:
if episodes is None:
video_keys = [key for key, ft in info["features"].items() if ft["dtype"] == "video"]
episodes = episodes_factory(
total_episodes=info["total_episodes"], total_frames=info["total_frames"], video_keys=video_keys, tasks=tasks
features=info["features"],
total_episodes=info["total_episodes"],
total_frames=info["total_frames"],
video_keys=video_keys,
tasks=tasks,
)
mock_snapshot_download = mock_snapshot_download_factory(
info=info,
stats=stats,
episodes_stats=episodes_stats,
tasks=tasks,
episodes=episodes,
)
@@ -374,7 +389,6 @@ def lerobot_dataset_metadata_factory(
def lerobot_dataset_factory(
info_factory,
stats_factory,
episodes_stats_factory,
tasks_factory,
episodes_factory,
hf_dataset_factory,
@@ -390,25 +404,23 @@ def lerobot_dataset_factory(
multi_task: bool = False,
info: dict | None = None,
stats: dict | None = None,
episodes_stats: datasets.Dataset | None = None,
tasks: datasets.Dataset | None = None,
episode_dicts: datasets.Dataset | None = None,
tasks: pd.DataFrame | None = None,
episodes_metadata: datasets.Dataset | None = None,
hf_dataset: datasets.Dataset | None = None,
**kwargs,
) -> LeRobotDataset:
if not info:
if info is None:
info = info_factory(
total_episodes=total_episodes, total_frames=total_frames, total_tasks=total_tasks
)
if not stats:
if stats is None:
stats = stats_factory(features=info["features"])
if not episodes_stats:
episodes_stats = episodes_stats_factory(features=info["features"], total_episodes=total_episodes)
if not tasks:
if tasks is None:
tasks = tasks_factory(total_tasks=info["total_tasks"])
if not episode_dicts:
if episodes_metadata is None:
video_keys = [key for key, ft in info["features"].items() if ft["dtype"] == "video"]
episode_dicts = episodes_factory(
episodes_metadata = episodes_factory(
features=info["features"],
total_episodes=info["total_episodes"],
total_frames=info["total_frames"],
video_keys=video_keys,
@@ -416,14 +428,13 @@ def lerobot_dataset_factory(
multi_task=multi_task,
)
if not hf_dataset:
hf_dataset = hf_dataset_factory(tasks=tasks, episodes=episode_dicts, fps=info["fps"])
hf_dataset = hf_dataset_factory(tasks=tasks, episodes=episodes_metadata, fps=info["fps"])
mock_snapshot_download = mock_snapshot_download_factory(
info=info,
stats=stats,
episodes_stats=episodes_stats,
tasks=tasks,
episodes=episode_dicts,
episodes=episodes_metadata,
hf_dataset=hf_dataset,
)
mock_metadata = lerobot_dataset_metadata_factory(
@@ -431,9 +442,8 @@ def lerobot_dataset_factory(
repo_id=repo_id,
info=info,
stats=stats,
episodes_stats=episodes_stats,
tasks=tasks,
episodes=episode_dicts,
episodes=episodes_metadata,
)
with (
patch("lerobot.common.datasets.lerobot_dataset.LeRobotDatasetMetadata") as mock_metadata_patch,