Add test_delta_timestamps.py

This commit is contained in:
Simon Alibert
2024-10-31 13:48:40 +01:00
parent ff84024ee9
commit e69f0c5059
3 changed files with 267 additions and 6 deletions

View File

@@ -265,7 +265,9 @@ def create_empty_dataset_info(
}
def get_episode_data_index(episodes: list, episode_dicts: list[dict]) -> dict[str, torch.Tensor]:
def get_episode_data_index(
episode_dicts: list[dict], episodes: list[int] | None = None
) -> dict[str, torch.Tensor]:
episode_lengths = {ep_idx: ep_dict["length"] for ep_idx, ep_dict in enumerate(episode_dicts)}
if episodes is not None:
episode_lengths = {ep_idx: episode_lengths[ep_idx] for ep_idx in episodes}
@@ -289,8 +291,6 @@ def check_timestamps_sync(
account for possible numerical error.
"""
timestamps = torch.stack(hf_dataset["timestamp"])
# timestamps[2] += tolerance_s # TODO delete
# timestamps[-2] += tolerance_s/2 # TODO delete
diffs = torch.diff(timestamps)
within_tolerance = torch.abs(diffs - 1 / fps) <= tolerance_s
@@ -339,7 +339,7 @@ def check_delta_timestamps(
"""
outside_tolerance = {}
for key, delta_ts in delta_timestamps.items():
within_tolerance = [abs(ts * fps - round(ts * fps)) <= tolerance_s for ts in delta_ts]
within_tolerance = [abs(ts * fps - round(ts * fps)) / fps <= tolerance_s for ts in delta_ts]
if not all(within_tolerance):
outside_tolerance[key] = [
ts for ts, is_within in zip(delta_ts, within_tolerance, strict=True) if not is_within