From e69f0c50596d89b77dc4eff2c7a912195c517eb6 Mon Sep 17 00:00:00 2001 From: Simon Alibert Date: Thu, 31 Oct 2024 13:48:40 +0100 Subject: [PATCH] Add test_delta_timestamps.py --- lerobot/common/datasets/lerobot_dataset.py | 4 +- lerobot/common/datasets/utils.py | 8 +- tests/test_delta_timestamps.py | 261 +++++++++++++++++++++ 3 files changed, 267 insertions(+), 6 deletions(-) create mode 100644 tests/test_delta_timestamps.py diff --git a/lerobot/common/datasets/lerobot_dataset.py b/lerobot/common/datasets/lerobot_dataset.py index 4a48d51d..9af0b03c 100644 --- a/lerobot/common/datasets/lerobot_dataset.py +++ b/lerobot/common/datasets/lerobot_dataset.py @@ -202,7 +202,7 @@ class LeRobotDataset(torch.utils.data.Dataset): # Load actual data self.download_episodes(download_videos) self.hf_dataset = self.load_hf_dataset() - self.episode_data_index = get_episode_data_index(self.episodes, self.episode_dicts) + self.episode_data_index = get_episode_data_index(self.episode_dicts, self.episodes) # Check timestamps check_timestamps_sync(self.hf_dataset, self.episode_data_index, self.fps, self.tolerance_s) @@ -740,7 +740,7 @@ class LeRobotDataset(torch.utils.data.Dataset): def consolidate(self, run_compute_stats: bool = True, keep_image_files: bool = False) -> None: self.hf_dataset = self.load_hf_dataset() - self.episode_data_index = get_episode_data_index(self.episodes, self.episode_dicts) + self.episode_data_index = get_episode_data_index(self.episode_dicts, self.episodes) check_timestamps_sync(self.hf_dataset, self.episode_data_index, self.fps, self.tolerance_s) if len(self.video_keys) > 0: diff --git a/lerobot/common/datasets/utils.py b/lerobot/common/datasets/utils.py index 6d941ecf..e5cc02f9 100644 --- a/lerobot/common/datasets/utils.py +++ b/lerobot/common/datasets/utils.py @@ -265,7 +265,9 @@ def create_empty_dataset_info( } -def get_episode_data_index(episodes: list, episode_dicts: list[dict]) -> dict[str, torch.Tensor]: +def get_episode_data_index( + episode_dicts: list[dict], episodes: list[int] | None = None +) -> dict[str, torch.Tensor]: episode_lengths = {ep_idx: ep_dict["length"] for ep_idx, ep_dict in enumerate(episode_dicts)} if episodes is not None: episode_lengths = {ep_idx: episode_lengths[ep_idx] for ep_idx in episodes} @@ -289,8 +291,6 @@ def check_timestamps_sync( account for possible numerical error. """ timestamps = torch.stack(hf_dataset["timestamp"]) - # timestamps[2] += tolerance_s # TODO delete - # timestamps[-2] += tolerance_s/2 # TODO delete diffs = torch.diff(timestamps) within_tolerance = torch.abs(diffs - 1 / fps) <= tolerance_s @@ -339,7 +339,7 @@ def check_delta_timestamps( """ outside_tolerance = {} for key, delta_ts in delta_timestamps.items(): - within_tolerance = [abs(ts * fps - round(ts * fps)) <= tolerance_s for ts in delta_ts] + within_tolerance = [abs(ts * fps - round(ts * fps)) / fps <= tolerance_s for ts in delta_ts] if not all(within_tolerance): outside_tolerance[key] = [ ts for ts, is_within in zip(delta_ts, within_tolerance, strict=True) if not is_within diff --git a/tests/test_delta_timestamps.py b/tests/test_delta_timestamps.py new file mode 100644 index 00000000..29935fe4 --- /dev/null +++ b/tests/test_delta_timestamps.py @@ -0,0 +1,261 @@ +import pytest +import torch +from datasets import Dataset + +from lerobot.common.datasets.utils import ( + check_delta_timestamps, + check_timestamps_sync, + get_delta_indices, + hf_transform_to_torch, +) + + +@pytest.fixture(scope="module") +def synced_hf_dataset_factory(hf_dataset_factory, episode_dicts, tasks): + def _create_synced_hf_dataset(fps: int = 30, keys: list | None = None) -> Dataset: + if not keys: + keys = ["state", "action"] + shapes = {key: 10 for key in keys} + return hf_dataset_factory(episode_dicts, tasks, keys, shapes, fps=fps) + + return _create_synced_hf_dataset + + +@pytest.fixture(scope="module") +def unsynced_hf_dataset_factory(synced_hf_dataset_factory): + def _create_unsynced_hf_dataset( + fps: int = 30, tolerance_s: float = 1e-4, keys: list | None = None + ) -> Dataset: + hf_dataset = synced_hf_dataset_factory(fps=fps, keys=keys) + features = hf_dataset.features + df = hf_dataset.to_pandas() + dtype = df["timestamp"].dtype # This is to avoid pandas type warning + # Modify a single timestamp just outside tolerance + df.at[30, "timestamp"] = dtype.type(df.at[30, "timestamp"] + (tolerance_s * 1.1)) + unsynced_hf_dataset = Dataset.from_pandas(df, features=features) + unsynced_hf_dataset.set_transform(hf_transform_to_torch) + return unsynced_hf_dataset + + return _create_unsynced_hf_dataset + + +@pytest.fixture(scope="module") +def slightly_off_hf_dataset_factory(synced_hf_dataset_factory): + def _create_slightly_off_hf_dataset( + fps: int = 30, tolerance_s: float = 1e-4, keys: list | None = None + ) -> Dataset: + hf_dataset = synced_hf_dataset_factory(fps=fps, keys=keys) + features = hf_dataset.features + df = hf_dataset.to_pandas() + dtype = df["timestamp"].dtype # This is to avoid pandas type warning + # Modify a single timestamp just inside tolerance + df.at[30, "timestamp"] = dtype.type(df.at[30, "timestamp"] + (tolerance_s * 0.9)) + unsynced_hf_dataset = Dataset.from_pandas(df, features=features) + unsynced_hf_dataset.set_transform(hf_transform_to_torch) + return unsynced_hf_dataset + + return _create_slightly_off_hf_dataset + + +@pytest.fixture(scope="module") +def valid_delta_timestamps_factory(): + def _create_valid_delta_timestamps(fps: int = 30, keys: list | None = None) -> dict: + if not keys: + keys = ["state", "action"] + delta_timestamps = {key: [i * (1 / fps) for i in range(-10, 10)] for key in keys} + return delta_timestamps + + return _create_valid_delta_timestamps + + +@pytest.fixture(scope="module") +def invalid_delta_timestamps_factory(valid_delta_timestamps_factory): + def _create_invalid_delta_timestamps( + fps: int = 30, tolerance_s: float = 1e-4, keys: list | None = None + ) -> dict: + if not keys: + keys = ["state", "action"] + delta_timestamps = valid_delta_timestamps_factory(fps, keys) + # Modify a single timestamp just outside tolerance + for key in keys: + delta_timestamps[key][3] += tolerance_s * 1.1 + return delta_timestamps + + return _create_invalid_delta_timestamps + + +@pytest.fixture(scope="module") +def slightly_off_delta_timestamps_factory(valid_delta_timestamps_factory): + def _create_slightly_off_delta_timestamps( + fps: int = 30, tolerance_s: float = 1e-4, keys: list | None = None + ) -> dict: + if not keys: + keys = ["state", "action"] + delta_timestamps = valid_delta_timestamps_factory(fps, keys) + # Modify a single timestamp just inside tolerance + for key in delta_timestamps: + delta_timestamps[key][3] += tolerance_s * 0.9 + delta_timestamps[key][-3] += tolerance_s * 0.9 + return delta_timestamps + + return _create_slightly_off_delta_timestamps + + +@pytest.fixture(scope="module") +def delta_indices(keys: list | None = None) -> dict: + if not keys: + keys = ["state", "action"] + return {key: list(range(-10, 10)) for key in keys} + + +def test_check_timestamps_sync_synced(synced_hf_dataset_factory, episode_data_index): + fps = 30 + tolerance_s = 1e-4 + synced_hf_dataset = synced_hf_dataset_factory(fps) + result = check_timestamps_sync( + hf_dataset=synced_hf_dataset, + episode_data_index=episode_data_index, + fps=fps, + tolerance_s=tolerance_s, + ) + assert result is True + + +def test_check_timestamps_sync_unsynced(unsynced_hf_dataset_factory, episode_data_index): + fps = 30 + tolerance_s = 1e-4 + unsynced_hf_dataset = unsynced_hf_dataset_factory(fps, tolerance_s) + with pytest.raises(ValueError): + check_timestamps_sync( + hf_dataset=unsynced_hf_dataset, + episode_data_index=episode_data_index, + fps=fps, + tolerance_s=tolerance_s, + ) + + +def test_check_timestamps_sync_unsynced_no_exception(unsynced_hf_dataset_factory, episode_data_index): + fps = 30 + tolerance_s = 1e-4 + unsynced_hf_dataset = unsynced_hf_dataset_factory(fps, tolerance_s) + result = check_timestamps_sync( + hf_dataset=unsynced_hf_dataset, + episode_data_index=episode_data_index, + fps=fps, + tolerance_s=tolerance_s, + raise_value_error=False, + ) + assert result is False + + +def test_check_timestamps_sync_slightly_off(slightly_off_hf_dataset_factory, episode_data_index): + fps = 30 + tolerance_s = 1e-4 + slightly_off_hf_dataset = slightly_off_hf_dataset_factory(fps, tolerance_s) + result = check_timestamps_sync( + hf_dataset=slightly_off_hf_dataset, + episode_data_index=episode_data_index, + fps=fps, + tolerance_s=tolerance_s, + ) + assert result is True + + +def test_check_timestamps_sync_single_timestamp(): + single_timestamp_hf_dataset = Dataset.from_dict({"timestamp": [0.0], "episode_index": [0]}) + single_timestamp_hf_dataset.set_transform(hf_transform_to_torch) + episode_data_index = {"to": torch.tensor([1]), "from": torch.tensor([0])} + fps = 30 + tolerance_s = 1e-4 + result = check_timestamps_sync( + hf_dataset=single_timestamp_hf_dataset, + episode_data_index=episode_data_index, + fps=fps, + tolerance_s=tolerance_s, + ) + assert result is True + + +# TODO(aliberts): change behavior of hf_transform_to_torch so that it can work with empty dataset +# def test_check_timestamps_sync_empty_dataset(): +# fps = 30 +# tolerance_s = 1e-4 +# empty_hf_dataset = Dataset.from_dict({'timestamp': [], 'episode_index': []}) +# empty_hf_dataset.set_transform(hf_transform_to_torch) +# episode_data_index = {'to': torch.tensor([], dtype=torch.int64), 'from': torch.tensor([], dtype=torch.int64)} +# result = check_timestamps_sync( +# hf_dataset=empty_hf_dataset, +# episode_data_index=episode_data_index, +# fps=fps, +# tolerance_s=tolerance_s, +# ) +# assert result is True + + +def test_check_delta_timestamps_valid(valid_delta_timestamps_factory): + fps = 30 + tolerance_s = 1e-4 + valid_delta_timestamps = valid_delta_timestamps_factory(fps) + result = check_delta_timestamps( + delta_timestamps=valid_delta_timestamps, + fps=fps, + tolerance_s=tolerance_s, + ) + assert result is True + + +def test_check_delta_timestamps_slightly_off(slightly_off_delta_timestamps_factory): + fps = 30 + tolerance_s = 1e-4 + slightly_off_delta_timestamps = slightly_off_delta_timestamps_factory(fps, tolerance_s) + result = check_delta_timestamps( + delta_timestamps=slightly_off_delta_timestamps, + fps=fps, + tolerance_s=tolerance_s, + ) + assert result is True + + +def test_check_delta_timestamps_invalid(invalid_delta_timestamps_factory): + fps = 30 + tolerance_s = 1e-4 + invalid_delta_timestamps = invalid_delta_timestamps_factory(fps, tolerance_s) + with pytest.raises(ValueError): + check_delta_timestamps( + delta_timestamps=invalid_delta_timestamps, + fps=fps, + tolerance_s=tolerance_s, + ) + + +def test_check_delta_timestamps_invalid_no_exception(invalid_delta_timestamps_factory): + fps = 30 + tolerance_s = 1e-4 + invalid_delta_timestamps = invalid_delta_timestamps_factory(fps, tolerance_s) + result = check_delta_timestamps( + delta_timestamps=invalid_delta_timestamps, + fps=fps, + tolerance_s=tolerance_s, + raise_value_error=False, + ) + assert result is False + + +def test_check_delta_timestamps_empty(): + delta_timestamps = {} + fps = 30 + tolerance_s = 1e-4 + result = check_delta_timestamps( + delta_timestamps=delta_timestamps, + fps=fps, + tolerance_s=tolerance_s, + ) + assert result is True + + +def test_delta_indices(valid_delta_timestamps_factory, delta_indices): + fps = 30 + delta_timestamps = valid_delta_timestamps_factory(fps) + expected_delta_indices = delta_indices + actual_delta_indices = get_delta_indices(delta_timestamps, fps) + assert expected_delta_indices == actual_delta_indices