WIP add load functions + episode_data_index

This commit is contained in:
Cadene
2024-04-18 23:54:52 +00:00
parent 0bd2ca8d82
commit 64b09ea7a7
4 changed files with 159 additions and 40 deletions

View File

@@ -1,10 +1,13 @@
import logging
from copy import deepcopy
import json
import os
from pathlib import Path
import einops
import pytest
import torch
from datasets import Dataset
import lerobot
@@ -13,6 +16,8 @@ from lerobot.common.datasets.utils import (
compute_stats,
get_stats_einops_patterns,
load_previous_and_future_frames,
flatten_dict,
unflatten_dict,
)
from lerobot.common.transforms import Prod
from lerobot.common.utils.utils import init_hydra_config
@@ -160,15 +165,18 @@ def test_load_previous_and_future_frames_within_tolerance():
{
"timestamp": [0.1, 0.2, 0.3, 0.4, 0.5],
"index": [0, 1, 2, 3, 4],
"episode_data_index_from": [0, 0, 0, 0, 0],
"episode_data_index_to": [5, 5, 5, 5, 5],
"episode_id": [0, 0, 0, 0, 0],
}
)
hf_dataset = hf_dataset.with_format("torch")
item = hf_dataset[2]
episode_data_index = {
"from": torch.tensor([0]),
"to": torch.tensor([5]),
}
delta_timestamps = {"index": [-0.2, 0, 0.139]}
tol = 0.04
item = load_previous_and_future_frames(item, hf_dataset, delta_timestamps, tol)
item = hf_dataset[2]
item = load_previous_and_future_frames(item, hf_dataset, episode_data_index, delta_timestamps, tol)
data, is_pad = item["index"], item["index_is_pad"]
assert torch.equal(data, torch.tensor([0, 2, 3])), "Data does not match expected values"
assert not is_pad.any(), "Unexpected padding detected"
@@ -179,16 +187,19 @@ def test_load_previous_and_future_frames_outside_tolerance_inside_episode_range(
{
"timestamp": [0.1, 0.2, 0.3, 0.4, 0.5],
"index": [0, 1, 2, 3, 4],
"episode_data_index_from": [0, 0, 0, 0, 0],
"episode_data_index_to": [5, 5, 5, 5, 5],
"episode_id": [0, 0, 0, 0, 0],
}
)
hf_dataset = hf_dataset.with_format("torch")
item = hf_dataset[2]
episode_data_index = {
"from": torch.tensor([0]),
"to": torch.tensor([5]),
}
delta_timestamps = {"index": [-0.2, 0, 0.141]}
tol = 0.04
item = hf_dataset[2]
with pytest.raises(AssertionError):
load_previous_and_future_frames(item, hf_dataset, delta_timestamps, tol)
load_previous_and_future_frames(item, hf_dataset, episode_data_index, delta_timestamps, tol)
def test_load_previous_and_future_frames_outside_tolerance_outside_episode_range():
@@ -196,17 +207,43 @@ def test_load_previous_and_future_frames_outside_tolerance_outside_episode_range
{
"timestamp": [0.1, 0.2, 0.3, 0.4, 0.5],
"index": [0, 1, 2, 3, 4],
"episode_data_index_from": [0, 0, 0, 0, 0],
"episode_data_index_to": [5, 5, 5, 5, 5],
"episode_id": [0, 0, 0, 0, 0],
}
)
hf_dataset = hf_dataset.with_format("torch")
item = hf_dataset[2]
episode_data_index = {
"from": torch.tensor([0]),
"to": torch.tensor([5]),
}
delta_timestamps = {"index": [-0.3, -0.24, 0, 0.26, 0.3]}
tol = 0.04
item = load_previous_and_future_frames(item, hf_dataset, delta_timestamps, tol)
item = hf_dataset[2]
item = load_previous_and_future_frames(item, hf_dataset, episode_data_index, delta_timestamps, tol)
data, is_pad = item["index"], item["index_is_pad"]
assert torch.equal(data, torch.tensor([0, 0, 2, 4, 4])), "Data does not match expected values"
assert torch.equal(
is_pad, torch.tensor([True, False, False, True, True])
), "Padding does not match expected values"
def test_flatten_unflatten_dict():
d = {
"obs": {
"min": 0,
"max": 1,
"mean": 2,
"std": 3,
},
"action": {
"min": 4,
"max": 5,
"mean": 6,
"std": 7,
},
}
original_d = deepcopy(d)
d = unflatten_dict(flatten_dict(d))
# test equality between nested dicts
assert json.dumps(original_d, sort_keys=True) == json.dumps(d, sort_keys=True), f"{original_d} != {d}"