Most unit tests are passing

This commit is contained in:
Remi Cadene
2025-04-11 14:04:22 +02:00
parent c1b28f0b58
commit 34c5d4ce07
6 changed files with 391 additions and 322 deletions

View File

@@ -1,17 +1,13 @@
import json
from pathlib import Path
import datasets
import jsonlines
import pandas as pd
import pyarrow.compute as pc
import pyarrow.parquet as pq
import pytest
from datasets import Dataset
from lerobot.common.datasets.utils import (
write_episodes,
write_episodes_stats,
write_hf_dataset,
write_info,
write_stats,
@@ -22,7 +18,7 @@ from lerobot.common.datasets.utils import (
@pytest.fixture(scope="session")
def create_info(info_factory):
def _create_info(dir: Path, info: dict | None = None):
if not info:
if info is None:
info = info_factory()
write_info(info, dir)
@@ -32,27 +28,27 @@ def create_info(info_factory):
@pytest.fixture(scope="session")
def create_stats(stats_factory):
def _create_stats(dir: Path, stats: dict | None = None):
if not stats:
if stats is None:
stats = stats_factory()
write_stats(stats, dir)
return _create_stats
@pytest.fixture(scope="session")
def create_episodes_stats(episodes_stats_factory):
def _create_episodes_stats(dir: Path, episodes_stats: Dataset | None = None):
if not episodes_stats:
episodes_stats = episodes_stats_factory()
write_episodes_stats(episodes_stats, dir)
# @pytest.fixture(scope="session")
# def create_episodes_stats(episodes_stats_factory):
# def _create_episodes_stats(dir: Path, episodes_stats: Dataset | None = None):
# if episodes_stats is None:
# episodes_stats = episodes_stats_factory()
# write_episodes_stats(episodes_stats, dir)
return _create_episodes_stats
# return _create_episodes_stats
@pytest.fixture(scope="session")
def create_tasks(tasks_factory):
def _create_tasks(dir: Path, tasks: Dataset | None = None):
if not tasks:
def _create_tasks(dir: Path, tasks: pd.DataFrame | None = None):
if tasks is None:
tasks = tasks_factory()
write_tasks(tasks, dir)
@@ -61,17 +57,18 @@ def create_tasks(tasks_factory):
@pytest.fixture(scope="session")
def create_episodes(episodes_factory):
def _create_episodes(dir: Path, episodes: Dataset | None = None):
if not episodes:
def _create_episodes(dir: Path, episodes: datasets.Dataset | None = None):
if episodes is None:
episodes = episodes_factory()
write_episodes(episodes, dir)
return _create_episodes
@pytest.fixture(scope="session")
def create_hf_dataset(hf_dataset_factory):
def _create_hf_dataset(dir: Path, hf_dataset: Dataset | None = None):
if not hf_dataset:
def _create_hf_dataset(dir: Path, hf_dataset: datasets.Dataset | None = None):
if hf_dataset is None:
hf_dataset = hf_dataset_factory()
write_hf_dataset(hf_dataset, dir)
@@ -84,7 +81,7 @@ def single_episode_parquet_path(hf_dataset_factory, info_factory):
dir: Path, ep_idx: int = 0, hf_dataset: datasets.Dataset | None = None, info: dict | None = None
) -> Path:
raise NotImplementedError()
if not info:
if info is None:
info = info_factory()
if hf_dataset is None:
hf_dataset = hf_dataset_factory()
@@ -108,7 +105,7 @@ def multi_episode_parquet_path(hf_dataset_factory, info_factory):
dir: Path, hf_dataset: datasets.Dataset | None = None, info: dict | None = None
) -> Path:
raise NotImplementedError()
if not info:
if info is None:
info = info_factory()
if hf_dataset is None:
hf_dataset = hf_dataset_factory()