Commit before episodes episodes_stats merging

This commit is contained in:
Remi Cadene
2025-04-09 15:20:15 +02:00
parent 53ecec5fb2
commit c1b28f0b58
12 changed files with 905 additions and 396 deletions

View File

@@ -7,83 +7,75 @@ import pyarrow.compute as pc
import pyarrow.parquet as pq
import pytest
from datasets import Dataset
from lerobot.common.datasets.utils import (
EPISODES_PATH,
EPISODES_STATS_PATH,
INFO_PATH,
STATS_PATH,
TASKS_PATH,
write_episodes,
write_episodes_stats,
write_hf_dataset,
write_info,
write_stats,
write_tasks,
)
@pytest.fixture(scope="session")
def info_path(info_factory):
def _create_info_json_file(dir: Path, info: dict | None = None) -> Path:
def create_info(info_factory):
def _create_info(dir: Path, info: dict | None = None):
if not info:
info = info_factory()
fpath = dir / INFO_PATH
fpath.parent.mkdir(parents=True, exist_ok=True)
with open(fpath, "w") as f:
json.dump(info, f, indent=4, ensure_ascii=False)
return fpath
write_info(info, dir)
return _create_info_json_file
return _create_info
@pytest.fixture(scope="session")
def stats_path(stats_factory):
def _create_stats_json_file(dir: Path, stats: dict | None = None) -> Path:
def create_stats(stats_factory):
def _create_stats(dir: Path, stats: dict | None = None):
if not stats:
stats = stats_factory()
fpath = dir / STATS_PATH
fpath.parent.mkdir(parents=True, exist_ok=True)
with open(fpath, "w") as f:
json.dump(stats, f, indent=4, ensure_ascii=False)
return fpath
write_stats(stats, dir)
return _create_stats_json_file
return _create_stats
@pytest.fixture(scope="session")
def episodes_stats_path(episodes_stats_factory):
def _create_episodes_stats_jsonl_file(dir: Path, episodes_stats: list[dict] | None = None) -> Path:
def create_episodes_stats(episodes_stats_factory):
def _create_episodes_stats(dir: Path, episodes_stats: Dataset | None = None):
if not episodes_stats:
episodes_stats = episodes_stats_factory()
fpath = dir / EPISODES_STATS_PATH
fpath.parent.mkdir(parents=True, exist_ok=True)
with jsonlines.open(fpath, "w") as writer:
writer.write_all(episodes_stats.values())
return fpath
write_episodes_stats(episodes_stats, dir)
return _create_episodes_stats_jsonl_file
return _create_episodes_stats
@pytest.fixture(scope="session")
def tasks_path(tasks_factory):
def _create_tasks_jsonl_file(dir: Path, tasks: list | None = None) -> Path:
def create_tasks(tasks_factory):
def _create_tasks(dir: Path, tasks: Dataset | None = None):
if not tasks:
tasks = tasks_factory()
fpath = dir / TASKS_PATH
fpath.parent.mkdir(parents=True, exist_ok=True)
with jsonlines.open(fpath, "w") as writer:
writer.write_all(tasks.values())
return fpath
write_tasks(tasks, dir)
return _create_tasks_jsonl_file
return _create_tasks
@pytest.fixture(scope="session")
def episode_path(episodes_factory):
def _create_episodes_jsonl_file(dir: Path, episodes: list | None = None) -> Path:
def create_episodes(episodes_factory):
def _create_episodes(dir: Path, episodes: Dataset | None = None):
if not episodes:
episodes = episodes_factory()
fpath = dir / EPISODES_PATH
fpath.parent.mkdir(parents=True, exist_ok=True)
with jsonlines.open(fpath, "w") as writer:
writer.write_all(episodes.values())
return fpath
write_episodes(episodes, dir)
return _create_episodes_jsonl_file
return _create_episodes
@pytest.fixture(scope="session")
def create_hf_dataset(hf_dataset_factory):
def _create_hf_dataset(dir: Path, hf_dataset: Dataset | None = None):
if not hf_dataset:
hf_dataset = hf_dataset_factory()
write_hf_dataset(hf_dataset, dir)
return _create_hf_dataset
@pytest.fixture(scope="session")
@@ -91,6 +83,7 @@ def single_episode_parquet_path(hf_dataset_factory, info_factory):
def _create_single_episode_parquet(
dir: Path, ep_idx: int = 0, hf_dataset: datasets.Dataset | None = None, info: dict | None = None
) -> Path:
raise NotImplementedError()
if not info:
info = info_factory()
if hf_dataset is None:
@@ -114,6 +107,7 @@ def multi_episode_parquet_path(hf_dataset_factory, info_factory):
def _create_multi_episode_parquet(
dir: Path, hf_dataset: datasets.Dataset | None = None, info: dict | None = None
) -> Path:
raise NotImplementedError()
if not info:
info = info_factory()
if hf_dataset is None: