Refactor push_dataset_to_hub (#118)
@@ -12,7 +12,9 @@ from safetensors.torch import load_file
 import lerobot
 from lerobot.common.datasets.factory import make_dataset
-from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
+from lerobot.common.datasets.lerobot_dataset import (
+    LeRobotDataset,
+)
 from lerobot.common.datasets.utils import (
     compute_stats,
     flatten_dict,
@@ -22,8 +24,7 @@ from lerobot.common.datasets.utils import (
     unflatten_dict,
 )
 from lerobot.common.utils.utils import init_hydra_config
-
-from .utils import DEFAULT_CONFIG_PATH, DEVICE
+from tests.utils import DEFAULT_CONFIG_PATH, DEVICE
 
 
 @pytest.mark.parametrize("env_name, repo_id, policy_name", lerobot.env_dataset_policy_triplets)
@@ -238,68 +239,66 @@ def test_flatten_unflatten_dict():
     assert json.dumps(original_d, sort_keys=True) == json.dumps(d, sort_keys=True), f"{original_d} != {d}"
 
 
-def test_backward_compatibility():
-    """This tests artifacts have been generated by `tests/scripts/save_dataset_to_safetensors.py`."""
-
-    all_repo_id = [
+@pytest.mark.parametrize(
+    "repo_id",
+    [
         "lerobot/pusht",
-        # TODO (azouitine): Add artifacts for the following datasets
-        # "lerobot/aloha_sim_insertion_human",
-        # "lerobot/xarm_push_medium",
-        # "lerobot/umi_cup_in_the_wild",
-    ]
-    for repo_id in all_repo_id:
-        dataset = LeRobotDataset(
-            repo_id,
-            root=Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None,
-        )
+        "lerobot/aloha_sim_insertion_human",
+        "lerobot/xarm_lift_medium",
+    ],
+)
+def test_backward_compatibility(repo_id):
+    """The artifacts for this test have been generated by `tests/scripts/save_dataset_to_safetensors.py`."""
+
+    dataset = LeRobotDataset(
+        repo_id,
+        root=Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None,
+    )
 
-        data_dir = Path("tests/data/save_dataset_to_safetensors") / repo_id
+    test_dir = Path("tests/data/save_dataset_to_safetensors") / repo_id
 
-        def load_and_compare(i):
-            new_frame = dataset[i]  # noqa: B023
-            old_frame = load_file(data_dir / f"frame_{i}.safetensors")  # noqa: B023
+    def load_and_compare(i):
+        new_frame = dataset[i]  # noqa: B023
+        old_frame = load_file(test_dir / f"frame_{i}.safetensors")  # noqa: B023
 
-            new_keys = set(new_frame.keys())
-            old_keys = set(old_frame.keys())
-            assert new_keys == old_keys, f"{new_keys=} and {old_keys=} are not the same"
+        new_keys = set(new_frame.keys())
+        old_keys = set(old_frame.keys())
+        assert new_keys == old_keys, f"{new_keys=} and {old_keys=} are not the same"
 
-            for key in new_frame:
-                assert (
-                    new_frame[key] == old_frame[key]
-                ).all(), f"{key=} for index={i} does not contain the same value"
+        for key in new_frame:
+            assert torch.isclose(
+                new_frame[key], old_frame[key], rtol=1e-05, atol=1e-08
+            ).all(), f"{key=} for index={i} does not contain the same value"
 
-        # test2 first frames of first episode
-        i = dataset.episode_data_index["from"][0].item()
-        load_and_compare(i)
-        load_and_compare(i + 1)
+    # test2 first frames of first episode
+    i = dataset.episode_data_index["from"][0].item()
+    load_and_compare(i)
+    load_and_compare(i + 1)
 
-        # test 2 frames at the middle of first episode
-        i = int(
-            (dataset.episode_data_index["to"][0].item() - dataset.episode_data_index["from"][0].item()) / 2
-        )
-        load_and_compare(i)
-        load_and_compare(i + 1)
+    # test 2 frames at the middle of first episode
+    i = int((dataset.episode_data_index["to"][0].item() - dataset.episode_data_index["from"][0].item()) / 2)
+    load_and_compare(i)
+    load_and_compare(i + 1)
 
-        # test 2 last frames of first episode
-        i = dataset.episode_data_index["to"][0].item()
-        load_and_compare(i - 2)
-        load_and_compare(i - 1)
+    # test 2 last frames of first episode
+    i = dataset.episode_data_index["to"][0].item()
+    load_and_compare(i - 2)
+    load_and_compare(i - 1)
 
-        # TODO(rcadene): Enable testing on second and last episode
-        # We currently cant because our test dataset only contains the first episode
+    # TODO(rcadene): Enable testing on second and last episode
+    # We currently cant because our test dataset only contains the first episode
 
-        # # test 2 first frames of second episode
-        # i = dataset.episode_data_index["from"][1].item()
-        # load_and_compare(i)
-        # load_and_compare(i+1)
+    # # test 2 first frames of second episode
+    # i = dataset.episode_data_index["from"][1].item()
+    # load_and_compare(i)
+    # load_and_compare(i + 1)
 
-        # #test 2 last frames of second episode
-        # i = dataset.episode_data_index["to"][1].item()
-        # load_and_compare(i-2)
-        # load_and_compare(i-1)
+    # # test 2 last frames of second episode
+    # i = dataset.episode_data_index["to"][1].item()
+    # load_and_compare(i - 2)
+    # load_and_compare(i - 1)
 
-        # # test 2 last frames of last episode
-        # i = dataset.episode_data_index["to"][-1].item()
-        # load_and_compare(i-2)
-        # load_and_compare(i-1)
+    # # test 2 last frames of last episode
+    # i = dataset.episode_data_index["to"][-1].item()
+    # load_and_compare(i - 2)
+    # load_and_compare(i - 1)
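
Note: the core of this change to the test is replacing an in-test loop over `all_repo_id` with `@pytest.mark.parametrize`, so each dataset becomes its own collected test and failures are reported per repo_id instead of aborting the remaining iterations. A minimal, self-contained sketch of the pattern (the test name and repo list below are illustrative, not part of the commit):

import pytest

# Hypothetical list standing in for the repo_ids used in the real test.
REPO_IDS = ["lerobot/pusht", "lerobot/aloha_sim_insertion_human"]


@pytest.mark.parametrize("repo_id", REPO_IDS)
def test_example(repo_id):
    # pytest collects one independent test item per parameter, reported as
    # test_example[lerobot/pusht] and test_example[lerobot/aloha_sim_insertion_human],
    # so one failing dataset no longer hides the results for the others.
    assert repo_id.startswith("lerobot/")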
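
The frame comparison also moves from exact tensor equality to `torch.isclose` with rtol=1e-05 and atol=1e-08, which tolerates small floating-point drift between the freshly loaded frames and the saved safetensors artifacts. A small illustration of the difference, with made-up values:

import torch

a = torch.tensor([1.0, 2.0, 3.0])
b = a + 1e-6  # tiny float drift

# The old check requires bit-exact equality and fails here:
assert not (a == b).all()

# The new check passes when |a - b| <= atol + rtol * |b| elementwise:
assert torch.isclose(a, b, rtol=1e-05, atol=1e-08).all()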