Commit before episodes episodes_stats merging

This commit is contained in:
Remi Cadene
2025-04-09 15:20:15 +02:00
parent 53ecec5fb2
commit c1b28f0b58
12 changed files with 905 additions and 396 deletions

View File

@@ -9,13 +9,16 @@ import numpy as np
import PIL.Image
import pytest
import torch
from datasets import Dataset
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset, LeRobotDatasetMetadata
from lerobot.common.datasets.utils import (
DEFAULT_CHUNK_SIZE,
DEFAULT_FEATURES,
DEFAULT_PARQUET_PATH,
DEFAULT_DATA_PATH,
DEFAULT_FILE_SIZE_IN_MB,
DEFAULT_VIDEO_PATH,
flatten_dict,
get_hf_features_from_features,
hf_transform_to_torch,
)
@@ -33,10 +36,9 @@ class LeRobotDatasetFactory(Protocol):
def __call__(self, *args, **kwargs) -> LeRobotDataset: ...
def get_task_index(task_dicts: dict, task: str) -> int:
tasks = {d["task_index"]: d["task"] for d in task_dicts.values()}
task_to_task_index = {task: task_idx for task_idx, task in tasks.items()}
return task_to_task_index[task]
def get_task_index(tasks: Dataset, task: str) -> int:
task_idx = tasks["task"].index(task)
return task_idx
@pytest.fixture(scope="session")
@@ -104,9 +106,9 @@ def info_factory(features_factory):
total_frames: int = 0,
total_tasks: int = 0,
total_videos: int = 0,
total_chunks: int = 0,
chunks_size: int = DEFAULT_CHUNK_SIZE,
data_path: str = DEFAULT_PARQUET_PATH,
files_size_in_mb: float = DEFAULT_FILE_SIZE_IN_MB,
data_path: str = DEFAULT_DATA_PATH,
video_path: str = DEFAULT_VIDEO_PATH,
motor_features: dict = DUMMY_MOTOR_FEATURES,
camera_features: dict = DUMMY_CAMERA_FEATURES,
@@ -120,8 +122,8 @@ def info_factory(features_factory):
"total_frames": total_frames,
"total_tasks": total_tasks,
"total_videos": total_videos,
"total_chunks": total_chunks,
"chunks_size": chunks_size,
"files_size_in_mb": files_size_in_mb,
"fps": fps,
"splits": {},
"data_path": data_path,
@@ -168,25 +170,25 @@ def episodes_stats_factory(stats_factory):
features: dict[str],
total_episodes: int = 3,
) -> dict:
episodes_stats = {}
for episode_index in range(total_episodes):
episodes_stats[episode_index] = {
"episode_index": episode_index,
"stats": stats_factory(features),
}
return episodes_stats
def _generator(total_episodes):
for ep_idx in range(total_episodes):
flat_ep_stats = flatten_dict(stats_factory(features))
flat_ep_stats["episode_index"] = ep_idx
yield flat_ep_stats
# Simpler to rely on generator instead of from_dict
return Dataset.from_generator(lambda: _generator(total_episodes))
return _create_episodes_stats
@pytest.fixture(scope="session")
def tasks_factory():
def _create_tasks(total_tasks: int = 3) -> int:
tasks = {}
for task_index in range(total_tasks):
task_dict = {"task_index": task_index, "task": f"Perform action {task_index}."}
tasks[task_index] = task_dict
return tasks
def _create_tasks(total_tasks: int = 3) -> Dataset:
ids = list(range(total_tasks))
tasks = [f"Perform action {i}." for i in ids]
return Dataset.from_dict({"task_index": ids, "task": tasks})
return _create_tasks
@@ -196,6 +198,7 @@ def episodes_factory(tasks_factory):
def _create_episodes(
total_episodes: int = 3,
total_frames: int = 400,
video_keys: list[str] | None = None,
tasks: dict | None = None,
multi_task: bool = False,
):
@@ -215,26 +218,41 @@ def episodes_factory(tasks_factory):
# Generate random lengths that sum up to total_length
lengths = np.random.multinomial(total_frames, [1 / total_episodes] * total_episodes).tolist()
tasks_list = [task_dict["task"] for task_dict in tasks.values()]
num_tasks_available = len(tasks_list)
num_tasks_available = len(tasks["task"])
episodes = {}
remaining_tasks = tasks_list.copy()
d = {
"episode_index": [],
"data/chunk_index": [],
"data/file_index": [],
"tasks": [],
"length": [],
}
if video_keys is not None:
for video_key in video_keys:
d[f"{video_key}/chunk_index"] = []
d[f"{video_key}/file_index"] = []
remaining_tasks = tasks["task"].copy()
for ep_idx in range(total_episodes):
num_tasks_in_episode = random.randint(1, min(3, num_tasks_available)) if multi_task else 1
tasks_to_sample = remaining_tasks if remaining_tasks else tasks_list
tasks_to_sample = remaining_tasks if remaining_tasks else tasks["task"]
episode_tasks = random.sample(tasks_to_sample, min(num_tasks_in_episode, len(tasks_to_sample)))
if remaining_tasks:
for task in episode_tasks:
remaining_tasks.remove(task)
episodes[ep_idx] = {
"episode_index": ep_idx,
"tasks": episode_tasks,
"length": lengths[ep_idx],
}
d["episode_index"].append(ep_idx)
# TODO(rcadene): remove heuristic of only one file
d["data/chunk_index"].append(0)
d["data/file_index"].append(0)
d["tasks"].append(episode_tasks)
d["length"].append(lengths[ep_idx])
if video_keys is not None:
for video_key in video_keys:
d[f"{video_key}/chunk_index"].append(0)
d[f"{video_key}/file_index"].append(0)
return episodes
return Dataset.from_dict(d)
return _create_episodes
@@ -258,7 +276,7 @@ def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_ar
frame_index_col = np.array([], dtype=np.int64)
episode_index_col = np.array([], dtype=np.int64)
task_index = np.array([], dtype=np.int64)
for ep_dict in episodes.values():
for ep_dict in episodes:
timestamp_col = np.concatenate((timestamp_col, np.arange(ep_dict["length"]) / fps))
frame_index_col = np.concatenate((frame_index_col, np.arange(ep_dict["length"], dtype=int)))
episode_index_col = np.concatenate(
@@ -291,7 +309,7 @@ def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_ar
},
features=hf_features,
)
dataset.set_transform(hf_transform_to_torch)
dataset.set_format("torch")
return dataset
return _create_hf_dataset
@@ -326,8 +344,9 @@ def lerobot_dataset_metadata_factory(
if not tasks:
tasks = tasks_factory(total_tasks=info["total_tasks"])
if not episodes:
video_keys = [key for key, ft in info["features"].items() if ft["dtype"] == "video"]
episodes = episodes_factory(
total_episodes=info["total_episodes"], total_frames=info["total_frames"], tasks=tasks
total_episodes=info["total_episodes"], total_frames=info["total_frames"], video_keys=video_keys, tasks=tasks
)
mock_snapshot_download = mock_snapshot_download_factory(
@@ -371,9 +390,9 @@ def lerobot_dataset_factory(
multi_task: bool = False,
info: dict | None = None,
stats: dict | None = None,
episodes_stats: list[dict] | None = None,
tasks: list[dict] | None = None,
episode_dicts: list[dict] | None = None,
episodes_stats: datasets.Dataset | None = None,
tasks: datasets.Dataset | None = None,
episode_dicts: datasets.Dataset | None = None,
hf_dataset: datasets.Dataset | None = None,
**kwargs,
) -> LeRobotDataset:
@@ -388,9 +407,11 @@ def lerobot_dataset_factory(
if not tasks:
tasks = tasks_factory(total_tasks=info["total_tasks"])
if not episode_dicts:
video_keys = [key for key, ft in info["features"].items() if ft["dtype"] == "video"]
episode_dicts = episodes_factory(
total_episodes=info["total_episodes"],
total_frames=info["total_frames"],
video_keys=video_keys,
tasks=tasks,
multi_task=multi_task,
)