id -> index, finish moving compute_stats before hf_dataset push_to_hub
This commit is contained in:
@@ -1,22 +1,21 @@
|
||||
import logging
|
||||
from copy import deepcopy
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
|
||||
import einops
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from datasets import Dataset
|
||||
|
||||
import lerobot
|
||||
from lerobot.common.datasets.factory import make_dataset
|
||||
from lerobot.common.datasets.utils import (
|
||||
compute_stats,
|
||||
flatten_dict,
|
||||
get_stats_einops_patterns,
|
||||
load_previous_and_future_frames,
|
||||
flatten_dict,
|
||||
unflatten_dict,
|
||||
)
|
||||
from lerobot.common.transforms import Prod
|
||||
@@ -44,8 +43,8 @@ def test_factory(env_name, dataset_id, policy_name):
|
||||
|
||||
keys_ndim_required = [
|
||||
("action", 1, True),
|
||||
("episode_id", 0, True),
|
||||
("frame_id", 0, True),
|
||||
("episode_index", 0, True),
|
||||
("frame_index", 0, True),
|
||||
("timestamp", 0, True),
|
||||
# TODO(rcadene): should we rename it agent_pos?
|
||||
("observation.state", 1, True),
|
||||
@@ -165,13 +164,13 @@ def test_load_previous_and_future_frames_within_tolerance():
|
||||
{
|
||||
"timestamp": [0.1, 0.2, 0.3, 0.4, 0.5],
|
||||
"index": [0, 1, 2, 3, 4],
|
||||
"episode_id": [0, 0, 0, 0, 0],
|
||||
"episode_index": [0, 0, 0, 0, 0],
|
||||
}
|
||||
)
|
||||
hf_dataset = hf_dataset.with_format("torch")
|
||||
episode_data_index = {
|
||||
"from": torch.tensor([0]),
|
||||
"to": torch.tensor([5]),
|
||||
"from": torch.tensor([0]),
|
||||
"to": torch.tensor([5]),
|
||||
}
|
||||
delta_timestamps = {"index": [-0.2, 0, 0.139]}
|
||||
tol = 0.04
|
||||
@@ -187,13 +186,13 @@ def test_load_previous_and_future_frames_outside_tolerance_inside_episode_range(
|
||||
{
|
||||
"timestamp": [0.1, 0.2, 0.3, 0.4, 0.5],
|
||||
"index": [0, 1, 2, 3, 4],
|
||||
"episode_id": [0, 0, 0, 0, 0],
|
||||
"episode_index": [0, 0, 0, 0, 0],
|
||||
}
|
||||
)
|
||||
hf_dataset = hf_dataset.with_format("torch")
|
||||
episode_data_index = {
|
||||
"from": torch.tensor([0]),
|
||||
"to": torch.tensor([5]),
|
||||
"from": torch.tensor([0]),
|
||||
"to": torch.tensor([5]),
|
||||
}
|
||||
delta_timestamps = {"index": [-0.2, 0, 0.141]}
|
||||
tol = 0.04
|
||||
@@ -207,13 +206,13 @@ def test_load_previous_and_future_frames_outside_tolerance_outside_episode_range
|
||||
{
|
||||
"timestamp": [0.1, 0.2, 0.3, 0.4, 0.5],
|
||||
"index": [0, 1, 2, 3, 4],
|
||||
"episode_id": [0, 0, 0, 0, 0],
|
||||
"episode_index": [0, 0, 0, 0, 0],
|
||||
}
|
||||
)
|
||||
hf_dataset = hf_dataset.with_format("torch")
|
||||
episode_data_index = {
|
||||
"from": torch.tensor([0]),
|
||||
"to": torch.tensor([5]),
|
||||
"from": torch.tensor([0]),
|
||||
"to": torch.tensor([5]),
|
||||
}
|
||||
delta_timestamps = {"index": [-0.3, -0.24, 0, 0.26, 0.3]}
|
||||
tol = 0.04
|
||||
@@ -224,7 +223,7 @@ def test_load_previous_and_future_frames_outside_tolerance_outside_episode_range
|
||||
assert torch.equal(
|
||||
is_pad, torch.tensor([True, False, False, True, True])
|
||||
), "Padding does not match expected values"
|
||||
|
||||
|
||||
|
||||
def test_flatten_unflatten_dict():
|
||||
d = {
|
||||
|
||||
Reference in New Issue
Block a user