Refactor push_dataset_to_hub (#118)

Remi
2024-04-30 14:25:41 +02:00
committed by GitHub
parent 2765877f28
commit e4e739f4f8
25 changed files with 1089 additions and 1192 deletions


@@ -12,7 +12,9 @@ from safetensors.torch import load_file
 import lerobot
 from lerobot.common.datasets.factory import make_dataset
-from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
+from lerobot.common.datasets.lerobot_dataset import (
+    LeRobotDataset,
+)
 from lerobot.common.datasets.utils import (
     compute_stats,
     flatten_dict,
@@ -22,8 +24,7 @@ from lerobot.common.datasets.utils import (
     unflatten_dict,
 )
 from lerobot.common.utils.utils import init_hydra_config
-from .utils import DEFAULT_CONFIG_PATH, DEVICE
+from tests.utils import DEFAULT_CONFIG_PATH, DEVICE

 @pytest.mark.parametrize("env_name, repo_id, policy_name", lerobot.env_dataset_policy_triplets)
@@ -238,68 +239,66 @@ def test_flatten_unflatten_dict():
     assert json.dumps(original_d, sort_keys=True) == json.dumps(d, sort_keys=True), f"{original_d} != {d}"


-def test_backward_compatibility():
-    """This tests artifacts have been generated by `tests/scripts/save_dataset_to_safetensors.py`."""
-    all_repo_id = [
+@pytest.mark.parametrize(
+    "repo_id",
+    [
         "lerobot/pusht",
-        # TODO (azouitine): Add artifacts for the following datasets
-        # "lerobot/aloha_sim_insertion_human",
-        # "lerobot/xarm_push_medium",
-        # "lerobot/umi_cup_in_the_wild",
-    ]
-    for repo_id in all_repo_id:
-        dataset = LeRobotDataset(
-            repo_id,
-            root=Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None,
-        )
-
-        test_dir = Path("tests/data/save_dataset_to_safetensors") / repo_id
+        "lerobot/aloha_sim_insertion_human",
+        "lerobot/xarm_lift_medium",
+    ],
+)
+def test_backward_compatibility(repo_id):
+    """The artifacts for this test have been generated by `tests/scripts/save_dataset_to_safetensors.py`."""
+    data_dir = Path("tests/data/save_dataset_to_safetensors") / repo_id
+    dataset = LeRobotDataset(
+        repo_id,
+        root=Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None,
+    )

-        def load_and_compare(i):
-            new_frame = dataset[i]  # noqa: B023
-            old_frame = load_file(test_dir / f"frame_{i}.safetensors")  # noqa: B023
+    def load_and_compare(i):
+        new_frame = dataset[i]  # noqa: B023
+        old_frame = load_file(data_dir / f"frame_{i}.safetensors")  # noqa: B023

-            new_keys = set(new_frame.keys())
-            old_keys = set(old_frame.keys())
-            assert new_keys == old_keys, f"{new_keys=} and {old_keys=} are not the same"
+        new_keys = set(new_frame.keys())
+        old_keys = set(old_frame.keys())
+        assert new_keys == old_keys, f"{new_keys=} and {old_keys=} are not the same"

-            for key in new_frame:
-                assert (
-                    new_frame[key] == old_frame[key]
-                ).all(), f"{key=} for index={i} does not contain the same value"
+        for key in new_frame:
+            assert torch.isclose(
+                new_frame[key], old_frame[key], rtol=1e-05, atol=1e-08
+            ).all(), f"{key=} for index={i} does not contain the same value"

-        # test2 first frames of first episode
-        i = dataset.episode_data_index["from"][0].item()
-        load_and_compare(i)
-        load_and_compare(i + 1)
+    # test2 first frames of first episode
+    i = dataset.episode_data_index["from"][0].item()
+    load_and_compare(i)
+    load_and_compare(i + 1)

-        # test 2 frames at the middle of first episode
-        i = int(
-            (dataset.episode_data_index["to"][0].item() - dataset.episode_data_index["from"][0].item()) / 2
-        )
-        load_and_compare(i)
-        load_and_compare(i + 1)
+    # test 2 frames at the middle of first episode
+    i = int((dataset.episode_data_index["to"][0].item() - dataset.episode_data_index["from"][0].item()) / 2)
+    load_and_compare(i)
+    load_and_compare(i + 1)

-        # test 2 last frames of first episode
-        i = dataset.episode_data_index["to"][0].item()
-        load_and_compare(i - 2)
-        load_and_compare(i - 1)
+    # test 2 last frames of first episode
+    i = dataset.episode_data_index["to"][0].item()
+    load_and_compare(i - 2)
+    load_and_compare(i - 1)

-        # TODO(rcadene): Enable testing on second and last episode
-        # We currently cant because our test dataset only contains the first episode
+    # TODO(rcadene): Enable testing on second and last episode
+    # We currently cant because our test dataset only contains the first episode

-        # # test 2 first frames of second episode
-        # i = dataset.episode_data_index["from"][1].item()
-        # load_and_compare(i)
-        # load_and_compare(i+1)
+    # # test 2 first frames of second episode
+    # i = dataset.episode_data_index["from"][1].item()
+    # load_and_compare(i)
+    # load_and_compare(i + 1)

-        # #test 2 last frames of second episode
-        # i = dataset.episode_data_index["to"][1].item()
-        # load_and_compare(i-2)
-        # load_and_compare(i-1)
+    # # test 2 last frames of second episode
+    # i = dataset.episode_data_index["to"][1].item()
+    # load_and_compare(i - 2)
+    # load_and_compare(i - 1)

-        # # test 2 last frames of last episode
-        # i = dataset.episode_data_index["to"][-1].item()
-        # load_and_compare(i-2)
-        # load_and_compare(i-1)
+    # # test 2 last frames of last episode
+    # i = dataset.episode_data_index["to"][-1].item()
+    # load_and_compare(i - 2)
+    # load_and_compare(i - 1)
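
A quick note on the frame indices the refactored test exercises (this sketch is not part of the commit, and the boundary values are hypothetical): the test reads episode_data_index, which gives each episode's first frame index ("from") and an exclusive end index ("to"); that exclusivity is why the last two frames of episode 0 are loaded as i - 2 and i - 1.

import torch

# Hypothetical boundaries for a dataset whose first episode holds frames 0..99.
episode_data_index = {
    "from": torch.tensor([0]),    # first frame index of episode 0
    "to": torch.tensor([100]),    # exclusive end index of episode 0
}

first = episode_data_index["from"][0].item()    # 0   -> frames 0 and 1 are compared
middle = int(
    (episode_data_index["to"][0].item() - episode_data_index["from"][0].item()) / 2
)                                               # 50  -> frames 50 and 51 are compared
last = episode_data_index["to"][0].item()       # 100 -> frames 98 and 99 are compared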