Refactor push_dataset_to_hub (#118)
12 binary files changed (not shown): regenerated safetensors test artifacts, presumably under tests/data/save_dataset_to_safetensors/.
tests/scripts/save_dataset_to_safetensors.py
@@ -11,6 +11,7 @@ Example usage:
 `python tests/script/save_dataset_to_safetensors.py`
 """

+import os
 import shutil
 from pathlib import Path

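The `import os` added above exists to support the `DATA_DIR` environment override that the refactored script passes to `LeRobotDataset` in the next hunk. A minimal standalone sketch of that resolution logic, not part of the diff itself:

    import os
    from pathlib import Path

    # Resolve the dataset root the same way the refactored script does:
    # if DATA_DIR is set, read datasets from that local directory tree;
    # otherwise pass root=None so LeRobotDataset falls back to its default
    # source (presumably the Hugging Face Hub).
    root = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
    print(root)  # e.g. PosixPath('tests/data') when DATA_DIR=tests/data is set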
@@ -21,54 +22,56 @@ from lerobot.common.datasets.lerobot_dataset import LeRobotDataset


 def save_dataset_to_safetensors(output_dir, repo_id="lerobot/pusht"):
-    data_dir = Path(output_dir) / repo_id
+    repo_dir = Path(output_dir) / repo_id

-    if data_dir.exists():
-        shutil.rmtree(data_dir)
+    if repo_dir.exists():
+        shutil.rmtree(repo_dir)

-    data_dir.mkdir(parents=True, exist_ok=True)
-    dataset = LeRobotDataset(repo_id=repo_id, root=data_dir)
+    repo_dir.mkdir(parents=True, exist_ok=True)
+    dataset = LeRobotDataset(
+        repo_id=repo_id, root=Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
+    )

     # save 2 first frames of first episode
     i = dataset.episode_data_index["from"][0].item()
-    save_file(dataset[i], data_dir / f"frame_{i}.safetensors")
-    save_file(dataset[i + 1], data_dir / f"frame_{i+1}.safetensors")
+    save_file(dataset[i], repo_dir / f"frame_{i}.safetensors")
+    save_file(dataset[i + 1], repo_dir / f"frame_{i+1}.safetensors")

     # save 2 frames at the middle of first episode
     i = int((dataset.episode_data_index["to"][0].item() - dataset.episode_data_index["from"][0].item()) / 2)
-    save_file(dataset[i], data_dir / f"frame_{i}.safetensors")
-    save_file(dataset[i + 1], data_dir / f"frame_{i+1}.safetensors")
+    save_file(dataset[i], repo_dir / f"frame_{i}.safetensors")
+    save_file(dataset[i + 1], repo_dir / f"frame_{i+1}.safetensors")

     # save 2 last frames of first episode
     i = dataset.episode_data_index["to"][0].item()
-    save_file(dataset[i - 2], data_dir / f"frame_{i-2}.safetensors")
-    save_file(dataset[i - 1], data_dir / f"frame_{i-1}.safetensors")
+    save_file(dataset[i - 2], repo_dir / f"frame_{i-2}.safetensors")
+    save_file(dataset[i - 1], repo_dir / f"frame_{i-1}.safetensors")

     # TODO(rcadene): Enable testing on second and last episode
     # We currently cant because our test dataset only contains the first episode

     # # save 2 first frames of second episode
     # i = dataset.episode_data_index["from"][1].item()
-    # save_file(dataset[i], data_dir / f"frame_{i}.safetensors")
-    # save_file(dataset[i+1], data_dir / f"frame_{i+1}.safetensors")
+    # save_file(dataset[i], repo_dir / f"frame_{i}.safetensors")
+    # save_file(dataset[i + 1], repo_dir / f"frame_{i+1}.safetensors")

     # # save 2 last frames of second episode
     # i = dataset.episode_data_index["to"][1].item()
-    # save_file(dataset[i-2], data_dir / f"frame_{i-2}.safetensors")
-    # save_file(dataset[i-1], data_dir / f"frame_{i-1}.safetensors")
+    # save_file(dataset[i - 2], repo_dir / f"frame_{i-2}.safetensors")
+    # save_file(dataset[i - 1], repo_dir / f"frame_{i-1}.safetensors")

     # # save 2 last frames of last episode
     # i = dataset.episode_data_index["to"][-1].item()
-    # save_file(dataset[i-2], data_dir / f"frame_{i-2}.safetensors")
-    # save_file(dataset[i-1], data_dir / f"frame_{i-1}.safetensors")
+    # save_file(dataset[i - 2], repo_dir / f"frame_{i-2}.safetensors")
+    # save_file(dataset[i - 1], repo_dir / f"frame_{i-1}.safetensors")


 if __name__ == "__main__":
     available_datasets = [
         "lerobot/pusht",
-        "lerobot/xarm_push_medium",
+        "lerobot/xarm_lift_medium",
         "lerobot/aloha_sim_insertion_human",
-        "lerobot/umi_cup_in_the_wild",
+        # "lerobot/umi_cup_in_the_wild",
     ]
     for dataset in available_datasets:
         save_dataset_to_safetensors("tests/data/save_dataset_to_safetensors", repo_id=dataset)
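For reference, a frame saved by this script can be read back with `safetensors.torch.load_file`, which returns a flat `dict[str, torch.Tensor]`. A small sketch, assuming the `lerobot/pusht` artifacts have already been generated and that the first episode starts at index 0 (both assumptions, not guaranteed by the diff):

    from pathlib import Path

    from safetensors.torch import load_file

    # Path layout produced by save_dataset_to_safetensors above (assumed).
    frame_path = Path("tests/data/save_dataset_to_safetensors/lerobot/pusht/frame_0.safetensors")

    # load_file is the inverse of the save_file calls in the script.
    frame = load_file(str(frame_path))
    for key, tensor in frame.items():
        print(key, tuple(tensor.shape), tensor.dtype)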
tests/test_datasets.py
@@ -12,7 +12,9 @@ from safetensors.torch import load_file

 import lerobot
 from lerobot.common.datasets.factory import make_dataset
-from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
+from lerobot.common.datasets.lerobot_dataset import (
+    LeRobotDataset,
+)
 from lerobot.common.datasets.utils import (
     compute_stats,
     flatten_dict,
@@ -22,8 +24,7 @@ from lerobot.common.datasets.utils import (
     unflatten_dict,
 )
 from lerobot.common.utils.utils import init_hydra_config
-
-from .utils import DEFAULT_CONFIG_PATH, DEVICE
+from tests.utils import DEFAULT_CONFIG_PATH, DEVICE


 @pytest.mark.parametrize("env_name, repo_id, policy_name", lerobot.env_dataset_policy_triplets)
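The next hunk replaces a `for repo_id in all_repo_id:` loop inside `test_backward_compatibility` with `@pytest.mark.parametrize`, so each dataset becomes its own test case that passes or fails independently and can be selected with `pytest -k`. A minimal standalone sketch of the pattern (the assertion body is a placeholder, not the real check):

    import pytest

    @pytest.mark.parametrize(
        "repo_id",
        [
            "lerobot/pusht",
            "lerobot/aloha_sim_insertion_human",
            "lerobot/xarm_lift_medium",
        ],
    )
    def test_repo_id_is_namespaced(repo_id):
        # Placeholder check standing in for the real backward-compatibility
        # test; pytest reports one result per repo_id.
        assert repo_id.startswith("lerobot/")

For example, `pytest tests/test_datasets.py -k pusht` would then run only the pusht case.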
@@ -238,68 +239,66 @@ def test_flatten_unflatten_dict():
     assert json.dumps(original_d, sort_keys=True) == json.dumps(d, sort_keys=True), f"{original_d} != {d}"


-def test_backward_compatibility():
-    """This tests artifacts have been generated by `tests/scripts/save_dataset_to_safetensors.py`."""
-
-    all_repo_id = [
+@pytest.mark.parametrize(
+    "repo_id",
+    [
         "lerobot/pusht",
-        # TODO (azouitine): Add artifacts for the following datasets
-        # "lerobot/aloha_sim_insertion_human",
-        # "lerobot/xarm_push_medium",
-        # "lerobot/umi_cup_in_the_wild",
-    ]
-    for repo_id in all_repo_id:
-        dataset = LeRobotDataset(
-            repo_id,
-            root=Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None,
-        )
+        "lerobot/aloha_sim_insertion_human",
+        "lerobot/xarm_lift_medium",
+    ],
+)
+def test_backward_compatibility(repo_id):
+    """The artifacts for this test have been generated by `tests/scripts/save_dataset_to_safetensors.py`."""

-        test_dir = Path("tests/data/save_dataset_to_safetensors") / repo_id
+    data_dir = Path("tests/data/save_dataset_to_safetensors") / repo_id

-        def load_and_compare(i):
-            new_frame = dataset[i]  # noqa: B023
-            old_frame = load_file(test_dir / f"frame_{i}.safetensors")  # noqa: B023
+    dataset = LeRobotDataset(
+        repo_id,
+        root=Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None,
+    )

-            new_keys = set(new_frame.keys())
-            old_keys = set(old_frame.keys())
-            assert new_keys == old_keys, f"{new_keys=} and {old_keys=} are not the same"
+    def load_and_compare(i):
+        new_frame = dataset[i]  # noqa: B023
+        old_frame = load_file(data_dir / f"frame_{i}.safetensors")  # noqa: B023

-            for key in new_frame:
-                assert (
-                    new_frame[key] == old_frame[key]
-                ).all(), f"{key=} for index={i} does not contain the same value"
+        new_keys = set(new_frame.keys())
+        old_keys = set(old_frame.keys())
+        assert new_keys == old_keys, f"{new_keys=} and {old_keys=} are not the same"

-        # test2 first frames of first episode
-        i = dataset.episode_data_index["from"][0].item()
-        load_and_compare(i)
-        load_and_compare(i + 1)
+        for key in new_frame:
+            assert torch.isclose(
+                new_frame[key], old_frame[key], rtol=1e-05, atol=1e-08
+            ).all(), f"{key=} for index={i} does not contain the same value"

-        # test 2 frames at the middle of first episode
-        i = int(
-            (dataset.episode_data_index["to"][0].item() - dataset.episode_data_index["from"][0].item()) / 2
-        )
-        load_and_compare(i)
-        load_and_compare(i + 1)
+    # test2 first frames of first episode
+    i = dataset.episode_data_index["from"][0].item()
+    load_and_compare(i)
+    load_and_compare(i + 1)

-        # test 2 last frames of first episode
-        i = dataset.episode_data_index["to"][0].item()
-        load_and_compare(i - 2)
-        load_and_compare(i - 1)
+    # test 2 frames at the middle of first episode
+    i = int((dataset.episode_data_index["to"][0].item() - dataset.episode_data_index["from"][0].item()) / 2)
+    load_and_compare(i)
+    load_and_compare(i + 1)

-        # TODO(rcadene): Enable testing on second and last episode
-        # We currently cant because our test dataset only contains the first episode
+    # test 2 last frames of first episode
+    i = dataset.episode_data_index["to"][0].item()
+    load_and_compare(i - 2)
+    load_and_compare(i - 1)

-        # # test 2 first frames of second episode
-        # i = dataset.episode_data_index["from"][1].item()
-        # load_and_compare(i)
-        # load_and_compare(i+1)
+    # TODO(rcadene): Enable testing on second and last episode
+    # We currently cant because our test dataset only contains the first episode

-        # #test 2 last frames of second episode
-        # i = dataset.episode_data_index["to"][1].item()
-        # load_and_compare(i-2)
-        # load_and_compare(i-1)
+    # # test 2 first frames of second episode
+    # i = dataset.episode_data_index["from"][1].item()
+    # load_and_compare(i)
+    # load_and_compare(i + 1)

-        # # test 2 last frames of last episode
-        # i = dataset.episode_data_index["to"][-1].item()
-        # load_and_compare(i-2)
-        # load_and_compare(i-1)
+    # # test 2 last frames of second episode
+    # i = dataset.episode_data_index["to"][1].item()
+    # load_and_compare(i - 2)
+    # load_and_compare(i - 1)
+
+    # # test 2 last frames of last episode
+    # i = dataset.episode_data_index["to"][-1].item()
+    # load_and_compare(i - 2)
+    # load_and_compare(i - 1)
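The other substantive change in this hunk swaps the exact `==` comparison for `torch.isclose`, with torch's default tolerances (`rtol=1e-05`, `atol=1e-08`) written out explicitly. A self-contained illustration of why: tensors that differ only by float rounding fail exact equality but pass the tolerance check.

    import torch

    a = torch.tensor([0.1, 0.2, 0.3])
    b = a + 1e-7  # tiny perturbation, e.g. from a save/load round trip

    # Exact comparison is brittle for floats...
    print((a == b).all())  # tensor(False)

    # ...while isclose tolerates |a - b| <= atol + rtol * |b|.
    print(torch.isclose(a, b, rtol=1e-05, atol=1e-08).all())  # tensor(True)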