Refactor the download and publication of the datasets and convert it into CLI script (#95)
Co-authored-by: Remi <re.cadene@gmail.com>
This commit is contained in:
@@ -241,57 +241,65 @@ def test_flatten_unflatten_dict():
|
||||
def test_backward_compatibility():
|
||||
"""This tests artifacts have been generated by `tests/scripts/save_dataset_to_safetensors.py`."""
|
||||
|
||||
repo_id = "lerobot/pusht"
|
||||
all_repo_id = [
|
||||
"lerobot/pusht",
|
||||
# TODO (azouitine): Add artifacts for the following datasets
|
||||
# "lerobot/aloha_sim_insertion_human",
|
||||
# "lerobot/xarm_push_medium",
|
||||
# "lerobot/umi_cup_in_the_wild",
|
||||
]
|
||||
for repo_id in all_repo_id:
|
||||
dataset = LeRobotDataset(
|
||||
repo_id,
|
||||
root=Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None,
|
||||
)
|
||||
|
||||
dataset = LeRobotDataset(
|
||||
repo_id,
|
||||
root=Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None,
|
||||
)
|
||||
data_dir = Path("tests/data/save_dataset_to_safetensors") / repo_id
|
||||
|
||||
data_dir = Path("tests/data/save_dataset_to_safetensors") / repo_id
|
||||
def load_and_compare(i):
|
||||
new_frame = dataset[i] # noqa: B023
|
||||
old_frame = load_file(data_dir / f"frame_{i}.safetensors") # noqa: B023
|
||||
|
||||
def load_and_compare(i):
|
||||
new_frame = dataset[i]
|
||||
old_frame = load_file(data_dir / f"frame_{i}.safetensors")
|
||||
new_keys = set(new_frame.keys())
|
||||
old_keys = set(old_frame.keys())
|
||||
assert new_keys == old_keys, f"{new_keys=} and {old_keys=} are not the same"
|
||||
|
||||
new_keys = set(new_frame.keys())
|
||||
old_keys = set(old_frame.keys())
|
||||
assert new_keys == old_keys, f"{new_keys=} and {old_keys=} are not the same"
|
||||
for key in new_frame:
|
||||
assert (
|
||||
new_frame[key] == old_frame[key]
|
||||
).all(), f"{key=} for index={i} does not contain the same value"
|
||||
|
||||
for key in new_frame:
|
||||
assert (
|
||||
new_frame[key] == old_frame[key]
|
||||
).all(), f"{key=} for index={i} does not contain the same value"
|
||||
# test2 first frames of first episode
|
||||
i = dataset.episode_data_index["from"][0].item()
|
||||
load_and_compare(i)
|
||||
load_and_compare(i + 1)
|
||||
|
||||
# test2 first frames of first episode
|
||||
i = dataset.episode_data_index["from"][0].item()
|
||||
load_and_compare(i)
|
||||
load_and_compare(i + 1)
|
||||
# test 2 frames at the middle of first episode
|
||||
i = int(
|
||||
(dataset.episode_data_index["to"][0].item() - dataset.episode_data_index["from"][0].item()) / 2
|
||||
)
|
||||
load_and_compare(i)
|
||||
load_and_compare(i + 1)
|
||||
|
||||
# test 2 frames at the middle of first episode
|
||||
i = int((dataset.episode_data_index["to"][0].item() - dataset.episode_data_index["from"][0].item()) / 2)
|
||||
load_and_compare(i)
|
||||
load_and_compare(i + 1)
|
||||
# test 2 last frames of first episode
|
||||
i = dataset.episode_data_index["to"][0].item()
|
||||
load_and_compare(i - 2)
|
||||
load_and_compare(i - 1)
|
||||
|
||||
# test 2 last frames of first episode
|
||||
i = dataset.episode_data_index["to"][0].item()
|
||||
load_and_compare(i - 2)
|
||||
load_and_compare(i - 1)
|
||||
# TODO(rcadene): Enable testing on second and last episode
|
||||
# We currently cant because our test dataset only contains the first episode
|
||||
|
||||
# TODO(rcadene): Enable testing on second and last episode
|
||||
# We currently cant because our test dataset only contains the first episode
|
||||
# # test 2 first frames of second episode
|
||||
# i = dataset.episode_data_index["from"][1].item()
|
||||
# load_and_compare(i)
|
||||
# load_and_compare(i+1)
|
||||
|
||||
# # test 2 first frames of second episode
|
||||
# i = dataset.episode_data_index["from"][1].item()
|
||||
# load_and_compare(i)
|
||||
# load_and_compare(i+1)
|
||||
# #test 2 last frames of second episode
|
||||
# i = dataset.episode_data_index["to"][1].item()
|
||||
# load_and_compare(i-2)
|
||||
# load_and_compare(i-1)
|
||||
|
||||
# #test 2 last frames of second episode
|
||||
# i = dataset.episode_data_index["to"][1].item()
|
||||
# load_and_compare(i-2)
|
||||
# load_and_compare(i-1)
|
||||
|
||||
# # test 2 last frames of last episode
|
||||
# i = dataset.episode_data_index["to"][-1].item()
|
||||
# load_and_compare(i-2)
|
||||
# load_and_compare(i-1)
|
||||
# # test 2 last frames of last episode
|
||||
# i = dataset.episode_data_index["to"][-1].item()
|
||||
# load_and_compare(i-2)
|
||||
# load_and_compare(i-1)
|
||||
|
||||
Reference in New Issue
Block a user