Add test_delta_timestamps.py
This commit is contained in:
@@ -202,7 +202,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
# Load actual data
|
||||
self.download_episodes(download_videos)
|
||||
self.hf_dataset = self.load_hf_dataset()
|
||||
self.episode_data_index = get_episode_data_index(self.episodes, self.episode_dicts)
|
||||
self.episode_data_index = get_episode_data_index(self.episode_dicts, self.episodes)
|
||||
|
||||
# Check timestamps
|
||||
check_timestamps_sync(self.hf_dataset, self.episode_data_index, self.fps, self.tolerance_s)
|
||||
@@ -740,7 +740,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
|
||||
def consolidate(self, run_compute_stats: bool = True, keep_image_files: bool = False) -> None:
|
||||
self.hf_dataset = self.load_hf_dataset()
|
||||
self.episode_data_index = get_episode_data_index(self.episodes, self.episode_dicts)
|
||||
self.episode_data_index = get_episode_data_index(self.episode_dicts, self.episodes)
|
||||
check_timestamps_sync(self.hf_dataset, self.episode_data_index, self.fps, self.tolerance_s)
|
||||
|
||||
if len(self.video_keys) > 0:
|
||||
|
||||
@@ -265,7 +265,9 @@ def create_empty_dataset_info(
|
||||
}
|
||||
|
||||
|
||||
def get_episode_data_index(episodes: list, episode_dicts: list[dict]) -> dict[str, torch.Tensor]:
|
||||
def get_episode_data_index(
|
||||
episode_dicts: list[dict], episodes: list[int] | None = None
|
||||
) -> dict[str, torch.Tensor]:
|
||||
episode_lengths = {ep_idx: ep_dict["length"] for ep_idx, ep_dict in enumerate(episode_dicts)}
|
||||
if episodes is not None:
|
||||
episode_lengths = {ep_idx: episode_lengths[ep_idx] for ep_idx in episodes}
|
||||
@@ -289,8 +291,6 @@ def check_timestamps_sync(
|
||||
account for possible numerical error.
|
||||
"""
|
||||
timestamps = torch.stack(hf_dataset["timestamp"])
|
||||
# timestamps[2] += tolerance_s # TODO delete
|
||||
# timestamps[-2] += tolerance_s/2 # TODO delete
|
||||
diffs = torch.diff(timestamps)
|
||||
within_tolerance = torch.abs(diffs - 1 / fps) <= tolerance_s
|
||||
|
||||
@@ -339,7 +339,7 @@ def check_delta_timestamps(
|
||||
"""
|
||||
outside_tolerance = {}
|
||||
for key, delta_ts in delta_timestamps.items():
|
||||
within_tolerance = [abs(ts * fps - round(ts * fps)) <= tolerance_s for ts in delta_ts]
|
||||
within_tolerance = [abs(ts * fps - round(ts * fps)) / fps <= tolerance_s for ts in delta_ts]
|
||||
if not all(within_tolerance):
|
||||
outside_tolerance[key] = [
|
||||
ts for ts, is_within in zip(delta_ts, within_tolerance, strict=True) if not is_within
|
||||
|
||||
261
tests/test_delta_timestamps.py
Normal file
261
tests/test_delta_timestamps.py
Normal file
@@ -0,0 +1,261 @@
|
||||
import pytest
|
||||
import torch
|
||||
from datasets import Dataset
|
||||
|
||||
from lerobot.common.datasets.utils import (
|
||||
check_delta_timestamps,
|
||||
check_timestamps_sync,
|
||||
get_delta_indices,
|
||||
hf_transform_to_torch,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def synced_hf_dataset_factory(hf_dataset_factory, episode_dicts, tasks):
    """Provide a factory that builds a dataset through ``hf_dataset_factory``.

    The returned callable forwards the shared ``episode_dicts`` and ``tasks``
    fixtures, the requested feature keys (each given a shape of 10) and ``fps``.
    """

    def _create_synced_hf_dataset(fps: int = 30, keys: list | None = None) -> Dataset:
        # Fall back to the default feature keys when none (or an empty list) is given.
        feature_keys = keys or ["state", "action"]
        feature_shapes = dict.fromkeys(feature_keys, 10)
        return hf_dataset_factory(episode_dicts, tasks, feature_keys, feature_shapes, fps=fps)

    return _create_synced_hf_dataset
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def unsynced_hf_dataset_factory(synced_hf_dataset_factory):
    """Provide a factory for datasets with one timestamp pushed outside tolerance.

    The returned callable builds a synced dataset, adds ``tolerance_s * 1.1`` to
    the timestamp stored at row label 30, and returns the rebuilt dataset with
    ``hf_transform_to_torch`` set as its transform.
    """

    def _create_unsynced_hf_dataset(
        fps: int = 30, tolerance_s: float = 1e-4, keys: list | None = None
    ) -> Dataset:
        base_dataset = synced_hf_dataset_factory(fps=fps, keys=keys)
        original_features = base_dataset.features
        frame = base_dataset.to_pandas()

        # Cast the shifted value back to the column's scalar type to avoid a pandas type warning.
        to_column_scalar = frame["timestamp"].dtype.type

        # Shift a single timestamp just outside tolerance.
        frame.at[30, "timestamp"] = to_column_scalar(frame.at[30, "timestamp"] + (tolerance_s * 1.1))

        perturbed_dataset = Dataset.from_pandas(frame, features=original_features)
        perturbed_dataset.set_transform(hf_transform_to_torch)
        return perturbed_dataset

    return _create_unsynced_hf_dataset
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def slightly_off_hf_dataset_factory(synced_hf_dataset_factory):
    """Provide a factory for datasets with one timestamp nudged but still within tolerance.

    The returned callable builds a synced dataset, adds ``tolerance_s * 0.9`` to
    the timestamp stored at row label 30, and returns the rebuilt dataset with
    ``hf_transform_to_torch`` set as its transform.
    """

    def _create_slightly_off_hf_dataset(
        fps: int = 30, tolerance_s: float = 1e-4, keys: list | None = None
    ) -> Dataset:
        base_dataset = synced_hf_dataset_factory(fps=fps, keys=keys)
        original_features = base_dataset.features
        frame = base_dataset.to_pandas()

        # Cast the shifted value back to the column's scalar type to avoid a pandas type warning.
        to_column_scalar = frame["timestamp"].dtype.type

        # Shift a single timestamp while staying just inside tolerance.
        frame.at[30, "timestamp"] = to_column_scalar(frame.at[30, "timestamp"] + (tolerance_s * 0.9))

        slightly_off_dataset = Dataset.from_pandas(frame, features=original_features)
        slightly_off_dataset.set_transform(hf_transform_to_torch)
        return slightly_off_dataset

    return _create_slightly_off_hf_dataset
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def valid_delta_timestamps_factory():
    """Provide a factory for delta timestamps that are exact multiples of ``1 / fps``.

    The returned callable maps every key to its own list of 20 values,
    ``i * (1 / fps)`` for ``i`` in ``range(-10, 10)``.
    """

    def _create_valid_delta_timestamps(fps: int = 30, keys: list | None = None) -> dict:
        frame_period = 1 / fps
        delta_timestamps = {}
        for key in keys or ["state", "action"]:
            # Build a fresh list per key so callers can mutate one entry independently.
            delta_timestamps[key] = [step * frame_period for step in range(-10, 10)]
        return delta_timestamps

    return _create_valid_delta_timestamps
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def invalid_delta_timestamps_factory(valid_delta_timestamps_factory):
    """Provide a factory for delta timestamps with one entry pushed outside tolerance.

    The returned callable starts from the valid delta timestamps and adds
    ``tolerance_s * 1.1`` to the element at index 3 of every requested key.
    """

    def _create_invalid_delta_timestamps(
        fps: int = 30, tolerance_s: float = 1e-4, keys: list | None = None
    ) -> dict:
        selected_keys = keys or ["state", "action"]
        delta_timestamps = valid_delta_timestamps_factory(fps, selected_keys)

        # Shift a single timestamp per key just outside tolerance.
        excess = tolerance_s * 1.1
        for key in selected_keys:
            delta_timestamps[key][3] += excess
        return delta_timestamps

    return _create_invalid_delta_timestamps
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def slightly_off_delta_timestamps_factory(valid_delta_timestamps_factory):
    """Provide a factory for delta timestamps nudged but still within tolerance.

    The returned callable starts from the valid delta timestamps and adds
    ``tolerance_s * 0.9`` to the elements at indices 3 and -3 of every key.
    """

    def _create_slightly_off_delta_timestamps(
        fps: int = 30, tolerance_s: float = 1e-4, keys: list | None = None
    ) -> dict:
        selected_keys = keys or ["state", "action"]
        delta_timestamps = valid_delta_timestamps_factory(fps, selected_keys)

        # Shift two timestamps per key while staying just inside tolerance.
        nudge = tolerance_s * 0.9
        for series in delta_timestamps.values():
            series[3] += nudge
            series[-3] += nudge
        return delta_timestamps

    return _create_slightly_off_delta_timestamps
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def delta_indices(keys: list | None = None) -> dict:
    """Return, for each key, the integer offsets ``-10`` through ``9``.

    NOTE(review): ``keys`` has a default value, so pytest presumably never
    injects it and the default keys are always used — confirm.
    """
    expected = {}
    for key in keys or ["state", "action"]:
        expected[key] = list(range(-10, 10))
    return expected
|
||||
|
||||
|
||||
def test_check_timestamps_sync_synced(synced_hf_dataset_factory, episode_data_index):
    """A dataset from the synced factory must pass the timestamp sync check."""
    frame_rate, tolerance = 30, 1e-4
    dataset = synced_hf_dataset_factory(frame_rate)

    is_synced = check_timestamps_sync(
        hf_dataset=dataset,
        episode_data_index=episode_data_index,
        fps=frame_rate,
        tolerance_s=tolerance,
    )

    assert is_synced is True
|
||||
|
||||
|
||||
def test_check_timestamps_sync_unsynced(unsynced_hf_dataset_factory, episode_data_index):
    """A timestamp shifted beyond tolerance must make the sync check raise ValueError."""
    frame_rate, tolerance = 30, 1e-4
    dataset = unsynced_hf_dataset_factory(frame_rate, tolerance)
    check_kwargs = {
        "hf_dataset": dataset,
        "episode_data_index": episode_data_index,
        "fps": frame_rate,
        "tolerance_s": tolerance,
    }

    with pytest.raises(ValueError):
        check_timestamps_sync(**check_kwargs)
|
||||
|
||||
|
||||
def test_check_timestamps_sync_unsynced_no_exception(unsynced_hf_dataset_factory, episode_data_index):
    """With ``raise_value_error=False`` an unsynced dataset must yield False, not raise."""
    frame_rate, tolerance = 30, 1e-4
    dataset = unsynced_hf_dataset_factory(frame_rate, tolerance)

    is_synced = check_timestamps_sync(
        hf_dataset=dataset,
        episode_data_index=episode_data_index,
        fps=frame_rate,
        tolerance_s=tolerance,
        raise_value_error=False,
    )

    assert is_synced is False
|
||||
|
||||
|
||||
def test_check_timestamps_sync_slightly_off(slightly_off_hf_dataset_factory, episode_data_index):
    """A timestamp shifted by less than the tolerance must still pass the sync check."""
    frame_rate, tolerance = 30, 1e-4
    dataset = slightly_off_hf_dataset_factory(frame_rate, tolerance)

    is_synced = check_timestamps_sync(
        hf_dataset=dataset,
        episode_data_index=episode_data_index,
        fps=frame_rate,
        tolerance_s=tolerance,
    )

    assert is_synced is True
|
||||
|
||||
|
||||
def test_check_timestamps_sync_single_timestamp():
    """A dataset holding a single timestamp must pass the sync check."""
    frame_rate, tolerance = 30, 1e-4

    # One-row dataset: a single episode spanning indices [0, 1).
    dataset = Dataset.from_dict({"timestamp": [0.0], "episode_index": [0]})
    dataset.set_transform(hf_transform_to_torch)
    single_episode_index = {"to": torch.tensor([1]), "from": torch.tensor([0])}

    is_synced = check_timestamps_sync(
        hf_dataset=dataset,
        episode_data_index=single_episode_index,
        fps=frame_rate,
        tolerance_s=tolerance,
    )

    assert is_synced is True
|
||||
|
||||
|
||||
# TODO(aliberts): change behavior of hf_transform_to_torch so that it can work with empty dataset
|
||||
# def test_check_timestamps_sync_empty_dataset():
|
||||
# fps = 30
|
||||
# tolerance_s = 1e-4
|
||||
# empty_hf_dataset = Dataset.from_dict({'timestamp': [], 'episode_index': []})
|
||||
# empty_hf_dataset.set_transform(hf_transform_to_torch)
|
||||
# episode_data_index = {'to': torch.tensor([], dtype=torch.int64), 'from': torch.tensor([], dtype=torch.int64)}
|
||||
# result = check_timestamps_sync(
|
||||
# hf_dataset=empty_hf_dataset,
|
||||
# episode_data_index=episode_data_index,
|
||||
# fps=fps,
|
||||
# tolerance_s=tolerance_s,
|
||||
# )
|
||||
# assert result is True
|
||||
|
||||
|
||||
def test_check_delta_timestamps_valid(valid_delta_timestamps_factory):
    """Delta timestamps that are exact multiples of ``1 / fps`` must be accepted."""
    frame_rate, tolerance = 30, 1e-4
    deltas = valid_delta_timestamps_factory(frame_rate)

    is_valid = check_delta_timestamps(
        delta_timestamps=deltas,
        fps=frame_rate,
        tolerance_s=tolerance,
    )

    assert is_valid is True
|
||||
|
||||
|
||||
def test_check_delta_timestamps_slightly_off(slightly_off_delta_timestamps_factory):
    """Delta timestamps shifted by less than the tolerance must still be accepted."""
    frame_rate, tolerance = 30, 1e-4
    deltas = slightly_off_delta_timestamps_factory(frame_rate, tolerance)

    is_valid = check_delta_timestamps(
        delta_timestamps=deltas,
        fps=frame_rate,
        tolerance_s=tolerance,
    )

    assert is_valid is True
|
||||
|
||||
|
||||
def test_check_delta_timestamps_invalid(invalid_delta_timestamps_factory):
    """Delta timestamps shifted beyond tolerance must make the check raise ValueError."""
    frame_rate, tolerance = 30, 1e-4
    deltas = invalid_delta_timestamps_factory(frame_rate, tolerance)
    check_kwargs = {
        "delta_timestamps": deltas,
        "fps": frame_rate,
        "tolerance_s": tolerance,
    }

    with pytest.raises(ValueError):
        check_delta_timestamps(**check_kwargs)
|
||||
|
||||
|
||||
def test_check_delta_timestamps_invalid_no_exception(invalid_delta_timestamps_factory):
    """With ``raise_value_error=False`` invalid delta timestamps must yield False, not raise."""
    frame_rate, tolerance = 30, 1e-4
    deltas = invalid_delta_timestamps_factory(frame_rate, tolerance)

    is_valid = check_delta_timestamps(
        delta_timestamps=deltas,
        fps=frame_rate,
        tolerance_s=tolerance,
        raise_value_error=False,
    )

    assert is_valid is False
|
||||
|
||||
|
||||
def test_check_delta_timestamps_empty():
    """An empty delta-timestamps mapping must be accepted."""
    frame_rate, tolerance = 30, 1e-4
    no_deltas = {}

    is_valid = check_delta_timestamps(
        delta_timestamps=no_deltas,
        fps=frame_rate,
        tolerance_s=tolerance,
    )

    assert is_valid is True
|
||||
|
||||
|
||||
def test_delta_indices(valid_delta_timestamps_factory, delta_indices):
    """``get_delta_indices`` must turn the valid delta timestamps into the expected indices."""
    frame_rate = 30
    deltas = valid_delta_timestamps_factory(frame_rate)

    computed_indices = get_delta_indices(deltas, frame_rate)

    # ``delta_indices`` is the fixture holding the expected frame offsets per key.
    assert delta_indices == computed_indices
|
||||
Reference in New Issue
Block a user