Loads episode_data_index and stats during dataset __init__ (#85)

Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>
Co-authored-by: Alexander Soare <alexander.soare159@gmail.com>
This commit is contained in:
Remi
2024-04-23 14:13:25 +02:00
committed by GitHub
parent e2168163cd
commit 1030ea0070
89 changed files with 1008 additions and 432 deletions

View File

@@ -1,20 +1,26 @@
import json
import logging
import os
from copy import deepcopy
from pathlib import Path
import einops
import pytest
import torch
from datasets import Dataset
from safetensors.torch import load_file
import lerobot
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.datasets.pusht import PushtDataset
from lerobot.common.datasets.utils import (
compute_stats,
flatten_dict,
get_stats_einops_patterns,
hf_transform_to_torch,
load_previous_and_future_frames,
unflatten_dict,
)
from lerobot.common.transforms import Prod
from lerobot.common.utils.utils import init_hydra_config
from .utils import DEFAULT_CONFIG_PATH, DEVICE
@@ -39,8 +45,8 @@ def test_factory(env_name, dataset_id, policy_name):
keys_ndim_required = [
("action", 1, True),
("episode_id", 0, True),
("frame_id", 0, True),
("episode_index", 0, True),
("frame_index", 0, True),
("timestamp", 0, True),
# TODO(rcadene): should we rename it agent_pos?
("observation.state", 1, True),
@@ -48,12 +54,6 @@ def test_factory(env_name, dataset_id, policy_name):
("next.done", 0, False),
]
for key in image_keys:
keys_ndim_required.append(
(key, 3, True),
)
assert dataset.hf_dataset[key].dtype == torch.uint8, f"{key}"
# test number of dimensions
for key, ndim, required in keys_ndim_required:
if key not in item:
@@ -94,26 +94,21 @@ def test_compute_stats_on_xarm():
We compare with taking a straight min, mean, max, std of all the data in one pass (which we can do
because we are working with a small dataset).
"""
# TODO(rcadene): Reduce size of dataset sample on which stats compute is tested
from lerobot.common.datasets.xarm import XarmDataset
data_dir = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
# get transform to convert images from uint8 [0,255] to float32 [0,1]
transform = Prod(in_keys=XarmDataset.image_keys, prod=1 / 255.0)
dataset = XarmDataset(
dataset_id="xarm_lift_medium",
root=data_dir,
transform=transform,
root=Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None,
)
# Note: we set the batch size to be smaller than the whole dataset to make sure we are testing batched
# computation of the statistics. While doing this, we also make sure it works when we don't divide the
# dataset into even batches.
computed_stats = compute_stats(dataset, batch_size=int(len(dataset) * 0.25))
computed_stats = compute_stats(dataset.hf_dataset, batch_size=int(len(dataset) * 0.25))
# get einops patterns to aggregate batches and compute statistics
stats_patterns = get_stats_einops_patterns(dataset)
stats_patterns = get_stats_einops_patterns(dataset.hf_dataset)
# get all frames from the dataset in the same dtype and range as during compute_stats
dataloader = torch.utils.data.DataLoader(
@@ -122,18 +117,19 @@ def test_compute_stats_on_xarm():
batch_size=len(dataset),
shuffle=False,
)
hf_dataset = next(iter(dataloader))
full_batch = next(iter(dataloader))
# compute stats based on all frames from the dataset without any batching
expected_stats = {}
for k, pattern in stats_patterns.items():
full_batch[k] = full_batch[k].float()
expected_stats[k] = {}
expected_stats[k]["mean"] = einops.reduce(hf_dataset[k], pattern, "mean")
expected_stats[k]["mean"] = einops.reduce(full_batch[k], pattern, "mean")
expected_stats[k]["std"] = torch.sqrt(
einops.reduce((hf_dataset[k] - expected_stats[k]["mean"]) ** 2, pattern, "mean")
einops.reduce((full_batch[k] - expected_stats[k]["mean"]) ** 2, pattern, "mean")
)
expected_stats[k]["min"] = einops.reduce(hf_dataset[k], pattern, "min")
expected_stats[k]["max"] = einops.reduce(hf_dataset[k], pattern, "max")
expected_stats[k]["min"] = einops.reduce(full_batch[k], pattern, "min")
expected_stats[k]["max"] = einops.reduce(full_batch[k], pattern, "max")
# test computed stats match expected stats
for k in stats_patterns:
@@ -142,11 +138,10 @@ def test_compute_stats_on_xarm():
assert torch.allclose(computed_stats[k]["min"], expected_stats[k]["min"])
assert torch.allclose(computed_stats[k]["max"], expected_stats[k]["max"])
# TODO(rcadene): check that the stats used for training are correct too
# # load stats that are expected to match the ones returned by computed_stats
# assert (dataset.data_dir / "stats.pth").exists()
# loaded_stats = torch.load(dataset.data_dir / "stats.pth")
# load stats used during training which are expected to match the ones returned by computed_stats
loaded_stats = dataset.stats # noqa: F841
# TODO(rcadene): we can't test this because expected_stats is computed on a subset
# # test loaded stats match expected stats
# for k in stats_patterns:
# assert torch.allclose(loaded_stats[k]["mean"], expected_stats[k]["mean"])
@@ -160,15 +155,18 @@ def test_load_previous_and_future_frames_within_tolerance():
{
"timestamp": [0.1, 0.2, 0.3, 0.4, 0.5],
"index": [0, 1, 2, 3, 4],
"episode_data_index_from": [0, 0, 0, 0, 0],
"episode_data_index_to": [5, 5, 5, 5, 5],
"episode_index": [0, 0, 0, 0, 0],
}
)
hf_dataset = hf_dataset.with_format("torch")
item = hf_dataset[2]
hf_dataset.set_transform(hf_transform_to_torch)
episode_data_index = {
"from": torch.tensor([0]),
"to": torch.tensor([5]),
}
delta_timestamps = {"index": [-0.2, 0, 0.139]}
tol = 0.04
item = load_previous_and_future_frames(item, hf_dataset, delta_timestamps, tol)
item = hf_dataset[2]
item = load_previous_and_future_frames(item, hf_dataset, episode_data_index, delta_timestamps, tol)
data, is_pad = item["index"], item["index_is_pad"]
assert torch.equal(data, torch.tensor([0, 2, 3])), "Data does not match expected values"
assert not is_pad.any(), "Unexpected padding detected"
@@ -179,16 +177,19 @@ def test_load_previous_and_future_frames_outside_tolerance_inside_episode_range(
{
"timestamp": [0.1, 0.2, 0.3, 0.4, 0.5],
"index": [0, 1, 2, 3, 4],
"episode_data_index_from": [0, 0, 0, 0, 0],
"episode_data_index_to": [5, 5, 5, 5, 5],
"episode_index": [0, 0, 0, 0, 0],
}
)
hf_dataset = hf_dataset.with_format("torch")
item = hf_dataset[2]
hf_dataset.set_transform(hf_transform_to_torch)
episode_data_index = {
"from": torch.tensor([0]),
"to": torch.tensor([5]),
}
delta_timestamps = {"index": [-0.2, 0, 0.141]}
tol = 0.04
item = hf_dataset[2]
with pytest.raises(AssertionError):
load_previous_and_future_frames(item, hf_dataset, delta_timestamps, tol)
load_previous_and_future_frames(item, hf_dataset, episode_data_index, delta_timestamps, tol)
def test_load_previous_and_future_frames_outside_tolerance_outside_episode_range():
@@ -196,17 +197,102 @@ def test_load_previous_and_future_frames_outside_tolerance_outside_episode_range
{
"timestamp": [0.1, 0.2, 0.3, 0.4, 0.5],
"index": [0, 1, 2, 3, 4],
"episode_data_index_from": [0, 0, 0, 0, 0],
"episode_data_index_to": [5, 5, 5, 5, 5],
"episode_index": [0, 0, 0, 0, 0],
}
)
hf_dataset = hf_dataset.with_format("torch")
item = hf_dataset[2]
hf_dataset.set_transform(hf_transform_to_torch)
episode_data_index = {
"from": torch.tensor([0]),
"to": torch.tensor([5]),
}
delta_timestamps = {"index": [-0.3, -0.24, 0, 0.26, 0.3]}
tol = 0.04
item = load_previous_and_future_frames(item, hf_dataset, delta_timestamps, tol)
item = hf_dataset[2]
item = load_previous_and_future_frames(item, hf_dataset, episode_data_index, delta_timestamps, tol)
data, is_pad = item["index"], item["index_is_pad"]
assert torch.equal(data, torch.tensor([0, 0, 2, 4, 4])), "Data does not match expected values"
assert torch.equal(
is_pad, torch.tensor([True, False, False, True, True])
), "Padding does not match expected values"
def test_flatten_unflatten_dict():
d = {
"obs": {
"min": 0,
"max": 1,
"mean": 2,
"std": 3,
},
"action": {
"min": 4,
"max": 5,
"mean": 6,
"std": 7,
},
}
original_d = deepcopy(d)
d = unflatten_dict(flatten_dict(d))
# test equality between nested dicts
assert json.dumps(original_d, sort_keys=True) == json.dumps(d, sort_keys=True), f"{original_d} != {d}"
def test_backward_compatibility():
"""This tests artifacts have been generated by `tests/scripts/save_dataset_to_safetensors.py`."""
# TODO(rcadene): make it work for all datasets with LeRobotDataset(repo_id)
dataset_id = "pusht"
data_dir = Path("tests/data/save_dataset_to_safetensors") / dataset_id
dataset = PushtDataset(
dataset_id=dataset_id,
split="train",
root=Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None,
)
def load_and_compare(i):
new_frame = dataset[i]
old_frame = load_file(data_dir / f"frame_{i}.safetensors")
new_keys = set(new_frame.keys())
old_keys = set(old_frame.keys())
assert new_keys == old_keys, f"{new_keys=} and {old_keys=} are not the same"
for key in new_frame:
assert (
new_frame[key] == old_frame[key]
).all(), f"{key=} for index={i} does not contain the same value"
# test2 first frames of first episode
i = dataset.episode_data_index["from"][0].item()
load_and_compare(i)
load_and_compare(i + 1)
# test 2 frames at the middle of first episode
i = int((dataset.episode_data_index["to"][0].item() - dataset.episode_data_index["from"][0].item()) / 2)
load_and_compare(i)
load_and_compare(i + 1)
# test 2 last frames of first episode
i = dataset.episode_data_index["to"][0].item()
load_and_compare(i - 2)
load_and_compare(i - 1)
# TODO(rcadene): Enable testing on second and last episode
# We currently cant because our test dataset only contains the first episode
# # test 2 first frames of second episode
# i = dataset.episode_data_index["from"][1].item()
# load_and_compare(i)
# load_and_compare(i+1)
# #test 2 last frames of second episode
# i = dataset.episode_data_index["to"][1].item()
# load_and_compare(i-2)
# load_and_compare(i-1)
# # test 2 last frames of last episode
# i = dataset.episode_data_index["to"][-1].item()
# load_and_compare(i-2)
# load_and_compare(i-1)