forked from tangger/lerobot
Loads episode_data_index and stats during dataset __init__ (#85)
Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>
Co-authored-by: Alexander Soare <alexander.soare159@gmail.com>
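This commit adds a meta_data/ directory per test dataset (info.json with the fps, stats.safetensors, episode_data_index.safetensors) and has the dataset load them at construction time, instead of carrying per-frame episode_data_index_from / episode_data_index_to columns. Below is a minimal sketch of loading those artifacts, assuming the {"from", "to"} keys that the tests further down rely on; it is an illustration, not the actual LeRobotDataset code.

# Illustrative sketch (not the actual lerobot dataset code): loading the per-dataset
# meta_data artifacts introduced by this commit during a dataset's __init__.
import json
from pathlib import Path

from safetensors.torch import load_file


def load_meta_data(meta_data_dir: Path) -> dict:
    """Load fps info, stats and episode_data_index from a dataset's meta_data directory."""
    with open(meta_data_dir / "info.json") as f:
        info = json.load(f)  # e.g. {"fps": 50}

    # Flat tensors (one entry per stat); nesting them back into a dict of dicts is left out here.
    stats = load_file(meta_data_dir / "stats.safetensors")

    # Maps episode i to the [from, to) range of frame indices belonging to it.
    episode_data_index = load_file(meta_data_dir / "episode_data_index.safetensors")

    return {"info": info, "stats": stats, "episode_data_index": episode_data_index}


# Example (paths are the test fixtures added by this commit):
# meta = load_meta_data(Path("tests/data/pusht/meta_data"))
# first_episode = (meta["episode_data_index"]["from"][0], meta["episode_data_index"]["to"][0])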
Binary file not shown.
tests/data/aloha_sim_insertion_human/meta_data/info.json (new file, 3 lines)
@@ -0,0 +1,3 @@
+{
+"fps": 50
+}

tests/data/aloha_sim_insertion_human/meta_data/stats.safetensors (new binary file, not shown)
Binary file not shown.
@@ -21,11 +21,11 @@
 "length": 14,
 "_type": "Sequence"
 },
-"episode_id": {
+"episode_index": {
 "dtype": "int64",
 "_type": "Value"
 },
-"frame_id": {
+"frame_index": {
 "dtype": "int64",
 "_type": "Value"
 },
@@ -37,14 +37,6 @@
 "dtype": "bool",
 "_type": "Value"
 },
-"episode_data_index_from": {
-"dtype": "int64",
-"_type": "Value"
-},
-"episode_data_index_to": {
-"dtype": "int64",
-"_type": "Value"
-},
 "index": {
 "dtype": "int64",
 "_type": "Value"
@@ -4,7 +4,7 @@
 "filename": "data-00000-of-00001.arrow"
 }
 ],
-"_fingerprint": "d79cf82ffc86f110",
+"_fingerprint": "22eeca7a3f4725ee",
 "_format_columns": null,
 "_format_kwargs": {},
 "_format_type": "torch",
Binary file not shown.
@@ -0,0 +1,3 @@
+{
+"fps": 50
+}
Binary file not shown.
Binary file not shown.
@@ -21,11 +21,11 @@
 "length": 14,
 "_type": "Sequence"
 },
-"episode_id": {
+"episode_index": {
 "dtype": "int64",
 "_type": "Value"
 },
-"frame_id": {
+"frame_index": {
 "dtype": "int64",
 "_type": "Value"
 },
@@ -37,14 +37,6 @@
 "dtype": "bool",
 "_type": "Value"
 },
-"episode_data_index_from": {
-"dtype": "int64",
-"_type": "Value"
-},
-"episode_data_index_to": {
-"dtype": "int64",
-"_type": "Value"
-},
 "index": {
 "dtype": "int64",
 "_type": "Value"
@@ -4,7 +4,7 @@
 "filename": "data-00000-of-00001.arrow"
 }
 ],
-"_fingerprint": "d8e4a817b5449498",
+"_fingerprint": "97c28d4ad1536e4c",
 "_format_columns": null,
 "_format_kwargs": {},
 "_format_type": "torch",
Binary file not shown.
@@ -0,0 +1,3 @@
+{
+"fps": 50
+}
Binary file not shown.
Binary file not shown.
@@ -21,11 +21,11 @@
 "length": 14,
 "_type": "Sequence"
 },
-"episode_id": {
+"episode_index": {
 "dtype": "int64",
 "_type": "Value"
 },
-"frame_id": {
+"frame_index": {
 "dtype": "int64",
 "_type": "Value"
 },
@@ -37,14 +37,6 @@
 "dtype": "bool",
 "_type": "Value"
 },
-"episode_data_index_from": {
-"dtype": "int64",
-"_type": "Value"
-},
-"episode_data_index_to": {
-"dtype": "int64",
-"_type": "Value"
-},
 "index": {
 "dtype": "int64",
 "_type": "Value"
@@ -4,7 +4,7 @@
 "filename": "data-00000-of-00001.arrow"
 }
 ],
-"_fingerprint": "f03482befa767127",
+"_fingerprint": "cb9349b5c92951e8",
 "_format_columns": null,
 "_format_kwargs": {},
 "_format_type": "torch",
Binary file not shown.
@@ -0,0 +1,3 @@
+{
+"fps": 50
+}
Binary file not shown.
Binary file not shown.
@@ -21,11 +21,11 @@
 "length": 14,
 "_type": "Sequence"
 },
-"episode_id": {
+"episode_index": {
 "dtype": "int64",
 "_type": "Value"
 },
-"frame_id": {
+"frame_index": {
 "dtype": "int64",
 "_type": "Value"
 },
@@ -37,14 +37,6 @@
 "dtype": "bool",
 "_type": "Value"
 },
-"episode_data_index_from": {
-"dtype": "int64",
-"_type": "Value"
-},
-"episode_data_index_to": {
-"dtype": "int64",
-"_type": "Value"
-},
 "index": {
 "dtype": "int64",
 "_type": "Value"
@@ -4,7 +4,7 @@
 "filename": "data-00000-of-00001.arrow"
 }
 ],
-"_fingerprint": "93e03c6320c7d56e",
+"_fingerprint": "e4d7ad2b360db1af",
 "_format_columns": null,
 "_format_kwargs": {},
 "_format_type": "torch",
tests/data/pusht/meta_data/episode_data_index.safetensors (new binary file, not shown)

tests/data/pusht/meta_data/info.json (new file, 3 lines)
@@ -0,0 +1,3 @@
+{
+"fps": 10
+}

tests/data/pusht/meta_data/stats.safetensors (new binary file, not shown)
Binary file not shown.
@@ -21,11 +21,11 @@
 "length": 2,
 "_type": "Sequence"
 },
-"episode_id": {
+"episode_index": {
 "dtype": "int64",
 "_type": "Value"
 },
-"frame_id": {
+"frame_index": {
 "dtype": "int64",
 "_type": "Value"
 },
@@ -45,14 +45,6 @@
 "dtype": "bool",
 "_type": "Value"
 },
-"episode_data_index_from": {
-"dtype": "int64",
-"_type": "Value"
-},
-"episode_data_index_to": {
-"dtype": "int64",
-"_type": "Value"
-},
 "index": {
 "dtype": "int64",
 "_type": "Value"
tests/data/pusht/train/meta_data/episode_data_index.safetensors (new binary file, not shown)

tests/data/pusht/train/meta_data/info.json (new file, 3 lines)
@@ -0,0 +1,3 @@
+{
+"fps": 10
+}

tests/data/pusht/train/meta_data/stats_action.safetensors (new binary file, not shown)
Binary file not shown.
Binary file not shown.
@@ -4,7 +4,7 @@
 "filename": "data-00000-of-00001.arrow"
 }
 ],
-"_fingerprint": "21bb9a76ed78a475",
+"_fingerprint": "a04a9ce660122e23",
 "_format_columns": null,
 "_format_kwargs": {},
 "_format_type": "torch",
tests/data/save_dataset_to_safetensors/pusht/frame_0.safetensors (new binary file, not shown)
tests/data/save_dataset_to_safetensors/pusht/frame_1.safetensors (new binary file, not shown)
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
tests/data/xarm_lift_medium/meta_data/info.json (new file, 3 lines)
@@ -0,0 +1,3 @@
+{
+"fps": 15
+}

tests/data/xarm_lift_medium/meta_data/stats.safetensors (new binary file, not shown)
Binary file not shown.
@@ -21,11 +21,11 @@
 "length": 4,
 "_type": "Sequence"
 },
-"episode_id": {
+"episode_index": {
 "dtype": "int64",
 "_type": "Value"
 },
-"frame_id": {
+"frame_index": {
 "dtype": "int64",
 "_type": "Value"
 },
@@ -41,14 +41,6 @@
 "dtype": "bool",
 "_type": "Value"
 },
-"episode_data_index_from": {
-"dtype": "int64",
-"_type": "Value"
-},
-"episode_data_index_to": {
-"dtype": "int64",
-"_type": "Value"
-},
 "index": {
 "dtype": "int64",
 "_type": "Value"
@@ -4,7 +4,7 @@
 "filename": "data-00000-of-00001.arrow"
 }
 ],
-"_fingerprint": "a95cbec45e3bb9d6",
+"_fingerprint": "cc6afdfcdd6f63ab",
 "_format_columns": null,
 "_format_kwargs": {},
 "_format_type": "torch",
Binary file not shown.
tests/data/xarm_lift_medium_replay/meta_data/info.json (new file, 3 lines)
@@ -0,0 +1,3 @@
+{
+"fps": 15
+}

tests/data/xarm_lift_medium_replay/meta_data/stats.safetensors (new binary file, not shown)
Binary file not shown.
tests/data/xarm_lift_medium_replay/train/dataset_info.json (new file, 51 lines)
@@ -0,0 +1,51 @@
+{
+"citation": "",
+"description": "",
+"features": {
+"observation.image": {
+"_type": "Image"
+},
+"observation.state": {
+"feature": {
+"dtype": "float32",
+"_type": "Value"
+},
+"length": 4,
+"_type": "Sequence"
+},
+"action": {
+"feature": {
+"dtype": "float32",
+"_type": "Value"
+},
+"length": 4,
+"_type": "Sequence"
+},
+"episode_index": {
+"dtype": "int64",
+"_type": "Value"
+},
+"frame_index": {
+"dtype": "int64",
+"_type": "Value"
+},
+"timestamp": {
+"dtype": "float32",
+"_type": "Value"
+},
+"next.reward": {
+"dtype": "float32",
+"_type": "Value"
+},
+"next.done": {
+"dtype": "bool",
+"_type": "Value"
+},
+"index": {
+"dtype": "int64",
+"_type": "Value"
+}
+},
+"homepage": "",
+"license": ""
+}
tests/data/xarm_lift_medium_replay/train/state.json (new file, 13 lines)
@@ -0,0 +1,13 @@
+{
+"_data_files": [
+{
+"filename": "data-00000-of-00001.arrow"
+}
+],
+"_fingerprint": "9f8e1a8c1845df55",
+"_format_columns": null,
+"_format_kwargs": {},
+"_format_type": "torch",
+"_output_all_columns": false,
+"_split": null
+}
Binary file not shown.
tests/data/xarm_push_medium/meta_data/info.json (new file, 3 lines)
@@ -0,0 +1,3 @@
+{
+"fps": 15
+}

tests/data/xarm_push_medium/meta_data/stats.safetensors (new binary file, not shown)
tests/data/xarm_push_medium/train/data-00000-of-00001.arrow (new binary file, not shown)
tests/data/xarm_push_medium/train/dataset_info.json (new file, 51 lines)
@@ -0,0 +1,51 @@
+{
+"citation": "",
+"description": "",
+"features": {
+"observation.image": {
+"_type": "Image"
+},
+"observation.state": {
+"feature": {
+"dtype": "float32",
+"_type": "Value"
+},
+"length": 4,
+"_type": "Sequence"
+},
+"action": {
+"feature": {
+"dtype": "float32",
+"_type": "Value"
+},
+"length": 3,
+"_type": "Sequence"
+},
+"episode_index": {
+"dtype": "int64",
+"_type": "Value"
+},
+"frame_index": {
+"dtype": "int64",
+"_type": "Value"
+},
+"timestamp": {
+"dtype": "float32",
+"_type": "Value"
+},
+"next.reward": {
+"dtype": "float32",
+"_type": "Value"
+},
+"next.done": {
+"dtype": "bool",
+"_type": "Value"
+},
+"index": {
+"dtype": "int64",
+"_type": "Value"
+}
+},
+"homepage": "",
+"license": ""
+}
tests/data/xarm_push_medium/train/state.json (new file, 13 lines)
@@ -0,0 +1,13 @@
+{
+"_data_files": [
+{
+"filename": "data-00000-of-00001.arrow"
+}
+],
+"_fingerprint": "c900258061dd0b3f",
+"_format_columns": null,
+"_format_kwargs": {},
+"_format_type": "torch",
+"_output_all_columns": false,
+"_split": null
+}
Binary file not shown.
tests/data/xarm_push_medium_replay/meta_data/info.json (new file, 3 lines)
@@ -0,0 +1,3 @@
+{
+"fps": 15
+}

tests/data/xarm_push_medium_replay/meta_data/stats.safetensors (new binary file, not shown)
Binary file not shown.
tests/data/xarm_push_medium_replay/train/dataset_info.json (new file, 51 lines)
@@ -0,0 +1,51 @@
+{
+"citation": "",
+"description": "",
+"features": {
+"observation.image": {
+"_type": "Image"
+},
+"observation.state": {
+"feature": {
+"dtype": "float32",
+"_type": "Value"
+},
+"length": 4,
+"_type": "Sequence"
+},
+"action": {
+"feature": {
+"dtype": "float32",
+"_type": "Value"
+},
+"length": 3,
+"_type": "Sequence"
+},
+"episode_index": {
+"dtype": "int64",
+"_type": "Value"
+},
+"frame_index": {
+"dtype": "int64",
+"_type": "Value"
+},
+"timestamp": {
+"dtype": "float32",
+"_type": "Value"
+},
+"next.reward": {
+"dtype": "float32",
+"_type": "Value"
+},
+"next.done": {
+"dtype": "bool",
+"_type": "Value"
+},
+"index": {
+"dtype": "int64",
+"_type": "Value"
+}
+},
+"homepage": "",
+"license": ""
+}
tests/data/xarm_push_medium_replay/train/state.json (new file, 13 lines)
@@ -0,0 +1,13 @@
+{
+"_data_files": [
+{
+"filename": "data-00000-of-00001.arrow"
+}
+],
+"_fingerprint": "e51c80a33c7688c0",
+"_format_columns": null,
+"_format_kwargs": {},
+"_format_type": "torch",
+"_output_all_columns": false,
+"_split": null
+}
tests/scripts/save_dataset_to_safetensors.py (new file, 71 lines)
@@ -0,0 +1,71 @@
+"""
+This script provides a utility for saving a dataset as safetensors files for the purpose of testing backward compatibility
+when updating the data format. It uses the `PushtDataset` to create a DataLoader and saves selected frames from the
+dataset into corresponding safetensors files in a specified output directory.
+
+If you know that your change will break backward compatibility, you should write a short-lived test by modifying
+`tests/test_datasets.py::test_backward_compatibility` accordingly, and make sure this custom test passes. Your custom test
+doesn't need to be merged into the `main` branch. Then you need to run this script and update the test artifacts.
+
+Example usage:
+    `python tests/scripts/save_dataset_to_safetensors.py`
+"""
+
+import shutil
+from pathlib import Path
+
+from safetensors.torch import save_file
+
+from lerobot.common.datasets.pusht import PushtDataset
+
+
+def save_dataset_to_safetensors(output_dir, dataset_id="pusht"):
+    data_dir = Path(output_dir) / dataset_id
+
+    if data_dir.exists():
+        shutil.rmtree(data_dir)
+
+    data_dir.mkdir(parents=True, exist_ok=True)
+
+    # TODO(rcadene): make it work for all datasets with LeRobotDataset(repo_id)
+    dataset = PushtDataset(
+        dataset_id=dataset_id,
+        split="train",
+    )
+
+    # save 2 first frames of first episode
+    i = dataset.episode_data_index["from"][0].item()
+    save_file(dataset[i], data_dir / f"frame_{i}.safetensors")
+    save_file(dataset[i + 1], data_dir / f"frame_{i+1}.safetensors")
+
+    # save 2 frames at the middle of first episode
+    i = int((dataset.episode_data_index["to"][0].item() - dataset.episode_data_index["from"][0].item()) / 2)
+    save_file(dataset[i], data_dir / f"frame_{i}.safetensors")
+    save_file(dataset[i + 1], data_dir / f"frame_{i+1}.safetensors")
+
+    # save 2 last frames of first episode
+    i = dataset.episode_data_index["to"][0].item()
+    save_file(dataset[i - 2], data_dir / f"frame_{i-2}.safetensors")
+    save_file(dataset[i - 1], data_dir / f"frame_{i-1}.safetensors")
+
+    # TODO(rcadene): Enable testing on second and last episode
+    # We currently can't because our test dataset only contains the first episode
+
+    # # save 2 first frames of second episode
+    # i = dataset.episode_data_index["from"][1].item()
+    # save_file(dataset[i], data_dir / f"frame_{i}.safetensors")
+    # save_file(dataset[i+1], data_dir / f"frame_{i+1}.safetensors")
+
+    # # save 2 last frames of second episode
+    # i = dataset.episode_data_index["to"][1].item()
+    # save_file(dataset[i-2], data_dir / f"frame_{i-2}.safetensors")
+    # save_file(dataset[i-1], data_dir / f"frame_{i-1}.safetensors")
+
+    # # save 2 last frames of last episode
+    # i = dataset.episode_data_index["to"][-1].item()
+    # save_file(dataset[i-2], data_dir / f"frame_{i-2}.safetensors")
+    # save_file(dataset[i-1], data_dir / f"frame_{i-1}.safetensors")
+
+
+if __name__ == "__main__":
+    save_dataset_to_safetensors("tests/data/save_dataset_to_safetensors")
@@ -1,20 +1,26 @@
import json
import logging
import os
from copy import deepcopy
from pathlib import Path

import einops
import pytest
import torch
from datasets import Dataset
from safetensors.torch import load_file

import lerobot
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.datasets.pusht import PushtDataset
from lerobot.common.datasets.utils import (
    compute_stats,
    flatten_dict,
    get_stats_einops_patterns,
    hf_transform_to_torch,
    load_previous_and_future_frames,
    unflatten_dict,
)
from lerobot.common.transforms import Prod
from lerobot.common.utils.utils import init_hydra_config

from .utils import DEFAULT_CONFIG_PATH, DEVICE
@@ -39,8 +45,8 @@ def test_factory(env_name, dataset_id, policy_name):

     keys_ndim_required = [
         ("action", 1, True),
-        ("episode_id", 0, True),
-        ("frame_id", 0, True),
+        ("episode_index", 0, True),
+        ("frame_index", 0, True),
         ("timestamp", 0, True),
         # TODO(rcadene): should we rename it agent_pos?
         ("observation.state", 1, True),
@@ -48,12 +54,6 @@ def test_factory(env_name, dataset_id, policy_name):
         ("next.done", 0, False),
     ]

-    for key in image_keys:
-        keys_ndim_required.append(
-            (key, 3, True),
-        )
-        assert dataset.hf_dataset[key].dtype == torch.uint8, f"{key}"
-
     # test number of dimensions
     for key, ndim, required in keys_ndim_required:
         if key not in item:
@@ -94,26 +94,21 @@ def test_compute_stats_on_xarm():
     We compare with taking a straight min, mean, max, std of all the data in one pass (which we can do
     because we are working with a small dataset).
     """
+    # TODO(rcadene): Reduce size of dataset sample on which stats compute is tested
     from lerobot.common.datasets.xarm import XarmDataset

-    data_dir = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
-
-    # get transform to convert images from uint8 [0,255] to float32 [0,1]
-    transform = Prod(in_keys=XarmDataset.image_keys, prod=1 / 255.0)
-
     dataset = XarmDataset(
         dataset_id="xarm_lift_medium",
-        root=data_dir,
-        transform=transform,
+        root=Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None,
     )

     # Note: we set the batch size to be smaller than the whole dataset to make sure we are testing batched
     # computation of the statistics. While doing this, we also make sure it works when we don't divide the
     # dataset into even batches.
-    computed_stats = compute_stats(dataset, batch_size=int(len(dataset) * 0.25))
+    computed_stats = compute_stats(dataset.hf_dataset, batch_size=int(len(dataset) * 0.25))

     # get einops patterns to aggregate batches and compute statistics
-    stats_patterns = get_stats_einops_patterns(dataset)
+    stats_patterns = get_stats_einops_patterns(dataset.hf_dataset)

     # get all frames from the dataset in the same dtype and range as during compute_stats
     dataloader = torch.utils.data.DataLoader(
@@ -122,18 +117,19 @@
         batch_size=len(dataset),
         shuffle=False,
     )
-    hf_dataset = next(iter(dataloader))
+    full_batch = next(iter(dataloader))

     # compute stats based on all frames from the dataset without any batching
     expected_stats = {}
     for k, pattern in stats_patterns.items():
+        full_batch[k] = full_batch[k].float()
         expected_stats[k] = {}
-        expected_stats[k]["mean"] = einops.reduce(hf_dataset[k], pattern, "mean")
+        expected_stats[k]["mean"] = einops.reduce(full_batch[k], pattern, "mean")
         expected_stats[k]["std"] = torch.sqrt(
-            einops.reduce((hf_dataset[k] - expected_stats[k]["mean"]) ** 2, pattern, "mean")
+            einops.reduce((full_batch[k] - expected_stats[k]["mean"]) ** 2, pattern, "mean")
         )
-        expected_stats[k]["min"] = einops.reduce(hf_dataset[k], pattern, "min")
-        expected_stats[k]["max"] = einops.reduce(hf_dataset[k], pattern, "max")
+        expected_stats[k]["min"] = einops.reduce(full_batch[k], pattern, "min")
+        expected_stats[k]["max"] = einops.reduce(full_batch[k], pattern, "max")

     # test computed stats match expected stats
     for k in stats_patterns:
@@ -142,11 +138,10 @@
         assert torch.allclose(computed_stats[k]["min"], expected_stats[k]["min"])
         assert torch.allclose(computed_stats[k]["max"], expected_stats[k]["max"])

-    # TODO(rcadene): check that the stats used for training are correct too
-    # # load stats that are expected to match the ones returned by computed_stats
-    # assert (dataset.data_dir / "stats.pth").exists()
-    # loaded_stats = torch.load(dataset.data_dir / "stats.pth")
+    # load stats used during training which are expected to match the ones returned by computed_stats
+    loaded_stats = dataset.stats  # noqa: F841

+    # TODO(rcadene): we can't test this because expected_stats is computed on a subset
     # # test loaded stats match expected stats
     # for k in stats_patterns:
     #     assert torch.allclose(loaded_stats[k]["mean"], expected_stats[k]["mean"])
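The batched-computation note in the hunk above refers to compute_stats aggregating statistics across DataLoader batches rather than in a single pass. A rough sketch of that kind of batched aggregation for one key follows; the function and variable names are assumptions for illustration, not the lerobot compute_stats implementation.

# Illustrative sketch of batched min/max/mean/std aggregation over uneven batches.
# This is not the lerobot compute_stats code.
import torch


def batched_stats(batches, key="x"):
    """Aggregate stats for `key` over an iterable of {key: tensor of shape (b, ...)} batches."""
    total = 0
    running_sum = running_sqsum = running_min = running_max = None
    for batch in batches:
        x = batch[key].float()
        flat = x.reshape(x.shape[0], -1)
        if running_sum is None:
            running_sum = flat.sum(0)
            running_sqsum = (flat**2).sum(0)
            running_min = flat.min(0).values
            running_max = flat.max(0).values
        else:
            running_sum += flat.sum(0)
            running_sqsum += (flat**2).sum(0)
            running_min = torch.minimum(running_min, flat.min(0).values)
            running_max = torch.maximum(running_max, flat.max(0).values)
        total += flat.shape[0]
    mean = running_sum / total
    std = torch.sqrt(running_sqsum / total - mean**2)
    return {"mean": mean, "std": std, "min": running_min, "max": running_max}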
@@ -160,15 +155,18 @@
         {
             "timestamp": [0.1, 0.2, 0.3, 0.4, 0.5],
             "index": [0, 1, 2, 3, 4],
-            "episode_data_index_from": [0, 0, 0, 0, 0],
-            "episode_data_index_to": [5, 5, 5, 5, 5],
+            "episode_index": [0, 0, 0, 0, 0],
         }
     )
-    hf_dataset = hf_dataset.with_format("torch")
-    item = hf_dataset[2]
+    hf_dataset.set_transform(hf_transform_to_torch)
+    episode_data_index = {
+        "from": torch.tensor([0]),
+        "to": torch.tensor([5]),
+    }
     delta_timestamps = {"index": [-0.2, 0, 0.139]}
     tol = 0.04
-    item = load_previous_and_future_frames(item, hf_dataset, delta_timestamps, tol)
+    item = hf_dataset[2]
+    item = load_previous_and_future_frames(item, hf_dataset, episode_data_index, delta_timestamps, tol)
     data, is_pad = item["index"], item["index_is_pad"]
     assert torch.equal(data, torch.tensor([0, 2, 3])), "Data does not match expected values"
     assert not is_pad.any(), "Unexpected padding detected"
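The per-frame episode_data_index_from / episode_data_index_to columns removed throughout this commit are replaced by a single episode_data_index mapping like the one built inline in the test above. A minimal sketch of deriving that {"from", "to"} mapping from a per-frame episode_index column follows; it is an assumption for illustration, not necessarily how the library computes it.

# Illustrative sketch: derive episode_data_index ({"from", "to"} tensors, one entry per
# episode) from a per-frame episode_index column. Not necessarily the library's code.
import torch


def build_episode_data_index(episode_index: torch.Tensor) -> dict[str, torch.Tensor]:
    """episode_index: 1D int64 tensor, non-decreasing, e.g. [0, 0, 0, 1, 1, 2]."""
    _, counts = torch.unique_consecutive(episode_index, return_counts=True)
    to = torch.cumsum(counts, dim=0)                  # exclusive end index of each episode
    from_ = torch.cat([torch.tensor([0]), to[:-1]])   # start index of each episode
    return {"from": from_, "to": to}


# For the 5-frame single-episode fixture used in the tests above, this yields
# {"from": tensor([0]), "to": tensor([5])}.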
@@ -179,16 +177,19 @@ def test_load_previous_and_future_frames_outside_tolerance_inside_episode_range(
         {
             "timestamp": [0.1, 0.2, 0.3, 0.4, 0.5],
             "index": [0, 1, 2, 3, 4],
-            "episode_data_index_from": [0, 0, 0, 0, 0],
-            "episode_data_index_to": [5, 5, 5, 5, 5],
+            "episode_index": [0, 0, 0, 0, 0],
         }
     )
-    hf_dataset = hf_dataset.with_format("torch")
-    item = hf_dataset[2]
+    hf_dataset.set_transform(hf_transform_to_torch)
+    episode_data_index = {
+        "from": torch.tensor([0]),
+        "to": torch.tensor([5]),
+    }
     delta_timestamps = {"index": [-0.2, 0, 0.141]}
     tol = 0.04
+    item = hf_dataset[2]
     with pytest.raises(AssertionError):
-        load_previous_and_future_frames(item, hf_dataset, delta_timestamps, tol)
+        load_previous_and_future_frames(item, hf_dataset, episode_data_index, delta_timestamps, tol)


 def test_load_previous_and_future_frames_outside_tolerance_outside_episode_range():
@@ -196,17 +197,102 @@ def test_load_previous_and_future_frames_outside_tolerance_outside_episode_range
         {
             "timestamp": [0.1, 0.2, 0.3, 0.4, 0.5],
             "index": [0, 1, 2, 3, 4],
-            "episode_data_index_from": [0, 0, 0, 0, 0],
-            "episode_data_index_to": [5, 5, 5, 5, 5],
+            "episode_index": [0, 0, 0, 0, 0],
         }
     )
-    hf_dataset = hf_dataset.with_format("torch")
-    item = hf_dataset[2]
+    hf_dataset.set_transform(hf_transform_to_torch)
+    episode_data_index = {
+        "from": torch.tensor([0]),
+        "to": torch.tensor([5]),
+    }
     delta_timestamps = {"index": [-0.3, -0.24, 0, 0.26, 0.3]}
     tol = 0.04
-    item = load_previous_and_future_frames(item, hf_dataset, delta_timestamps, tol)
+    item = hf_dataset[2]
+    item = load_previous_and_future_frames(item, hf_dataset, episode_data_index, delta_timestamps, tol)
     data, is_pad = item["index"], item["index_is_pad"]
     assert torch.equal(data, torch.tensor([0, 0, 2, 4, 4])), "Data does not match expected values"
     assert torch.equal(
         is_pad, torch.tensor([True, False, False, True, True])
     ), "Padding does not match expected values"
+
+
+def test_flatten_unflatten_dict():
+    d = {
+        "obs": {
+            "min": 0,
+            "max": 1,
+            "mean": 2,
+            "std": 3,
+        },
+        "action": {
+            "min": 4,
+            "max": 5,
+            "mean": 6,
+            "std": 7,
+        },
+    }
+
+    original_d = deepcopy(d)
+    d = unflatten_dict(flatten_dict(d))
+
+    # test equality between nested dicts
+    assert json.dumps(original_d, sort_keys=True) == json.dumps(d, sort_keys=True), f"{original_d} != {d}"
+
+
+def test_backward_compatibility():
+    """This test uses artifacts that have been generated by `tests/scripts/save_dataset_to_safetensors.py`."""
+    # TODO(rcadene): make it work for all datasets with LeRobotDataset(repo_id)
+    dataset_id = "pusht"
+    data_dir = Path("tests/data/save_dataset_to_safetensors") / dataset_id
+
+    dataset = PushtDataset(
+        dataset_id=dataset_id,
+        split="train",
+        root=Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None,
+    )
+
+    def load_and_compare(i):
+        new_frame = dataset[i]
+        old_frame = load_file(data_dir / f"frame_{i}.safetensors")
+
+        new_keys = set(new_frame.keys())
+        old_keys = set(old_frame.keys())
+        assert new_keys == old_keys, f"{new_keys=} and {old_keys=} are not the same"
+
+        for key in new_frame:
+            assert (
+                new_frame[key] == old_frame[key]
+            ).all(), f"{key=} for index={i} does not contain the same value"
+
+    # test 2 first frames of first episode
+    i = dataset.episode_data_index["from"][0].item()
+    load_and_compare(i)
+    load_and_compare(i + 1)
+
+    # test 2 frames at the middle of first episode
+    i = int((dataset.episode_data_index["to"][0].item() - dataset.episode_data_index["from"][0].item()) / 2)
+    load_and_compare(i)
+    load_and_compare(i + 1)
+
+    # test 2 last frames of first episode
+    i = dataset.episode_data_index["to"][0].item()
+    load_and_compare(i - 2)
+    load_and_compare(i - 1)
+
+    # TODO(rcadene): Enable testing on second and last episode
+    # We currently can't because our test dataset only contains the first episode
+
+    # # test 2 first frames of second episode
+    # i = dataset.episode_data_index["from"][1].item()
+    # load_and_compare(i)
+    # load_and_compare(i+1)
+
+    # # test 2 last frames of second episode
+    # i = dataset.episode_data_index["to"][1].item()
+    # load_and_compare(i-2)
+    # load_and_compare(i-1)
+
+    # # test 2 last frames of last episode
+    # i = dataset.episode_data_index["to"][-1].item()
+    # load_and_compare(i-2)
+    # load_and_compare(i-1)
@@ -1,3 +1,4 @@
+# TODO(aliberts): Mute logging for these tests
 import subprocess
 from pathlib import Path

tests/test_visualize_dataset.py (new file, 31 lines)
@@ -0,0 +1,31 @@
+import pytest
+
+from lerobot.common.utils.utils import init_hydra_config
+from lerobot.scripts.visualize_dataset import visualize_dataset
+
+from .utils import DEFAULT_CONFIG_PATH
+
+
+@pytest.mark.parametrize(
+    "dataset_id",
+    [
+        "aloha_sim_insertion_human",
+    ],
+)
+def test_visualize_dataset(tmpdir, dataset_id):
+    # TODO(rcadene): this test might fail with other datasets/policies/envs, since visualize_dataset
+    # doesn't support multiple timesteps, which requires delta_timestamps to be None for images.
+    cfg = init_hydra_config(
+        DEFAULT_CONFIG_PATH,
+        overrides=[
+            "policy=act",
+            "env=aloha",
+            f"dataset_id={dataset_id}",
+        ],
+    )
+    video_paths = visualize_dataset(cfg, out_dir=tmpdir)
+
+    assert len(video_paths) > 0
+
+    for video_path in video_paths:
+        assert video_path.exists()