Most unit tests are passing
This commit is contained in:
41
tests/fixtures/files.py
vendored
41
tests/fixtures/files.py
vendored
@@ -1,17 +1,13 @@
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import datasets
|
||||
import jsonlines
|
||||
import pandas as pd
|
||||
import pyarrow.compute as pc
|
||||
import pyarrow.parquet as pq
|
||||
import pytest
|
||||
|
||||
from datasets import Dataset
|
||||
|
||||
from lerobot.common.datasets.utils import (
|
||||
write_episodes,
|
||||
write_episodes_stats,
|
||||
write_hf_dataset,
|
||||
write_info,
|
||||
write_stats,
|
||||
@@ -22,7 +18,7 @@ from lerobot.common.datasets.utils import (
|
||||
@pytest.fixture(scope="session")
|
||||
def create_info(info_factory):
|
||||
def _create_info(dir: Path, info: dict | None = None):
|
||||
if not info:
|
||||
if info is None:
|
||||
info = info_factory()
|
||||
write_info(info, dir)
|
||||
|
||||
@@ -32,27 +28,27 @@ def create_info(info_factory):
|
||||
@pytest.fixture(scope="session")
|
||||
def create_stats(stats_factory):
|
||||
def _create_stats(dir: Path, stats: dict | None = None):
|
||||
if not stats:
|
||||
if stats is None:
|
||||
stats = stats_factory()
|
||||
write_stats(stats, dir)
|
||||
|
||||
return _create_stats
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def create_episodes_stats(episodes_stats_factory):
|
||||
def _create_episodes_stats(dir: Path, episodes_stats: Dataset | None = None):
|
||||
if not episodes_stats:
|
||||
episodes_stats = episodes_stats_factory()
|
||||
write_episodes_stats(episodes_stats, dir)
|
||||
# @pytest.fixture(scope="session")
|
||||
# def create_episodes_stats(episodes_stats_factory):
|
||||
# def _create_episodes_stats(dir: Path, episodes_stats: Dataset | None = None):
|
||||
# if episodes_stats is None:
|
||||
# episodes_stats = episodes_stats_factory()
|
||||
# write_episodes_stats(episodes_stats, dir)
|
||||
|
||||
return _create_episodes_stats
|
||||
# return _create_episodes_stats
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def create_tasks(tasks_factory):
|
||||
def _create_tasks(dir: Path, tasks: Dataset | None = None):
|
||||
if not tasks:
|
||||
def _create_tasks(dir: Path, tasks: pd.DataFrame | None = None):
|
||||
if tasks is None:
|
||||
tasks = tasks_factory()
|
||||
write_tasks(tasks, dir)
|
||||
|
||||
@@ -61,17 +57,18 @@ def create_tasks(tasks_factory):
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def create_episodes(episodes_factory):
|
||||
def _create_episodes(dir: Path, episodes: Dataset | None = None):
|
||||
if not episodes:
|
||||
def _create_episodes(dir: Path, episodes: datasets.Dataset | None = None):
|
||||
if episodes is None:
|
||||
episodes = episodes_factory()
|
||||
write_episodes(episodes, dir)
|
||||
|
||||
return _create_episodes
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def create_hf_dataset(hf_dataset_factory):
|
||||
def _create_hf_dataset(dir: Path, hf_dataset: Dataset | None = None):
|
||||
if not hf_dataset:
|
||||
def _create_hf_dataset(dir: Path, hf_dataset: datasets.Dataset | None = None):
|
||||
if hf_dataset is None:
|
||||
hf_dataset = hf_dataset_factory()
|
||||
write_hf_dataset(hf_dataset, dir)
|
||||
|
||||
@@ -84,7 +81,7 @@ def single_episode_parquet_path(hf_dataset_factory, info_factory):
|
||||
dir: Path, ep_idx: int = 0, hf_dataset: datasets.Dataset | None = None, info: dict | None = None
|
||||
) -> Path:
|
||||
raise NotImplementedError()
|
||||
if not info:
|
||||
if info is None:
|
||||
info = info_factory()
|
||||
if hf_dataset is None:
|
||||
hf_dataset = hf_dataset_factory()
|
||||
@@ -108,7 +105,7 @@ def multi_episode_parquet_path(hf_dataset_factory, info_factory):
|
||||
dir: Path, hf_dataset: datasets.Dataset | None = None, info: dict | None = None
|
||||
) -> Path:
|
||||
raise NotImplementedError()
|
||||
if not info:
|
||||
if info is None:
|
||||
info = info_factory()
|
||||
if hf_dataset is None:
|
||||
hf_dataset = hf_dataset_factory()
|
||||
|
||||
Reference in New Issue
Block a user