fix load_data_with_delta_timestamps and add tests

This commit is contained in:
Cadene
2024-04-11 12:59:09 +00:00
parent 9229226522
commit 657b27cc8f
2 changed files with 89 additions and 39 deletions

View File

@@ -4,7 +4,7 @@ import einops
import pytest
import torch
from lerobot.common.datasets.utils import compute_stats, get_stats_einops_patterns
from lerobot.common.datasets.utils import compute_stats, get_stats_einops_patterns, is_contiguously_true_or_false, load_data_with_delta_timestamps
from lerobot.common.datasets.xarm import XarmDataset
from lerobot.common.transforms import Prod
from lerobot.common.utils import init_hydra_config
@@ -142,3 +142,49 @@ def test_compute_stats():
# assert torch.allclose(loaded_stats[k]["std"], expected_stats[k]["std"])
# assert torch.allclose(loaded_stats[k]["min"], expected_stats[k]["min"])
# assert torch.allclose(loaded_stats[k]["max"], expected_stats[k]["max"])
def test_load_data_with_delta_timestamps_within_tolerance():
data_dict = {
"timestamp": torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5]),
"index": torch.tensor([0, 1, 2, 3, 4]),
}
data_ids_per_episode = {0: torch.tensor([0, 1, 2, 3, 4])}
delta_timestamps = {"index": [-0.2, 0, 0.24]}
key = "index"
current_ts = 0.3
episode = 0
tol = 0.04
data, is_pad = load_data_with_delta_timestamps(data_dict, data_ids_per_episode, delta_timestamps, key, current_ts, episode, tol)
assert not is_pad.any(), "Unexpected padding detected"
assert torch.equal(data, torch.tensor([0, 2, 4])), "Data does not match expected values"
def test_load_data_with_delta_timestamps_outside_tolerance_inside_episode_range():
data_dict = {
"timestamp": torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5]),
"index": torch.tensor([0, 1, 2, 3, 4]),
}
data_ids_per_episode = {0: torch.tensor([0, 1, 2, 3, 4])}
delta_timestamps = {"index": [-0.2, 0, 0.14, 0.2]}
key = "index"
current_ts = 0.3
episode = 0
tol = 0.03
with pytest.raises(AssertionError):
load_data_with_delta_timestamps(data_dict, data_ids_per_episode, delta_timestamps, key, current_ts, episode, tol)
def test_load_data_with_delta_timestamps_outside_tolerance_outside_episode_range():
data_dict = {
"timestamp": torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5]),
"index": torch.tensor([0, 1, 2, 3, 4]),
}
data_ids_per_episode = {0: torch.tensor([0, 1, 2, 3, 4])}
delta_timestamps = {"index": [-0.3, -0.24, 0, 0.26, 0.3]}
key = "index"
current_ts = 0.3
episode = 0
tol = 0.04
data, is_pad = load_data_with_delta_timestamps(data_dict, data_ids_per_episode, delta_timestamps, key, current_ts, episode, tol)
assert torch.equal(is_pad, torch.tensor([True, False, False, True, True])), "Padding does not match expected values"
assert torch.equal(data, torch.tensor([0, 0, 2, 4, 4])), "Data does not match expected values"