forked from tangger/lerobot
Split fixtures into factories and files
This commit is contained in:
@@ -8,25 +8,21 @@ from lerobot.common.datasets.utils import (
|
||||
get_delta_indices,
|
||||
hf_transform_to_torch,
|
||||
)
|
||||
from tests.fixtures.defaults import DUMMY_KEYS
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def synced_hf_dataset_factory(hf_dataset_factory, episode_dicts, tasks):
|
||||
def _create_synced_hf_dataset(fps: int = 30, keys: list | None = None) -> Dataset:
|
||||
if not keys:
|
||||
keys = ["state", "action"]
|
||||
shapes = {key: 10 for key in keys}
|
||||
return hf_dataset_factory(episode_dicts, tasks, keys, shapes, fps=fps)
|
||||
def synced_hf_dataset_factory(hf_dataset_factory):
|
||||
def _create_synced_hf_dataset(fps: int = 30) -> Dataset:
|
||||
return hf_dataset_factory(fps=fps)
|
||||
|
||||
return _create_synced_hf_dataset
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def unsynced_hf_dataset_factory(synced_hf_dataset_factory):
|
||||
def _create_unsynced_hf_dataset(
|
||||
fps: int = 30, tolerance_s: float = 1e-4, keys: list | None = None
|
||||
) -> Dataset:
|
||||
hf_dataset = synced_hf_dataset_factory(fps=fps, keys=keys)
|
||||
def _create_unsynced_hf_dataset(fps: int = 30, tolerance_s: float = 1e-4) -> Dataset:
|
||||
hf_dataset = synced_hf_dataset_factory(fps=fps)
|
||||
features = hf_dataset.features
|
||||
df = hf_dataset.to_pandas()
|
||||
dtype = df["timestamp"].dtype # This is to avoid pandas type warning
|
||||
@@ -41,10 +37,8 @@ def unsynced_hf_dataset_factory(synced_hf_dataset_factory):
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def slightly_off_hf_dataset_factory(synced_hf_dataset_factory):
|
||||
def _create_slightly_off_hf_dataset(
|
||||
fps: int = 30, tolerance_s: float = 1e-4, keys: list | None = None
|
||||
) -> Dataset:
|
||||
hf_dataset = synced_hf_dataset_factory(fps=fps, keys=keys)
|
||||
def _create_slightly_off_hf_dataset(fps: int = 30, tolerance_s: float = 1e-4) -> Dataset:
|
||||
hf_dataset = synced_hf_dataset_factory(fps=fps)
|
||||
features = hf_dataset.features
|
||||
df = hf_dataset.to_pandas()
|
||||
dtype = df["timestamp"].dtype # This is to avoid pandas type warning
|
||||
@@ -59,9 +53,7 @@ def slightly_off_hf_dataset_factory(synced_hf_dataset_factory):
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def valid_delta_timestamps_factory():
|
||||
def _create_valid_delta_timestamps(fps: int = 30, keys: list | None = None) -> dict:
|
||||
if not keys:
|
||||
keys = ["state", "action"]
|
||||
def _create_valid_delta_timestamps(fps: int = 30, keys: list = DUMMY_KEYS) -> dict:
|
||||
delta_timestamps = {key: [i * (1 / fps) for i in range(-10, 10)] for key in keys}
|
||||
return delta_timestamps
|
||||
|
||||
@@ -71,10 +63,8 @@ def valid_delta_timestamps_factory():
|
||||
@pytest.fixture(scope="module")
|
||||
def invalid_delta_timestamps_factory(valid_delta_timestamps_factory):
|
||||
def _create_invalid_delta_timestamps(
|
||||
fps: int = 30, tolerance_s: float = 1e-4, keys: list | None = None
|
||||
fps: int = 30, tolerance_s: float = 1e-4, keys: list = DUMMY_KEYS
|
||||
) -> dict:
|
||||
if not keys:
|
||||
keys = ["state", "action"]
|
||||
delta_timestamps = valid_delta_timestamps_factory(fps, keys)
|
||||
# Modify a single timestamp just outside tolerance
|
||||
for key in keys:
|
||||
@@ -87,10 +77,8 @@ def invalid_delta_timestamps_factory(valid_delta_timestamps_factory):
|
||||
@pytest.fixture(scope="module")
|
||||
def slightly_off_delta_timestamps_factory(valid_delta_timestamps_factory):
|
||||
def _create_slightly_off_delta_timestamps(
|
||||
fps: int = 30, tolerance_s: float = 1e-4, keys: list | None = None
|
||||
fps: int = 30, tolerance_s: float = 1e-4, keys: list = DUMMY_KEYS
|
||||
) -> dict:
|
||||
if not keys:
|
||||
keys = ["state", "action"]
|
||||
delta_timestamps = valid_delta_timestamps_factory(fps, keys)
|
||||
# Modify a single timestamp just inside tolerance
|
||||
for key in delta_timestamps:
|
||||
@@ -102,9 +90,7 @@ def slightly_off_delta_timestamps_factory(valid_delta_timestamps_factory):
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def delta_indices(keys: list | None = None) -> dict:
|
||||
if not keys:
|
||||
keys = ["state", "action"]
|
||||
def delta_indices(keys: list = DUMMY_KEYS) -> dict:
|
||||
return {key: list(range(-10, 10)) for key in keys}
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user