fix load_data_with_delta_timestamps and add tests
This commit is contained in:
@@ -4,7 +4,7 @@ import einops
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.common.datasets.utils import compute_stats, get_stats_einops_patterns
|
||||
from lerobot.common.datasets.utils import compute_stats, get_stats_einops_patterns, is_contiguously_true_or_false, load_data_with_delta_timestamps
|
||||
from lerobot.common.datasets.xarm import XarmDataset
|
||||
from lerobot.common.transforms import Prod
|
||||
from lerobot.common.utils import init_hydra_config
|
||||
@@ -142,3 +142,49 @@ def test_compute_stats():
|
||||
# assert torch.allclose(loaded_stats[k]["std"], expected_stats[k]["std"])
|
||||
# assert torch.allclose(loaded_stats[k]["min"], expected_stats[k]["min"])
|
||||
# assert torch.allclose(loaded_stats[k]["max"], expected_stats[k]["max"])
|
||||
|
||||
|
||||
def test_load_data_with_delta_timestamps_within_tolerance():
|
||||
data_dict = {
|
||||
"timestamp": torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5]),
|
||||
"index": torch.tensor([0, 1, 2, 3, 4]),
|
||||
}
|
||||
data_ids_per_episode = {0: torch.tensor([0, 1, 2, 3, 4])}
|
||||
delta_timestamps = {"index": [-0.2, 0, 0.24]}
|
||||
key = "index"
|
||||
current_ts = 0.3
|
||||
episode = 0
|
||||
tol = 0.04
|
||||
data, is_pad = load_data_with_delta_timestamps(data_dict, data_ids_per_episode, delta_timestamps, key, current_ts, episode, tol)
|
||||
assert not is_pad.any(), "Unexpected padding detected"
|
||||
assert torch.equal(data, torch.tensor([0, 2, 4])), "Data does not match expected values"
|
||||
|
||||
def test_load_data_with_delta_timestamps_outside_tolerance_inside_episode_range():
|
||||
data_dict = {
|
||||
"timestamp": torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5]),
|
||||
"index": torch.tensor([0, 1, 2, 3, 4]),
|
||||
}
|
||||
data_ids_per_episode = {0: torch.tensor([0, 1, 2, 3, 4])}
|
||||
delta_timestamps = {"index": [-0.2, 0, 0.14, 0.2]}
|
||||
key = "index"
|
||||
current_ts = 0.3
|
||||
episode = 0
|
||||
tol = 0.03
|
||||
with pytest.raises(AssertionError):
|
||||
load_data_with_delta_timestamps(data_dict, data_ids_per_episode, delta_timestamps, key, current_ts, episode, tol)
|
||||
|
||||
def test_load_data_with_delta_timestamps_outside_tolerance_outside_episode_range():
|
||||
data_dict = {
|
||||
"timestamp": torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5]),
|
||||
"index": torch.tensor([0, 1, 2, 3, 4]),
|
||||
}
|
||||
data_ids_per_episode = {0: torch.tensor([0, 1, 2, 3, 4])}
|
||||
delta_timestamps = {"index": [-0.3, -0.24, 0, 0.26, 0.3]}
|
||||
key = "index"
|
||||
current_ts = 0.3
|
||||
episode = 0
|
||||
tol = 0.04
|
||||
data, is_pad = load_data_with_delta_timestamps(data_dict, data_ids_per_episode, delta_timestamps, key, current_ts, episode, tol)
|
||||
assert torch.equal(is_pad, torch.tensor([True, False, False, True, True])), "Padding does not match expected values"
|
||||
assert torch.equal(data, torch.tensor([0, 0, 2, 4, 4])), "Data does not match expected values"
|
||||
|
||||
|
||||
Reference in New Issue
Block a user