Refactor the download and publication of the datasets and convert it into a CLI script (#95)

commit 55dc9f7f51
parent 81e490d46f
Author: Adil Zouitine
Co-authored-by: Remi <re.cadene@gmail.com>
Committed-by: GitHub
Date: 2024-04-29 00:08:17 +02:00

15 changed files with 1410 additions and 827 deletions


@@ -241,57 +241,65 @@ def test_flatten_unflatten_dict():
 def test_backward_compatibility():
     """This tests artifacts have been generated by `tests/scripts/save_dataset_to_safetensors.py`."""
-    repo_id = "lerobot/pusht"
-
-    dataset = LeRobotDataset(
-        repo_id,
-        root=Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None,
-    )
+    all_repo_id = [
+        "lerobot/pusht",
+        # TODO (azouitine): Add artifacts for the following datasets
+        # "lerobot/aloha_sim_insertion_human",
+        # "lerobot/xarm_push_medium",
+        # "lerobot/umi_cup_in_the_wild",
+    ]
+    for repo_id in all_repo_id:
+        dataset = LeRobotDataset(
+            repo_id,
+            root=Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None,
+        )

-    data_dir = Path("tests/data/save_dataset_to_safetensors") / repo_id
+        data_dir = Path("tests/data/save_dataset_to_safetensors") / repo_id

-    def load_and_compare(i):
-        new_frame = dataset[i]
-        old_frame = load_file(data_dir / f"frame_{i}.safetensors")
+        def load_and_compare(i):
+            new_frame = dataset[i]  # noqa: B023
+            old_frame = load_file(data_dir / f"frame_{i}.safetensors")  # noqa: B023

-        new_keys = set(new_frame.keys())
-        old_keys = set(old_frame.keys())
-        assert new_keys == old_keys, f"{new_keys=} and {old_keys=} are not the same"
+            new_keys = set(new_frame.keys())
+            old_keys = set(old_frame.keys())
+            assert new_keys == old_keys, f"{new_keys=} and {old_keys=} are not the same"

-        for key in new_frame:
-            assert (
-                new_frame[key] == old_frame[key]
-            ).all(), f"{key=} for index={i} does not contain the same value"
+            for key in new_frame:
+                assert (
+                    new_frame[key] == old_frame[key]
+                ).all(), f"{key=} for index={i} does not contain the same value"
-    # test 2 first frames of first episode
-    i = dataset.episode_data_index["from"][0].item()
-    load_and_compare(i)
-    load_and_compare(i + 1)
+        # test 2 first frames of first episode
+        i = dataset.episode_data_index["from"][0].item()
+        load_and_compare(i)
+        load_and_compare(i + 1)

-    # test 2 frames at the middle of first episode
-    i = int((dataset.episode_data_index["to"][0].item() - dataset.episode_data_index["from"][0].item()) / 2)
-    load_and_compare(i)
-    load_and_compare(i + 1)
+        # test 2 frames at the middle of first episode
+        i = int(
+            (dataset.episode_data_index["to"][0].item() - dataset.episode_data_index["from"][0].item()) / 2
+        )
+        load_and_compare(i)
+        load_and_compare(i + 1)

-    # test 2 last frames of first episode
-    i = dataset.episode_data_index["to"][0].item()
-    load_and_compare(i - 2)
-    load_and_compare(i - 1)
+        # test 2 last frames of first episode
+        i = dataset.episode_data_index["to"][0].item()
+        load_and_compare(i - 2)
+        load_and_compare(i - 1)
-    # TODO(rcadene): Enable testing on second and last episode
-    # We currently can't because our test dataset only contains the first episode
+        # TODO(rcadene): Enable testing on second and last episode
+        # We currently can't because our test dataset only contains the first episode

-    # # test 2 first frames of second episode
-    # i = dataset.episode_data_index["from"][1].item()
-    # load_and_compare(i)
-    # load_and_compare(i+1)
+        # # test 2 first frames of second episode
+        # i = dataset.episode_data_index["from"][1].item()
+        # load_and_compare(i)
+        # load_and_compare(i+1)

-    # # test 2 last frames of second episode
-    # i = dataset.episode_data_index["to"][1].item()
-    # load_and_compare(i-2)
-    # load_and_compare(i-1)
+        # # test 2 last frames of second episode
+        # i = dataset.episode_data_index["to"][1].item()
+        # load_and_compare(i-2)
+        # load_and_compare(i-1)

-    # # test 2 last frames of last episode
-    # i = dataset.episode_data_index["to"][-1].item()
-    # load_and_compare(i-2)
-    # load_and_compare(i-1)
+        # # test 2 last frames of last episode
+        # i = dataset.episode_data_index["to"][-1].item()
+        # load_and_compare(i-2)
+        # load_and_compare(i-1)
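
Note on the new noqa markers: the linter rule B023 (flake8-bugbear's "function definition does not bind loop variable", also implemented by ruff) fires because load_and_compare now closes over dataset and data_dir, which are rebound on every iteration of the for repo_id loop. The snippet below is a minimal illustration of the pitfall the rule guards against; the names are hypothetical and not part of the commit:

# Closures capture loop variables by reference, not by value: every lambda
# below sees the final value of i once the loop has finished.
callbacks = [lambda: i for i in range(3)]
print([f() for f in callbacks])  # prints [2, 2, 2], not [0, 1, 2]

In this test the warning is a false positive: load_and_compare is defined and called within the same iteration, so the captured dataset and data_dir always refer to the current repo_id, hence the noqa markers rather than a restructuring.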
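
The TODO at the top of the hunk asks for equivalent artifacts for the other datasets. The compared frames live under tests/data/save_dataset_to_safetensors/<repo_id> and are produced by tests/scripts/save_dataset_to_safetensors.py. Below is a rough sketch of what such a script could look like, using safetensors.torch.save_file, the counterpart of the load_file call above; the import path and the assumption that every value in a frame is a tensor are guesses, not taken from the commit:

from pathlib import Path

from safetensors.torch import save_file

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset  # import path assumed


def save_first_episode_frames(repo_id: str, out_root: Path) -> None:
    """Save the frames that test_backward_compatibility reads back."""
    dataset = LeRobotDataset(repo_id)
    data_dir = out_root / repo_id
    data_dir.mkdir(parents=True, exist_ok=True)

    start = dataset.episode_data_index["from"][0].item()
    end = dataset.episode_data_index["to"][0].item()
    middle = int((end - start) / 2)  # mirrors the test's middle-of-episode formula

    # Two frames at the start, middle, and end of the first episode,
    # matching the indices the test probes.
    for i in [start, start + 1, middle, middle + 1, end - 2, end - 1]:
        # save_file expects a flat dict of tensors with string keys,
        # which is assumed to be what dataset[i] returns.
        save_file(dataset[i], data_dir / f"frame_{i}.safetensors")


save_first_episode_frames("lerobot/pusht", Path("tests/data/save_dataset_to_safetensors"))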