Fix tests
This commit is contained in:
@@ -3,12 +3,13 @@ import torch
|
||||
from datasets import Dataset
|
||||
|
||||
from lerobot.common.datasets.utils import (
|
||||
calculate_episode_data_index,
|
||||
check_delta_timestamps,
|
||||
check_timestamps_sync,
|
||||
get_delta_indices,
|
||||
hf_transform_to_torch,
|
||||
)
|
||||
from tests.fixtures.defaults import DUMMY_KEYS
|
||||
from tests.fixtures.defaults import DUMMY_MOTOR_FEATURES
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
@@ -53,7 +54,7 @@ def slightly_off_hf_dataset_factory(synced_hf_dataset_factory):
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def valid_delta_timestamps_factory():
|
||||
def _create_valid_delta_timestamps(fps: int = 30, keys: list = DUMMY_KEYS) -> dict:
|
||||
def _create_valid_delta_timestamps(fps: int = 30, keys: list = DUMMY_MOTOR_FEATURES) -> dict:
|
||||
delta_timestamps = {key: [i * (1 / fps) for i in range(-10, 10)] for key in keys}
|
||||
return delta_timestamps
|
||||
|
||||
@@ -63,7 +64,7 @@ def valid_delta_timestamps_factory():
|
||||
@pytest.fixture(scope="module")
|
||||
def invalid_delta_timestamps_factory(valid_delta_timestamps_factory):
|
||||
def _create_invalid_delta_timestamps(
|
||||
fps: int = 30, tolerance_s: float = 1e-4, keys: list = DUMMY_KEYS
|
||||
fps: int = 30, tolerance_s: float = 1e-4, keys: list = DUMMY_MOTOR_FEATURES
|
||||
) -> dict:
|
||||
delta_timestamps = valid_delta_timestamps_factory(fps, keys)
|
||||
# Modify a single timestamp just outside tolerance
|
||||
@@ -77,7 +78,7 @@ def invalid_delta_timestamps_factory(valid_delta_timestamps_factory):
|
||||
@pytest.fixture(scope="module")
|
||||
def slightly_off_delta_timestamps_factory(valid_delta_timestamps_factory):
|
||||
def _create_slightly_off_delta_timestamps(
|
||||
fps: int = 30, tolerance_s: float = 1e-4, keys: list = DUMMY_KEYS
|
||||
fps: int = 30, tolerance_s: float = 1e-4, keys: list = DUMMY_MOTOR_FEATURES
|
||||
) -> dict:
|
||||
delta_timestamps = valid_delta_timestamps_factory(fps, keys)
|
||||
# Modify a single timestamp just inside tolerance
|
||||
@@ -90,14 +91,15 @@ def slightly_off_delta_timestamps_factory(valid_delta_timestamps_factory):
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def delta_indices(keys: list = DUMMY_KEYS) -> dict:
|
||||
def delta_indices(keys: list = DUMMY_MOTOR_FEATURES) -> dict:
|
||||
return {key: list(range(-10, 10)) for key in keys}
|
||||
|
||||
|
||||
def test_check_timestamps_sync_synced(synced_hf_dataset_factory, episode_data_index):
|
||||
def test_check_timestamps_sync_synced(synced_hf_dataset_factory):
|
||||
fps = 30
|
||||
tolerance_s = 1e-4
|
||||
synced_hf_dataset = synced_hf_dataset_factory(fps)
|
||||
episode_data_index = calculate_episode_data_index(synced_hf_dataset)
|
||||
result = check_timestamps_sync(
|
||||
hf_dataset=synced_hf_dataset,
|
||||
episode_data_index=episode_data_index,
|
||||
@@ -107,10 +109,11 @@ def test_check_timestamps_sync_synced(synced_hf_dataset_factory, episode_data_in
|
||||
assert result is True
|
||||
|
||||
|
||||
def test_check_timestamps_sync_unsynced(unsynced_hf_dataset_factory, episode_data_index):
|
||||
def test_check_timestamps_sync_unsynced(unsynced_hf_dataset_factory):
|
||||
fps = 30
|
||||
tolerance_s = 1e-4
|
||||
unsynced_hf_dataset = unsynced_hf_dataset_factory(fps, tolerance_s)
|
||||
episode_data_index = calculate_episode_data_index(unsynced_hf_dataset)
|
||||
with pytest.raises(ValueError):
|
||||
check_timestamps_sync(
|
||||
hf_dataset=unsynced_hf_dataset,
|
||||
@@ -120,10 +123,11 @@ def test_check_timestamps_sync_unsynced(unsynced_hf_dataset_factory, episode_dat
|
||||
)
|
||||
|
||||
|
||||
def test_check_timestamps_sync_unsynced_no_exception(unsynced_hf_dataset_factory, episode_data_index):
|
||||
def test_check_timestamps_sync_unsynced_no_exception(unsynced_hf_dataset_factory):
|
||||
fps = 30
|
||||
tolerance_s = 1e-4
|
||||
unsynced_hf_dataset = unsynced_hf_dataset_factory(fps, tolerance_s)
|
||||
episode_data_index = calculate_episode_data_index(unsynced_hf_dataset)
|
||||
result = check_timestamps_sync(
|
||||
hf_dataset=unsynced_hf_dataset,
|
||||
episode_data_index=episode_data_index,
|
||||
@@ -134,10 +138,11 @@ def test_check_timestamps_sync_unsynced_no_exception(unsynced_hf_dataset_factory
|
||||
assert result is False
|
||||
|
||||
|
||||
def test_check_timestamps_sync_slightly_off(slightly_off_hf_dataset_factory, episode_data_index):
|
||||
def test_check_timestamps_sync_slightly_off(slightly_off_hf_dataset_factory):
|
||||
fps = 30
|
||||
tolerance_s = 1e-4
|
||||
slightly_off_hf_dataset = slightly_off_hf_dataset_factory(fps, tolerance_s)
|
||||
episode_data_index = calculate_episode_data_index(slightly_off_hf_dataset)
|
||||
result = check_timestamps_sync(
|
||||
hf_dataset=slightly_off_hf_dataset,
|
||||
episode_data_index=episode_data_index,
|
||||
|
||||
Reference in New Issue
Block a user