id -> index, finish moving compute_stats before hf_dataset push_to_hub

This commit is contained in:
Cadene
2024-04-19 10:33:42 +00:00
parent 64b09ea7a7
commit 714a776277
9 changed files with 120 additions and 99 deletions

View File

@@ -1,22 +1,21 @@
import logging
from copy import deepcopy
import json
import logging
import os
from copy import deepcopy
from pathlib import Path
import einops
import pytest
import torch
from datasets import Dataset
import lerobot
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.datasets.utils import (
compute_stats,
flatten_dict,
get_stats_einops_patterns,
load_previous_and_future_frames,
flatten_dict,
unflatten_dict,
)
from lerobot.common.transforms import Prod
@@ -44,8 +43,8 @@ def test_factory(env_name, dataset_id, policy_name):
keys_ndim_required = [
("action", 1, True),
("episode_id", 0, True),
("frame_id", 0, True),
("episode_index", 0, True),
("frame_index", 0, True),
("timestamp", 0, True),
# TODO(rcadene): should we rename it agent_pos?
("observation.state", 1, True),
@@ -165,13 +164,13 @@ def test_load_previous_and_future_frames_within_tolerance():
{
"timestamp": [0.1, 0.2, 0.3, 0.4, 0.5],
"index": [0, 1, 2, 3, 4],
"episode_id": [0, 0, 0, 0, 0],
"episode_index": [0, 0, 0, 0, 0],
}
)
hf_dataset = hf_dataset.with_format("torch")
episode_data_index = {
"from": torch.tensor([0]),
"to": torch.tensor([5]),
"from": torch.tensor([0]),
"to": torch.tensor([5]),
}
delta_timestamps = {"index": [-0.2, 0, 0.139]}
tol = 0.04
@@ -187,13 +186,13 @@ def test_load_previous_and_future_frames_outside_tolerance_inside_episode_range(
{
"timestamp": [0.1, 0.2, 0.3, 0.4, 0.5],
"index": [0, 1, 2, 3, 4],
"episode_id": [0, 0, 0, 0, 0],
"episode_index": [0, 0, 0, 0, 0],
}
)
hf_dataset = hf_dataset.with_format("torch")
episode_data_index = {
"from": torch.tensor([0]),
"to": torch.tensor([5]),
"from": torch.tensor([0]),
"to": torch.tensor([5]),
}
delta_timestamps = {"index": [-0.2, 0, 0.141]}
tol = 0.04
@@ -207,13 +206,13 @@ def test_load_previous_and_future_frames_outside_tolerance_outside_episode_range
{
"timestamp": [0.1, 0.2, 0.3, 0.4, 0.5],
"index": [0, 1, 2, 3, 4],
"episode_id": [0, 0, 0, 0, 0],
"episode_index": [0, 0, 0, 0, 0],
}
)
hf_dataset = hf_dataset.with_format("torch")
episode_data_index = {
"from": torch.tensor([0]),
"to": torch.tensor([5]),
"from": torch.tensor([0]),
"to": torch.tensor([5]),
}
delta_timestamps = {"index": [-0.3, -0.24, 0, 0.26, 0.3]}
tol = 0.04
@@ -224,7 +223,7 @@ def test_load_previous_and_future_frames_outside_tolerance_outside_episode_range
assert torch.equal(
is_pad, torch.tensor([True, False, False, True, True])
), "Padding does not match expected values"
def test_flatten_unflatten_dict():
d = {