diff --git a/.github/poetry/cpu/poetry.lock b/.github/poetry/cpu/poetry.lock index dfd664b..7b4a746 100644 --- a/.github/poetry/cpu/poetry.lock +++ b/.github/poetry/cpu/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. [[package]] name = "absl-py" @@ -522,21 +522,21 @@ toml = ["tomli"] [[package]] name = "datasets" -version = "2.18.0" +version = "2.19.0" description = "HuggingFace community-driven open-source library of datasets" optional = false python-versions = ">=3.8.0" files = [ - {file = "datasets-2.18.0-py3-none-any.whl", hash = "sha256:f1bbf0e2896917a914de01cbd37075b14deea3837af87ad0d9f697388ccaeb50"}, - {file = "datasets-2.18.0.tar.gz", hash = "sha256:cdf8b8c6abf7316377ba4f49f9589a4c74556d6b481afd0abd2284f3d69185cb"}, + {file = "datasets-2.19.0-py3-none-any.whl", hash = "sha256:f57c5316e123d4721b970c68c1cb856505f289cda58f5557ffe745b49c011a8e"}, + {file = "datasets-2.19.0.tar.gz", hash = "sha256:0b47e08cc7af2c6800a42cadc4657b22a0afc7197786c8986d703c08d90886a6"}, ] [package.dependencies] aiohttp = "*" dill = ">=0.3.0,<0.3.9" filelock = "*" -fsspec = {version = ">=2023.1.0,<=2024.2.0", extras = ["http"]} -huggingface-hub = ">=0.19.4" +fsspec = {version = ">=2023.1.0,<=2024.3.1", extras = ["http"]} +huggingface-hub = ">=0.21.2" multiprocess = "*" numpy = ">=1.17" packaging = "*" @@ -552,15 +552,15 @@ xxhash = "*" apache-beam = ["apache-beam (>=2.26.0)"] audio = ["librosa", "soundfile (>=0.12.1)"] benchmarks = ["tensorflow (==2.12.0)", "torch (==2.0.1)", "transformers (==4.30.1)"] -dev = ["Pillow (>=6.2.1)", "absl-py", "apache-beam (>=2.26.0)", "elasticsearch (<8.0.0)", "faiss-cpu (>=1.6.4)", "jax (>=0.3.14)", "jaxlib (>=0.3.14)", "joblib (<1.3.0)", "joblibspark", "librosa", "lz4", "py7zr", "pyspark (>=3.4)", "pytest", "pytest-datadir", "pytest-xdist", "rarfile (>=4.0)", "ruff (>=0.3.0)", "s3fs", "s3fs (>=2021.11.1)", "soundfile (>=0.12.1)", "sqlalchemy", "tensorflow (>=2.2.0,!=2.6.0,!=2.6.1)", "tensorflow (>=2.3,!=2.6.0,!=2.6.1)", "tensorflow-macos", "tiktoken", "torch", "torch (>=2.0.0)", "transformers", "typing-extensions (>=4.6.1)", "zstandard"] -docs = ["s3fs", "tensorflow (>=2.2.0,!=2.6.0,!=2.6.1)", "tensorflow-macos", "torch", "transformers"] +dev = ["Pillow (>=6.2.1)", "absl-py", "apache-beam (>=2.26.0)", "elasticsearch (<8.0.0)", "faiss-cpu (>=1.6.4)", "jax (>=0.3.14)", "jaxlib (>=0.3.14)", "joblib (<1.3.0)", "joblibspark", "librosa", "lz4", "polars[timezone] (>=0.20.0)", "protobuf (<4.0.0)", "py7zr", "pyspark (>=3.4)", "pytest", "pytest-datadir", "pytest-xdist", "rarfile (>=4.0)", "ruff (>=0.3.0)", "s3fs", "s3fs (>=2021.11.1)", "soundfile (>=0.12.1)", "sqlalchemy", "tensorflow (>=2.6.0)", "tiktoken", "torch", "torch (>=2.0.0)", "transformers", "typing-extensions (>=4.6.1)", "zstandard"] +docs = ["s3fs", "tensorflow (>=2.6.0)", "torch", "transformers"] jax = ["jax (>=0.3.14)", "jaxlib (>=0.3.14)"] metrics-tests = ["Werkzeug (>=1.0.1)", "accelerate", "bert-score (>=0.3.6)", "jiwer", "langdetect", "mauve-text", "nltk", "requests-file (>=1.5.1)", "rouge-score", "sacrebleu", "sacremoses", "scikit-learn", "scipy", "sentencepiece", "seqeval", "six (>=1.15.0,<1.16.0)", "spacy (>=3.0.0)", "texttable (>=1.6.3)", "tldextract", "tldextract (>=3.1.0)", "toml (>=0.10.1)", "typer (<0.5.0)"] quality = ["ruff (>=0.3.0)"] s3 = ["s3fs"] -tensorflow = ["tensorflow (>=2.2.0,!=2.6.0,!=2.6.1)", "tensorflow-macos"] -tensorflow-gpu = ["tensorflow-gpu (>=2.2.0,!=2.6.0,!=2.6.1)"] -tests = ["Pillow (>=6.2.1)", "absl-py", "apache-beam (>=2.26.0)", "elasticsearch (<8.0.0)", "faiss-cpu (>=1.6.4)", "jax (>=0.3.14)", "jaxlib (>=0.3.14)", "joblib (<1.3.0)", "joblibspark", "librosa", "lz4", "py7zr", "pyspark (>=3.4)", "pytest", "pytest-datadir", "pytest-xdist", "rarfile (>=4.0)", "s3fs (>=2021.11.1)", "soundfile (>=0.12.1)", "sqlalchemy", "tensorflow (>=2.3,!=2.6.0,!=2.6.1)", "tensorflow-macos", "tiktoken", "torch (>=2.0.0)", "transformers", "typing-extensions (>=4.6.1)", "zstandard"] +tensorflow = ["tensorflow (>=2.6.0)"] +tensorflow-gpu = ["tensorflow (>=2.6.0)"] +tests = ["Pillow (>=6.2.1)", "absl-py", "apache-beam (>=2.26.0)", "elasticsearch (<8.0.0)", "faiss-cpu (>=1.6.4)", "jax (>=0.3.14)", "jaxlib (>=0.3.14)", "joblib (<1.3.0)", "joblibspark", "librosa", "lz4", "polars[timezone] (>=0.20.0)", "protobuf (<4.0.0)", "py7zr", "pyspark (>=3.4)", "pytest", "pytest-datadir", "pytest-xdist", "rarfile (>=4.0)", "s3fs (>=2021.11.1)", "soundfile (>=0.12.1)", "sqlalchemy", "tensorflow (>=2.6.0)", "tiktoken", "torch (>=2.0.0)", "transformers", "typing-extensions (>=4.6.1)", "zstandard"] torch = ["torch"] vision = ["Pillow (>=6.2.1)"] @@ -1524,7 +1524,6 @@ description = "Powerful and Pythonic XML processing library combining libxml2/li optional = true python-versions = ">=3.6" files = [ - {file = "lxml-5.1.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:704f5572ff473a5f897745abebc6df40f22d4133c1e0a1f124e4f2bd3330ff7e"}, {file = "lxml-5.1.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:9d3c0f8567ffe7502d969c2c1b809892dc793b5d0665f602aad19895f8d508da"}, {file = "lxml-5.1.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:5fcfbebdb0c5d8d18b84118842f31965d59ee3e66996ac842e21f957eb76138c"}, {file = "lxml-5.1.0-cp310-cp310-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2f37c6d7106a9d6f0708d4e164b707037b7380fcd0b04c5bd9cae1fb46a856fb"}, @@ -1534,7 +1533,6 @@ files = [ {file = "lxml-5.1.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:82bddf0e72cb2af3cbba7cec1d2fd11fda0de6be8f4492223d4a268713ef2147"}, {file = "lxml-5.1.0-cp310-cp310-win32.whl", hash = "sha256:b66aa6357b265670bb574f050ffceefb98549c721cf28351b748be1ef9577d93"}, {file = "lxml-5.1.0-cp310-cp310-win_amd64.whl", hash = "sha256:4946e7f59b7b6a9e27bef34422f645e9a368cb2be11bf1ef3cafc39a1f6ba68d"}, - {file = "lxml-5.1.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:14deca1460b4b0f6b01f1ddc9557704e8b365f55c63070463f6c18619ebf964f"}, {file = "lxml-5.1.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:ed8c3d2cd329bf779b7ed38db176738f3f8be637bb395ce9629fc76f78afe3d4"}, {file = "lxml-5.1.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:436a943c2900bb98123b06437cdd30580a61340fbdb7b28aaf345a459c19046a"}, {file = "lxml-5.1.0-cp311-cp311-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:acb6b2f96f60f70e7f34efe0c3ea34ca63f19ca63ce90019c6cbca6b676e81fa"}, @@ -1544,7 +1542,6 @@ files = [ {file = "lxml-5.1.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:f4c9bda132ad108b387c33fabfea47866af87f4ea6ffb79418004f0521e63204"}, {file = "lxml-5.1.0-cp311-cp311-win32.whl", hash = "sha256:bc64d1b1dab08f679fb89c368f4c05693f58a9faf744c4d390d7ed1d8223869b"}, {file = "lxml-5.1.0-cp311-cp311-win_amd64.whl", hash = "sha256:a5ab722ae5a873d8dcee1f5f45ddd93c34210aed44ff2dc643b5025981908cda"}, - {file = "lxml-5.1.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:9aa543980ab1fbf1720969af1d99095a548ea42e00361e727c58a40832439114"}, {file = "lxml-5.1.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:6f11b77ec0979f7e4dc5ae081325a2946f1fe424148d3945f943ceaede98adb8"}, {file = "lxml-5.1.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:a36c506e5f8aeb40680491d39ed94670487ce6614b9d27cabe45d94cd5d63e1e"}, {file = "lxml-5.1.0-cp312-cp312-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f643ffd2669ffd4b5a3e9b41c909b72b2a1d5e4915da90a77e119b8d48ce867a"}, @@ -1570,8 +1567,8 @@ files = [ {file = "lxml-5.1.0-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:8f52fe6859b9db71ee609b0c0a70fea5f1e71c3462ecf144ca800d3f434f0764"}, {file = "lxml-5.1.0-cp37-cp37m-win32.whl", hash = "sha256:d42e3a3fc18acc88b838efded0e6ec3edf3e328a58c68fbd36a7263a874906c8"}, {file = "lxml-5.1.0-cp37-cp37m-win_amd64.whl", hash = "sha256:eac68f96539b32fce2c9b47eb7c25bb2582bdaf1bbb360d25f564ee9e04c542b"}, - {file = "lxml-5.1.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:ae15347a88cf8af0949a9872b57a320d2605ae069bcdf047677318bc0bba45b1"}, {file = "lxml-5.1.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:c26aab6ea9c54d3bed716b8851c8bfc40cb249b8e9880e250d1eddde9f709bf5"}, + {file = "lxml-5.1.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:cfbac9f6149174f76df7e08c2e28b19d74aed90cad60383ad8671d3af7d0502f"}, {file = "lxml-5.1.0-cp38-cp38-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:342e95bddec3a698ac24378d61996b3ee5ba9acfeb253986002ac53c9a5f6f84"}, {file = "lxml-5.1.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:725e171e0b99a66ec8605ac77fa12239dbe061482ac854d25720e2294652eeaa"}, {file = "lxml-5.1.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3d184e0d5c918cff04cdde9dbdf9600e960161d773666958c9d7b565ccc60c45"}, @@ -1579,7 +1576,6 @@ files = [ {file = "lxml-5.1.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:6d48fc57e7c1e3df57be5ae8614bab6d4e7b60f65c5457915c26892c41afc59e"}, {file = "lxml-5.1.0-cp38-cp38-win32.whl", hash = "sha256:7ec465e6549ed97e9f1e5ed51c657c9ede767bc1c11552f7f4d022c4df4a977a"}, {file = "lxml-5.1.0-cp38-cp38-win_amd64.whl", hash = "sha256:b21b4031b53d25b0858d4e124f2f9131ffc1530431c6d1321805c90da78388d1"}, - {file = "lxml-5.1.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:52427a7eadc98f9e62cb1368a5079ae826f94f05755d2d567d93ee1bc3ceb354"}, {file = "lxml-5.1.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:6a2a2c724d97c1eb8cf966b16ca2915566a4904b9aad2ed9a09c748ffe14f969"}, {file = "lxml-5.1.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:843b9c835580d52828d8f69ea4302537337a21e6b4f1ec711a52241ba4a824f3"}, {file = "lxml-5.1.0-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9b99f564659cfa704a2dd82d0684207b1aadf7d02d33e54845f9fc78e06b7581"}, @@ -2688,7 +2684,6 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, - {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"}, {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, @@ -3919,4 +3914,4 @@ xarm = ["gym-xarm"] [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "bd9c506d2499d5e1e3b5e8b1a0f65df45c8feef38d89d0daeade56847fdb6a2e" +content-hash = "e526416d1282dea2550680b2be7fcf9ff6e1c67ac89d34c684b486d94a6addee" diff --git a/.github/poetry/cpu/pyproject.toml b/.github/poetry/cpu/pyproject.toml index b13a9e9..bea7085 100644 --- a/.github/poetry/cpu/pyproject.toml +++ b/.github/poetry/cpu/pyproject.toml @@ -53,7 +53,7 @@ pre-commit = {version = "^3.7.0", optional = true} debugpy = {version = "^1.8.1", optional = true} pytest = {version = "^8.1.0", optional = true} pytest-cov = {version = "^5.0.0", optional = true} -datasets = "^2.18.0" +datasets = "^2.19.0" [tool.poetry.extras] diff --git a/README.md b/README.md index 202b90e..a0045bf 100644 --- a/README.md +++ b/README.md @@ -208,7 +208,7 @@ HF_HUB_ENABLE_HF_TRANSFER=1 huggingface-cli upload $HF_USER/$DATASET data/$DATAS You will need to set the corresponding version as a default argument in your dataset class: ```python - version: str | None = "v1.0", + version: str | None = "v1.1", ``` See: [`lerobot/common/datasets/pusht.py`](https://github.com/Cadene/lerobot/blob/main/lerobot/common/datasets/pusht.py) diff --git a/download_and_upload_dataset.py b/download_and_upload_dataset.py index d0d3577..8e1e27c 100644 --- a/download_and_upload_dataset.py +++ b/download_and_upload_dataset.py @@ -4,6 +4,7 @@ useless dependencies when using datasets. """ import io +import json import pickle import shutil from pathlib import Path @@ -14,16 +15,20 @@ import numpy as np import torch import tqdm from datasets import Dataset, Features, Image, Sequence, Value +from huggingface_hub import HfApi from PIL import Image as PILImage +from safetensors.torch import save_file + +from lerobot.common.datasets.utils import compute_stats, flatten_dict, hf_transform_to_torch -def download_and_upload(root, root_tests, dataset_id): +def download_and_upload(root, revision, dataset_id): if "pusht" in dataset_id: - download_and_upload_pusht(root, root_tests, dataset_id) + download_and_upload_pusht(root, revision, dataset_id) elif "xarm" in dataset_id: - download_and_upload_xarm(root, root_tests, dataset_id) + download_and_upload_xarm(root, revision, dataset_id) elif "aloha" in dataset_id: - download_and_upload_aloha(root, root_tests, dataset_id) + download_and_upload_aloha(root, revision, dataset_id) else: raise ValueError(dataset_id) @@ -56,7 +61,102 @@ def download_and_extract_zip(url: str, destination_folder: Path) -> bool: return False -def download_and_upload_pusht(root, root_tests, dataset_id="pusht", fps=10): +def concatenate_episodes(ep_dicts): + data_dict = {} + + keys = ep_dicts[0].keys() + for key in keys: + if torch.is_tensor(ep_dicts[0][key][0]): + data_dict[key] = torch.cat([ep_dict[key] for ep_dict in ep_dicts]) + else: + if key not in data_dict: + data_dict[key] = [] + for ep_dict in ep_dicts: + for x in ep_dict[key]: + data_dict[key].append(x) + + total_frames = data_dict["frame_index"].shape[0] + data_dict["index"] = torch.arange(0, total_frames, 1) + return data_dict + + +def push_to_hub(hf_dataset, episode_data_index, info, stats, root, revision, dataset_id): + # push to main to indicate latest version + hf_dataset.push_to_hub(f"lerobot/{dataset_id}", token=True) + + # push to version branch + hf_dataset.push_to_hub(f"lerobot/{dataset_id}", token=True, revision=revision) + + # create and store meta_data + meta_data_dir = root / dataset_id / "meta_data" + meta_data_dir.mkdir(parents=True, exist_ok=True) + + api = HfApi() + + # info + info_path = meta_data_dir / "info.json" + with open(str(info_path), "w") as f: + json.dump(info, f, indent=4) + api.upload_file( + path_or_fileobj=info_path, + path_in_repo=str(info_path).replace(f"{root}/{dataset_id}", ""), + repo_id=f"lerobot/{dataset_id}", + repo_type="dataset", + ) + api.upload_file( + path_or_fileobj=info_path, + path_in_repo=str(info_path).replace(f"{root}/{dataset_id}", ""), + repo_id=f"lerobot/{dataset_id}", + repo_type="dataset", + revision=revision, + ) + + # stats + stats_path = meta_data_dir / "stats.safetensors" + save_file(flatten_dict(stats), stats_path) + api.upload_file( + path_or_fileobj=stats_path, + path_in_repo=str(stats_path).replace(f"{root}/{dataset_id}", ""), + repo_id=f"lerobot/{dataset_id}", + repo_type="dataset", + ) + api.upload_file( + path_or_fileobj=stats_path, + path_in_repo=str(stats_path).replace(f"{root}/{dataset_id}", ""), + repo_id=f"lerobot/{dataset_id}", + repo_type="dataset", + revision=revision, + ) + + # episode_data_index + episode_data_index = {key: torch.tensor(episode_data_index[key]) for key in episode_data_index} + ep_data_idx_path = meta_data_dir / "episode_data_index.safetensors" + save_file(episode_data_index, ep_data_idx_path) + api.upload_file( + path_or_fileobj=ep_data_idx_path, + path_in_repo=str(ep_data_idx_path).replace(f"{root}/{dataset_id}", ""), + repo_id=f"lerobot/{dataset_id}", + repo_type="dataset", + ) + api.upload_file( + path_or_fileobj=ep_data_idx_path, + path_in_repo=str(ep_data_idx_path).replace(f"{root}/{dataset_id}", ""), + repo_id=f"lerobot/{dataset_id}", + repo_type="dataset", + revision=revision, + ) + + # copy in tests folder, the first episode and the meta_data directory + num_items_first_ep = episode_data_index["to"][0] - episode_data_index["from"][0] + hf_dataset.select(range(num_items_first_ep)).with_format("torch").save_to_disk( + f"tests/data/{dataset_id}/train" + ) + if Path(f"tests/data/{dataset_id}/meta_data").exists(): + shutil.rmtree(f"tests/data/{dataset_id}/meta_data") + shutil.copytree(meta_data_dir, f"tests/data/{dataset_id}/meta_data") + + +def download_and_upload_pusht(root, revision, dataset_id="pusht", fps=10): try: import pymunk from gym_pusht.envs.pusht import PushTEnv, pymunk_to_shapely @@ -99,6 +199,7 @@ def download_and_upload_pusht(root, root_tests, dataset_id="pusht", fps=10): actions = torch.from_numpy(dataset_dict["action"]) ep_dicts = [] + episode_data_index = {"from": [], "to": []} id_from = 0 for episode_id in tqdm.tqdm(range(num_episodes)): @@ -151,8 +252,8 @@ def download_and_upload_pusht(root, root_tests, dataset_id="pusht", fps=10): "observation.image": [PILImage.fromarray(x.numpy()) for x in image], "observation.state": agent_pos, "action": actions[id_from:id_to], - "episode_id": torch.tensor([episode_id] * num_frames, dtype=torch.int), - "frame_id": torch.arange(0, num_frames, 1), + "episode_index": torch.tensor([episode_id] * num_frames, dtype=torch.int), + "frame_index": torch.arange(0, num_frames, 1), "timestamp": torch.arange(0, num_frames, 1) / fps, # "next.observation.image": image[1:], # "next.observation.state": agent_pos[1:], @@ -160,28 +261,15 @@ def download_and_upload_pusht(root, root_tests, dataset_id="pusht", fps=10): "next.reward": torch.cat([reward[1:], reward[[-1]]]), "next.done": torch.cat([done[1:], done[[-1]]]), "next.success": torch.cat([success[1:], success[[-1]]]), - "episode_data_index_from": torch.tensor([id_from] * num_frames), - "episode_data_index_to": torch.tensor([id_from + num_frames] * num_frames), } ep_dicts.append(ep_dict) + episode_data_index["from"].append(id_from) + episode_data_index["to"].append(id_from + num_frames) + id_from += num_frames - data_dict = {} - - keys = ep_dicts[0].keys() - for key in keys: - if torch.is_tensor(ep_dicts[0][key][0]): - data_dict[key] = torch.cat([ep_dict[key] for ep_dict in ep_dicts]) - else: - if key not in data_dict: - data_dict[key] = [] - for ep_dict in ep_dicts: - for x in ep_dict[key]: - data_dict[key].append(x) - - total_frames = id_from - data_dict["index"] = torch.arange(0, total_frames, 1) + data_dict = concatenate_episodes(ep_dicts) features = { "observation.image": Image(), @@ -189,35 +277,35 @@ def download_and_upload_pusht(root, root_tests, dataset_id="pusht", fps=10): length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None) ), "action": Sequence(length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)), - "episode_id": Value(dtype="int64", id=None), - "frame_id": Value(dtype="int64", id=None), + "episode_index": Value(dtype="int64", id=None), + "frame_index": Value(dtype="int64", id=None), "timestamp": Value(dtype="float32", id=None), "next.reward": Value(dtype="float32", id=None), "next.done": Value(dtype="bool", id=None), "next.success": Value(dtype="bool", id=None), "index": Value(dtype="int64", id=None), - "episode_data_index_from": Value(dtype="int64", id=None), - "episode_data_index_to": Value(dtype="int64", id=None), } features = Features(features) hf_dataset = Dataset.from_dict(data_dict, features=features) - hf_dataset = hf_dataset.with_format("torch") + hf_dataset.set_transform(hf_transform_to_torch) - num_items_first_ep = ep_dicts[0]["frame_id"].shape[0] - hf_dataset.select(range(num_items_first_ep)).save_to_disk(f"{root_tests}/{dataset_id}/train") - hf_dataset.push_to_hub(f"lerobot/{dataset_id}", token=True) - hf_dataset.push_to_hub(f"lerobot/{dataset_id}", token=True, revision="v1.0") + info = { + "fps": fps, + } + stats = compute_stats(hf_dataset) + push_to_hub(hf_dataset, episode_data_index, info, stats, root, revision, dataset_id) -def download_and_upload_xarm(root, root_tests, dataset_id, fps=15): +def download_and_upload_xarm(root, revision, dataset_id, fps=15): root = Path(root) - raw_dir = root / f"{dataset_id}_raw" + raw_dir = root / "xarm_datasets_raw" if not raw_dir.exists(): import zipfile import gdown raw_dir.mkdir(parents=True, exist_ok=True) + # from https://github.com/fyhMer/fowm/blob/main/scripts/download_datasets.py url = "https://drive.google.com/uc?id=1nhxpykGtPDhmQKm-_B8zBSywVRdgeVya" zip_path = raw_dir / "data.zip" gdown.download(url, str(zip_path), quiet=False) @@ -234,13 +322,13 @@ def download_and_upload_xarm(root, root_tests, dataset_id, fps=15): with open(dataset_path, "rb") as f: dataset_dict = pickle.load(f) - total_frames = dataset_dict["actions"].shape[0] - ep_dicts = [] + episode_data_index = {"from": [], "to": []} id_from = 0 id_to = 0 episode_id = 0 + total_frames = dataset_dict["actions"].shape[0] for i in tqdm.tqdm(range(total_frames)): id_to += 1 @@ -264,35 +352,23 @@ def download_and_upload_xarm(root, root_tests, dataset_id, fps=15): "observation.image": [PILImage.fromarray(x.numpy()) for x in image], "observation.state": state, "action": action, - "episode_id": torch.tensor([episode_id] * num_frames, dtype=torch.int), - "frame_id": torch.arange(0, num_frames, 1), + "episode_index": torch.tensor([episode_id] * num_frames, dtype=torch.int), + "frame_index": torch.arange(0, num_frames, 1), "timestamp": torch.arange(0, num_frames, 1) / fps, # "next.observation.image": next_image, # "next.observation.state": next_state, "next.reward": next_reward, "next.done": next_done, - "episode_data_index_from": torch.tensor([id_from] * num_frames), - "episode_data_index_to": torch.tensor([id_from + num_frames] * num_frames), } ep_dicts.append(ep_dict) + episode_data_index["from"].append(id_from) + episode_data_index["to"].append(id_from + num_frames) + id_from = id_to episode_id += 1 - data_dict = {} - keys = ep_dicts[0].keys() - for key in keys: - if torch.is_tensor(ep_dicts[0][key][0]): - data_dict[key] = torch.cat([ep_dict[key] for ep_dict in ep_dicts]) - else: - if key not in data_dict: - data_dict[key] = [] - for ep_dict in ep_dicts: - for x in ep_dict[key]: - data_dict[key].append(x) - - total_frames = id_from - data_dict["index"] = torch.arange(0, total_frames, 1) + data_dict = concatenate_episodes(ep_dicts) features = { "observation.image": Image(), @@ -300,27 +376,26 @@ def download_and_upload_xarm(root, root_tests, dataset_id, fps=15): length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None) ), "action": Sequence(length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)), - "episode_id": Value(dtype="int64", id=None), - "frame_id": Value(dtype="int64", id=None), + "episode_index": Value(dtype="int64", id=None), + "frame_index": Value(dtype="int64", id=None), "timestamp": Value(dtype="float32", id=None), "next.reward": Value(dtype="float32", id=None), "next.done": Value(dtype="bool", id=None), #'next.success': Value(dtype='bool', id=None), "index": Value(dtype="int64", id=None), - "episode_data_index_from": Value(dtype="int64", id=None), - "episode_data_index_to": Value(dtype="int64", id=None), } features = Features(features) hf_dataset = Dataset.from_dict(data_dict, features=features) - hf_dataset = hf_dataset.with_format("torch") + hf_dataset.set_transform(hf_transform_to_torch) - num_items_first_ep = ep_dicts[0]["frame_id"].shape[0] - hf_dataset.select(range(num_items_first_ep)).save_to_disk(f"{root_tests}/{dataset_id}/train") - hf_dataset.push_to_hub(f"lerobot/{dataset_id}", token=True) - hf_dataset.push_to_hub(f"lerobot/{dataset_id}", token=True, revision="v1.0") + info = { + "fps": fps, + } + stats = compute_stats(hf_dataset) + push_to_hub(hf_dataset, episode_data_index, info, stats, root, revision, dataset_id) -def download_and_upload_aloha(root, root_tests, dataset_id, fps=50): +def download_and_upload_aloha(root, revision, dataset_id, fps=50): folder_urls = { "aloha_sim_insertion_human": "https://drive.google.com/drive/folders/1RgyD0JgTX30H4IM5XZn8I3zSV_mr8pyF", "aloha_sim_insertion_scripted": "https://drive.google.com/drive/folders/1TsojQQSXtHEoGnqgJ3gmpPQR2DPLtS2N", @@ -381,6 +456,7 @@ def download_and_upload_aloha(root, root_tests, dataset_id, fps=50): gdown.download(ep49_urls[dataset_id], output=str(raw_dir / "episode_49.hdf5"), fuzzy=True) ep_dicts = [] + episode_data_index = {"from": [], "to": []} id_from = 0 for ep_id in tqdm.tqdm(range(num_episodes[dataset_id])): @@ -408,40 +484,26 @@ def download_and_upload_aloha(root, root_tests, dataset_id, fps=50): { "observation.state": state, "action": action, - "episode_id": torch.tensor([ep_id] * num_frames), - "frame_id": torch.arange(0, num_frames, 1), + "episode_index": torch.tensor([ep_id] * num_frames), + "frame_index": torch.arange(0, num_frames, 1), "timestamp": torch.arange(0, num_frames, 1) / fps, # "next.observation.state": state, # TODO(rcadene): compute reward and success # "next.reward": reward, "next.done": done, # "next.success": success, - "episode_data_index_from": torch.tensor([id_from] * num_frames), - "episode_data_index_to": torch.tensor([id_from + num_frames] * num_frames), } ) assert isinstance(ep_id, int) ep_dicts.append(ep_dict) + episode_data_index["from"].append(id_from) + episode_data_index["to"].append(id_from + num_frames) + id_from += num_frames - data_dict = {} - - data_dict = {} - keys = ep_dicts[0].keys() - for key in keys: - if torch.is_tensor(ep_dicts[0][key][0]): - data_dict[key] = torch.cat([ep_dict[key] for ep_dict in ep_dicts]) - else: - if key not in data_dict: - data_dict[key] = [] - for ep_dict in ep_dicts: - for x in ep_dict[key]: - data_dict[key].append(x) - - total_frames = id_from - data_dict["index"] = torch.arange(0, total_frames, 1) + data_dict = concatenate_episodes(ep_dicts) features = { "observation.images.top": Image(), @@ -449,39 +511,39 @@ def download_and_upload_aloha(root, root_tests, dataset_id, fps=50): length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None) ), "action": Sequence(length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)), - "episode_id": Value(dtype="int64", id=None), - "frame_id": Value(dtype="int64", id=None), + "episode_index": Value(dtype="int64", id=None), + "frame_index": Value(dtype="int64", id=None), "timestamp": Value(dtype="float32", id=None), #'next.reward': Value(dtype='float32', id=None), "next.done": Value(dtype="bool", id=None), #'next.success': Value(dtype='bool', id=None), "index": Value(dtype="int64", id=None), - "episode_data_index_from": Value(dtype="int64", id=None), - "episode_data_index_to": Value(dtype="int64", id=None), } features = Features(features) hf_dataset = Dataset.from_dict(data_dict, features=features) - hf_dataset = hf_dataset.with_format("torch") + hf_dataset.set_transform(hf_transform_to_torch) - num_items_first_ep = ep_dicts[0]["frame_id"].shape[0] - hf_dataset.select(range(num_items_first_ep)).save_to_disk(f"{root_tests}/{dataset_id}/train") - hf_dataset.push_to_hub(f"lerobot/{dataset_id}", token=True) - hf_dataset.push_to_hub(f"lerobot/{dataset_id}", token=True, revision="v1.0") + info = { + "fps": fps, + } + stats = compute_stats(hf_dataset) + push_to_hub(hf_dataset, episode_data_index, info, stats, root, revision, dataset_id) if __name__ == "__main__": root = "data" - root_tests = "tests/data" + revision = "v1.1" dataset_ids = [ - # "pusht", - # "xarm_lift_medium", - # "aloha_sim_insertion_human", - # "aloha_sim_insertion_scripted", - # "aloha_sim_transfer_cube_human", + "pusht", + "xarm_lift_medium", + "xarm_lift_medium_replay", + "xarm_push_medium", + "xarm_push_medium_replay", + "aloha_sim_insertion_human", + "aloha_sim_insertion_scripted", + "aloha_sim_transfer_cube_human", "aloha_sim_transfer_cube_scripted", ] for dataset_id in dataset_ids: - download_and_upload(root, root_tests, dataset_id) - # assume stats have been precomputed - shutil.copy(f"{root}/{dataset_id}/stats.pth", f"{root_tests}/{dataset_id}/stats.pth") + download_and_upload(root, revision, dataset_id) diff --git a/examples/1_load_hugging_face_dataset.py b/examples/1_load_hugging_face_dataset.py index 17d2891..d249394 100644 --- a/examples/1_load_hugging_face_dataset.py +++ b/examples/1_load_hugging_face_dataset.py @@ -10,10 +10,13 @@ As an example, this script saves frames of episode number 5 of the PushT dataset This script supports several Hugging Face datasets, among which: 1. [Pusht](https://huggingface.co/datasets/lerobot/pusht) 2. [Xarm Lift Medium](https://huggingface.co/datasets/lerobot/xarm_lift_medium) -3. [Aloha Sim Insertion Human](https://huggingface.co/datasets/lerobot/aloha_sim_insertion_human) -4. [Aloha Sim Insertion Scripted](https://huggingface.co/datasets/lerobot/aloha_sim_insertion_scripted) -5. [Aloha Sim Transfer Cube Human](https://huggingface.co/datasets/lerobot/aloha_sim_transfer_cube_human) -6. [Aloha Sim Transfer Cube Scripted](https://huggingface.co/datasets/lerobot/aloha_sim_transfer_cube_scripted) +3. [Xarm Lift Medium Replay](https://huggingface.co/datasets/lerobot/xarm_lift_medium_replay) +4. [Xarm Push Medium](https://huggingface.co/datasets/lerobot/xarm_push_medium) +5. [Xarm Push Medium Replay](https://huggingface.co/datasets/lerobot/xarm_push_medium_replay) +6. [Aloha Sim Insertion Human](https://huggingface.co/datasets/lerobot/aloha_sim_insertion_human) +7. [Aloha Sim Insertion Scripted](https://huggingface.co/datasets/lerobot/aloha_sim_insertion_scripted) +8. [Aloha Sim Transfer Cube Human](https://huggingface.co/datasets/lerobot/aloha_sim_transfer_cube_human) +9. [Aloha Sim Transfer Cube Scripted](https://huggingface.co/datasets/lerobot/aloha_sim_transfer_cube_scripted) To try a different Hugging Face dataset, you can replace this line: ```python @@ -22,12 +25,16 @@ hf_dataset, fps = load_dataset("lerobot/pusht", split="train"), 10 by one of these: ```python hf_dataset, fps = load_dataset("lerobot/xarm_lift_medium", split="train"), 15 +hf_dataset, fps = load_dataset("lerobot/xarm_lift_medium_replay", split="train"), 15 +hf_dataset, fps = load_dataset("lerobot/xarm_push_medium", split="train"), 15 +hf_dataset, fps = load_dataset("lerobot/xarm_push_medium_replay", split="train"), 15 hf_dataset, fps = load_dataset("lerobot/aloha_sim_insertion_human", split="train"), 50 hf_dataset, fps = load_dataset("lerobot/aloha_sim_insertion_scripted", split="train"), 50 hf_dataset, fps = load_dataset("lerobot/aloha_sim_transfer_cube_human", split="train"), 50 hf_dataset, fps = load_dataset("lerobot/aloha_sim_transfer_cube_scripted", split="train"), 50 ``` """ +# TODO(rcadene): remove this example file of using hf_dataset from pathlib import Path @@ -37,19 +44,22 @@ from datasets import load_dataset # TODO(rcadene): list available datasets on lerobot page using `datasets` # download/load hugging face dataset in pyarrow format -hf_dataset, fps = load_dataset("lerobot/pusht", revision="v1.0", split="train"), 10 +hf_dataset, fps = load_dataset("lerobot/pusht", split="train"), 10 # display name of dataset and its features +# TODO(rcadene): update to make the print pretty print(f"{hf_dataset=}") print(f"{hf_dataset.features=}") # display useful statistics about frames and episodes, which are sequences of frames from the same video print(f"number of frames: {len(hf_dataset)=}") -print(f"number of episodes: {len(hf_dataset.unique('episode_id'))=}") -print(f"average number of frames per episode: {len(hf_dataset) / len(hf_dataset.unique('episode_id')):.3f}") +print(f"number of episodes: {len(hf_dataset.unique('episode_index'))=}") +print( + f"average number of frames per episode: {len(hf_dataset) / len(hf_dataset.unique('episode_index')):.3f}" +) # select the frames belonging to episode number 5 -hf_dataset = hf_dataset.filter(lambda frame: frame["episode_id"] == 5) +hf_dataset = hf_dataset.filter(lambda frame: frame["episode_index"] == 5) # load all frames of episode 5 in RAM in PIL format frames = hf_dataset["observation.image"] diff --git a/examples/2_load_lerobot_dataset.py b/examples/2_load_lerobot_dataset.py index 49a53d8..4eaed23 100644 --- a/examples/2_load_lerobot_dataset.py +++ b/examples/2_load_lerobot_dataset.py @@ -18,7 +18,10 @@ dataset = PushtDataset() ``` by one of these: ```python -dataset = XarmDataset() +dataset = XarmDataset("xarm_lift_medium") +dataset = XarmDataset("xarm_lift_medium_replay") +dataset = XarmDataset("xarm_push_medium") +dataset = XarmDataset("xarm_push_medium_replay") dataset = AlohaDataset("aloha_sim_insertion_human") dataset = AlohaDataset("aloha_sim_insertion_scripted") dataset = AlohaDataset("aloha_sim_transfer_cube_human") @@ -44,6 +47,7 @@ from lerobot.common.datasets.pusht import PushtDataset dataset = PushtDataset() # All LeRobot datasets are actually a thin wrapper around an underlying Hugging Face dataset (see https://huggingface.co/docs/datasets/index for more information). +# TODO(rcadene): update to make the print pretty print(f"{dataset=}") print(f"{dataset.hf_dataset=}") @@ -55,13 +59,16 @@ print(f"frames per second used during data collection: {dataset.fps=}") print(f"keys to access images from cameras: {dataset.image_keys=}") # While the LeRobot dataset adds helpers for working within our library, we still expose the underling Hugging Face dataset. It may be freely replaced or modified in place. Here we use the filtering to keep only frames from episode 5. -dataset.hf_dataset = dataset.hf_dataset.filter(lambda frame: frame["episode_id"] == 5) +# TODO(rcadene): remove this example of accessing hf_dataset +dataset.hf_dataset = dataset.hf_dataset.filter(lambda frame: frame["episode_index"] == 5) -# LeRobot datsets actually subclass PyTorch datasets. So you can do everything you know and love from working with the latter, for example: iterating through the dataset. Here we grap all the image frames. +# LeRobot datsets actually subclass PyTorch datasets. So you can do everything you know and love from working with the latter, for example: iterating through the dataset. Here we grab all the image frames. frames = [sample["observation.image"] for sample in dataset] -# but frames are now channel first to follow pytorch convention, -# to view them, we convert to channel last +# but frames are now float32 range [0,1] channel first (c,h,w) to follow pytorch convention, +# to view them, we convert to uint8 range [0,255] +frames = [(frame * 255).type(torch.uint8) for frame in frames] +# and to channel last (h,w,c) frames = [frame.permute((1, 2, 0)).numpy() for frame in frames] # and finally save them to a mp4 video diff --git a/lerobot/__init__.py b/lerobot/__init__.py index 83e51c7..70d7d7b 100644 --- a/lerobot/__init__.py +++ b/lerobot/__init__.py @@ -50,7 +50,12 @@ available_datasets = { "aloha_sim_transfer_cube_scripted", ], "pusht": ["pusht"], - "xarm": ["xarm_lift_medium"], + "xarm": [ + "xarm_lift_medium", + "xarm_lift_medium_replay", + "xarm_push_medium", + "xarm_push_medium_replay", + ], } available_policies = [ diff --git a/lerobot/common/datasets/aloha.py b/lerobot/common/datasets/aloha.py index 785b68e..f96d32b 100644 --- a/lerobot/common/datasets/aloha.py +++ b/lerobot/common/datasets/aloha.py @@ -1,9 +1,13 @@ from pathlib import Path import torch -from datasets import load_dataset, load_from_disk -from lerobot.common.datasets.utils import load_previous_and_future_frames +from lerobot.common.datasets.utils import ( + load_episode_data_index, + load_hf_dataset, + load_previous_and_future_frames, + load_stats, +) class AlohaDataset(torch.utils.data.Dataset): @@ -27,7 +31,7 @@ class AlohaDataset(torch.utils.data.Dataset): def __init__( self, dataset_id: str, - version: str | None = "v1.0", + version: str | None = "v1.1", root: Path | None = None, split: str = "train", transform: callable = None, @@ -40,13 +44,10 @@ class AlohaDataset(torch.utils.data.Dataset): self.split = split self.transform = transform self.delta_timestamps = delta_timestamps - if self.root is not None: - self.hf_dataset = load_from_disk(Path(self.root) / self.dataset_id / self.split) - else: - self.hf_dataset = load_dataset( - f"lerobot/{self.dataset_id}", revision=self.version, split=self.split - ) - self.hf_dataset = self.hf_dataset.with_format("torch") + # load data from hub or locally when root is provided + self.hf_dataset = load_hf_dataset(dataset_id, version, root, split) + self.episode_data_index = load_episode_data_index(dataset_id, version, root) + self.stats = load_stats(dataset_id, version, root) @property def num_samples(self) -> int: @@ -54,7 +55,7 @@ class AlohaDataset(torch.utils.data.Dataset): @property def num_episodes(self) -> int: - return len(self.hf_dataset.unique("episode_id")) + return len(self.hf_dataset.unique("episode_index")) def __len__(self): return self.num_samples @@ -66,19 +67,11 @@ class AlohaDataset(torch.utils.data.Dataset): item = load_previous_and_future_frames( item, self.hf_dataset, + self.episode_data_index, self.delta_timestamps, tol=1 / self.fps - 1e-4, # 1e-4 to account for possible numerical error ) - # convert images from channel last (PIL) to channel first (pytorch) - for key in self.image_keys: - if item[key].ndim == 3: - item[key] = item[key].permute((2, 0, 1)) # h w c -> c h w - elif item[key].ndim == 4: - item[key] = item[key].permute((0, 3, 1, 2)) # t h w c -> t c h w - else: - raise ValueError(item[key].ndim) - if self.transform is not None: item = self.transform(item) diff --git a/lerobot/common/datasets/factory.py b/lerobot/common/datasets/factory.py index 07afb61..0fbfff6 100644 --- a/lerobot/common/datasets/factory.py +++ b/lerobot/common/datasets/factory.py @@ -1,12 +1,10 @@ -import logging import os from pathlib import Path import torch from torchvision.transforms import v2 -from lerobot.common.datasets.utils import compute_stats -from lerobot.common.transforms import NormalizeTransform, Prod +from lerobot.common.transforms import NormalizeTransform DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None @@ -52,32 +50,18 @@ def make_dataset( stats["action"]["min"] = torch.tensor([12.0, 25.0], dtype=torch.float32) stats["action"]["max"] = torch.tensor([511.0, 511.0], dtype=torch.float32) elif stats_path is None: - # load stats if the file exists already or compute stats and save it - if DATA_DIR is None: - # TODO(rcadene): clean stats - precomputed_stats_path = Path("data") / cfg.dataset_id / "stats.pth" - else: - precomputed_stats_path = DATA_DIR / cfg.dataset_id / "stats.pth" - if precomputed_stats_path.exists(): - stats = torch.load(precomputed_stats_path) - else: - logging.info(f"compute_stats and save to {precomputed_stats_path}") - # Create a dataset for stats computation. - stats_dataset = clsfunc( - dataset_id=cfg.dataset_id, - split="train", - root=DATA_DIR, - transform=Prod(in_keys=clsfunc.image_keys, prod=1 / 255.0), - ) - stats = compute_stats(stats_dataset) - precomputed_stats_path.parent.mkdir(parents=True, exist_ok=True) - torch.save(stats, precomputed_stats_path) + # load a first dataset to access precomputed stats + stats_dataset = clsfunc( + dataset_id=cfg.dataset_id, + split="train", + root=DATA_DIR, + ) + stats = stats_dataset.stats else: stats = torch.load(stats_path) transforms = v2.Compose( [ - Prod(in_keys=clsfunc.image_keys, prod=1 / 255.0), NormalizeTransform( stats, in_keys=[ diff --git a/lerobot/common/datasets/pusht.py b/lerobot/common/datasets/pusht.py index 2879c17..bc978b7 100644 --- a/lerobot/common/datasets/pusht.py +++ b/lerobot/common/datasets/pusht.py @@ -1,9 +1,13 @@ from pathlib import Path import torch -from datasets import load_dataset, load_from_disk -from lerobot.common.datasets.utils import load_previous_and_future_frames +from lerobot.common.datasets.utils import ( + load_episode_data_index, + load_hf_dataset, + load_previous_and_future_frames, + load_stats, +) class PushtDataset(torch.utils.data.Dataset): @@ -25,7 +29,7 @@ class PushtDataset(torch.utils.data.Dataset): def __init__( self, dataset_id: str = "pusht", - version: str | None = "v1.0", + version: str | None = "v1.1", root: Path | None = None, split: str = "train", transform: callable = None, @@ -38,13 +42,10 @@ class PushtDataset(torch.utils.data.Dataset): self.split = split self.transform = transform self.delta_timestamps = delta_timestamps - if self.root is not None: - self.hf_dataset = load_from_disk(Path(self.root) / self.dataset_id / self.split) - else: - self.hf_dataset = load_dataset( - f"lerobot/{self.dataset_id}", revision=self.version, split=self.split - ) - self.hf_dataset = self.hf_dataset.with_format("torch") + # load data from hub or locally when root is provided + self.hf_dataset = load_hf_dataset(dataset_id, version, root, split) + self.episode_data_index = load_episode_data_index(dataset_id, version, root) + self.stats = load_stats(dataset_id, version, root) @property def num_samples(self) -> int: @@ -52,7 +53,7 @@ class PushtDataset(torch.utils.data.Dataset): @property def num_episodes(self) -> int: - return len(self.hf_dataset.unique("episode_id")) + return len(self.episode_data_index["from"]) def __len__(self): return self.num_samples @@ -64,19 +65,11 @@ class PushtDataset(torch.utils.data.Dataset): item = load_previous_and_future_frames( item, self.hf_dataset, + self.episode_data_index, self.delta_timestamps, tol=1 / self.fps - 1e-4, # 1e-4 to account for possible numerical error ) - # convert images from channel last (PIL) to channel first (pytorch) - for key in self.image_keys: - if item[key].ndim == 3: - item[key] = item[key].permute((2, 0, 1)) # h w c -> c h w - elif item[key].ndim == 4: - item[key] = item[key].permute((0, 3, 1, 2)) # t h w c -> t c h w - else: - raise ValueError(item[key].ndim) - if self.transform is not None: item = self.transform(item) diff --git a/lerobot/common/datasets/utils.py b/lerobot/common/datasets/utils.py index 50c5085..f5246c7 100644 --- a/lerobot/common/datasets/utils.py +++ b/lerobot/common/datasets/utils.py @@ -1,15 +1,121 @@ from copy import deepcopy from math import ceil +from pathlib import Path import datasets import einops import torch import tqdm +from datasets import Image, load_dataset, load_from_disk +from huggingface_hub import hf_hub_download +from PIL import Image as PILImage +from safetensors.torch import load_file +from torchvision import transforms + + +def flatten_dict(d, parent_key="", sep="/"): + """Flatten a nested dictionary structure by collapsing nested keys into one key with a separator. + + For example: + ``` + >>> dct = {"a": {"b": 1, "c": {"d": 2}}, "e": 3}` + >>> print(flatten_dict(dct)) + {"a/b": 1, "a/c/d": 2, "e": 3} + """ + items = [] + for k, v in d.items(): + new_key = f"{parent_key}{sep}{k}" if parent_key else k + if isinstance(v, dict): + items.extend(flatten_dict(v, new_key, sep=sep).items()) + else: + items.append((new_key, v)) + return dict(items) + + +def unflatten_dict(d, sep="/"): + outdict = {} + for key, value in d.items(): + parts = key.split(sep) + d = outdict + for part in parts[:-1]: + if part not in d: + d[part] = {} + d = d[part] + d[parts[-1]] = value + return outdict + + +def hf_transform_to_torch(items_dict): + """Get a transform function that convert items from Hugging Face dataset (pyarrow) + to torch tensors. Importantly, images are converted from PIL, which corresponds to + a channel last representation (h w c) of uint8 type, to a torch image representation + with channel first (c h w) of float32 type in range [0,1]. + """ + for key in items_dict: + first_item = items_dict[key][0] + if isinstance(first_item, PILImage.Image): + to_tensor = transforms.ToTensor() + items_dict[key] = [to_tensor(img) for img in items_dict[key]] + else: + items_dict[key] = [torch.tensor(x) for x in items_dict[key]] + return items_dict + + +def load_hf_dataset(dataset_id, version, root, split) -> datasets.Dataset: + """hf_dataset contains all the observations, states, actions, rewards, etc.""" + if root is not None: + hf_dataset = load_from_disk(str(Path(root) / dataset_id / split)) + else: + # TODO(rcadene): remove dataset_id everywhere and use repo_id instead + repo_id = f"lerobot/{dataset_id}" + hf_dataset = load_dataset(repo_id, revision=version, split=split) + hf_dataset.set_transform(hf_transform_to_torch) + return hf_dataset + + +def load_episode_data_index(dataset_id, version, root) -> dict[str, torch.Tensor]: + """episode_data_index contains the range of indices for each episode + + Example: + ```python + from_id = episode_data_index["from"][episode_id].item() + to_id = episode_data_index["to"][episode_id].item() + episode_frames = [dataset[i] for i in range(from_id, to_id)] + ``` + """ + if root is not None: + path = Path(root) / dataset_id / "meta_data" / "episode_data_index.safetensors" + else: + repo_id = f"lerobot/{dataset_id}" + path = hf_hub_download( + repo_id, "meta_data/episode_data_index.safetensors", repo_type="dataset", revision=version + ) + + return load_file(path) + + +def load_stats(dataset_id, version, root) -> dict[str, dict[str, torch.Tensor]]: + """stats contains the statistics per modality computed over the full dataset, such as max, min, mean, std + + Example: + ```python + normalized_action = (action - stats["action"]["mean"]) / stats["action"]["std"] + ``` + """ + if root is not None: + path = Path(root) / dataset_id / "meta_data" / "stats.safetensors" + else: + repo_id = f"lerobot/{dataset_id}" + path = hf_hub_download(repo_id, "meta_data/stats.safetensors", repo_type="dataset", revision=version) + + stats = load_file(path) + return unflatten_dict(stats) def load_previous_and_future_frames( item: dict[str, torch.Tensor], hf_dataset: datasets.Dataset, + episode_data_index: dict[str, torch.Tensor], delta_timestamps: dict[str, list[float]], tol: float, ) -> dict[torch.Tensor]: @@ -31,6 +137,8 @@ def load_previous_and_future_frames( corresponds to a different modality (e.g., "timestamp", "observation.image", "action"). - hf_dataset (datasets.Dataset): A dictionary containing the full dataset. Each key corresponds to a different modality (e.g., "timestamp", "observation.image", "action"). + - episode_data_index (dict): A dictionary containing two keys ("from" and "to") associated to dataset indices. + They indicate the start index and end index of each episode in the dataset. - delta_timestamps (dict): A dictionary containing lists of delta timestamps for each possible modality to be retrieved. These deltas are added to the item timestamp to form the query timestamps. - tol (float, optional): The tolerance level used to determine if a data point is close enough to the query @@ -46,12 +154,14 @@ def load_previous_and_future_frames( issues with timestamps during data collection. """ # get indices of the frames associated to the episode, and their timestamps - ep_data_id_from = item["episode_data_index_from"].item() - ep_data_id_to = item["episode_data_index_to"].item() + ep_id = item["episode_index"].item() + ep_data_id_from = episode_data_index["from"][ep_id].item() + ep_data_id_to = episode_data_index["to"][ep_id].item() ep_data_ids = torch.arange(ep_data_id_from, ep_data_id_to, 1) # load timestamps ep_timestamps = hf_dataset.select_columns("timestamp")[ep_data_id_from:ep_data_id_to]["timestamp"] + ep_timestamps = torch.stack(ep_timestamps) # we make the assumption that the timestamps are sorted ep_first_ts = ep_timestamps[0] @@ -82,39 +192,57 @@ def load_previous_and_future_frames( # load frames modality item[key] = hf_dataset.select_columns(key)[data_ids][key] + item[key] = torch.stack(item[key]) item[f"{key}_is_pad"] = is_pad return item -def get_stats_einops_patterns(dataset): - """These einops patterns will be used to aggregate batches and compute statistics.""" - stats_patterns = { - "action": "b c -> c", - "observation.state": "b c -> c", - } - for key in dataset.image_keys: - stats_patterns[key] = "b c h w -> c 1 1" +def get_stats_einops_patterns(hf_dataset): + """These einops patterns will be used to aggregate batches and compute statistics. + + Note: We assume the images of `hf_dataset` are in channel first format + """ + + dataloader = torch.utils.data.DataLoader( + hf_dataset, + num_workers=0, + batch_size=2, + shuffle=False, + ) + batch = next(iter(dataloader)) + + stats_patterns = {} + for key, feats_type in hf_dataset.features.items(): + # sanity check that tensors are not float64 + assert batch[key].dtype != torch.float64 + + if isinstance(feats_type, Image): + # sanity check that images are channel first + _, c, h, w = batch[key].shape + assert c < h and c < w, f"expect channel first images, but instead {batch[key].shape}" + + # sanity check that images are float32 in range [0,1] + assert batch[key].dtype == torch.float32, f"expect torch.float32, but instead {batch[key].dtype=}" + assert batch[key].max() <= 1, f"expect pixels lower than 1, but instead {batch[key].max()=}" + assert batch[key].min() >= 0, f"expect pixels greater than 1, but instead {batch[key].min()=}" + + stats_patterns[key] = "b c h w -> c 1 1" + elif batch[key].ndim == 2: + stats_patterns[key] = "b c -> c " + elif batch[key].ndim == 1: + stats_patterns[key] = "b -> 1" + else: + raise ValueError(f"{key}, {feats_type}, {batch[key].shape}") + return stats_patterns -def compute_stats(dataset, batch_size=32, max_num_samples=None): +def compute_stats(hf_dataset, batch_size=32, max_num_samples=None): if max_num_samples is None: - max_num_samples = len(dataset) - else: - raise NotImplementedError("We need to set shuffle=True, but this violate an assert for now.") + max_num_samples = len(hf_dataset) - dataloader = torch.utils.data.DataLoader( - dataset, - num_workers=4, - batch_size=batch_size, - shuffle=False, - # pin_memory=cfg.device != "cpu", - drop_last=False, - ) - - # get einops patterns to aggregate batches and compute statistics - stats_patterns = get_stats_einops_patterns(dataset) + stats_patterns = get_stats_einops_patterns(hf_dataset) # mean and std will be computed incrementally while max and min will track the running value. mean, std, max, min = {}, {}, {}, {} @@ -124,10 +252,24 @@ def compute_stats(dataset, batch_size=32, max_num_samples=None): max[key] = torch.tensor(-float("inf")).float() min[key] = torch.tensor(float("inf")).float() + def create_seeded_dataloader(hf_dataset, batch_size, seed): + generator = torch.Generator() + generator.manual_seed(seed) + dataloader = torch.utils.data.DataLoader( + hf_dataset, + num_workers=4, + batch_size=batch_size, + shuffle=True, + drop_last=False, + generator=generator, + ) + return dataloader + # Note: Due to be refactored soon. The point of storing `first_batch` is to make sure we don't get # surprises when rerunning the sampler. first_batch = None running_item_count = 0 # for online mean computation + dataloader = create_seeded_dataloader(hf_dataset, batch_size, seed=1337) for i, batch in enumerate( tqdm.tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute mean, min, max") ): @@ -153,6 +295,7 @@ def compute_stats(dataset, batch_size=32, max_num_samples=None): first_batch_ = None running_item_count = 0 # for online std computation + dataloader = create_seeded_dataloader(hf_dataset, batch_size, seed=1337) for i, batch in enumerate( tqdm.tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute std") ): diff --git a/lerobot/common/datasets/xarm.py b/lerobot/common/datasets/xarm.py index 385b7d9..7e69e7d 100644 --- a/lerobot/common/datasets/xarm.py +++ b/lerobot/common/datasets/xarm.py @@ -1,25 +1,37 @@ from pathlib import Path import torch -from datasets import load_dataset, load_from_disk -from lerobot.common.datasets.utils import load_previous_and_future_frames +from lerobot.common.datasets.utils import ( + load_episode_data_index, + load_hf_dataset, + load_previous_and_future_frames, + load_stats, +) class XarmDataset(torch.utils.data.Dataset): """ https://huggingface.co/datasets/lerobot/xarm_lift_medium + https://huggingface.co/datasets/lerobot/xarm_lift_medium_replay + https://huggingface.co/datasets/lerobot/xarm_push_medium + https://huggingface.co/datasets/lerobot/xarm_push_medium_replay """ # Copied from lerobot/__init__.py - available_datasets = ["xarm_lift_medium"] + available_datasets = [ + "xarm_lift_medium", + "xarm_lift_medium_replay", + "xarm_push_medium", + "xarm_push_medium_replay", + ] fps = 15 image_keys = ["observation.image"] def __init__( self, - dataset_id: str = "xarm_lift_medium", - version: str | None = "v1.0", + dataset_id: str, + version: str | None = "v1.1", root: Path | None = None, split: str = "train", transform: callable = None, @@ -32,13 +44,10 @@ class XarmDataset(torch.utils.data.Dataset): self.split = split self.transform = transform self.delta_timestamps = delta_timestamps - if self.root is not None: - self.hf_dataset = load_from_disk(Path(self.root) / self.dataset_id / self.split) - else: - self.hf_dataset = load_dataset( - f"lerobot/{self.dataset_id}", revision=self.version, split=self.split - ) - self.hf_dataset = self.hf_dataset.with_format("torch") + # load data from hub or locally when root is provided + self.hf_dataset = load_hf_dataset(dataset_id, version, root, split) + self.episode_data_index = load_episode_data_index(dataset_id, version, root) + self.stats = load_stats(dataset_id, version, root) @property def num_samples(self) -> int: @@ -46,7 +55,7 @@ class XarmDataset(torch.utils.data.Dataset): @property def num_episodes(self) -> int: - return len(self.hf_dataset.unique("episode_id")) + return len(self.hf_dataset.unique("episode_index")) def __len__(self): return self.num_samples @@ -58,19 +67,11 @@ class XarmDataset(torch.utils.data.Dataset): item = load_previous_and_future_frames( item, self.hf_dataset, + self.episode_data_index, self.delta_timestamps, tol=1 / self.fps - 1e-4, # 1e-4 to account for possible numerical error ) - # convert images from channel last (PIL) to channel first (pytorch) - for key in self.image_keys: - if item[key].ndim == 3: - item[key] = item[key].permute((2, 0, 1)) # h w c -> c h w - elif item[key].ndim == 4: - item[key] = item[key].permute((0, 3, 1, 2)) # t h w c -> t c h w - else: - raise ValueError(item[key].ndim) - if self.transform is not None: item = self.transform(item) diff --git a/lerobot/common/envs/factory.py b/lerobot/common/envs/factory.py index d557193..24d69c3 100644 --- a/lerobot/common/envs/factory.py +++ b/lerobot/common/envs/factory.py @@ -39,4 +39,5 @@ def make_env(cfg, num_parallel_envs=0) -> gym.Env | gym.vector.SyncVectorEnv: for _ in range(num_parallel_envs) ] ) + return env diff --git a/lerobot/common/envs/utils.py b/lerobot/common/envs/utils.py index 7f5216c..dcce1bc 100644 --- a/lerobot/common/envs/utils.py +++ b/lerobot/common/envs/utils.py @@ -15,8 +15,19 @@ def preprocess_observation(observation, transform=None): for imgkey, img in imgs.items(): img = torch.from_numpy(img) - # convert to (b c h w) torch format + + # sanity check that images are channel last + _, h, w, c = img.shape + assert c < h and c < w, f"expect channel first images, but instead {img.shape}" + + # sanity check that images are uint8 + assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}" + + # convert to channel first of type float32 in range [0,1] img = einops.rearrange(img, "b h w c -> b c h w") + img = img.type(torch.float32) + img /= 255 + obs[imgkey] = img # TODO(rcadene): enable pixels only baseline with `obs_type="pixels"` in environment by removing requirement for "agent_pos" diff --git a/lerobot/common/transforms.py b/lerobot/common/transforms.py index ec96761..fffa835 100644 --- a/lerobot/common/transforms.py +++ b/lerobot/common/transforms.py @@ -1,4 +1,3 @@ -import torch from torchvision.transforms.v2 import Compose, Transform @@ -12,40 +11,6 @@ def apply_inverse_transform(item, transform): return item -class Prod(Transform): - invertible = True - - def __init__(self, in_keys: list[str], prod: float): - super().__init__() - self.in_keys = in_keys - self.prod = prod - self.original_dtypes = {} - - def forward(self, item): - for key in self.in_keys: - if key not in item: - continue - self.original_dtypes[key] = item[key].dtype - item[key] = item[key].type(torch.float32) * self.prod - return item - - def inverse_transform(self, item): - for key in self.in_keys: - if key not in item: - continue - item[key] = (item[key] / self.prod).type(self.original_dtypes[key]) - return item - - # def transform_observation_spec(self, obs_spec): - # for key in self.in_keys: - # if obs_spec.get(key, None) is None: - # continue - # obs_spec[key].space.high = obs_spec[key].space.high.type(torch.float32) * self.prod - # obs_spec[key].space.low = obs_spec[key].space.low.type(torch.float32) * self.prod - # obs_spec[key].dtype = torch.float32 - # return obs_spec - - class NormalizeTransform(Transform): invertible = True diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py index 0c0e8e8..7b3c6dd 100644 --- a/lerobot/scripts/eval.py +++ b/lerobot/scripts/eval.py @@ -47,6 +47,7 @@ from PIL import Image as PILImage from tqdm import trange from lerobot.common.datasets.factory import make_dataset +from lerobot.common.datasets.utils import hf_transform_to_torch from lerobot.common.envs.factory import make_env from lerobot.common.envs.utils import postprocess_action, preprocess_observation from lerobot.common.logger import log_output_dir @@ -208,11 +209,12 @@ def eval_policy( max_rewards.extend(batch_max_reward.tolist()) all_successes.extend(batch_success.tolist()) - # similar logic is implemented in dataset preprocessing + # similar logic is implemented when datasets are pushed to hub (see: `push_to_hub`) ep_dicts = [] + episode_data_index = {"from": [], "to": []} num_episodes = dones.shape[0] total_frames = 0 - idx_from = 0 + id_from = 0 for ep_id in range(num_episodes): num_frames = done_indices[ep_id].item() + 1 total_frames += num_frames @@ -222,19 +224,20 @@ def eval_policy( if return_episode_data: ep_dict = { "action": actions[ep_id, :num_frames], - "episode_id": torch.tensor([ep_id] * num_frames), - "frame_id": torch.arange(0, num_frames, 1), + "episode_index": torch.tensor([ep_id] * num_frames), + "frame_index": torch.arange(0, num_frames, 1), "timestamp": torch.arange(0, num_frames, 1) / fps, "next.done": dones[ep_id, :num_frames], "next.reward": rewards[ep_id, :num_frames].type(torch.float32), - "episode_data_index_from": torch.tensor([idx_from] * num_frames), - "episode_data_index_to": torch.tensor([idx_from + num_frames] * num_frames), } for key in observations: ep_dict[key] = observations[key][ep_id][:num_frames] ep_dicts.append(ep_dict) - idx_from += num_frames + episode_data_index["from"].append(id_from) + episode_data_index["to"].append(id_from + num_frames) + + id_from += num_frames # similar logic is implemented in dataset preprocessing if return_episode_data: @@ -247,14 +250,29 @@ def eval_policy( if key not in data_dict: data_dict[key] = [] for ep_dict in ep_dicts: - for x in ep_dict[key]: - # c h w -> h w c - img = PILImage.fromarray(x.permute(1, 2, 0).numpy()) + for img in ep_dict[key]: + # sanity check that images are channel first + c, h, w = img.shape + assert c < h and c < w, f"expect channel first images, but instead {img.shape}" + + # sanity check that images are float32 in range [0,1] + assert img.dtype == torch.float32, f"expect torch.float32, but instead {img.dtype=}" + assert img.max() <= 1, f"expect pixels lower than 1, but instead {img.max()=}" + assert img.min() >= 0, f"expect pixels greater than 1, but instead {img.min()=}" + + # from float32 in range [0,1] to uint8 in range [0,255] + img *= 255 + img = img.type(torch.uint8) + + # convert to channel last and numpy as expected by PIL + img = PILImage.fromarray(img.permute(1, 2, 0).numpy()) + data_dict[key].append(img) data_dict["index"] = torch.arange(0, total_frames, 1) - hf_dataset = Dataset.from_dict(data_dict).with_format("torch") + hf_dataset = Dataset.from_dict(data_dict) + hf_dataset.set_transform(hf_transform_to_torch) if max_episodes_rendered > 0: batch_stacked_frames = np.stack(ep_frames, 1) # (b, t, *) @@ -307,7 +325,10 @@ def eval_policy( }, } if return_episode_data: - info["episodes"] = hf_dataset + info["episodes"] = { + "hf_dataset": hf_dataset, + "episode_data_index": episode_data_index, + } if max_episodes_rendered > 0: info["videos"] = videos return info diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py index 4d8c247..8a70a21 100644 --- a/lerobot/scripts/train.py +++ b/lerobot/scripts/train.py @@ -136,6 +136,7 @@ def add_episodes_inplace( concat_dataset: torch.utils.data.ConcatDataset, sampler: torch.utils.data.WeightedRandomSampler, hf_dataset: datasets.Dataset, + episode_data_index: dict[str, torch.Tensor], pc_online_samples: float, ): """ @@ -151,13 +152,15 @@ def add_episodes_inplace( - sampler (torch.utils.data.WeightedRandomSampler): A sampler that will be updated to reflect changes in the dataset sizes and specified sampling weights. - hf_dataset (datasets.Dataset): A Hugging Face dataset containing the new episodes to be added. + - episode_data_index (dict): A dictionary containing two keys ("from" and "to") associated to dataset indices. + They indicate the start index and end index of each episode in the dataset. - pc_online_samples (float): The target percentage of samples that should come from the online dataset during sampling operations. Raises: - AssertionError: If the first episode_id or index in hf_dataset is not 0 """ - first_episode_id = hf_dataset.select_columns("episode_id")[0]["episode_id"].item() + first_episode_id = hf_dataset.select_columns("episode_index")[0]["episode_index"].item() first_index = hf_dataset.select_columns("index")[0]["index"].item() assert first_episode_id == 0, f"We expect the first episode_id to be 0 and not {first_episode_id}" assert first_index == 0, f"We expect the first first_index to be 0 and not {first_index}" @@ -167,21 +170,22 @@ def add_episodes_inplace( online_dataset.hf_dataset = hf_dataset else: # find episode index and data frame indices according to previous episode in online_dataset - start_episode = online_dataset.select_columns("episode_id")[-1]["episode_id"].item() + 1 + start_episode = online_dataset.select_columns("episode_index")[-1]["episode_index"].item() + 1 start_index = online_dataset.select_columns("index")[-1]["index"].item() + 1 def shift_indices(example): - # note: we dont shift "frame_id" since it represents the index of the frame in the episode it belongs to - example["episode_id"] += start_episode + # note: we dont shift "frame_index" since it represents the index of the frame in the episode it belongs to + example["episode_index"] += start_episode example["index"] += start_index - example["episode_data_index_from"] += start_index - example["episode_data_index_to"] += start_index return example disable_progress_bars() # map has a tqdm progress bar hf_dataset = hf_dataset.map(shift_indices) enable_progress_bars() + episode_data_index["from"] += start_index + episode_data_index["to"] += start_index + # extend online dataset online_dataset.hf_dataset = concatenate_datasets([online_dataset.hf_dataset, hf_dataset]) @@ -334,9 +338,13 @@ def train(cfg: dict, out_dir=None, job_name=None): seed=cfg.seed, ) - online_pc_sampling = cfg.get("demo_schedule", 0.5) add_episodes_inplace( - online_dataset, concat_dataset, sampler, eval_info["episodes"], online_pc_sampling + online_dataset, + concat_dataset, + sampler, + hf_dataset=eval_info["episodes"]["hf_dataset"], + episode_data_index=eval_info["episodes"]["episode_data_index"], + pc_online_samples=cfg.get("demo_schedule", 0.5), ) for _ in range(cfg.policy.utd): diff --git a/lerobot/scripts/visualize_dataset.py b/lerobot/scripts/visualize_dataset.py index 226fdc1..b51e62b 100644 --- a/lerobot/scripts/visualize_dataset.py +++ b/lerobot/scripts/visualize_dataset.py @@ -22,11 +22,24 @@ def visualize_dataset_cli(cfg: dict): def cat_and_write_video(video_path, frames, fps): - # Expects images in [0, 255]. frames = torch.cat(frames) - assert frames.dtype == torch.uint8 - frames = einops.rearrange(frames, "b c h w -> b h w c").numpy() - imageio.mimsave(video_path, frames, fps=fps) + + # Expects images in [0, 1]. + frame = frames[0] + if frame.ndim == 4: + raise NotImplementedError("We currently dont support multiple timestamps.") + c, h, w = frame.shape + assert c < h and c < w, f"expect channel first images, but instead {frame.shape}" + + # sanity check that images are float32 in range [0,1] + assert frame.dtype == torch.float32, f"expect torch.float32, but instead {frame.dtype=}" + assert frame.max() <= 1, f"expect pixels lower than 1, but instead {frame.max()=}" + assert frame.min() >= 0, f"expect pixels greater than 1, but instead {frame.min()=}" + + # convert to channel last uint8 [0, 255] + frames = einops.rearrange(frames, "b c h w -> b h w c") + frames = (frames * 255).type(torch.uint8) + imageio.mimsave(video_path, frames.numpy(), fps=fps) def visualize_dataset(cfg: dict, out_dir=None): @@ -44,9 +57,10 @@ def visualize_dataset(cfg: dict, out_dir=None): ) logging.info("Start rendering episodes from offline buffer") - video_paths = render_dataset(dataset, out_dir, MAX_NUM_STEPS * NUM_EPISODES_TO_RENDER, cfg.fps) + video_paths = render_dataset(dataset, out_dir, MAX_NUM_STEPS * NUM_EPISODES_TO_RENDER) for video_path in video_paths: logging.info(video_path) + return video_paths def render_dataset(dataset, out_dir, max_num_episodes): @@ -77,7 +91,7 @@ def render_dataset(dataset, out_dir, max_num_episodes): # add current frame to list of frames to render frames[im_key].append(item[im_key]) - end_of_episode = item["index"].item() == item["episode_data_index_to"].item() - 1 + end_of_episode = item["index"].item() == dataset.episode_data_index["to"][ep_id] - 1 out_dir.mkdir(parents=True, exist_ok=True) for im_key in dataset.image_keys: diff --git a/poetry.lock b/poetry.lock index a70e404..7b66604 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. [[package]] name = "absl-py" @@ -522,21 +522,21 @@ toml = ["tomli"] [[package]] name = "datasets" -version = "2.18.0" +version = "2.19.0" description = "HuggingFace community-driven open-source library of datasets" optional = false python-versions = ">=3.8.0" files = [ - {file = "datasets-2.18.0-py3-none-any.whl", hash = "sha256:f1bbf0e2896917a914de01cbd37075b14deea3837af87ad0d9f697388ccaeb50"}, - {file = "datasets-2.18.0.tar.gz", hash = "sha256:cdf8b8c6abf7316377ba4f49f9589a4c74556d6b481afd0abd2284f3d69185cb"}, + {file = "datasets-2.19.0-py3-none-any.whl", hash = "sha256:f57c5316e123d4721b970c68c1cb856505f289cda58f5557ffe745b49c011a8e"}, + {file = "datasets-2.19.0.tar.gz", hash = "sha256:0b47e08cc7af2c6800a42cadc4657b22a0afc7197786c8986d703c08d90886a6"}, ] [package.dependencies] aiohttp = "*" dill = ">=0.3.0,<0.3.9" filelock = "*" -fsspec = {version = ">=2023.1.0,<=2024.2.0", extras = ["http"]} -huggingface-hub = ">=0.19.4" +fsspec = {version = ">=2023.1.0,<=2024.3.1", extras = ["http"]} +huggingface-hub = ">=0.21.2" multiprocess = "*" numpy = ">=1.17" packaging = "*" @@ -552,15 +552,15 @@ xxhash = "*" apache-beam = ["apache-beam (>=2.26.0)"] audio = ["librosa", "soundfile (>=0.12.1)"] benchmarks = ["tensorflow (==2.12.0)", "torch (==2.0.1)", "transformers (==4.30.1)"] -dev = ["Pillow (>=6.2.1)", "absl-py", "apache-beam (>=2.26.0)", "elasticsearch (<8.0.0)", "faiss-cpu (>=1.6.4)", "jax (>=0.3.14)", "jaxlib (>=0.3.14)", "joblib (<1.3.0)", "joblibspark", "librosa", "lz4", "py7zr", "pyspark (>=3.4)", "pytest", "pytest-datadir", "pytest-xdist", "rarfile (>=4.0)", "ruff (>=0.3.0)", "s3fs", "s3fs (>=2021.11.1)", "soundfile (>=0.12.1)", "sqlalchemy", "tensorflow (>=2.2.0,!=2.6.0,!=2.6.1)", "tensorflow (>=2.3,!=2.6.0,!=2.6.1)", "tensorflow-macos", "tiktoken", "torch", "torch (>=2.0.0)", "transformers", "typing-extensions (>=4.6.1)", "zstandard"] -docs = ["s3fs", "tensorflow (>=2.2.0,!=2.6.0,!=2.6.1)", "tensorflow-macos", "torch", "transformers"] +dev = ["Pillow (>=6.2.1)", "absl-py", "apache-beam (>=2.26.0)", "elasticsearch (<8.0.0)", "faiss-cpu (>=1.6.4)", "jax (>=0.3.14)", "jaxlib (>=0.3.14)", "joblib (<1.3.0)", "joblibspark", "librosa", "lz4", "polars[timezone] (>=0.20.0)", "protobuf (<4.0.0)", "py7zr", "pyspark (>=3.4)", "pytest", "pytest-datadir", "pytest-xdist", "rarfile (>=4.0)", "ruff (>=0.3.0)", "s3fs", "s3fs (>=2021.11.1)", "soundfile (>=0.12.1)", "sqlalchemy", "tensorflow (>=2.6.0)", "tiktoken", "torch", "torch (>=2.0.0)", "transformers", "typing-extensions (>=4.6.1)", "zstandard"] +docs = ["s3fs", "tensorflow (>=2.6.0)", "torch", "transformers"] jax = ["jax (>=0.3.14)", "jaxlib (>=0.3.14)"] metrics-tests = ["Werkzeug (>=1.0.1)", "accelerate", "bert-score (>=0.3.6)", "jiwer", "langdetect", "mauve-text", "nltk", "requests-file (>=1.5.1)", "rouge-score", "sacrebleu", "sacremoses", "scikit-learn", "scipy", "sentencepiece", "seqeval", "six (>=1.15.0,<1.16.0)", "spacy (>=3.0.0)", "texttable (>=1.6.3)", "tldextract", "tldextract (>=3.1.0)", "toml (>=0.10.1)", "typer (<0.5.0)"] quality = ["ruff (>=0.3.0)"] s3 = ["s3fs"] -tensorflow = ["tensorflow (>=2.2.0,!=2.6.0,!=2.6.1)", "tensorflow-macos"] -tensorflow-gpu = ["tensorflow-gpu (>=2.2.0,!=2.6.0,!=2.6.1)"] -tests = ["Pillow (>=6.2.1)", "absl-py", "apache-beam (>=2.26.0)", "elasticsearch (<8.0.0)", "faiss-cpu (>=1.6.4)", "jax (>=0.3.14)", "jaxlib (>=0.3.14)", "joblib (<1.3.0)", "joblibspark", "librosa", "lz4", "py7zr", "pyspark (>=3.4)", "pytest", "pytest-datadir", "pytest-xdist", "rarfile (>=4.0)", "s3fs (>=2021.11.1)", "soundfile (>=0.12.1)", "sqlalchemy", "tensorflow (>=2.3,!=2.6.0,!=2.6.1)", "tensorflow-macos", "tiktoken", "torch (>=2.0.0)", "transformers", "typing-extensions (>=4.6.1)", "zstandard"] +tensorflow = ["tensorflow (>=2.6.0)"] +tensorflow-gpu = ["tensorflow (>=2.6.0)"] +tests = ["Pillow (>=6.2.1)", "absl-py", "apache-beam (>=2.26.0)", "elasticsearch (<8.0.0)", "faiss-cpu (>=1.6.4)", "jax (>=0.3.14)", "jaxlib (>=0.3.14)", "joblib (<1.3.0)", "joblibspark", "librosa", "lz4", "polars[timezone] (>=0.20.0)", "protobuf (<4.0.0)", "py7zr", "pyspark (>=3.4)", "pytest", "pytest-datadir", "pytest-xdist", "rarfile (>=4.0)", "s3fs (>=2021.11.1)", "soundfile (>=0.12.1)", "sqlalchemy", "tensorflow (>=2.6.0)", "tiktoken", "torch (>=2.0.0)", "transformers", "typing-extensions (>=4.6.1)", "zstandard"] torch = ["torch"] vision = ["Pillow (>=6.2.1)"] @@ -2909,7 +2909,6 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, - {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"}, {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, @@ -4195,4 +4194,4 @@ xarm = ["gym-xarm"] [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "01ad4eb04061ec9f785d4574bf66d3e5cb4549e2ea11ab175895f94cb62c1f1c" +content-hash = "7f5afa48aead953f598e686e767891d3d23f2862b80144f76dc064101ef80b4a" diff --git a/pyproject.toml b/pyproject.toml index 0934898..a3fa4d2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -53,7 +53,8 @@ pre-commit = {version = "^3.7.0", optional = true} debugpy = {version = "^1.8.1", optional = true} pytest = {version = "^8.1.0", optional = true} pytest-cov = {version = "^5.0.0", optional = true} -datasets = "^2.18.0" +datasets = "^2.19.0" + [tool.poetry.extras] pusht = ["gym-pusht"] diff --git a/tests/data/aloha_sim_insertion_human/meta_data/episode_data_index.safetensors b/tests/data/aloha_sim_insertion_human/meta_data/episode_data_index.safetensors new file mode 100644 index 0000000..8ed156c Binary files /dev/null and b/tests/data/aloha_sim_insertion_human/meta_data/episode_data_index.safetensors differ diff --git a/tests/data/aloha_sim_insertion_human/meta_data/info.json b/tests/data/aloha_sim_insertion_human/meta_data/info.json new file mode 100644 index 0000000..02e62b6 --- /dev/null +++ b/tests/data/aloha_sim_insertion_human/meta_data/info.json @@ -0,0 +1,3 @@ +{ + "fps": 50 +} \ No newline at end of file diff --git a/tests/data/aloha_sim_insertion_human/meta_data/stats.safetensors b/tests/data/aloha_sim_insertion_human/meta_data/stats.safetensors new file mode 100644 index 0000000..e956450 Binary files /dev/null and b/tests/data/aloha_sim_insertion_human/meta_data/stats.safetensors differ diff --git a/tests/data/aloha_sim_insertion_human/train/data-00000-of-00001.arrow b/tests/data/aloha_sim_insertion_human/train/data-00000-of-00001.arrow index 165298c..f6c89ff 100644 Binary files a/tests/data/aloha_sim_insertion_human/train/data-00000-of-00001.arrow and b/tests/data/aloha_sim_insertion_human/train/data-00000-of-00001.arrow differ diff --git a/tests/data/aloha_sim_insertion_human/train/dataset_info.json b/tests/data/aloha_sim_insertion_human/train/dataset_info.json index 542c7bf..69f8083 100644 --- a/tests/data/aloha_sim_insertion_human/train/dataset_info.json +++ b/tests/data/aloha_sim_insertion_human/train/dataset_info.json @@ -21,11 +21,11 @@ "length": 14, "_type": "Sequence" }, - "episode_id": { + "episode_index": { "dtype": "int64", "_type": "Value" }, - "frame_id": { + "frame_index": { "dtype": "int64", "_type": "Value" }, @@ -37,14 +37,6 @@ "dtype": "bool", "_type": "Value" }, - "episode_data_index_from": { - "dtype": "int64", - "_type": "Value" - }, - "episode_data_index_to": { - "dtype": "int64", - "_type": "Value" - }, "index": { "dtype": "int64", "_type": "Value" diff --git a/tests/data/aloha_sim_insertion_human/train/state.json b/tests/data/aloha_sim_insertion_human/train/state.json index 39101fd..153a412 100644 --- a/tests/data/aloha_sim_insertion_human/train/state.json +++ b/tests/data/aloha_sim_insertion_human/train/state.json @@ -4,7 +4,7 @@ "filename": "data-00000-of-00001.arrow" } ], - "_fingerprint": "d79cf82ffc86f110", + "_fingerprint": "22eeca7a3f4725ee", "_format_columns": null, "_format_kwargs": {}, "_format_type": "torch", diff --git a/tests/data/aloha_sim_insertion_scripted/meta_data/episode_data_index.safetensors b/tests/data/aloha_sim_insertion_scripted/meta_data/episode_data_index.safetensors new file mode 100644 index 0000000..8685de7 Binary files /dev/null and b/tests/data/aloha_sim_insertion_scripted/meta_data/episode_data_index.safetensors differ diff --git a/tests/data/aloha_sim_insertion_scripted/meta_data/info.json b/tests/data/aloha_sim_insertion_scripted/meta_data/info.json new file mode 100644 index 0000000..02e62b6 --- /dev/null +++ b/tests/data/aloha_sim_insertion_scripted/meta_data/info.json @@ -0,0 +1,3 @@ +{ + "fps": 50 +} \ No newline at end of file diff --git a/tests/data/aloha_sim_insertion_scripted/meta_data/stats.safetensors b/tests/data/aloha_sim_insertion_scripted/meta_data/stats.safetensors new file mode 100644 index 0000000..619ee90 Binary files /dev/null and b/tests/data/aloha_sim_insertion_scripted/meta_data/stats.safetensors differ diff --git a/tests/data/aloha_sim_insertion_scripted/train/data-00000-of-00001.arrow b/tests/data/aloha_sim_insertion_scripted/train/data-00000-of-00001.arrow index 034f759..f5cdadc 100644 Binary files a/tests/data/aloha_sim_insertion_scripted/train/data-00000-of-00001.arrow and b/tests/data/aloha_sim_insertion_scripted/train/data-00000-of-00001.arrow differ diff --git a/tests/data/aloha_sim_insertion_scripted/train/dataset_info.json b/tests/data/aloha_sim_insertion_scripted/train/dataset_info.json index 542c7bf..69f8083 100644 --- a/tests/data/aloha_sim_insertion_scripted/train/dataset_info.json +++ b/tests/data/aloha_sim_insertion_scripted/train/dataset_info.json @@ -21,11 +21,11 @@ "length": 14, "_type": "Sequence" }, - "episode_id": { + "episode_index": { "dtype": "int64", "_type": "Value" }, - "frame_id": { + "frame_index": { "dtype": "int64", "_type": "Value" }, @@ -37,14 +37,6 @@ "dtype": "bool", "_type": "Value" }, - "episode_data_index_from": { - "dtype": "int64", - "_type": "Value" - }, - "episode_data_index_to": { - "dtype": "int64", - "_type": "Value" - }, "index": { "dtype": "int64", "_type": "Value" diff --git a/tests/data/aloha_sim_insertion_scripted/train/state.json b/tests/data/aloha_sim_insertion_scripted/train/state.json index ecaa8fd..716aca6 100644 --- a/tests/data/aloha_sim_insertion_scripted/train/state.json +++ b/tests/data/aloha_sim_insertion_scripted/train/state.json @@ -4,7 +4,7 @@ "filename": "data-00000-of-00001.arrow" } ], - "_fingerprint": "d8e4a817b5449498", + "_fingerprint": "97c28d4ad1536e4c", "_format_columns": null, "_format_kwargs": {}, "_format_type": "torch", diff --git a/tests/data/aloha_sim_transfer_cube_human/meta_data/episode_data_index.safetensors b/tests/data/aloha_sim_transfer_cube_human/meta_data/episode_data_index.safetensors new file mode 100644 index 0000000..8685de7 Binary files /dev/null and b/tests/data/aloha_sim_transfer_cube_human/meta_data/episode_data_index.safetensors differ diff --git a/tests/data/aloha_sim_transfer_cube_human/meta_data/info.json b/tests/data/aloha_sim_transfer_cube_human/meta_data/info.json new file mode 100644 index 0000000..02e62b6 --- /dev/null +++ b/tests/data/aloha_sim_transfer_cube_human/meta_data/info.json @@ -0,0 +1,3 @@ +{ + "fps": 50 +} \ No newline at end of file diff --git a/tests/data/aloha_sim_transfer_cube_human/meta_data/stats.safetensors b/tests/data/aloha_sim_transfer_cube_human/meta_data/stats.safetensors new file mode 100644 index 0000000..998e610 Binary files /dev/null and b/tests/data/aloha_sim_transfer_cube_human/meta_data/stats.safetensors differ diff --git a/tests/data/aloha_sim_transfer_cube_human/train/data-00000-of-00001.arrow b/tests/data/aloha_sim_transfer_cube_human/train/data-00000-of-00001.arrow index 9682f00..1bb1f51 100644 Binary files a/tests/data/aloha_sim_transfer_cube_human/train/data-00000-of-00001.arrow and b/tests/data/aloha_sim_transfer_cube_human/train/data-00000-of-00001.arrow differ diff --git a/tests/data/aloha_sim_transfer_cube_human/train/dataset_info.json b/tests/data/aloha_sim_transfer_cube_human/train/dataset_info.json index 542c7bf..69f8083 100644 --- a/tests/data/aloha_sim_transfer_cube_human/train/dataset_info.json +++ b/tests/data/aloha_sim_transfer_cube_human/train/dataset_info.json @@ -21,11 +21,11 @@ "length": 14, "_type": "Sequence" }, - "episode_id": { + "episode_index": { "dtype": "int64", "_type": "Value" }, - "frame_id": { + "frame_index": { "dtype": "int64", "_type": "Value" }, @@ -37,14 +37,6 @@ "dtype": "bool", "_type": "Value" }, - "episode_data_index_from": { - "dtype": "int64", - "_type": "Value" - }, - "episode_data_index_to": { - "dtype": "int64", - "_type": "Value" - }, "index": { "dtype": "int64", "_type": "Value" diff --git a/tests/data/aloha_sim_transfer_cube_human/train/state.json b/tests/data/aloha_sim_transfer_cube_human/train/state.json index 0167986..d9449a3 100644 --- a/tests/data/aloha_sim_transfer_cube_human/train/state.json +++ b/tests/data/aloha_sim_transfer_cube_human/train/state.json @@ -4,7 +4,7 @@ "filename": "data-00000-of-00001.arrow" } ], - "_fingerprint": "f03482befa767127", + "_fingerprint": "cb9349b5c92951e8", "_format_columns": null, "_format_kwargs": {}, "_format_type": "torch", diff --git a/tests/data/aloha_sim_transfer_cube_scripted/meta_data/episode_data_index.safetensors b/tests/data/aloha_sim_transfer_cube_scripted/meta_data/episode_data_index.safetensors new file mode 100644 index 0000000..8685de7 Binary files /dev/null and b/tests/data/aloha_sim_transfer_cube_scripted/meta_data/episode_data_index.safetensors differ diff --git a/tests/data/aloha_sim_transfer_cube_scripted/meta_data/info.json b/tests/data/aloha_sim_transfer_cube_scripted/meta_data/info.json new file mode 100644 index 0000000..02e62b6 --- /dev/null +++ b/tests/data/aloha_sim_transfer_cube_scripted/meta_data/info.json @@ -0,0 +1,3 @@ +{ + "fps": 50 +} \ No newline at end of file diff --git a/tests/data/aloha_sim_transfer_cube_scripted/meta_data/stats.safetensors b/tests/data/aloha_sim_transfer_cube_scripted/meta_data/stats.safetensors new file mode 100644 index 0000000..91696d3 Binary files /dev/null and b/tests/data/aloha_sim_transfer_cube_scripted/meta_data/stats.safetensors differ diff --git a/tests/data/aloha_sim_transfer_cube_scripted/train/data-00000-of-00001.arrow b/tests/data/aloha_sim_transfer_cube_scripted/train/data-00000-of-00001.arrow index 567191d..d658a6d 100644 Binary files a/tests/data/aloha_sim_transfer_cube_scripted/train/data-00000-of-00001.arrow and b/tests/data/aloha_sim_transfer_cube_scripted/train/data-00000-of-00001.arrow differ diff --git a/tests/data/aloha_sim_transfer_cube_scripted/train/dataset_info.json b/tests/data/aloha_sim_transfer_cube_scripted/train/dataset_info.json index 542c7bf..69f8083 100644 --- a/tests/data/aloha_sim_transfer_cube_scripted/train/dataset_info.json +++ b/tests/data/aloha_sim_transfer_cube_scripted/train/dataset_info.json @@ -21,11 +21,11 @@ "length": 14, "_type": "Sequence" }, - "episode_id": { + "episode_index": { "dtype": "int64", "_type": "Value" }, - "frame_id": { + "frame_index": { "dtype": "int64", "_type": "Value" }, @@ -37,14 +37,6 @@ "dtype": "bool", "_type": "Value" }, - "episode_data_index_from": { - "dtype": "int64", - "_type": "Value" - }, - "episode_data_index_to": { - "dtype": "int64", - "_type": "Value" - }, "index": { "dtype": "int64", "_type": "Value" diff --git a/tests/data/aloha_sim_transfer_cube_scripted/train/state.json b/tests/data/aloha_sim_transfer_cube_scripted/train/state.json index 56005bc..2d4dfc6 100644 --- a/tests/data/aloha_sim_transfer_cube_scripted/train/state.json +++ b/tests/data/aloha_sim_transfer_cube_scripted/train/state.json @@ -4,7 +4,7 @@ "filename": "data-00000-of-00001.arrow" } ], - "_fingerprint": "93e03c6320c7d56e", + "_fingerprint": "e4d7ad2b360db1af", "_format_columns": null, "_format_kwargs": {}, "_format_type": "torch", diff --git a/tests/data/pusht/meta_data/episode_data_index.safetensors b/tests/data/pusht/meta_data/episode_data_index.safetensors new file mode 100644 index 0000000..9343d2d Binary files /dev/null and b/tests/data/pusht/meta_data/episode_data_index.safetensors differ diff --git a/tests/data/pusht/meta_data/info.json b/tests/data/pusht/meta_data/info.json new file mode 100644 index 0000000..5c9a8ae --- /dev/null +++ b/tests/data/pusht/meta_data/info.json @@ -0,0 +1,3 @@ +{ + "fps": 10 +} \ No newline at end of file diff --git a/tests/data/pusht/meta_data/stats.safetensors b/tests/data/pusht/meta_data/stats.safetensors new file mode 100644 index 0000000..fa2380e Binary files /dev/null and b/tests/data/pusht/meta_data/stats.safetensors differ diff --git a/tests/data/pusht/train/data-00000-of-00001.arrow b/tests/data/pusht/train/data-00000-of-00001.arrow index 9a36a8d..5972be9 100644 Binary files a/tests/data/pusht/train/data-00000-of-00001.arrow and b/tests/data/pusht/train/data-00000-of-00001.arrow differ diff --git a/tests/data/pusht/train/dataset_info.json b/tests/data/pusht/train/dataset_info.json index 667e06f..aefe478 100644 --- a/tests/data/pusht/train/dataset_info.json +++ b/tests/data/pusht/train/dataset_info.json @@ -21,11 +21,11 @@ "length": 2, "_type": "Sequence" }, - "episode_id": { + "episode_index": { "dtype": "int64", "_type": "Value" }, - "frame_id": { + "frame_index": { "dtype": "int64", "_type": "Value" }, @@ -45,14 +45,6 @@ "dtype": "bool", "_type": "Value" }, - "episode_data_index_from": { - "dtype": "int64", - "_type": "Value" - }, - "episode_data_index_to": { - "dtype": "int64", - "_type": "Value" - }, "index": { "dtype": "int64", "_type": "Value" diff --git a/tests/data/pusht/train/meta_data/episode_data_index.safetensors b/tests/data/pusht/train/meta_data/episode_data_index.safetensors new file mode 100644 index 0000000..3511c26 Binary files /dev/null and b/tests/data/pusht/train/meta_data/episode_data_index.safetensors differ diff --git a/tests/data/pusht/train/meta_data/info.json b/tests/data/pusht/train/meta_data/info.json new file mode 100644 index 0000000..5c9a8ae --- /dev/null +++ b/tests/data/pusht/train/meta_data/info.json @@ -0,0 +1,3 @@ +{ + "fps": 10 +} \ No newline at end of file diff --git a/tests/data/pusht/train/meta_data/stats_action.safetensors b/tests/data/pusht/train/meta_data/stats_action.safetensors new file mode 100644 index 0000000..2c2553b Binary files /dev/null and b/tests/data/pusht/train/meta_data/stats_action.safetensors differ diff --git a/tests/data/pusht/train/meta_data/stats_observation.image.safetensors b/tests/data/pusht/train/meta_data/stats_observation.image.safetensors new file mode 100644 index 0000000..0a145d4 Binary files /dev/null and b/tests/data/pusht/train/meta_data/stats_observation.image.safetensors differ diff --git a/tests/data/pusht/train/meta_data/stats_observation.state.safetensors b/tests/data/pusht/train/meta_data/stats_observation.state.safetensors new file mode 100644 index 0000000..28ee285 Binary files /dev/null and b/tests/data/pusht/train/meta_data/stats_observation.state.safetensors differ diff --git a/tests/data/pusht/train/state.json b/tests/data/pusht/train/state.json index 7e0ff57..dda3f88 100644 --- a/tests/data/pusht/train/state.json +++ b/tests/data/pusht/train/state.json @@ -4,7 +4,7 @@ "filename": "data-00000-of-00001.arrow" } ], - "_fingerprint": "21bb9a76ed78a475", + "_fingerprint": "a04a9ce660122e23", "_format_columns": null, "_format_kwargs": {}, "_format_type": "torch", diff --git a/tests/data/save_dataset_to_safetensors/pusht/frame_0.safetensors b/tests/data/save_dataset_to_safetensors/pusht/frame_0.safetensors new file mode 100644 index 0000000..1bb7c06 Binary files /dev/null and b/tests/data/save_dataset_to_safetensors/pusht/frame_0.safetensors differ diff --git a/tests/data/save_dataset_to_safetensors/pusht/frame_1.safetensors b/tests/data/save_dataset_to_safetensors/pusht/frame_1.safetensors new file mode 100644 index 0000000..ae46012 Binary files /dev/null and b/tests/data/save_dataset_to_safetensors/pusht/frame_1.safetensors differ diff --git a/tests/data/save_dataset_to_safetensors/pusht/frame_159.safetensors b/tests/data/save_dataset_to_safetensors/pusht/frame_159.safetensors new file mode 100644 index 0000000..2b5729d Binary files /dev/null and b/tests/data/save_dataset_to_safetensors/pusht/frame_159.safetensors differ diff --git a/tests/data/save_dataset_to_safetensors/pusht/frame_160.safetensors b/tests/data/save_dataset_to_safetensors/pusht/frame_160.safetensors new file mode 100644 index 0000000..a048c0c Binary files /dev/null and b/tests/data/save_dataset_to_safetensors/pusht/frame_160.safetensors differ diff --git a/tests/data/save_dataset_to_safetensors/pusht/frame_80.safetensors b/tests/data/save_dataset_to_safetensors/pusht/frame_80.safetensors new file mode 100644 index 0000000..e37d54c Binary files /dev/null and b/tests/data/save_dataset_to_safetensors/pusht/frame_80.safetensors differ diff --git a/tests/data/save_dataset_to_safetensors/pusht/frame_81.safetensors b/tests/data/save_dataset_to_safetensors/pusht/frame_81.safetensors new file mode 100644 index 0000000..5cd8451 Binary files /dev/null and b/tests/data/save_dataset_to_safetensors/pusht/frame_81.safetensors differ diff --git a/tests/data/xarm_lift_medium/meta_data/episode_data_index.safetensors b/tests/data/xarm_lift_medium/meta_data/episode_data_index.safetensors new file mode 100644 index 0000000..7216093 Binary files /dev/null and b/tests/data/xarm_lift_medium/meta_data/episode_data_index.safetensors differ diff --git a/tests/data/xarm_lift_medium/meta_data/info.json b/tests/data/xarm_lift_medium/meta_data/info.json new file mode 100644 index 0000000..f9d6b30 --- /dev/null +++ b/tests/data/xarm_lift_medium/meta_data/info.json @@ -0,0 +1,3 @@ +{ + "fps": 15 +} \ No newline at end of file diff --git a/tests/data/xarm_lift_medium/meta_data/stats.safetensors b/tests/data/xarm_lift_medium/meta_data/stats.safetensors new file mode 100644 index 0000000..bdcc1b0 Binary files /dev/null and b/tests/data/xarm_lift_medium/meta_data/stats.safetensors differ diff --git a/tests/data/xarm_lift_medium/train/data-00000-of-00001.arrow b/tests/data/xarm_lift_medium/train/data-00000-of-00001.arrow index 45d527e..d621210 100644 Binary files a/tests/data/xarm_lift_medium/train/data-00000-of-00001.arrow and b/tests/data/xarm_lift_medium/train/data-00000-of-00001.arrow differ diff --git a/tests/data/xarm_lift_medium/train/dataset_info.json b/tests/data/xarm_lift_medium/train/dataset_info.json index bb647c4..59a43bd 100644 --- a/tests/data/xarm_lift_medium/train/dataset_info.json +++ b/tests/data/xarm_lift_medium/train/dataset_info.json @@ -21,11 +21,11 @@ "length": 4, "_type": "Sequence" }, - "episode_id": { + "episode_index": { "dtype": "int64", "_type": "Value" }, - "frame_id": { + "frame_index": { "dtype": "int64", "_type": "Value" }, @@ -41,14 +41,6 @@ "dtype": "bool", "_type": "Value" }, - "episode_data_index_from": { - "dtype": "int64", - "_type": "Value" - }, - "episode_data_index_to": { - "dtype": "int64", - "_type": "Value" - }, "index": { "dtype": "int64", "_type": "Value" diff --git a/tests/data/xarm_lift_medium/train/state.json b/tests/data/xarm_lift_medium/train/state.json index c930c52..642fda3 100644 --- a/tests/data/xarm_lift_medium/train/state.json +++ b/tests/data/xarm_lift_medium/train/state.json @@ -4,7 +4,7 @@ "filename": "data-00000-of-00001.arrow" } ], - "_fingerprint": "a95cbec45e3bb9d6", + "_fingerprint": "cc6afdfcdd6f63ab", "_format_columns": null, "_format_kwargs": {}, "_format_type": "torch", diff --git a/tests/data/xarm_lift_medium_replay/meta_data/episode_data_index.safetensors b/tests/data/xarm_lift_medium_replay/meta_data/episode_data_index.safetensors new file mode 100644 index 0000000..7216093 Binary files /dev/null and b/tests/data/xarm_lift_medium_replay/meta_data/episode_data_index.safetensors differ diff --git a/tests/data/xarm_lift_medium_replay/meta_data/info.json b/tests/data/xarm_lift_medium_replay/meta_data/info.json new file mode 100644 index 0000000..f9d6b30 --- /dev/null +++ b/tests/data/xarm_lift_medium_replay/meta_data/info.json @@ -0,0 +1,3 @@ +{ + "fps": 15 +} \ No newline at end of file diff --git a/tests/data/xarm_lift_medium_replay/meta_data/stats.safetensors b/tests/data/xarm_lift_medium_replay/meta_data/stats.safetensors new file mode 100644 index 0000000..4808895 Binary files /dev/null and b/tests/data/xarm_lift_medium_replay/meta_data/stats.safetensors differ diff --git a/tests/data/xarm_lift_medium_replay/train/data-00000-of-00001.arrow b/tests/data/xarm_lift_medium_replay/train/data-00000-of-00001.arrow new file mode 100644 index 0000000..b524811 Binary files /dev/null and b/tests/data/xarm_lift_medium_replay/train/data-00000-of-00001.arrow differ diff --git a/tests/data/xarm_lift_medium_replay/train/dataset_info.json b/tests/data/xarm_lift_medium_replay/train/dataset_info.json new file mode 100644 index 0000000..59a43bd --- /dev/null +++ b/tests/data/xarm_lift_medium_replay/train/dataset_info.json @@ -0,0 +1,51 @@ +{ + "citation": "", + "description": "", + "features": { + "observation.image": { + "_type": "Image" + }, + "observation.state": { + "feature": { + "dtype": "float32", + "_type": "Value" + }, + "length": 4, + "_type": "Sequence" + }, + "action": { + "feature": { + "dtype": "float32", + "_type": "Value" + }, + "length": 4, + "_type": "Sequence" + }, + "episode_index": { + "dtype": "int64", + "_type": "Value" + }, + "frame_index": { + "dtype": "int64", + "_type": "Value" + }, + "timestamp": { + "dtype": "float32", + "_type": "Value" + }, + "next.reward": { + "dtype": "float32", + "_type": "Value" + }, + "next.done": { + "dtype": "bool", + "_type": "Value" + }, + "index": { + "dtype": "int64", + "_type": "Value" + } + }, + "homepage": "", + "license": "" +} \ No newline at end of file diff --git a/tests/data/xarm_lift_medium_replay/train/state.json b/tests/data/xarm_lift_medium_replay/train/state.json new file mode 100644 index 0000000..e9b74d7 --- /dev/null +++ b/tests/data/xarm_lift_medium_replay/train/state.json @@ -0,0 +1,13 @@ +{ + "_data_files": [ + { + "filename": "data-00000-of-00001.arrow" + } + ], + "_fingerprint": "9f8e1a8c1845df55", + "_format_columns": null, + "_format_kwargs": {}, + "_format_type": "torch", + "_output_all_columns": false, + "_split": null +} \ No newline at end of file diff --git a/tests/data/xarm_push_medium/meta_data/episode_data_index.safetensors b/tests/data/xarm_push_medium/meta_data/episode_data_index.safetensors new file mode 100644 index 0000000..7216093 Binary files /dev/null and b/tests/data/xarm_push_medium/meta_data/episode_data_index.safetensors differ diff --git a/tests/data/xarm_push_medium/meta_data/info.json b/tests/data/xarm_push_medium/meta_data/info.json new file mode 100644 index 0000000..f9d6b30 --- /dev/null +++ b/tests/data/xarm_push_medium/meta_data/info.json @@ -0,0 +1,3 @@ +{ + "fps": 15 +} \ No newline at end of file diff --git a/tests/data/xarm_push_medium/meta_data/stats.safetensors b/tests/data/xarm_push_medium/meta_data/stats.safetensors new file mode 100644 index 0000000..f216e05 Binary files /dev/null and b/tests/data/xarm_push_medium/meta_data/stats.safetensors differ diff --git a/tests/data/xarm_push_medium/train/data-00000-of-00001.arrow b/tests/data/xarm_push_medium/train/data-00000-of-00001.arrow new file mode 100644 index 0000000..241117c Binary files /dev/null and b/tests/data/xarm_push_medium/train/data-00000-of-00001.arrow differ diff --git a/tests/data/xarm_push_medium/train/dataset_info.json b/tests/data/xarm_push_medium/train/dataset_info.json new file mode 100644 index 0000000..9e47b34 --- /dev/null +++ b/tests/data/xarm_push_medium/train/dataset_info.json @@ -0,0 +1,51 @@ +{ + "citation": "", + "description": "", + "features": { + "observation.image": { + "_type": "Image" + }, + "observation.state": { + "feature": { + "dtype": "float32", + "_type": "Value" + }, + "length": 4, + "_type": "Sequence" + }, + "action": { + "feature": { + "dtype": "float32", + "_type": "Value" + }, + "length": 3, + "_type": "Sequence" + }, + "episode_index": { + "dtype": "int64", + "_type": "Value" + }, + "frame_index": { + "dtype": "int64", + "_type": "Value" + }, + "timestamp": { + "dtype": "float32", + "_type": "Value" + }, + "next.reward": { + "dtype": "float32", + "_type": "Value" + }, + "next.done": { + "dtype": "bool", + "_type": "Value" + }, + "index": { + "dtype": "int64", + "_type": "Value" + } + }, + "homepage": "", + "license": "" +} \ No newline at end of file diff --git a/tests/data/xarm_push_medium/train/state.json b/tests/data/xarm_push_medium/train/state.json new file mode 100644 index 0000000..0ec1f04 --- /dev/null +++ b/tests/data/xarm_push_medium/train/state.json @@ -0,0 +1,13 @@ +{ + "_data_files": [ + { + "filename": "data-00000-of-00001.arrow" + } + ], + "_fingerprint": "c900258061dd0b3f", + "_format_columns": null, + "_format_kwargs": {}, + "_format_type": "torch", + "_output_all_columns": false, + "_split": null +} \ No newline at end of file diff --git a/tests/data/xarm_push_medium_replay/meta_data/episode_data_index.safetensors b/tests/data/xarm_push_medium_replay/meta_data/episode_data_index.safetensors new file mode 100644 index 0000000..7216093 Binary files /dev/null and b/tests/data/xarm_push_medium_replay/meta_data/episode_data_index.safetensors differ diff --git a/tests/data/xarm_push_medium_replay/meta_data/info.json b/tests/data/xarm_push_medium_replay/meta_data/info.json new file mode 100644 index 0000000..f9d6b30 --- /dev/null +++ b/tests/data/xarm_push_medium_replay/meta_data/info.json @@ -0,0 +1,3 @@ +{ + "fps": 15 +} \ No newline at end of file diff --git a/tests/data/xarm_push_medium_replay/meta_data/stats.safetensors b/tests/data/xarm_push_medium_replay/meta_data/stats.safetensors new file mode 100644 index 0000000..0de4755 Binary files /dev/null and b/tests/data/xarm_push_medium_replay/meta_data/stats.safetensors differ diff --git a/tests/data/xarm_push_medium_replay/train/data-00000-of-00001.arrow b/tests/data/xarm_push_medium_replay/train/data-00000-of-00001.arrow new file mode 100644 index 0000000..2e07ea9 Binary files /dev/null and b/tests/data/xarm_push_medium_replay/train/data-00000-of-00001.arrow differ diff --git a/tests/data/xarm_push_medium_replay/train/dataset_info.json b/tests/data/xarm_push_medium_replay/train/dataset_info.json new file mode 100644 index 0000000..9e47b34 --- /dev/null +++ b/tests/data/xarm_push_medium_replay/train/dataset_info.json @@ -0,0 +1,51 @@ +{ + "citation": "", + "description": "", + "features": { + "observation.image": { + "_type": "Image" + }, + "observation.state": { + "feature": { + "dtype": "float32", + "_type": "Value" + }, + "length": 4, + "_type": "Sequence" + }, + "action": { + "feature": { + "dtype": "float32", + "_type": "Value" + }, + "length": 3, + "_type": "Sequence" + }, + "episode_index": { + "dtype": "int64", + "_type": "Value" + }, + "frame_index": { + "dtype": "int64", + "_type": "Value" + }, + "timestamp": { + "dtype": "float32", + "_type": "Value" + }, + "next.reward": { + "dtype": "float32", + "_type": "Value" + }, + "next.done": { + "dtype": "bool", + "_type": "Value" + }, + "index": { + "dtype": "int64", + "_type": "Value" + } + }, + "homepage": "", + "license": "" +} \ No newline at end of file diff --git a/tests/data/xarm_push_medium_replay/train/state.json b/tests/data/xarm_push_medium_replay/train/state.json new file mode 100644 index 0000000..39ffeaf --- /dev/null +++ b/tests/data/xarm_push_medium_replay/train/state.json @@ -0,0 +1,13 @@ +{ + "_data_files": [ + { + "filename": "data-00000-of-00001.arrow" + } + ], + "_fingerprint": "e51c80a33c7688c0", + "_format_columns": null, + "_format_kwargs": {}, + "_format_type": "torch", + "_output_all_columns": false, + "_split": null +} \ No newline at end of file diff --git a/tests/scripts/save_dataset_to_safetensors.py b/tests/scripts/save_dataset_to_safetensors.py new file mode 100644 index 0000000..4f0875e --- /dev/null +++ b/tests/scripts/save_dataset_to_safetensors.py @@ -0,0 +1,71 @@ +""" +This script provides a utility for saving a dataset as safetensors files for the purpose of testing backward compatibility +when updating the data format. It uses the `PushtDataset` to create a DataLoader and saves selected frame from the +dataset into a corresponding safetensors file in a specified output directory. + +If you know that your change will break backward compatibility, you should write a shortlived test by modifying +`tests/test_datasets.py::test_backward_compatibility` accordingly, and make sure this custom test pass. Your custom test +doesnt need to be merged into the `main` branch. Then you need to run this script and update the tests artifacts. + +Example usage: + `python tests/script/save_dataset_to_safetensors.py` +""" + +import shutil +from pathlib import Path + +from safetensors.torch import save_file + +from lerobot.common.datasets.pusht import PushtDataset + + +def save_dataset_to_safetensors(output_dir, dataset_id="pusht"): + data_dir = Path(output_dir) / dataset_id + + if data_dir.exists(): + shutil.rmtree(data_dir) + + data_dir.mkdir(parents=True, exist_ok=True) + + # TODO(rcadene): make it work for all datasets with LeRobotDataset(repo_id) + dataset = PushtDataset( + dataset_id=dataset_id, + split="train", + ) + + # save 2 first frames of first episode + i = dataset.episode_data_index["from"][0].item() + save_file(dataset[i], data_dir / f"frame_{i}.safetensors") + save_file(dataset[i + 1], data_dir / f"frame_{i+1}.safetensors") + + # save 2 frames at the middle of first episode + i = int((dataset.episode_data_index["to"][0].item() - dataset.episode_data_index["from"][0].item()) / 2) + save_file(dataset[i], data_dir / f"frame_{i}.safetensors") + save_file(dataset[i + 1], data_dir / f"frame_{i+1}.safetensors") + + # save 2 last frames of first episode + i = dataset.episode_data_index["to"][0].item() + save_file(dataset[i - 2], data_dir / f"frame_{i-2}.safetensors") + save_file(dataset[i - 1], data_dir / f"frame_{i-1}.safetensors") + + # TODO(rcadene): Enable testing on second and last episode + # We currently cant because our test dataset only contains the first episode + + # # save 2 first frames of second episode + # i = dataset.episode_data_index["from"][1].item() + # save_file(dataset[i], data_dir / f"frame_{i}.safetensors") + # save_file(dataset[i+1], data_dir / f"frame_{i+1}.safetensors") + + # # save 2 last frames of second episode + # i = dataset.episode_data_index["to"][1].item() + # save_file(dataset[i-2], data_dir / f"frame_{i-2}.safetensors") + # save_file(dataset[i-1], data_dir / f"frame_{i-1}.safetensors") + + # # save 2 last frames of last episode + # i = dataset.episode_data_index["to"][-1].item() + # save_file(dataset[i-2], data_dir / f"frame_{i-2}.safetensors") + # save_file(dataset[i-1], data_dir / f"frame_{i-1}.safetensors") + + +if __name__ == "__main__": + save_dataset_to_safetensors("tests/data/save_dataset_to_safetensors") diff --git a/tests/test_datasets.py b/tests/test_datasets.py index e488c30..ec459c5 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -1,20 +1,26 @@ +import json import logging import os +from copy import deepcopy from pathlib import Path import einops import pytest import torch from datasets import Dataset +from safetensors.torch import load_file import lerobot from lerobot.common.datasets.factory import make_dataset +from lerobot.common.datasets.pusht import PushtDataset from lerobot.common.datasets.utils import ( compute_stats, + flatten_dict, get_stats_einops_patterns, + hf_transform_to_torch, load_previous_and_future_frames, + unflatten_dict, ) -from lerobot.common.transforms import Prod from lerobot.common.utils.utils import init_hydra_config from .utils import DEFAULT_CONFIG_PATH, DEVICE @@ -39,8 +45,8 @@ def test_factory(env_name, dataset_id, policy_name): keys_ndim_required = [ ("action", 1, True), - ("episode_id", 0, True), - ("frame_id", 0, True), + ("episode_index", 0, True), + ("frame_index", 0, True), ("timestamp", 0, True), # TODO(rcadene): should we rename it agent_pos? ("observation.state", 1, True), @@ -48,12 +54,6 @@ def test_factory(env_name, dataset_id, policy_name): ("next.done", 0, False), ] - for key in image_keys: - keys_ndim_required.append( - (key, 3, True), - ) - assert dataset.hf_dataset[key].dtype == torch.uint8, f"{key}" - # test number of dimensions for key, ndim, required in keys_ndim_required: if key not in item: @@ -94,26 +94,21 @@ def test_compute_stats_on_xarm(): We compare with taking a straight min, mean, max, std of all the data in one pass (which we can do because we are working with a small dataset). """ + # TODO(rcadene): Reduce size of dataset sample on which stats compute is tested from lerobot.common.datasets.xarm import XarmDataset - data_dir = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None - - # get transform to convert images from uint8 [0,255] to float32 [0,1] - transform = Prod(in_keys=XarmDataset.image_keys, prod=1 / 255.0) - dataset = XarmDataset( dataset_id="xarm_lift_medium", - root=data_dir, - transform=transform, + root=Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None, ) # Note: we set the batch size to be smaller than the whole dataset to make sure we are testing batched # computation of the statistics. While doing this, we also make sure it works when we don't divide the # dataset into even batches. - computed_stats = compute_stats(dataset, batch_size=int(len(dataset) * 0.25)) + computed_stats = compute_stats(dataset.hf_dataset, batch_size=int(len(dataset) * 0.25)) # get einops patterns to aggregate batches and compute statistics - stats_patterns = get_stats_einops_patterns(dataset) + stats_patterns = get_stats_einops_patterns(dataset.hf_dataset) # get all frames from the dataset in the same dtype and range as during compute_stats dataloader = torch.utils.data.DataLoader( @@ -122,18 +117,19 @@ def test_compute_stats_on_xarm(): batch_size=len(dataset), shuffle=False, ) - hf_dataset = next(iter(dataloader)) + full_batch = next(iter(dataloader)) # compute stats based on all frames from the dataset without any batching expected_stats = {} for k, pattern in stats_patterns.items(): + full_batch[k] = full_batch[k].float() expected_stats[k] = {} - expected_stats[k]["mean"] = einops.reduce(hf_dataset[k], pattern, "mean") + expected_stats[k]["mean"] = einops.reduce(full_batch[k], pattern, "mean") expected_stats[k]["std"] = torch.sqrt( - einops.reduce((hf_dataset[k] - expected_stats[k]["mean"]) ** 2, pattern, "mean") + einops.reduce((full_batch[k] - expected_stats[k]["mean"]) ** 2, pattern, "mean") ) - expected_stats[k]["min"] = einops.reduce(hf_dataset[k], pattern, "min") - expected_stats[k]["max"] = einops.reduce(hf_dataset[k], pattern, "max") + expected_stats[k]["min"] = einops.reduce(full_batch[k], pattern, "min") + expected_stats[k]["max"] = einops.reduce(full_batch[k], pattern, "max") # test computed stats match expected stats for k in stats_patterns: @@ -142,11 +138,10 @@ def test_compute_stats_on_xarm(): assert torch.allclose(computed_stats[k]["min"], expected_stats[k]["min"]) assert torch.allclose(computed_stats[k]["max"], expected_stats[k]["max"]) - # TODO(rcadene): check that the stats used for training are correct too - # # load stats that are expected to match the ones returned by computed_stats - # assert (dataset.data_dir / "stats.pth").exists() - # loaded_stats = torch.load(dataset.data_dir / "stats.pth") + # load stats used during training which are expected to match the ones returned by computed_stats + loaded_stats = dataset.stats # noqa: F841 + # TODO(rcadene): we can't test this because expected_stats is computed on a subset # # test loaded stats match expected stats # for k in stats_patterns: # assert torch.allclose(loaded_stats[k]["mean"], expected_stats[k]["mean"]) @@ -160,15 +155,18 @@ def test_load_previous_and_future_frames_within_tolerance(): { "timestamp": [0.1, 0.2, 0.3, 0.4, 0.5], "index": [0, 1, 2, 3, 4], - "episode_data_index_from": [0, 0, 0, 0, 0], - "episode_data_index_to": [5, 5, 5, 5, 5], + "episode_index": [0, 0, 0, 0, 0], } ) - hf_dataset = hf_dataset.with_format("torch") - item = hf_dataset[2] + hf_dataset.set_transform(hf_transform_to_torch) + episode_data_index = { + "from": torch.tensor([0]), + "to": torch.tensor([5]), + } delta_timestamps = {"index": [-0.2, 0, 0.139]} tol = 0.04 - item = load_previous_and_future_frames(item, hf_dataset, delta_timestamps, tol) + item = hf_dataset[2] + item = load_previous_and_future_frames(item, hf_dataset, episode_data_index, delta_timestamps, tol) data, is_pad = item["index"], item["index_is_pad"] assert torch.equal(data, torch.tensor([0, 2, 3])), "Data does not match expected values" assert not is_pad.any(), "Unexpected padding detected" @@ -179,16 +177,19 @@ def test_load_previous_and_future_frames_outside_tolerance_inside_episode_range( { "timestamp": [0.1, 0.2, 0.3, 0.4, 0.5], "index": [0, 1, 2, 3, 4], - "episode_data_index_from": [0, 0, 0, 0, 0], - "episode_data_index_to": [5, 5, 5, 5, 5], + "episode_index": [0, 0, 0, 0, 0], } ) - hf_dataset = hf_dataset.with_format("torch") - item = hf_dataset[2] + hf_dataset.set_transform(hf_transform_to_torch) + episode_data_index = { + "from": torch.tensor([0]), + "to": torch.tensor([5]), + } delta_timestamps = {"index": [-0.2, 0, 0.141]} tol = 0.04 + item = hf_dataset[2] with pytest.raises(AssertionError): - load_previous_and_future_frames(item, hf_dataset, delta_timestamps, tol) + load_previous_and_future_frames(item, hf_dataset, episode_data_index, delta_timestamps, tol) def test_load_previous_and_future_frames_outside_tolerance_outside_episode_range(): @@ -196,17 +197,102 @@ def test_load_previous_and_future_frames_outside_tolerance_outside_episode_range { "timestamp": [0.1, 0.2, 0.3, 0.4, 0.5], "index": [0, 1, 2, 3, 4], - "episode_data_index_from": [0, 0, 0, 0, 0], - "episode_data_index_to": [5, 5, 5, 5, 5], + "episode_index": [0, 0, 0, 0, 0], } ) - hf_dataset = hf_dataset.with_format("torch") - item = hf_dataset[2] + hf_dataset.set_transform(hf_transform_to_torch) + episode_data_index = { + "from": torch.tensor([0]), + "to": torch.tensor([5]), + } delta_timestamps = {"index": [-0.3, -0.24, 0, 0.26, 0.3]} tol = 0.04 - item = load_previous_and_future_frames(item, hf_dataset, delta_timestamps, tol) + item = hf_dataset[2] + item = load_previous_and_future_frames(item, hf_dataset, episode_data_index, delta_timestamps, tol) data, is_pad = item["index"], item["index_is_pad"] assert torch.equal(data, torch.tensor([0, 0, 2, 4, 4])), "Data does not match expected values" assert torch.equal( is_pad, torch.tensor([True, False, False, True, True]) ), "Padding does not match expected values" + + +def test_flatten_unflatten_dict(): + d = { + "obs": { + "min": 0, + "max": 1, + "mean": 2, + "std": 3, + }, + "action": { + "min": 4, + "max": 5, + "mean": 6, + "std": 7, + }, + } + + original_d = deepcopy(d) + d = unflatten_dict(flatten_dict(d)) + + # test equality between nested dicts + assert json.dumps(original_d, sort_keys=True) == json.dumps(d, sort_keys=True), f"{original_d} != {d}" + + +def test_backward_compatibility(): + """This tests artifacts have been generated by `tests/scripts/save_dataset_to_safetensors.py`.""" + # TODO(rcadene): make it work for all datasets with LeRobotDataset(repo_id) + dataset_id = "pusht" + data_dir = Path("tests/data/save_dataset_to_safetensors") / dataset_id + + dataset = PushtDataset( + dataset_id=dataset_id, + split="train", + root=Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None, + ) + + def load_and_compare(i): + new_frame = dataset[i] + old_frame = load_file(data_dir / f"frame_{i}.safetensors") + + new_keys = set(new_frame.keys()) + old_keys = set(old_frame.keys()) + assert new_keys == old_keys, f"{new_keys=} and {old_keys=} are not the same" + + for key in new_frame: + assert ( + new_frame[key] == old_frame[key] + ).all(), f"{key=} for index={i} does not contain the same value" + + # test2 first frames of first episode + i = dataset.episode_data_index["from"][0].item() + load_and_compare(i) + load_and_compare(i + 1) + + # test 2 frames at the middle of first episode + i = int((dataset.episode_data_index["to"][0].item() - dataset.episode_data_index["from"][0].item()) / 2) + load_and_compare(i) + load_and_compare(i + 1) + + # test 2 last frames of first episode + i = dataset.episode_data_index["to"][0].item() + load_and_compare(i - 2) + load_and_compare(i - 1) + + # TODO(rcadene): Enable testing on second and last episode + # We currently cant because our test dataset only contains the first episode + + # # test 2 first frames of second episode + # i = dataset.episode_data_index["from"][1].item() + # load_and_compare(i) + # load_and_compare(i+1) + + # #test 2 last frames of second episode + # i = dataset.episode_data_index["to"][1].item() + # load_and_compare(i-2) + # load_and_compare(i-1) + + # # test 2 last frames of last episode + # i = dataset.episode_data_index["to"][-1].item() + # load_and_compare(i-2) + # load_and_compare(i-1) diff --git a/tests/test_examples.py b/tests/test_examples.py index a3f90cf..3ac040b 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -1,3 +1,4 @@ +# TODO(aliberts): Mute logging for these tests import subprocess from pathlib import Path diff --git a/tests/test_visualize_dataset.py b/tests/test_visualize_dataset.py new file mode 100644 index 0000000..6787c46 --- /dev/null +++ b/tests/test_visualize_dataset.py @@ -0,0 +1,31 @@ +import pytest + +from lerobot.common.utils.utils import init_hydra_config +from lerobot.scripts.visualize_dataset import visualize_dataset + +from .utils import DEFAULT_CONFIG_PATH + + +@pytest.mark.parametrize( + "dataset_id", + [ + "aloha_sim_insertion_human", + ], +) +def test_visualize_dataset(tmpdir, dataset_id): + # TODO(rcadene): this test might fail with other datasets/policies/envs, since visualization_dataset + # doesnt support multiple timesteps which requires delta_timestamps to None for images. + cfg = init_hydra_config( + DEFAULT_CONFIG_PATH, + overrides=[ + "policy=act", + "env=aloha", + f"dataset_id={dataset_id}", + ], + ) + video_paths = visualize_dataset(cfg, out_dir=tmpdir) + + assert len(video_paths) > 0 + + for video_path in video_paths: + assert video_path.exists()