Compare commits

..

12 Commits

Author SHA1 Message Date
Remi Cadene
7cee7a0f20 Add mobile Aloha and visu with rerun.io 2024-04-20 16:19:55 +02:00
Cadene
2a59825a00 fix online training 2024-04-20 00:12:34 +00:00
Cadene
06628ba059 fix online training 2024-04-19 23:58:38 +00:00
Cadene
b2b5329683 fix online training 2024-04-19 23:48:43 +00:00
Cadene
85f1554da8 fix visualize_dataset 2024-04-19 23:40:35 +00:00
Cadene
9b4c2e2a9f small fix 2024-04-19 23:30:39 +00:00
Cadene
20928021c0 Add tests/data 2024-04-19 23:27:11 +00:00
Cadene
c20cf2fbbc Remove Prod, Tests are passind 2024-04-19 23:27:10 +00:00
Cadene
35a573c98e Use v1.1, hf_transform_to_torch, Add 3 xarm datasets 2024-04-19 23:26:13 +00:00
Cadene
714a776277 id -> index, finish moving compute_stats before hf_dataset push_to_hub 2024-04-19 23:25:06 +00:00
Cadene
64b09ea7a7 WIP add load functions + episode_data_index 2024-04-19 23:24:08 +00:00
Cadene
0bd2ca8d82 Add meta_data, revision v1.1 2024-04-19 23:24:08 +00:00
116 changed files with 828 additions and 972 deletions

31
.github/poetry/cpu/poetry.lock generated vendored
View File

@@ -1,4 +1,4 @@
# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand.
# This file is automatically @generated by Poetry 1.8.1 and should not be changed by hand.
[[package]]
name = "absl-py"
@@ -522,21 +522,21 @@ toml = ["tomli"]
[[package]]
name = "datasets"
version = "2.19.0"
version = "2.18.0"
description = "HuggingFace community-driven open-source library of datasets"
optional = false
python-versions = ">=3.8.0"
files = [
{file = "datasets-2.19.0-py3-none-any.whl", hash = "sha256:f57c5316e123d4721b970c68c1cb856505f289cda58f5557ffe745b49c011a8e"},
{file = "datasets-2.19.0.tar.gz", hash = "sha256:0b47e08cc7af2c6800a42cadc4657b22a0afc7197786c8986d703c08d90886a6"},
{file = "datasets-2.18.0-py3-none-any.whl", hash = "sha256:f1bbf0e2896917a914de01cbd37075b14deea3837af87ad0d9f697388ccaeb50"},
{file = "datasets-2.18.0.tar.gz", hash = "sha256:cdf8b8c6abf7316377ba4f49f9589a4c74556d6b481afd0abd2284f3d69185cb"},
]
[package.dependencies]
aiohttp = "*"
dill = ">=0.3.0,<0.3.9"
filelock = "*"
fsspec = {version = ">=2023.1.0,<=2024.3.1", extras = ["http"]}
huggingface-hub = ">=0.21.2"
fsspec = {version = ">=2023.1.0,<=2024.2.0", extras = ["http"]}
huggingface-hub = ">=0.19.4"
multiprocess = "*"
numpy = ">=1.17"
packaging = "*"
@@ -552,15 +552,15 @@ xxhash = "*"
apache-beam = ["apache-beam (>=2.26.0)"]
audio = ["librosa", "soundfile (>=0.12.1)"]
benchmarks = ["tensorflow (==2.12.0)", "torch (==2.0.1)", "transformers (==4.30.1)"]
dev = ["Pillow (>=6.2.1)", "absl-py", "apache-beam (>=2.26.0)", "elasticsearch (<8.0.0)", "faiss-cpu (>=1.6.4)", "jax (>=0.3.14)", "jaxlib (>=0.3.14)", "joblib (<1.3.0)", "joblibspark", "librosa", "lz4", "polars[timezone] (>=0.20.0)", "protobuf (<4.0.0)", "py7zr", "pyspark (>=3.4)", "pytest", "pytest-datadir", "pytest-xdist", "rarfile (>=4.0)", "ruff (>=0.3.0)", "s3fs", "s3fs (>=2021.11.1)", "soundfile (>=0.12.1)", "sqlalchemy", "tensorflow (>=2.6.0)", "tiktoken", "torch", "torch (>=2.0.0)", "transformers", "typing-extensions (>=4.6.1)", "zstandard"]
docs = ["s3fs", "tensorflow (>=2.6.0)", "torch", "transformers"]
dev = ["Pillow (>=6.2.1)", "absl-py", "apache-beam (>=2.26.0)", "elasticsearch (<8.0.0)", "faiss-cpu (>=1.6.4)", "jax (>=0.3.14)", "jaxlib (>=0.3.14)", "joblib (<1.3.0)", "joblibspark", "librosa", "lz4", "py7zr", "pyspark (>=3.4)", "pytest", "pytest-datadir", "pytest-xdist", "rarfile (>=4.0)", "ruff (>=0.3.0)", "s3fs", "s3fs (>=2021.11.1)", "soundfile (>=0.12.1)", "sqlalchemy", "tensorflow (>=2.2.0,!=2.6.0,!=2.6.1)", "tensorflow (>=2.3,!=2.6.0,!=2.6.1)", "tensorflow-macos", "tiktoken", "torch", "torch (>=2.0.0)", "transformers", "typing-extensions (>=4.6.1)", "zstandard"]
docs = ["s3fs", "tensorflow (>=2.2.0,!=2.6.0,!=2.6.1)", "tensorflow-macos", "torch", "transformers"]
jax = ["jax (>=0.3.14)", "jaxlib (>=0.3.14)"]
metrics-tests = ["Werkzeug (>=1.0.1)", "accelerate", "bert-score (>=0.3.6)", "jiwer", "langdetect", "mauve-text", "nltk", "requests-file (>=1.5.1)", "rouge-score", "sacrebleu", "sacremoses", "scikit-learn", "scipy", "sentencepiece", "seqeval", "six (>=1.15.0,<1.16.0)", "spacy (>=3.0.0)", "texttable (>=1.6.3)", "tldextract", "tldextract (>=3.1.0)", "toml (>=0.10.1)", "typer (<0.5.0)"]
quality = ["ruff (>=0.3.0)"]
s3 = ["s3fs"]
tensorflow = ["tensorflow (>=2.6.0)"]
tensorflow-gpu = ["tensorflow (>=2.6.0)"]
tests = ["Pillow (>=6.2.1)", "absl-py", "apache-beam (>=2.26.0)", "elasticsearch (<8.0.0)", "faiss-cpu (>=1.6.4)", "jax (>=0.3.14)", "jaxlib (>=0.3.14)", "joblib (<1.3.0)", "joblibspark", "librosa", "lz4", "polars[timezone] (>=0.20.0)", "protobuf (<4.0.0)", "py7zr", "pyspark (>=3.4)", "pytest", "pytest-datadir", "pytest-xdist", "rarfile (>=4.0)", "s3fs (>=2021.11.1)", "soundfile (>=0.12.1)", "sqlalchemy", "tensorflow (>=2.6.0)", "tiktoken", "torch (>=2.0.0)", "transformers", "typing-extensions (>=4.6.1)", "zstandard"]
tensorflow = ["tensorflow (>=2.2.0,!=2.6.0,!=2.6.1)", "tensorflow-macos"]
tensorflow-gpu = ["tensorflow-gpu (>=2.2.0,!=2.6.0,!=2.6.1)"]
tests = ["Pillow (>=6.2.1)", "absl-py", "apache-beam (>=2.26.0)", "elasticsearch (<8.0.0)", "faiss-cpu (>=1.6.4)", "jax (>=0.3.14)", "jaxlib (>=0.3.14)", "joblib (<1.3.0)", "joblibspark", "librosa", "lz4", "py7zr", "pyspark (>=3.4)", "pytest", "pytest-datadir", "pytest-xdist", "rarfile (>=4.0)", "s3fs (>=2021.11.1)", "soundfile (>=0.12.1)", "sqlalchemy", "tensorflow (>=2.3,!=2.6.0,!=2.6.1)", "tensorflow-macos", "tiktoken", "torch (>=2.0.0)", "transformers", "typing-extensions (>=4.6.1)", "zstandard"]
torch = ["torch"]
vision = ["Pillow (>=6.2.1)"]
@@ -1524,6 +1524,7 @@ description = "Powerful and Pythonic XML processing library combining libxml2/li
optional = true
python-versions = ">=3.6"
files = [
{file = "lxml-5.1.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:704f5572ff473a5f897745abebc6df40f22d4133c1e0a1f124e4f2bd3330ff7e"},
{file = "lxml-5.1.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:9d3c0f8567ffe7502d969c2c1b809892dc793b5d0665f602aad19895f8d508da"},
{file = "lxml-5.1.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:5fcfbebdb0c5d8d18b84118842f31965d59ee3e66996ac842e21f957eb76138c"},
{file = "lxml-5.1.0-cp310-cp310-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2f37c6d7106a9d6f0708d4e164b707037b7380fcd0b04c5bd9cae1fb46a856fb"},
@@ -1533,6 +1534,7 @@ files = [
{file = "lxml-5.1.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:82bddf0e72cb2af3cbba7cec1d2fd11fda0de6be8f4492223d4a268713ef2147"},
{file = "lxml-5.1.0-cp310-cp310-win32.whl", hash = "sha256:b66aa6357b265670bb574f050ffceefb98549c721cf28351b748be1ef9577d93"},
{file = "lxml-5.1.0-cp310-cp310-win_amd64.whl", hash = "sha256:4946e7f59b7b6a9e27bef34422f645e9a368cb2be11bf1ef3cafc39a1f6ba68d"},
{file = "lxml-5.1.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:14deca1460b4b0f6b01f1ddc9557704e8b365f55c63070463f6c18619ebf964f"},
{file = "lxml-5.1.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:ed8c3d2cd329bf779b7ed38db176738f3f8be637bb395ce9629fc76f78afe3d4"},
{file = "lxml-5.1.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:436a943c2900bb98123b06437cdd30580a61340fbdb7b28aaf345a459c19046a"},
{file = "lxml-5.1.0-cp311-cp311-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:acb6b2f96f60f70e7f34efe0c3ea34ca63f19ca63ce90019c6cbca6b676e81fa"},
@@ -1542,6 +1544,7 @@ files = [
{file = "lxml-5.1.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:f4c9bda132ad108b387c33fabfea47866af87f4ea6ffb79418004f0521e63204"},
{file = "lxml-5.1.0-cp311-cp311-win32.whl", hash = "sha256:bc64d1b1dab08f679fb89c368f4c05693f58a9faf744c4d390d7ed1d8223869b"},
{file = "lxml-5.1.0-cp311-cp311-win_amd64.whl", hash = "sha256:a5ab722ae5a873d8dcee1f5f45ddd93c34210aed44ff2dc643b5025981908cda"},
{file = "lxml-5.1.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:9aa543980ab1fbf1720969af1d99095a548ea42e00361e727c58a40832439114"},
{file = "lxml-5.1.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:6f11b77ec0979f7e4dc5ae081325a2946f1fe424148d3945f943ceaede98adb8"},
{file = "lxml-5.1.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:a36c506e5f8aeb40680491d39ed94670487ce6614b9d27cabe45d94cd5d63e1e"},
{file = "lxml-5.1.0-cp312-cp312-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f643ffd2669ffd4b5a3e9b41c909b72b2a1d5e4915da90a77e119b8d48ce867a"},
@@ -1567,8 +1570,8 @@ files = [
{file = "lxml-5.1.0-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:8f52fe6859b9db71ee609b0c0a70fea5f1e71c3462ecf144ca800d3f434f0764"},
{file = "lxml-5.1.0-cp37-cp37m-win32.whl", hash = "sha256:d42e3a3fc18acc88b838efded0e6ec3edf3e328a58c68fbd36a7263a874906c8"},
{file = "lxml-5.1.0-cp37-cp37m-win_amd64.whl", hash = "sha256:eac68f96539b32fce2c9b47eb7c25bb2582bdaf1bbb360d25f564ee9e04c542b"},
{file = "lxml-5.1.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:ae15347a88cf8af0949a9872b57a320d2605ae069bcdf047677318bc0bba45b1"},
{file = "lxml-5.1.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:c26aab6ea9c54d3bed716b8851c8bfc40cb249b8e9880e250d1eddde9f709bf5"},
{file = "lxml-5.1.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:cfbac9f6149174f76df7e08c2e28b19d74aed90cad60383ad8671d3af7d0502f"},
{file = "lxml-5.1.0-cp38-cp38-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:342e95bddec3a698ac24378d61996b3ee5ba9acfeb253986002ac53c9a5f6f84"},
{file = "lxml-5.1.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:725e171e0b99a66ec8605ac77fa12239dbe061482ac854d25720e2294652eeaa"},
{file = "lxml-5.1.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3d184e0d5c918cff04cdde9dbdf9600e960161d773666958c9d7b565ccc60c45"},
@@ -1576,6 +1579,7 @@ files = [
{file = "lxml-5.1.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:6d48fc57e7c1e3df57be5ae8614bab6d4e7b60f65c5457915c26892c41afc59e"},
{file = "lxml-5.1.0-cp38-cp38-win32.whl", hash = "sha256:7ec465e6549ed97e9f1e5ed51c657c9ede767bc1c11552f7f4d022c4df4a977a"},
{file = "lxml-5.1.0-cp38-cp38-win_amd64.whl", hash = "sha256:b21b4031b53d25b0858d4e124f2f9131ffc1530431c6d1321805c90da78388d1"},
{file = "lxml-5.1.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:52427a7eadc98f9e62cb1368a5079ae826f94f05755d2d567d93ee1bc3ceb354"},
{file = "lxml-5.1.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:6a2a2c724d97c1eb8cf966b16ca2915566a4904b9aad2ed9a09c748ffe14f969"},
{file = "lxml-5.1.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:843b9c835580d52828d8f69ea4302537337a21e6b4f1ec711a52241ba4a824f3"},
{file = "lxml-5.1.0-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9b99f564659cfa704a2dd82d0684207b1aadf7d02d33e54845f9fc78e06b7581"},
@@ -2684,6 +2688,7 @@ files = [
{file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"},
{file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"},
{file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"},
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"},
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"},
{file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"},
{file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"},
@@ -3914,4 +3919,4 @@ xarm = ["gym-xarm"]
[metadata]
lock-version = "2.0"
python-versions = "^3.10"
content-hash = "e526416d1282dea2550680b2be7fcf9ff6e1c67ac89d34c684b486d94a6addee"
content-hash = "bd9c506d2499d5e1e3b5e8b1a0f65df45c8feef38d89d0daeade56847fdb6a2e"

View File

@@ -53,7 +53,7 @@ pre-commit = {version = "^3.7.0", optional = true}
debugpy = {version = "^1.8.1", optional = true}
pytest = {version = "^8.1.0", optional = true}
pytest-cov = {version = "^5.0.0", optional = true}
datasets = "^2.19.0"
datasets = "^2.18.0"
[tool.poetry.extras]

View File

@@ -193,9 +193,8 @@ jobs:
env=xarm \
wandb.enable=False \
offline_steps=1 \
online_steps=2 \
online_steps=1 \
eval_episodes=1 \
env.episode_length=2 \
device=cpu \
save_model=true \
save_freq=2 \

View File

@@ -73,14 +73,15 @@ environments ([aloha](https://github.com/huggingface/gym-aloha),
[pusht](https://github.com/huggingface/gym-pusht))
and follow the same api design.
When implementing a new dataset loadable with LeRobotDataset follow these steps:
- Update `available_datasets_per_env` in `lerobot/__init__.py`
When implementing a new dataset class (e.g. `AlohaDataset`) follow these steps:
- Update `available_datasets` in `lerobot/__init__.py`
- Copy it in the required `available_datasets` class attribute
When implementing a new environment (e.g. `gym_aloha`), follow these steps:
- Update `available_tasks_per_env` and `available_datasets_per_env` in `lerobot/__init__.py`
- Update `available_envs`, `available_tasks_per_env` and `available_datasets` in `lerobot/__init__.py`
When implementing a new policy class (e.g. `DiffusionPolicy`) follow these steps:
- Update `available_policies` and `available_policies_per_env`, in `lerobot/__init__.py`
- Update `available_policies` in `lerobot/__init__.py`
- Set the required `name` class attribute.
- Update variables in `tests/test_available.py` by importing your new Policy class

View File

@@ -118,7 +118,30 @@ wandb login
### Visualize datasets
Check out [examples](./examples) to see how you can import our dataset class, download the data from the HuggingFace hub and use our rendering utilities.
You can import our dataset class, download the data from the HuggingFace hub and use our rendering utilities:
```python
""" Copy pasted from `examples/1_visualize_dataset.py` """
import os
from pathlib import Path
import lerobot
from lerobot.common.datasets.aloha import AlohaDataset
from lerobot.scripts.visualize_dataset import render_dataset
print(lerobot.available_datasets)
# >>> ['aloha_sim_insertion_human', 'aloha_sim_insertion_scripted', 'aloha_sim_transfer_cube_human', 'aloha_sim_transfer_cube_scripted', 'pusht', 'xarm_lift_medium']
# TODO(rcadene): remove DATA_DIR
dataset = AlohaDataset("pusht", root=Path(os.environ.get("DATA_DIR")))
video_paths = render_dataset(
dataset,
out_dir="outputs/visualize_dataset/example",
max_num_episodes=1,
)
print(video_paths)
# ['outputs/visualize_dataset/example/episode_0.mp4']
```
Or you can achieve the same result by executing our script from the command line:
```bash
@@ -130,7 +153,7 @@ hydra.run.dir=outputs/visualize_dataset/example
### Evaluate a pretrained policy
Check out [examples](./examples) to see how you can load a pretrained policy from HuggingFace hub, load up the corresponding environment and model, and run an evaluation.
Check out [example 2](./examples/2_evaluate_pretrained_policy.py) to see how you can load a pretrained policy from HuggingFace hub, load up the corresponding environment and model, and run an evaluation.
Or you can achieve the same result by executing our script from the command line:
```bash
@@ -153,30 +176,24 @@ See `python lerobot/scripts/eval.py --help` for more instructions.
### Train your own policy
Check out [examples](./examples) to see how you can start training a model on a dataset, which will be automatically downloaded if needed.
You can import our dataset, environment, policy classes, and use our training utilities (if some data is missing, it will be automatically downloaded from HuggingFace hub): check out [example 3](./examples/3_train_policy.py). After you run this, you may want to revisit [example 2](./examples/2_evaluate_pretrained_policy.py) to evaluate your training output!
In general, you can use our training script to easily train any policy on any environment:
```bash
python lerobot/scripts/train.py \
env=aloha \
task=sim_insertion \
repo_id=lerobot/aloha_sim_insertion_scripted \
dataset_id=aloha_sim_insertion_scripted \
policy=act \
hydra.run.dir=outputs/train/aloha_act
```
After training, you may want to revisit model evaluation to change the evaluation settings. In fact, during training every checkpoint is already evaluated but on a low number of episodes for efficiency. Check out [example](./examples) to evaluate any model checkpoint on more episodes to increase statistical significance.
## Contribute
If you would like to contribute to 🤗 LeRobot, please check out our [contribution guide](https://github.com/huggingface/lerobot/blob/main/CONTRIBUTING.md).
### Add a new dataset
```python
# TODO(rcadene, AdilZouitine): rewrite this section
```
To add a dataset to the hub, first login and use a token generated from [huggingface settings](https://huggingface.co/settings/tokens) with write access:
```bash
huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
@@ -238,10 +255,6 @@ python tests/scripts/mock_dataset.py --in-data-dir data/$DATASET --out-data-dir
### Add a pretrained policy
```python
# TODO(rcadene, alexander-soare): rewrite this section
```
Once you have trained a policy you may upload it to the HuggingFace hub.
Firstly, make sure you have a model repository set up on the hub. The hub ID looks like HF_USER/REPO_NAME.
@@ -250,13 +263,15 @@ Secondly, assuming you have trained a policy, you need:
- `config.yaml` which you can get from the `.hydra` directory of your training output folder.
- `model.pt` which should be one of the saved models in the `models` directory of your training output folder (they won't be named `model.pt` but you will need to choose one).
- `stats.pth` which should point to the same file in the dataset directory (found in `data/{dataset_name}`).
To upload these to the hub, prepare a folder with the following structure (you can use symlinks rather than copying):
```
to_upload
├── config.yaml
── model.pt
── model.pt
└── stats.pth
```
With the folder prepared, run the following with a desired revision ID.

View File

@@ -23,13 +23,14 @@ from lerobot.common.datasets.utils import compute_stats, flatten_dict, hf_transf
def download_and_upload(root, revision, dataset_id):
# TODO(rcadene, adilzouitine): add community_id/user_id (e.g. "lerobot", "cadene") or repo_id (e.g. "lerobot/pusht")
if "pusht" in dataset_id:
download_and_upload_pusht(root, revision, dataset_id)
elif "xarm" in dataset_id:
download_and_upload_xarm(root, revision, dataset_id)
elif "aloha" in dataset_id:
elif "aloha_sim" in dataset_id:
download_and_upload_aloha(root, revision, dataset_id)
elif "aloha_mobile" in dataset_id:
download_and_upload_aloha_mobile(root, revision, dataset_id)
else:
raise ValueError(dataset_id)
@@ -150,11 +151,11 @@ def push_to_hub(hf_dataset, episode_data_index, info, stats, root, revision, dat
# copy in tests folder, the first episode and the meta_data directory
num_items_first_ep = episode_data_index["to"][0] - episode_data_index["from"][0]
hf_dataset.select(range(num_items_first_ep)).with_format("torch").save_to_disk(
f"tests/data/lerobot/{dataset_id}/train"
f"tests/data/{dataset_id}/train"
)
if Path(f"tests/data/lerobot/{dataset_id}/meta_data").exists():
shutil.rmtree(f"tests/data/lerobot/{dataset_id}/meta_data")
shutil.copytree(meta_data_dir, f"tests/data/lerobot/{dataset_id}/meta_data")
if Path(f"tests/data/{dataset_id}/meta_data").exists():
shutil.rmtree(f"tests/data/{dataset_id}/meta_data")
shutil.copytree(meta_data_dir, f"tests/data/{dataset_id}/meta_data")
def download_and_upload_pusht(root, revision, dataset_id="pusht", fps=10):
@@ -531,20 +532,125 @@ def download_and_upload_aloha(root, revision, dataset_id, fps=50):
push_to_hub(hf_dataset, episode_data_index, info, stats, root, revision, dataset_id)
def download_and_upload_aloha_mobile(root, revision, dataset_id, fps=50):
num_episodes = {
"aloha_mobile_trossen_block_handoff": 5,
}
# episode_len = {
# "aloha_sim_insertion_human": 500,
# "aloha_sim_insertion_scripted": 400,
# "aloha_sim_transfer_cube_human": 400,
# "aloha_sim_transfer_cube_scripted": 400,
# }
cameras = {
"aloha_mobile_trossen_block_handoff": ['cam_high', 'cam_left_wrist', 'cam_right_wrist'],
}
root = Path(root)
raw_dir = root / f"{dataset_id}_raw"
ep_dicts = []
episode_data_index = {"from": [], "to": []}
id_from = 0
for ep_id in tqdm.tqdm(range(num_episodes[dataset_id])):
ep_path = raw_dir / f"episode_{ep_id}.hdf5"
with h5py.File(ep_path, "r") as ep:
num_frames = ep["/action"].shape[0]
#assert episode_len[dataset_id] == num_frames
# last step of demonstration is considered done
done = torch.zeros(num_frames, dtype=torch.bool)
done[-1] = True
state = torch.from_numpy(ep["/observations/qpos"][:num_frames])
action = torch.from_numpy(ep["/action"][:num_frames])
ep_dict = {}
for cam in cameras[dataset_id]:
image = ep[f"/observations/images/{cam}"][:num_frames] # b h w c
import cv2
# un-pad and uncompress from: https://github.com/MarkFzp/act-plus-plus/blob/26bab0789d05b7496bacef04f5c6b2541a4403b5/postprocess_episodes.py#L50
image = np.array([cv2.imdecode(x, 1) for x in image])
image = [PILImage.fromarray(x) for x in image]
ep_dict[f"observation.images.{cam}"] = image
ep_dict.update(
{
"observation.state": state,
"action": action,
"episode_index": torch.tensor([ep_id] * num_frames),
"frame_index": torch.arange(0, num_frames, 1),
"timestamp": torch.arange(0, num_frames, 1) / fps,
# "next.observation.state": state,
# TODO(rcadene): compute reward and success
# "next.reward": reward,
"next.done": done,
# "next.success": success,
}
)
assert isinstance(ep_id, int)
ep_dicts.append(ep_dict)
episode_data_index["from"].append(id_from)
episode_data_index["to"].append(id_from + num_frames)
id_from += num_frames
break
data_dict = concatenate_episodes(ep_dicts)
features = {}
for cam in cameras[dataset_id]:
features[f"observation.images.{cam}"] = Image()
features.update({
"observation.state": Sequence(
length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
),
"action": Sequence(length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)),
"episode_index": Value(dtype="int64", id=None),
"frame_index": Value(dtype="int64", id=None),
"timestamp": Value(dtype="float32", id=None),
#'next.reward': Value(dtype='float32', id=None),
"next.done": Value(dtype="bool", id=None),
#'next.success': Value(dtype='bool', id=None),
"index": Value(dtype="int64", id=None),
})
features = Features(features)
hf_dataset = Dataset.from_dict(data_dict, features=features)
hf_dataset.set_transform(hf_transform_to_torch)
info = {
"fps": fps,
}
stats = compute_stats(hf_dataset)
push_to_hub(hf_dataset, episode_data_index, info, stats, root, revision, dataset_id)
if __name__ == "__main__":
root = "data"
revision = "v1.1"
dataset_ids = [
"pusht",
"xarm_lift_medium",
"xarm_lift_medium_replay",
"xarm_push_medium",
"xarm_push_medium_replay",
"aloha_sim_insertion_human",
"aloha_sim_insertion_scripted",
"aloha_sim_transfer_cube_human",
"aloha_sim_transfer_cube_scripted",
# "pusht",
# "xarm_lift_medium",
# "xarm_lift_medium_replay",
# "xarm_push_medium",
# "xarm_push_medium_replay",
# "aloha_sim_insertion_human",
# "aloha_sim_insertion_scripted",
# "aloha_sim_transfer_cube_human",
# "aloha_sim_transfer_cube_scripted",
"aloha_mobile_trossen_block_handoff",
]
for dataset_id in dataset_ids:
download_and_upload(root, revision, dataset_id)

View File

@@ -10,9 +10,9 @@ As an example, this script saves frames of episode number 5 of the PushT dataset
This script supports several Hugging Face datasets, among which:
1. [Pusht](https://huggingface.co/datasets/lerobot/pusht)
2. [Xarm Lift Medium](https://huggingface.co/datasets/lerobot/xarm_lift_medium)
3. [Xarm Lift Medium Replay](https://huggingface.co/datasets/lerobot/xarm_lift_medium_replay)
4. [Xarm Push Medium](https://huggingface.co/datasets/lerobot/xarm_push_medium)
5. [Xarm Push Medium Replay](https://huggingface.co/datasets/lerobot/xarm_push_medium_replay)
3. [Xarm Lift Medium](https://huggingface.co/datasets/lerobot/xarm_lift_medium_replay)
4. [Xarm Lift Medium](https://huggingface.co/datasets/lerobot/xarm_push_medium)
5. [Xarm Lift Medium](https://huggingface.co/datasets/lerobot/xarm_push_medium_replay)
6. [Aloha Sim Insertion Human](https://huggingface.co/datasets/lerobot/aloha_sim_insertion_human)
7. [Aloha Sim Insertion Scripted](https://huggingface.co/datasets/lerobot/aloha_sim_insertion_scripted)
8. [Aloha Sim Transfer Cube Human](https://huggingface.co/datasets/lerobot/aloha_sim_transfer_cube_human)
@@ -34,7 +34,6 @@ hf_dataset, fps = load_dataset("lerobot/aloha_sim_transfer_cube_human", split="t
hf_dataset, fps = load_dataset("lerobot/aloha_sim_transfer_cube_scripted", split="train"), 50
```
"""
# TODO(rcadene): remove this example file of using hf_dataset
from pathlib import Path
@@ -44,10 +43,9 @@ from datasets import load_dataset
# TODO(rcadene): list available datasets on lerobot page using `datasets`
# download/load hugging face dataset in pyarrow format
hf_dataset, fps = load_dataset("lerobot/pusht", split="train", revision="v1.1"), 10
hf_dataset, fps = load_dataset("lerobot/pusht", split="train"), 10
# display name of dataset and its features
# TODO(rcadene): update to make the print pretty
print(f"{hf_dataset=}")
print(f"{hf_dataset.features=}")

View File

@@ -1,5 +1,5 @@
"""
This script demonstrates the use of `LeRobotDataset` class for handling and processing robotic datasets from Hugging Face.
This script demonstrates the use of the PushtDataset class for handling and processing robotic datasets from Hugging Face.
It illustrates how to load datasets, manipulate them, and apply transformations suitable for machine learning tasks in PyTorch.
Features included in this script:
@@ -11,6 +11,22 @@ Features included in this script:
- Demonstrating compatibility with PyTorch DataLoader for batch processing.
The script ends with examples of how to batch process data using PyTorch's DataLoader.
To try a different Hugging Face dataset, you can replace:
```python
dataset = PushtDataset()
```
by one of these:
```python
dataset = XarmDataset("xarm_lift_medium")
dataset = XarmDataset("xarm_lift_medium_replay")
dataset = XarmDataset("xarm_push_medium")
dataset = XarmDataset("xarm_push_medium_replay")
dataset = AlohaDataset("aloha_sim_insertion_human")
dataset = AlohaDataset("aloha_sim_insertion_scripted")
dataset = AlohaDataset("aloha_sim_transfer_cube_human")
dataset = AlohaDataset("aloha_sim_transfer_cube_scripted")
```
"""
from pathlib import Path
@@ -18,43 +34,38 @@ from pathlib import Path
import imageio
import torch
import lerobot
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.pusht import PushtDataset
print("List of available datasets", lerobot.available_datasets)
# # >>> ['lerobot/aloha_sim_insertion_human', 'lerobot/aloha_sim_insertion_scripted',
# # 'lerobot/aloha_sim_transfer_cube_human', 'lerobot/aloha_sim_transfer_cube_scripted',
# # 'lerobot/pusht', 'lerobot/xarm_lift_medium']
# TODO(rcadene): List available datasets and their dataset ids (e.g. PushtDataset, AlohaDataset(dataset_id="aloha_sim_insertion_human"))
# print("List of available datasets", lerobot.available_datasets)
# # >>> ['aloha_sim_insertion_human', 'aloha_sim_insertion_scripted',
# # 'aloha_sim_transfer_cube_human', 'aloha_sim_transfer_cube_scripted',
# # 'pusht', 'xarm_lift_medium']
repo_id = "lerobot/pusht"
# You can easily load a dataset from a Hugging Face repositery
dataset = LeRobotDataset(repo_id)
# You can easily load datasets from LeRobot
dataset = PushtDataset()
# LeRobotDataset is actually a thin wrapper around an underlying Hugging Face dataset (see https://huggingface.co/docs/datasets/index for more information).
# TODO(rcadene): update to make the print pretty
# All LeRobot datasets are actually a thin wrapper around an underlying Hugging Face dataset (see https://huggingface.co/docs/datasets/index for more information).
print(f"{dataset=}")
print(f"{dataset.hf_dataset=}")
# and provides additional utilities for robotics and compatibility with pytorch
# and provide additional utilities for robotics and compatibility with pytorch
print(f"number of samples/frames: {dataset.num_samples=}")
print(f"number of episodes: {dataset.num_episodes=}")
print(f"average number of frames per episode: {dataset.num_samples / dataset.num_episodes:.3f}")
print(f"frames per second used during data collection: {dataset.fps=}")
print(f"keys to access images from cameras: {dataset.image_keys=}")
# While the LeRobotDataset adds helpers for working within our library, we still expose the underling Hugging Face dataset.
# It may be freely replaced or modified in place. Here we use the filtering to keep only frames from episode 5.
# TODO(rcadene): remove this example of accessing hf_dataset
# While the LeRobot dataset adds helpers for working within our library, we still expose the underling Hugging Face dataset. It may be freely replaced or modified in place. Here we use the filtering to keep only frames from episode 5.
dataset.hf_dataset = dataset.hf_dataset.filter(lambda frame: frame["episode_index"] == 5)
# LeRobot datsets actually subclass PyTorch datasets. So you can do everything you know and love from working with the latter, for example: iterating through the dataset. Here we grab all the image frames.
# LeRobot datsets actually subclass PyTorch datasets. So you can do everything you know and love from working with the latter, for example: iterating through the dataset. Here we grap all the image frames.
frames = [sample["observation.image"] for sample in dataset]
# but frames are now float32 range [0,1] channel first (c,h,w) to follow pytorch convention,
# to view them, we convert to uint8 range [0,255]
# but frames are now float32 range [0,1] channel first to follow pytorch convention,
# to view them, we convert to uint8 range [0,255] channel last
frames = [(frame * 255).type(torch.uint8) for frame in frames]
# and to channel last (h,w,c)
frames = [frame.permute((1, 2, 0)).numpy() for frame in frames]
# and finally save them to a mp4 video
@@ -71,7 +82,7 @@ delta_timestamps = {
# loads 64 action vectors: current frame, 1 frame in the future, 2 frames, ... 63 frames in the future
"action": [t / dataset.fps for t in range(64)],
}
dataset = LeRobotDataset(repo_id, delta_timestamps=delta_timestamps)
dataset = PushtDataset(delta_timestamps=delta_timestamps)
print(f"{dataset[0]['observation.image'].shape=}") # (4,c,h,w)
print(f"{dataset[0]['observation.state'].shape=}") # (8,c)
print(f"{dataset[0]['action'].shape=}") # (64,c)

View File

@@ -19,6 +19,7 @@ folder = Path(snapshot_download(hub_id))
config_path = folder / "config.yaml"
weights_path = folder / "model.pt"
stats_path = folder / "stats.pth" # normalization stats
# Override some config parameters to do with evaluation.
overrides = [
@@ -35,4 +36,5 @@ cfg = init_hydra_config(config_path, overrides)
eval(
cfg,
out_dir=f"outputs/eval/example_{cfg.env.name}_{cfg.policy.name}",
stats_path=stats_path,
)

View File

@@ -34,7 +34,7 @@ dataset = make_dataset(hydra_cfg)
# If you're doing something different, you will likely need to change at least some of the defaults.
cfg = DiffusionConfig()
# TODO(alexander-soare): Remove LR scheduler from the policy.
policy = DiffusionPolicy(cfg, lr_scheduler_num_training_steps=training_steps, dataset_stats=dataset.stats)
policy = DiffusionPolicy(cfg, lr_scheduler_num_training_steps=training_steps)
policy.train()
policy.to(device)
@@ -62,6 +62,7 @@ while not done:
done = True
break
# Save the policy and configuration for later use.
# Save the policy, configuration, and normalization stats for later use.
policy.save(output_directory / "model.pt")
OmegaConf.save(hydra_cfg, output_directory / "config.yaml")
torch.save(dataset.transform.transforms[-1].stats, output_directory / "stats.pth")

View File

@@ -8,25 +8,31 @@ Example:
print(lerobot.available_envs)
print(lerobot.available_tasks_per_env)
print(lerobot.available_datasets)
print(lerobot.available_datasets_per_env)
print(lerobot.available_policies)
print(lerobot.available_policies_per_env)
```
When implementing a new dataset loadable with LeRobotDataset follow these steps:
- Update `available_datasets_per_env` in `lerobot/__init__.py`
When implementing a new dataset class (e.g. `AlohaDataset`) follow these steps:
- Update `available_datasets` in `lerobot/__init__.py`
- Set the required `available_datasets` class attribute using the previously updated `lerobot.available_datasets`
When implementing a new environment (e.g. `gym_aloha`), follow these steps:
- Update `available_tasks_per_env` and `available_datasets_per_env` in `lerobot/__init__.py`
- Update `available_envs`, `available_tasks_per_env` and `available_datasets` in `lerobot/__init__.py`
When implementing a new policy class (e.g. `DiffusionPolicy`) follow these steps:
- Update `available_policies` and `available_policies_per_env`, in `lerobot/__init__.py`
- Update `available_policies` in `lerobot/__init__.py`
- Set the required `name` class attribute.
- Update variables in `tests/test_available.py` by importing your new Policy class
"""
from lerobot.__version__ import __version__ # noqa: F401
available_envs = [
"aloha",
"pusht",
"xarm",
]
available_tasks_per_env = {
"aloha": [
"AlohaInsertion-v0",
@@ -35,24 +41,22 @@ available_tasks_per_env = {
"pusht": ["PushT-v0"],
"xarm": ["XarmLift-v0"],
}
available_envs = list(available_tasks_per_env.keys())
available_datasets_per_env = {
available_datasets = {
"aloha": [
"lerobot/aloha_sim_insertion_human",
"lerobot/aloha_sim_insertion_scripted",
"lerobot/aloha_sim_transfer_cube_human",
"lerobot/aloha_sim_transfer_cube_scripted",
"aloha_sim_insertion_human",
"aloha_sim_insertion_scripted",
"aloha_sim_transfer_cube_human",
"aloha_sim_transfer_cube_scripted",
],
"pusht": ["lerobot/pusht"],
"pusht": ["pusht"],
"xarm": [
"lerobot/xarm_lift_medium",
"lerobot/xarm_lift_medium_replay",
"lerobot/xarm_push_medium",
"lerobot/xarm_push_medium_replay",
"xarm_lift_medium",
"xarm_lift_medium_replay",
"xarm_push_medium",
"xarm_push_medium_replay",
],
}
available_datasets = [dataset for datasets in available_datasets_per_env.values() for dataset in datasets]
available_policies = [
"act",
@@ -67,12 +71,10 @@ available_policies_per_env = {
}
env_task_pairs = [(env, task) for env, tasks in available_tasks_per_env.items() for task in tasks]
env_dataset_pairs = [
(env, dataset) for env, datasets in available_datasets_per_env.items() for dataset in datasets
]
env_dataset_pairs = [(env, dataset) for env, datasets in available_datasets.items() for dataset in datasets]
env_dataset_policy_triplets = [
(env, dataset, policy)
for env, datasets in available_datasets_per_env.items()
for env, datasets in available_datasets.items()
for dataset in datasets
for policy in available_policies_per_env[env]
]

View File

@@ -0,0 +1,78 @@
from pathlib import Path
import torch
from lerobot.common.datasets.utils import (
load_episode_data_index,
load_hf_dataset,
load_previous_and_future_frames,
load_stats,
)
class AlohaDataset(torch.utils.data.Dataset):
"""
https://huggingface.co/datasets/lerobot/aloha_sim_insertion_human
https://huggingface.co/datasets/lerobot/aloha_sim_insertion_scripted
https://huggingface.co/datasets/lerobot/aloha_sim_transfer_cube_human
https://huggingface.co/datasets/lerobot/aloha_sim_transfer_cube_scripted
"""
# Copied from lerobot/__init__.py
available_datasets = [
"aloha_sim_insertion_human",
"aloha_sim_insertion_scripted",
"aloha_sim_transfer_cube_human",
"aloha_sim_transfer_cube_scripted",
]
fps = 50
image_keys = ["observation.images.top"]
def __init__(
self,
dataset_id: str,
version: str | None = "v1.1",
root: Path | None = None,
split: str = "train",
transform: callable = None,
delta_timestamps: dict[list[float]] | None = None,
):
super().__init__()
self.dataset_id = dataset_id
self.version = version
self.root = root
self.split = split
self.transform = transform
self.delta_timestamps = delta_timestamps
# load data from hub or locally when root is provided
self.hf_dataset = load_hf_dataset(dataset_id, version, root, split)
self.episode_data_index = load_episode_data_index(dataset_id, version, root)
self.stats = load_stats(dataset_id, version, root)
@property
def num_samples(self) -> int:
return len(self.hf_dataset)
@property
def num_episodes(self) -> int:
return len(self.hf_dataset.unique("episode_index"))
def __len__(self):
return self.num_samples
def __getitem__(self, idx):
item = self.hf_dataset[idx]
if self.delta_timestamps is not None:
item = load_previous_and_future_frames(
item,
self.hf_dataset,
self.episode_data_index,
self.delta_timestamps,
tol=1 / self.fps - 1e-4, # 1e-4 to account for possible numerical error
)
if self.transform is not None:
item = self.transform(item)
return item

View File

@@ -1,22 +1,76 @@
import logging
import os
from pathlib import Path
import torch
from omegaconf import OmegaConf
from torchvision.transforms import v2
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.transforms import NormalizeTransform
DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
def make_dataset(
cfg,
# set normalize=False to remove all transformations and keep images unnormalized in [0,255]
normalize=True,
stats_path=None,
split="train",
):
if cfg.env.name not in cfg.dataset.repo_id:
logging.warning(
f"There might be a mismatch between your training dataset ({cfg.dataset.repo_id=}) and your environment ({cfg.env.name=})."
if cfg.env.name == "xarm":
from lerobot.common.datasets.xarm import XarmDataset
clsfunc = XarmDataset
elif cfg.env.name == "pusht":
from lerobot.common.datasets.pusht import PushtDataset
clsfunc = PushtDataset
elif cfg.env.name == "aloha":
from lerobot.common.datasets.aloha import AlohaDataset
clsfunc = AlohaDataset
else:
raise ValueError(cfg.env.name)
transforms = None
if normalize:
# TODO(rcadene): make normalization strategy configurable between mean_std, min_max, manual_min_max,
# min_max_from_spec
# TODO(rcadene): remove this and put it in config. Ideally we want to reproduce SOTA results just with mean_std
normalization_mode = "mean_std" if cfg.env.name == "aloha" else "min_max"
if cfg.policy.name == "diffusion" and cfg.env.name == "pusht":
stats = {}
# TODO(rcadene): we overwrite stats to have the same as pretrained model, but we should remove this
stats["observation.state"] = {}
stats["observation.state"]["min"] = torch.tensor([13.456424, 32.938293], dtype=torch.float32)
stats["observation.state"]["max"] = torch.tensor([496.14618, 510.9579], dtype=torch.float32)
stats["action"] = {}
stats["action"]["min"] = torch.tensor([12.0, 25.0], dtype=torch.float32)
stats["action"]["max"] = torch.tensor([511.0, 511.0], dtype=torch.float32)
elif stats_path is None:
# load a first dataset to access precomputed stats
stats_dataset = clsfunc(
dataset_id=cfg.dataset_id,
split="train",
root=DATA_DIR,
)
stats = stats_dataset.stats
else:
stats = torch.load(stats_path)
transforms = v2.Compose(
[
NormalizeTransform(
stats,
in_keys=[
"observation.state",
"action",
],
mode=normalization_mode,
),
]
)
delta_timestamps = cfg.policy.get("delta_timestamps")
@@ -25,20 +79,12 @@ def make_dataset(
if isinstance(delta_timestamps[key], str):
delta_timestamps[key] = eval(delta_timestamps[key])
# TODO(rcadene): add data augmentations
dataset = LeRobotDataset(
cfg.dataset.repo_id,
dataset = clsfunc(
dataset_id=cfg.dataset_id,
split=split,
root=DATA_DIR,
delta_timestamps=delta_timestamps,
transform=transforms,
)
if cfg.get("override_dataset_stats"):
for key, stats_dict in cfg.override_dataset_stats.items():
for stats_type, listconfig in stats_dict.items():
# example of stats_type: min, max, mean, std
stats = OmegaConf.to_container(listconfig, resolve=True)
dataset.stats[key][stats_type] = torch.tensor(stats, dtype=torch.float32)
return dataset

View File

@@ -0,0 +1,76 @@
from pathlib import Path
import torch
from lerobot.common.datasets.utils import (
load_episode_data_index,
load_hf_dataset,
load_previous_and_future_frames,
load_stats,
)
class PushtDataset(torch.utils.data.Dataset):
"""
https://huggingface.co/datasets/lerobot/pusht
Arguments
----------
delta_timestamps : dict[list[float]] | None, optional
Loads data from frames with a shift in timestamps with a different strategy for each data key (e.g. state, action or image)
If `None`, no shift is applied to current timestamp and the data from the current frame is loaded.
"""
# Copied from lerobot/__init__.py
available_datasets = ["pusht"]
fps = 10
image_keys = ["observation.image"]
def __init__(
self,
dataset_id: str = "pusht",
version: str | None = "v1.1",
root: Path | None = None,
split: str = "train",
transform: callable = None,
delta_timestamps: dict[list[float]] | None = None,
):
super().__init__()
self.dataset_id = dataset_id
self.version = version
self.root = root
self.split = split
self.transform = transform
self.delta_timestamps = delta_timestamps
# load data from hub or locally when root is provided
self.hf_dataset = load_hf_dataset(dataset_id, version, root, split)
self.episode_data_index = load_episode_data_index(dataset_id, version, root)
self.stats = load_stats(dataset_id, version, root)
@property
def num_samples(self) -> int:
return len(self.hf_dataset)
@property
def num_episodes(self) -> int:
return len(self.episode_data_index["from"])
def __len__(self):
return self.num_samples
def __getitem__(self, idx):
item = self.hf_dataset[idx]
if self.delta_timestamps is not None:
item = load_previous_and_future_frames(
item,
self.hf_dataset,
self.episode_data_index,
self.delta_timestamps,
tol=1 / self.fps - 1e-4, # 1e-4 to account for possible numerical error
)
if self.transform is not None:
item = self.transform(item)
return item

View File

@@ -1,4 +1,3 @@
import json
from copy import deepcopy
from math import ceil
from pathlib import Path
@@ -13,16 +12,10 @@ from PIL import Image as PILImage
from safetensors.torch import load_file
from torchvision import transforms
from lerobot.common.utils.utils import set_global_seed
def flatten_dict(d, parent_key="", sep="/"):
"""Flatten a nested dictionary structure by collapsing nested keys into one key with a separator.
For example:
```
>>> dct = {"a": {"b": 1, "c": {"d": 2}}, "e": 3}`
>>> print(flatten_dict(dct))
{"a/b": 1, "a/c/d": 2, "e": 3}
"""
items = []
for k, v in d.items():
new_key = f"{parent_key}{sep}{k}" if parent_key else k
@@ -62,17 +55,19 @@ def hf_transform_to_torch(items_dict):
return items_dict
def load_hf_dataset(repo_id, version, root, split) -> datasets.Dataset:
def load_hf_dataset(dataset_id, version, root, split) -> datasets.Dataset:
"""hf_dataset contains all the observations, states, actions, rewards, etc."""
if root is not None:
hf_dataset = load_from_disk(str(Path(root) / repo_id / split))
hf_dataset = load_from_disk(Path(root) / dataset_id / split)
else:
# TODO(rcadene): remove dataset_id everywhere and use repo_id instead
repo_id = f"lerobot/{dataset_id}"
hf_dataset = load_dataset(repo_id, revision=version, split=split)
hf_dataset.set_transform(hf_transform_to_torch)
return hf_dataset
def load_episode_data_index(repo_id, version, root) -> dict[str, torch.Tensor]:
def load_episode_data_index(dataset_id, version, root) -> dict[str, torch.Tensor]:
"""episode_data_index contains the range of indices for each episode
Example:
@@ -83,8 +78,9 @@ def load_episode_data_index(repo_id, version, root) -> dict[str, torch.Tensor]:
```
"""
if root is not None:
path = Path(root) / repo_id / "meta_data" / "episode_data_index.safetensors"
path = Path(root) / dataset_id / "meta_data" / "episode_data_index.safetensors"
else:
repo_id = f"lerobot/{dataset_id}"
path = hf_hub_download(
repo_id, "meta_data/episode_data_index.safetensors", repo_type="dataset", revision=version
)
@@ -92,7 +88,7 @@ def load_episode_data_index(repo_id, version, root) -> dict[str, torch.Tensor]:
return load_file(path)
def load_stats(repo_id, version, root) -> dict[str, dict[str, torch.Tensor]]:
def load_stats(dataset_id, version, root) -> dict[str, dict[str, torch.Tensor]]:
"""stats contains the statistics per modality computed over the full dataset, such as max, min, mean, std
Example:
@@ -101,32 +97,15 @@ def load_stats(repo_id, version, root) -> dict[str, dict[str, torch.Tensor]]:
```
"""
if root is not None:
path = Path(root) / repo_id / "meta_data" / "stats.safetensors"
path = Path(root) / dataset_id / "meta_data" / "stats.safetensors"
else:
repo_id = f"lerobot/{dataset_id}"
path = hf_hub_download(repo_id, "meta_data/stats.safetensors", repo_type="dataset", revision=version)
stats = load_file(path)
return unflatten_dict(stats)
def load_info(repo_id, version, root) -> dict:
"""info contains useful information regarding the dataset that are not stored elsewhere
Example:
```python
print("frame per second used to collect the video", info["fps"])
```
"""
if root is not None:
path = Path(root) / repo_id / "meta_data" / "info.json"
else:
path = hf_hub_download(repo_id, "meta_data/info.json", repo_type="dataset", revision=version)
with open(path) as f:
info = json.load(f)
return info
def load_previous_and_future_frames(
item: dict[str, torch.Tensor],
hf_dataset: datasets.Dataset,
@@ -216,7 +195,7 @@ def load_previous_and_future_frames(
def get_stats_einops_patterns(hf_dataset):
"""These einops patterns will be used to aggregate batches and compute statistics.
Note: We assume the images of `hf_dataset` are in channel first format
Note: We assume the images are returned in channel first format
"""
dataloader = torch.utils.data.DataLoader(
@@ -268,15 +247,13 @@ def compute_stats(hf_dataset, batch_size=32, max_num_samples=None):
min[key] = torch.tensor(float("inf")).float()
def create_seeded_dataloader(hf_dataset, batch_size, seed):
generator = torch.Generator()
generator.manual_seed(seed)
set_global_seed(seed)
dataloader = torch.utils.data.DataLoader(
hf_dataset,
num_workers=4,
batch_size=batch_size,
shuffle=True,
drop_last=False,
generator=generator,
)
return dataloader

View File

@@ -1,21 +1,36 @@
from pathlib import Path
import datasets
import torch
from lerobot.common.datasets.utils import (
load_episode_data_index,
load_hf_dataset,
load_info,
load_previous_and_future_frames,
load_stats,
)
class LeRobotDataset(torch.utils.data.Dataset):
class XarmDataset(torch.utils.data.Dataset):
"""
https://huggingface.co/datasets/lerobot/xarm_lift_medium
https://huggingface.co/datasets/lerobot/xarm_lift_medium_replay
https://huggingface.co/datasets/lerobot/xarm_push_medium
https://huggingface.co/datasets/lerobot/xarm_push_medium_replay
"""
# Copied from lerobot/__init__.py
available_datasets = [
"xarm_lift_medium",
"xarm_lift_medium_replay",
"xarm_push_medium",
"xarm_push_medium_replay",
]
fps = 15
image_keys = ["observation.image"]
def __init__(
self,
repo_id: str,
dataset_id: str,
version: str | None = "v1.1",
root: Path | None = None,
split: str = "train",
@@ -23,25 +38,16 @@ class LeRobotDataset(torch.utils.data.Dataset):
delta_timestamps: dict[list[float]] | None = None,
):
super().__init__()
self.repo_id = repo_id
self.dataset_id = dataset_id
self.version = version
self.root = root
self.split = split
self.transform = transform
self.delta_timestamps = delta_timestamps
# load data from hub or locally when root is provided
self.hf_dataset = load_hf_dataset(repo_id, version, root, split)
self.episode_data_index = load_episode_data_index(repo_id, version, root)
self.stats = load_stats(repo_id, version, root)
self.info = load_info(repo_id, version, root)
@property
def fps(self) -> int:
return self.info["fps"]
@property
def image_keys(self) -> list[str]:
return [key for key, feats in self.hf_dataset.features.items() if isinstance(feats, datasets.Image)]
self.hf_dataset = load_hf_dataset(dataset_id, version, root, split)
self.episode_data_index = load_episode_data_index(dataset_id, version, root)
self.stats = load_stats(dataset_id, version, root)
@property
def num_samples(self) -> int:

View File

@@ -39,5 +39,4 @@ def make_env(cfg, num_parallel_envs=0) -> gym.Env | gym.vector.SyncVectorEnv:
for _ in range(num_parallel_envs)
]
)
return env

View File

@@ -1,8 +1,10 @@
import einops
import torch
from lerobot.common.transforms import apply_inverse_transform
def preprocess_observation(observation):
def preprocess_observation(observation, transform=None):
# map to expected inputs for the policy
obs = {}
@@ -22,7 +24,7 @@ def preprocess_observation(observation):
assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}"
# convert to channel first of type float32 in range [0,1]
img = einops.rearrange(img, "b h w c -> b c h w").contiguous()
img = einops.rearrange(img, "b h w c -> b c h w")
img = img.type(torch.float32)
img /= 255
@@ -31,11 +33,19 @@ def preprocess_observation(observation):
# TODO(rcadene): enable pixels only baseline with `obs_type="pixels"` in environment by removing requirement for "agent_pos"
obs["observation.state"] = torch.from_numpy(observation["agent_pos"]).float()
# apply same transforms as in training
if transform is not None:
for key in obs:
obs[key] = torch.stack([transform({key: item})[key] for item in obs[key]])
return obs
def postprocess_action(action):
action = action.to("cpu").numpy()
def postprocess_action(action, transform=None):
action = action.to("cpu")
# action is a batch (num_env,action_dim) instead of an item (action_dim),
# we assume applying inverse transform on a batch works the same
action = apply_inverse_transform({"action": action}, transform)["action"].numpy()
assert (
action.ndim == 2
), "we assume dimensions are respectively the number of parallel envs, action dimensions"

View File

@@ -1,36 +1,30 @@
from transformers.configuration_utils import PretrainedConfig
from dataclasses import dataclass, field
class ActionChunkingTransformerConfig(PretrainedConfig):
@dataclass
class ActionChunkingTransformerConfig:
"""Configuration class for the Action Chunking Transformers policy.
Defaults are configured for training on bimanual Aloha tasks like "insertion" or "transfer".
The parameters you will most likely need to change are the ones which depend on the environment / sensors.
Those are: `input_shapes` and 'output_shapes`.
Those are: `state_dim`, `action_dim` and `camera_names`.
Args:
state_dim: Dimensionality of the observation state space (excluding images).
action_dim: Dimensionality of the action space.
n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the
current step and additional steps going back).
camera_names: The (unique) set of names for the cameras.
chunk_size: The size of the action prediction "chunks" in units of environment steps.
n_action_steps: The number of action steps to run in the environment for one invocation of the policy.
This should be no greater than the chunk size. For example, if the chunk size size 100, you may
set this to 50. This would mean that the model predicts 100 steps worth of actions, runs 50 in the
environment, and throws the other 50 out.
input_shapes: A dictionary defining the shapes of the input data for the policy.
The key represents the input data name, and the value is a list indicating the dimensions
of the corresponding data. For example, "observation.images.top" refers to an input from the
"top" camera with dimensions [3, 96, 96], indicating it has three color channels and 96x96 resolution.
Importantly, shapes doesnt include batch dimension or temporal dimension.
output_shapes: A dictionary defining the shapes of the output data for the policy.
The key represents the output data name, and the value is a list indicating the dimensions
of the corresponding data. For example, "action" refers to an output shape of [14], indicating
14-dimensional actions. Importantly, shapes doesnt include batch dimension or temporal dimension.
normalize_input_modes: A dictionary with key represents the modality (e.g. "observation.state"),
and the value specifies the normalization mode to apply. The two availables
modes are "mean_std" which substracts the mean and divide by the standard
deviation and "min_max" which rescale in a [-1, 1] range.
unnormalize_output_modes: Similar dictionary as `normalize_input_modes`, but to unormalize in original scale.
image_normalization_mean: Value to subtract from the input image pixels (inputs are assumed to be in
[0, 1]) for normalization.
image_normalization_std: Value by which to divide the input image pixels (after the mean has been
subtracted).
vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
use_pretrained_backbone: Whether the backbone should be initialized with pretrained weights from
torchvision.
@@ -54,41 +48,23 @@ class ActionChunkingTransformerConfig(PretrainedConfig):
dropout: Dropout to use in the transformer layers (see code for details).
kl_weight: The weight to use for the KL-divergence component of the loss if the variational objective
is enabled. Loss is then calculated as: `reconstruction_loss + kl_weight * kld_loss`.
"""
Example:
# Environment.
state_dim: int = 14
action_dim: int = 14
```python
>>> from lerobot import ActionChunkingTransformerConfig
>>> # Initializing an ACT style configuration
>>> configuration = ActionChunkingTransformerConfig()
>>> # Initializing a model (with random weights) from the ACT style configuration
>>> model = ActionChunkingTransformerPolicy(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
# Input / output structure.
# Inputs / output structure.
n_obs_steps: int = 1
camera_names: tuple[str] = ("top",)
chunk_size: int = 100
n_action_steps: int = 100
input_shapes: dict[str, list[str]] = {
"observation.images.top": [3, 480, 640],
"observation.state": [14],
}
output_shapes: dict[str, list[str]] = {"action": [14]}
# Normalization / Unnormalization
normalize_input_modes: dict[str, str] = {
"observation.image": "mean_std",
"observation.state": "mean_std",
}
unnormalize_output_modes: dict[str, str] = {"action": "mean_std"}
# Vision preprocessing.
image_normalization_mean: tuple[float, float, float] = field(
default_factory=lambda: [0.485, 0.456, 0.406]
)
image_normalization_std: tuple[float, float, float] = field(default_factory=lambda: [0.229, 0.224, 0.225])
# Architecture.
# Vision backbone.
@@ -141,10 +117,7 @@ class ActionChunkingTransformerConfig(PretrainedConfig):
raise ValueError(
f"Multiple observation steps not handled yet. Got `nobs_steps={self.n_obs_steps}`"
)
# Check that there is only one image.
# TODO(alexander-soare): generalize this to multiple images.
if (
sum(k.startswith("observation.images.") for k in self.input_shapes) != 1
or "observation.images.top" not in self.input_shapes
):
raise ValueError('For now, only "observation.images.top" is accepted for an image input.')
if self.camera_names != ["top"]:
raise ValueError(f"For now, `camera_names` can only be ['top']. Got {self.camera_names}.")
if len(set(self.camera_names)) != len(self.camera_names):
raise ValueError(f"`camera_names` should not have any repeated entries. Got {self.camera_names}.")

View File

@@ -15,12 +15,12 @@ import numpy as np
import torch
import torch.nn.functional as F # noqa: N812
import torchvision
import torchvision.transforms as transforms
from torch import Tensor, nn
from torchvision.models._utils import IntermediateLayerGetter
from torchvision.ops.misc import FrozenBatchNorm2d
from lerobot.common.policies.act.configuration_act import ActionChunkingTransformerConfig
from lerobot.common.policies.normalize import Normalize, Unnormalize
class ActionChunkingTransformerPolicy(nn.Module):
@@ -62,7 +62,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
name = "act"
def __init__(self, cfg: ActionChunkingTransformerConfig | None = None, dataset_stats=None):
def __init__(self, cfg: ActionChunkingTransformerConfig | None = None):
"""
Args:
cfg: Policy configuration class instance or None, in which case the default instantiation of the
@@ -72,8 +72,6 @@ class ActionChunkingTransformerPolicy(nn.Module):
if cfg is None:
cfg = ActionChunkingTransformerConfig()
self.cfg = cfg
self.normalize_inputs = Normalize(cfg.input_shapes, cfg.normalize_input_modes, dataset_stats)
self.unnormalize_outputs = Unnormalize(cfg.output_shapes, cfg.unnormalize_output_modes, dataset_stats)
# BERT style VAE encoder with input [cls, *joint_space_configuration, *action_sequence].
# The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]).
@@ -81,13 +79,9 @@ class ActionChunkingTransformerPolicy(nn.Module):
self.vae_encoder = _TransformerEncoder(cfg)
self.vae_encoder_cls_embed = nn.Embedding(1, cfg.d_model)
# Projection layer for joint-space configuration to hidden dimension.
self.vae_encoder_robot_state_input_proj = nn.Linear(
cfg.input_shapes["observation.state"][0], cfg.d_model
)
self.vae_encoder_robot_state_input_proj = nn.Linear(cfg.state_dim, cfg.d_model)
# Projection layer for action (joint-space target) to hidden dimension.
self.vae_encoder_action_input_proj = nn.Linear(
cfg.input_shapes["observation.state"][0], cfg.d_model
)
self.vae_encoder_action_input_proj = nn.Linear(cfg.state_dim, cfg.d_model)
self.latent_dim = cfg.latent_dim
# Projection layer from the VAE encoder's output to the latent distribution's parameter space.
self.vae_encoder_latent_output_proj = nn.Linear(cfg.d_model, self.latent_dim * 2)
@@ -99,6 +93,9 @@ class ActionChunkingTransformerPolicy(nn.Module):
)
# Backbone for image feature extraction.
self.image_normalizer = transforms.Normalize(
mean=cfg.image_normalization_mean, std=cfg.image_normalization_std
)
backbone_model = getattr(torchvision.models, cfg.vision_backbone)(
replace_stride_with_dilation=[False, False, cfg.replace_final_stride_with_dilation],
pretrained=cfg.use_pretrained_backbone,
@@ -115,7 +112,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
# Transformer encoder input projections. The tokens will be structured like
# [latent, robot_state, image_feature_map_pixels].
self.encoder_robot_state_input_proj = nn.Linear(cfg.input_shapes["observation.state"][0], cfg.d_model)
self.encoder_robot_state_input_proj = nn.Linear(cfg.state_dim, cfg.d_model)
self.encoder_latent_input_proj = nn.Linear(self.latent_dim, cfg.d_model)
self.encoder_img_feat_input_proj = nn.Conv2d(
backbone_model.fc.in_features, cfg.d_model, kernel_size=1
@@ -129,7 +126,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
self.decoder_pos_embed = nn.Embedding(cfg.chunk_size, cfg.d_model)
# Final action regression head on the output of the transformer's decoder.
self.action_head = nn.Linear(cfg.d_model, cfg.output_shapes["action"][0])
self.action_head = nn.Linear(cfg.d_model, cfg.action_dim)
self._reset_parameters()
self._create_optimizer()
@@ -172,18 +169,10 @@ class ActionChunkingTransformerPolicy(nn.Module):
queue is empty.
"""
self.eval()
batch = self.normalize_inputs(batch)
if len(self._action_queue) == 0:
# `_forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue effectively
# has shape (n_action_steps, batch_size, *), hence the transpose.
actions = self._forward(batch)[0][: self.cfg.n_action_steps]
# TODO(rcadene): make _forward return output dictionary?
actions = self.unnormalize_outputs({"action": actions})["action"]
self._action_queue.extend(actions.transpose(0, 1))
self._action_queue.extend(self._forward(batch)[0][: self.cfg.n_action_steps].transpose(0, 1))
return self._action_queue.popleft()
def forward(self, batch, **_) -> dict[str, Tensor]:
@@ -214,11 +203,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
"""Run the model in train mode, compute the loss, and do an optimization step."""
start_time = time.time()
self.train()
batch = self.normalize_inputs(batch)
loss_dict = self.forward(batch)
# TODO(rcadene): self.unnormalize_outputs(out_dict)
loss = loss_dict["loss"]
loss.backward()
@@ -247,9 +232,17 @@ class ActionChunkingTransformerPolicy(nn.Module):
"observation.images.{name}": (B, C, H, W) tensor of images.
}
"""
# Stack images in the order dictated by input_shapes.
# Check that there is only one image.
# TODO(alexander-soare): generalize this to multiple images.
provided_cameras = {k.rsplit(".", 1)[-1] for k in batch if k.startswith("observation.images.")}
if len(missing := set(self.cfg.camera_names).difference(provided_cameras)) > 0:
raise ValueError(
f"The following camera images are missing from the provided batch: {missing}. Check the "
"configuration parameter: `camera_names`."
)
# Stack images in the order dictated by the camera names.
batch["observation.images"] = torch.stack(
[batch[k] for k in self.cfg.input_shapes if k.startswith("observation.images.")],
[batch[f"observation.images.{name}"] for name in self.cfg.camera_names],
dim=-4,
)
@@ -316,8 +309,8 @@ class ActionChunkingTransformerPolicy(nn.Module):
# Camera observation features and positional embeddings.
all_cam_features = []
all_cam_pos_embeds = []
images = batch["observation.images"]
for cam_index in range(images.shape[-4]):
images = self.image_normalizer(batch["observation.images"])
for cam_index in range(len(self.cfg.camera_names)):
cam_features = self.backbone(images[:, cam_index])["feature_map"]
cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype)
cam_features = self.encoder_img_feat_input_proj(cam_features) # (B, C, h, w)

View File

@@ -1,4 +1,4 @@
from dataclasses import dataclass, field
from dataclasses import dataclass
@dataclass
@@ -8,28 +8,21 @@ class DiffusionConfig:
Defaults are configured for training with PushT providing proprioceptive and single camera observations.
The parameters you will most likely need to change are the ones which depend on the environment / sensors.
Those are: `input_shapes` and `output_shapes`.
Those are: `state_dim`, `action_dim` and `image_size`.
Args:
state_dim: Dimensionality of the observation state space (excluding images).
action_dim: Dimensionality of the action space.
image_size: (H, W) size of the input images.
n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the
current step and additional steps going back).
horizon: Diffusion model action prediction size as detailed in `DiffusionPolicy.select_action`.
n_action_steps: The number of action steps to run in the environment for one invocation of the policy.
See `DiffusionPolicy.select_action` for more details.
input_shapes: A dictionary defining the shapes of the input data for the policy.
The key represents the input data name, and the value is a list indicating the dimensions
of the corresponding data. For example, "observation.image" refers to an input from
a camera with dimensions [3, 96, 96], indicating it has three color channels and 96x96 resolution.
Importantly, shapes doesnt include batch dimension or temporal dimension.
output_shapes: A dictionary defining the shapes of the output data for the policy.
The key represents the output data name, and the value is a list indicating the dimensions
of the corresponding data. For example, "action" refers to an output shape of [14], indicating
14-dimensional actions. Importantly, shapes doesnt include batch dimension or temporal dimension.
normalize_input_modes: A dictionary with key represents the modality (e.g. "observation.state"),
and the value specifies the normalization mode to apply. The two availables
modes are "mean_std" which substracts the mean and divide by the standard
deviation and "min_max" which rescale in a [-1, 1] range.
unnormalize_output_modes: Similar dictionary as `normalize_input_modes`, but to unormalize in original scale.
image_normalization_mean: Value to subtract from the input image pixels (inputs are assumed to be in
[0, 1]) for normalization.
image_normalization_std: Value by which to divide the input image pixels (after the mean has been
subtracted).
vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
crop_shape: (H, W) shape to crop images to as a preprocessing step for the vision backbone. Must fit
within the image size. If None, no cropping is done.
@@ -65,35 +58,20 @@ class DiffusionConfig:
spaced). If not provided, this defaults to be the same as `num_train_timesteps`.
"""
# Environment.
# Inherit these from the environment config.
state_dim: int = 2
action_dim: int = 2
image_size: tuple[int, int] = (96, 96)
# Inputs / output structure.
n_obs_steps: int = 2
horizon: int = 16
n_action_steps: int = 8
input_shapes: dict[str, list[str]] = field(
default_factory=lambda: {
"observation.image": [3, 96, 96],
"observation.state": [2],
}
)
output_shapes: dict[str, list[str]] = field(
default_factory=lambda: {
"action": [2],
}
)
# Normalization / Unnormalization
normalize_input_modes: dict[str, str] = field(
default_factory=lambda: {
"observation.image": "mean_std",
"observation.state": "min_max",
}
)
unnormalize_output_modes: dict[str, str] = field(
default_factory=lambda: {
"action": "min_max",
}
)
# Vision preprocessing.
image_normalization_mean: tuple[float, float, float] = (0.5, 0.5, 0.5)
image_normalization_std: tuple[float, float, float] = (0.5, 0.5, 0.5)
# Architecture / modeling.
# Vision backbone.
@@ -145,14 +123,10 @@ class DiffusionConfig:
raise ValueError(
f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
)
if (
self.crop_shape[0] > self.input_shapes["observation.image"][1]
or self.crop_shape[1] > self.input_shapes["observation.image"][2]
):
if self.crop_shape[0] > self.image_size[0] or self.crop_shape[1] > self.image_size[1]:
raise ValueError(
f'`crop_shape` should fit within `input_shapes["observation.image"]`. Got {self.crop_shape} '
f'for `crop_shape` and {self.input_shapes["observation.image"]} for '
'`input_shapes["observation.image"]`.'
f"`crop_shape` should fit within `image_size`. Got {self.crop_shape} for `crop_shape` and "
f"{self.image_size} for `image_size`."
)
supported_prediction_types = ["epsilon", "sample"]
if self.prediction_type not in supported_prediction_types:

View File

@@ -13,6 +13,7 @@ import logging
import math
import time
from collections import deque
from itertools import chain
from typing import Callable
import einops
@@ -26,7 +27,6 @@ from torch import Tensor, nn
from torch.nn.modules.batchnorm import _BatchNorm
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.utils import (
get_device_from_parameters,
get_dtype_from_parameters,
@@ -42,9 +42,7 @@ class DiffusionPolicy(nn.Module):
name = "diffusion"
def __init__(
self, cfg: DiffusionConfig | None, lr_scheduler_num_training_steps: int = 0, dataset_stats=None
):
def __init__(self, cfg: DiffusionConfig | None, lr_scheduler_num_training_steps: int = 0):
"""
Args:
cfg: Policy configuration class instance or None, in which case the default instantiation of the
@@ -56,8 +54,6 @@ class DiffusionPolicy(nn.Module):
if cfg is None:
cfg = DiffusionConfig()
self.cfg = cfg
self.normalize_inputs = Normalize(cfg.input_shapes, cfg.normalize_input_modes, dataset_stats)
self.unnormalize_outputs = Unnormalize(cfg.output_shapes, cfg.unnormalize_output_modes, dataset_stats)
# queues are populated during rollout of the policy, they contain the n latest observations and actions
self._queues = None
@@ -130,8 +126,6 @@ class DiffusionPolicy(nn.Module):
assert "observation.state" in batch
assert len(batch) == 2
batch = self.normalize_inputs(batch)
self._queues = populate_queues(self._queues, batch)
if len(self._queues["action"]) == 0:
@@ -141,10 +135,6 @@ class DiffusionPolicy(nn.Module):
actions = self.ema_diffusion.generate_actions(batch)
else:
actions = self.diffusion.generate_actions(batch)
# TODO(rcadene): make above methods return output dictionary?
actions = self.unnormalize_outputs({"action": actions})["action"]
self._queues["action"].extend(actions.transpose(0, 1))
action = self._queues["action"].popleft()
@@ -161,13 +151,9 @@ class DiffusionPolicy(nn.Module):
self.diffusion.train()
batch = self.normalize_inputs(batch)
loss = self.forward(batch)["loss"]
loss.backward()
# TODO(rcadene): self.unnormalize_outputs(out_dict)
grad_norm = torch.nn.utils.clip_grad_norm_(
self.diffusion.parameters(),
self.cfg.grad_clip_norm,
@@ -211,8 +197,7 @@ class _DiffusionUnetImagePolicy(nn.Module):
self.rgb_encoder = _RgbEncoder(cfg)
self.unet = _ConditionalUnet1D(
cfg,
global_cond_dim=(cfg.output_shapes["action"][0] + self.rgb_encoder.feature_dim) * cfg.n_obs_steps,
cfg, global_cond_dim=(cfg.action_dim + self.rgb_encoder.feature_dim) * cfg.n_obs_steps
)
self.noise_scheduler = DDPMScheduler(
@@ -240,7 +225,7 @@ class _DiffusionUnetImagePolicy(nn.Module):
# Sample prior.
sample = torch.randn(
size=(batch_size, self.cfg.horizon, self.cfg.output_shapes["action"][0]),
size=(batch_size, self.cfg.horizon, self.cfg.action_dim),
dtype=dtype,
device=device,
generator=generator,
@@ -283,7 +268,7 @@ class _DiffusionUnetImagePolicy(nn.Module):
sample = self.conditional_sample(batch_size, global_cond=global_cond)
# `horizon` steps worth of actions (from the first observation).
actions = sample[..., : self.cfg.output_shapes["action"][0]]
actions = sample[..., : self.cfg.action_dim]
# Extract `n_action_steps` steps worth of actions (from the current observation).
start = n_obs_steps - 1
end = start + self.cfg.n_action_steps
@@ -361,6 +346,12 @@ class _RgbEncoder(nn.Module):
def __init__(self, cfg: DiffusionConfig):
super().__init__()
# Set up optional preprocessing.
if all(v == 1.0 for v in chain(cfg.image_normalization_mean, cfg.image_normalization_std)):
self.normalizer = nn.Identity()
else:
self.normalizer = torchvision.transforms.Normalize(
mean=cfg.image_normalization_mean, std=cfg.image_normalization_std
)
if cfg.crop_shape is not None:
self.do_crop = True
# Always use center crop for eval
@@ -393,9 +384,7 @@ class _RgbEncoder(nn.Module):
# Set up pooling and final layers.
# Use a dry run to get the feature map shape.
with torch.inference_mode():
feat_map_shape = tuple(
self.backbone(torch.zeros(size=(1, *cfg.input_shapes["observation.image"]))).shape[1:]
)
feat_map_shape = tuple(self.backbone(torch.zeros(size=(1, 3, *cfg.image_size))).shape[1:])
self.pool = SpatialSoftmax(feat_map_shape, num_kp=cfg.spatial_softmax_num_keypoints)
self.feature_dim = cfg.spatial_softmax_num_keypoints * 2
self.out = nn.Linear(cfg.spatial_softmax_num_keypoints * 2, self.feature_dim)
@@ -408,7 +397,8 @@ class _RgbEncoder(nn.Module):
Returns:
(B, D) image feature.
"""
# Preprocess: maybe crop (if it was set up in the __init__).
# Preprocess: normalize and maybe crop (if it was set up in the __init__).
x = self.normalizer(x)
if self.do_crop:
if self.training: # noqa: SIM108
x = self.maybe_random_crop(x)
@@ -512,7 +502,7 @@ class _ConditionalUnet1D(nn.Module):
# In channels / out channels for each downsampling block in the Unet's encoder. For the decoder, we
# just reverse these.
in_out = [(cfg.output_shapes["action"][0], cfg.down_dims[0])] + list(
in_out = [(cfg.action_dim, cfg.down_dims[0])] + list(
zip(cfg.down_dims[:-1], cfg.down_dims[1:], strict=True)
)
@@ -563,7 +553,7 @@ class _ConditionalUnet1D(nn.Module):
self.final_conv = nn.Sequential(
_Conv1dBlock(cfg.down_dims[0], cfg.down_dims[0], kernel_size=cfg.kernel_size),
nn.Conv1d(cfg.down_dims[0], cfg.output_shapes["action"][0], 1),
nn.Conv1d(cfg.down_dims[0], cfg.action_dim, 1),
)
def forward(self, x: Tensor, timestep: Tensor | int, global_cond=None) -> Tensor:

View File

@@ -20,7 +20,7 @@ def _policy_cfg_from_hydra_cfg(policy_cfg_class, hydra_cfg):
return policy_cfg
def make_policy(hydra_cfg: DictConfig, dataset_stats=None):
def make_policy(hydra_cfg: DictConfig):
if hydra_cfg.policy.name == "tdmpc":
from lerobot.common.policies.tdmpc.policy import TDMPCPolicy
@@ -35,14 +35,14 @@ def make_policy(hydra_cfg: DictConfig, dataset_stats=None):
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
policy_cfg = _policy_cfg_from_hydra_cfg(DiffusionConfig, hydra_cfg)
policy = DiffusionPolicy(policy_cfg, hydra_cfg.offline_steps, dataset_stats)
policy = DiffusionPolicy(policy_cfg, hydra_cfg.offline_steps)
policy.to(get_safe_torch_device(hydra_cfg.device))
elif hydra_cfg.policy.name == "act":
from lerobot.common.policies.act.configuration_act import ActionChunkingTransformerConfig
from lerobot.common.policies.act.modeling_act import ActionChunkingTransformerPolicy
policy_cfg = _policy_cfg_from_hydra_cfg(ActionChunkingTransformerConfig, hydra_cfg)
policy = ActionChunkingTransformerPolicy(policy_cfg, dataset_stats)
policy = ActionChunkingTransformerPolicy(policy_cfg)
policy.to(get_safe_torch_device(hydra_cfg.device))
else:
raise ValueError(hydra_cfg.policy.name)

View File

@@ -1,196 +0,0 @@
import torch
from torch import nn
def create_stats_buffers(shapes, modes, stats=None):
"""
Create buffers per modality (e.g. "observation.image", "action") containing their mean, std, min, max statistics.
Parameters:
shapes (dict): A dictionary where keys are input modalities (e.g. "observation.image") and values are their shapes (e.g. `[3,96,96]`]).
These shapes are used to create the tensor buffer containing mean, std, min, max statistics. If the provided `shapes` contain keys related to images, the shape is adjusted to be invariant to height
and width, assuming a channel-first (c, h, w) format.
modes (dict): A dictionary where keys are output modalities (e.g. "observation.image") and values are their normalization modes among:
- "mean_std": substract the mean and divide by standard deviation.
- "min_max": map to [-1, 1] range.
stats (dict, optional): A dictionary where keys are output modalities (e.g. "observation.image") and values are dictionaries of statistic types and their values
(e.g. `{"mean": torch.randn(3,1,1)}, "std": torch.randn(3,1,1)}`). If provided, as expected for training the model for the first time,
these statistics will overwrite the default buffers. If not provided, as expected for finetuning or evaluation, the default buffers should to be
be overwritten by a call to `policy.load_state_dict(state_dict)`. That way, initializing the dataset is not needed to get the stats, since
they are already in the policy state_dict.
Returns:
dict: A dictionary where keys are modalities and values are `nn.ParameterDict` containing `nn.Parameters` set to
`requires_grad=False`, suitable to not be updated during backpropagation.
"""
stats_buffers = {}
for key, mode in modes.items():
assert mode in ["mean_std", "min_max"]
shape = tuple(shapes[key])
if "image" in key:
# sanity checks
assert len(shape) == 3, f"number of dimensions of {key} != 3 ({shape=}"
c, h, w = shape
assert c < h and c < w, f"{key} is not channel first ({shape=})"
# override image shape to be invariant to height and width
shape = (c, 1, 1)
# Note: we initialize mean, std, min, max to infinity. They should be overwritten
# downstream by `stats` or `policy.load_state_dict`, as expected. During forward,
# we assert they are not infinity anymore.
buffer = {}
if mode == "mean_std":
mean = torch.ones(shape, dtype=torch.float32) * torch.inf
std = torch.ones(shape, dtype=torch.float32) * torch.inf
buffer = nn.ParameterDict(
{
"mean": nn.Parameter(mean, requires_grad=False),
"std": nn.Parameter(std, requires_grad=False),
}
)
elif mode == "min_max":
min = torch.ones(shape, dtype=torch.float32) * torch.inf
max = torch.ones(shape, dtype=torch.float32) * torch.inf
buffer = nn.ParameterDict(
{
"min": nn.Parameter(min, requires_grad=False),
"max": nn.Parameter(max, requires_grad=False),
}
)
if stats is not None:
if mode == "mean_std":
buffer["mean"].data = stats[key]["mean"]
buffer["std"].data = stats[key]["std"]
elif mode == "min_max":
buffer["min"].data = stats[key]["min"]
buffer["max"].data = stats[key]["max"]
stats_buffers[key] = buffer
return stats_buffers
class Normalize(nn.Module):
"""
Normalizes the input data (e.g. "observation.image") for more stable and faster convergence during training.
Parameters:
shapes (dict): A dictionary where keys are input modalities (e.g. "observation.image") and values are their shapes (e.g. `[3,96,96]`]).
These shapes are used to create the tensor buffer containing mean, std, min, max statistics. If the provided `shapes` contain keys related to images, the shape is adjusted to be invariant to height
and width, assuming a channel-first (c, h, w) format.
modes (dict): A dictionary where keys are output modalities (e.g. "observation.image") and values are their normalization modes among:
- "mean_std": substract the mean and divide by standard deviation.
- "min_max": map to [-1, 1] range.
stats (dict, optional): A dictionary where keys are output modalities (e.g. "observation.image") and values are dictionaries of statistic types and their values
(e.g. `{"mean": torch.randn(3,1,1)}, "std": torch.randn(3,1,1)}`). If provided, as expected for training the model for the first time,
these statistics will overwrite the default buffers. If not provided, as expected for finetuning or evaluation, the default buffers should to be
be overwritten by a call to `policy.load_state_dict(state_dict)`. That way, initializing the dataset is not needed to get the stats, since
they are already in the policy state_dict.
"""
def __init__(self, shapes, modes, stats=None):
super().__init__()
self.shapes = shapes
self.modes = modes
self.stats = stats
# `self.buffer_observation_state["mean"]` contains `torch.tensor(state_dim)`
stats_buffers = create_stats_buffers(shapes, modes, stats)
for key, buffer in stats_buffers.items():
setattr(self, "buffer_" + key.replace(".", "_"), buffer)
# TODO(rcadene): should we remove torch.no_grad?
@torch.no_grad
def forward(self, batch):
for key, mode in self.modes.items():
buffer = getattr(self, "buffer_" + key.replace(".", "_"))
if mode == "mean_std":
mean = buffer["mean"]
std = buffer["std"]
assert not torch.isinf(
mean
).any(), "`mean` is infinity. You forgot to initialize with `stats` as argument, or called `policy.load_state_dict`."
assert not torch.isinf(
std
).any(), "`std` is infinity. You forgot to initialize with `stats` as argument, or called `policy.load_state_dict`."
batch[key] = (batch[key] - mean) / (std + 1e-8)
elif mode == "min_max":
min = buffer["min"]
max = buffer["max"]
assert not torch.isinf(
min
).any(), "`min` is infinity. You forgot to initialize with `stats` as argument, or called `policy.load_state_dict`."
assert not torch.isinf(
max
).any(), "`max` is infinity. You forgot to initialize with `stats` as argument, or called `policy.load_state_dict`."
# normalize to [0,1]
batch[key] = (batch[key] - min) / (max - min)
# normalize to [-1, 1]
batch[key] = batch[key] * 2 - 1
else:
raise ValueError(mode)
return batch
class Unnormalize(nn.Module):
"""
Similar to `Normalize` but unnormalizes output data (e.g. `{"action": torch.randn(b,c)}`) in their original range used by the environment.
Parameters:
shapes (dict): A dictionary where keys are output modalities (e.g. "action") and values are their shapes (e.g. [10]).
These shapes are used to create the tensor buffer containing mean, std, min, max statistics. If the provided `shapes` contain keys related to images, the shape is adjusted to be invariant to height
and width, assuming a channel-first (c, h, w) format.
modes (dict): A dictionary where keys are output modalities (e.g. "action") and values are their unnormalization modes among:
- "mean_std": multiply by standard deviation and add mean
- "min_max": go from [-1, 1] range to original range.
stats (dict, optional): A dictionary where keys are output modalities (e.g. "action") and values are dictionaries of statistic types and their values
(e.g. `{"max": torch.tensor(1)}, "min": torch.tensor(0)}`). If provided, as expected for training the model for the first time,
these statistics will overwrite the default buffers. If not provided, as expected for finetuning or evaluation, the default buffers should to be
be overwritten by a call to `policy.load_state_dict(state_dict)`. That way, initializing the dataset is not needed to get the stats, since
they are already in the policy state_dict.
"""
def __init__(self, shapes, modes, stats=None):
super().__init__()
self.shapes = shapes
self.modes = modes
self.stats = stats
# `self.buffer_observation_state["mean"]` contains `torch.tensor(state_dim)`
stats_buffers = create_stats_buffers(shapes, modes, stats)
for key, buffer in stats_buffers.items():
setattr(self, "buffer_" + key.replace(".", "_"), buffer)
# TODO(rcadene): should we remove torch.no_grad?
@torch.no_grad
def forward(self, batch):
for key, mode in self.modes.items():
buffer = getattr(self, "buffer_" + key.replace(".", "_"))
if mode == "mean_std":
mean = buffer["mean"]
std = buffer["std"]
assert not torch.isinf(
mean
).any(), "`mean` is infinity. You forgot to initialize with `stats` as argument, or called `policy.load_state_dict`."
assert not torch.isinf(
std
).any(), "`std` is infinity. You forgot to initialize with `stats` as argument, or called `policy.load_state_dict`."
batch[key] = batch[key] * std + mean
elif mode == "min_max":
min = buffer["min"]
max = buffer["max"]
assert not torch.isinf(
min
).any(), "`min` is infinity. You forgot to initialize with `stats` as argument, or called `policy.load_state_dict`."
assert not torch.isinf(
max
).any(), "`max` is infinity. You forgot to initialize with `stats` as argument, or called `policy.load_state_dict`."
batch[key] = (batch[key] + 1) / 2
batch[key] = batch[key] * (max - min) + min
else:
raise ValueError(mode)
return batch

View File

@@ -0,0 +1,65 @@
from torchvision.transforms.v2 import Compose, Transform
def apply_inverse_transform(item, transform):
transforms = transform.transforms if isinstance(transform, Compose) else [transform]
for tf in transforms[::-1]:
if tf.invertible:
item = tf.inverse_transform(item)
else:
raise ValueError(f"Inverse transform called on a non invertible transform ({tf}).")
return item
class NormalizeTransform(Transform):
invertible = True
def __init__(
self,
stats: dict,
in_keys: list[str] = None,
out_keys: list[str] | None = None,
in_keys_inv: list[str] | None = None,
out_keys_inv: list[str] | None = None,
mode="mean_std",
):
super().__init__()
self.in_keys = in_keys
self.out_keys = in_keys if out_keys is None else out_keys
self.in_keys_inv = self.out_keys if in_keys_inv is None else in_keys_inv
self.out_keys_inv = self.in_keys if out_keys_inv is None else out_keys_inv
self.stats = stats
assert mode in ["mean_std", "min_max"]
self.mode = mode
def forward(self, item):
for inkey, outkey in zip(self.in_keys, self.out_keys, strict=False):
if inkey not in item:
continue
if self.mode == "mean_std":
mean = self.stats[inkey]["mean"]
std = self.stats[inkey]["std"]
item[outkey] = (item[inkey] - mean) / (std + 1e-8)
else:
min = self.stats[inkey]["min"]
max = self.stats[inkey]["max"]
# normalize to [0,1]
item[outkey] = (item[inkey] - min) / (max - min)
# normalize to [-1, 1]
item[outkey] = item[outkey] * 2 - 1
return item
def inverse_transform(self, item):
for inkey, outkey in zip(self.in_keys_inv, self.out_keys_inv, strict=False):
if inkey not in item:
continue
if self.mode == "mean_std":
mean = self.stats[inkey]["mean"]
std = self.stats[inkey]["std"]
item[outkey] = item[inkey] * std + mean
else:
min = self.stats[inkey]["min"]
max = self.stats[inkey]["max"]
item[outkey] = (item[inkey] + 1) / 2
item[outkey] = item[outkey] * (max - min) + min
return item

View File

@@ -26,8 +26,7 @@ fps: ???
offline_prioritized_sampler: true
dataset:
repo_id: ???
dataset_id: ???
n_action_steps: ???
n_obs_steps: ???

View File

@@ -10,8 +10,7 @@ online_steps: 25000
fps: 50
dataset:
repo_id: lerobot/aloha_sim_insertion_human
dataset_id: aloha_sim_insertion_human
env:
name: aloha
@@ -21,5 +20,7 @@ env:
image_size: [3, 480, 640]
episode_length: 400
fps: ${fps}
policy:
state_dim: 14
action_dim: 14

View File

@@ -10,8 +10,7 @@ online_steps: 25000
fps: 10
dataset:
repo_id: lerobot/pusht
dataset_id: pusht
env:
name: pusht
@@ -21,5 +20,7 @@ env:
image_size: 96
episode_length: 300
fps: ${fps}
policy:
state_dim: 2
action_dim: 2

View File

@@ -9,8 +9,7 @@ online_steps: 25000
fps: 15
dataset:
repo_id: lerobot/xarm_lift_medium
dataset_id: xarm_lift_medium
env:
name: xarm
@@ -20,5 +19,7 @@ env:
image_size: 84
episode_length: 25
fps: ${fps}
policy:
state_dim: 4
action_dim: 4

View File

@@ -11,36 +11,26 @@ log_freq: 250
n_obs_steps: 1
# when temporal_agg=False, n_action_steps=horizon
override_dataset_stats:
observation.images.top:
# stats from imagenet, since we use a pretrained vision model
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
# See `configuration_act.py` for more details.
policy:
name: act
pretrained_model_path:
# Input / output structure.
# Environment.
# Inherit these from the environment config.
state_dim: ???
action_dim: ???
# Inputs / output structure.
n_obs_steps: ${n_obs_steps}
camera_names: [top] # [top, front_close, left_pillar, right_pillar]
chunk_size: 100 # chunk_size
n_action_steps: 100
input_shapes:
# TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
observation.images.top: [3, 480, 640]
observation.state: ["${env.state_dim}"]
output_shapes:
action: ["${env.action_dim}"]
# Normalization / Unnormalization
normalize_input_modes:
observation.images.top: mean_std
observation.state: mean_std
unnormalize_output_modes:
action: mean_std
# Vision preprocessing.
image_normalization_mean: [0.485, 0.456, 0.406]
image_normalization_std: [0.229, 0.224, 0.225]
# Architecture.
# Vision backbone.

View File

@@ -18,43 +18,27 @@ online_steps: 0
offline_prioritized_sampler: true
override_dataset_stats:
# TODO(rcadene, alexander-soare): should we remove image stats as well? do we use a pretrained vision model?
observation.image:
mean: [[[0.5]], [[0.5]], [[0.5]]] # (c,1,1)
std: [[[0.5]], [[0.5]], [[0.5]]] # (c,1,1)
# TODO(rcadene, alexander-soare): we override state and action stats to use the same as the pretrained model
# from the original codebase, but we should remove these and train our own pretrained model
observation.state:
min: [13.456424, 32.938293]
max: [496.14618, 510.9579]
action:
min: [12.0, 25.0]
max: [511.0, 511.0]
policy:
name: diffusion
pretrained_model_path:
# Input / output structure.
# Environment.
# Inherit these from the environment config.
state_dim: ???
action_dim: ???
image_size:
- ${env.image_size} # height
- ${env.image_size} # width
# Inputs / output structure.
n_obs_steps: ${n_obs_steps}
horizon: ${horizon}
n_action_steps: ${n_action_steps}
input_shapes:
# TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
observation.image: [3, 96, 96]
observation.state: ["${env.state_dim}"]
output_shapes:
action: ["${env.action_dim}"]
# Normalization / Unnormalization
normalize_input_modes:
observation.image: mean_std
observation.state: min_max
unnormalize_output_modes:
action: min_max
# Vision preprocessing.
image_normalization_mean: [0.5, 0.5, 0.5]
image_normalization_std: [0.5, 0.5, 0.5]
# Architecture / modeling.
# Vision backbone.

View File

@@ -16,8 +16,8 @@ policy:
frame_stack: 1
num_channels: 32
img_size: ${env.image_size}
state_dim: ${env.action_dim}
action_dim: ${env.action_dim}
state_dim: ???
action_dim: ???
# planning
mpc: true

View File

@@ -41,11 +41,12 @@ import gymnasium as gym
import imageio
import numpy as np
import torch
from datasets import Dataset, Features, Image, Sequence, Value
from datasets import Dataset
from huggingface_hub import snapshot_download
from PIL import Image as PILImage
from tqdm import trange
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.datasets.utils import hf_transform_to_torch
from lerobot.common.envs.factory import make_env
from lerobot.common.envs.utils import postprocess_action, preprocess_observation
@@ -63,6 +64,8 @@ def eval_policy(
policy: torch.nn.Module,
max_episodes_rendered: int = 0,
video_dir: Path = None,
# TODO(rcadene): make it possible to overwrite fps? we should use env.fps
transform: callable = None,
return_episode_data: bool = False,
seed=None,
):
@@ -129,6 +132,10 @@ def eval_policy(
if return_episode_data:
observations.append(deepcopy(observation))
# apply transform to normalize the observations
for key in observation:
observation[key] = torch.stack([transform({key: item})[key] for item in observation[key]])
# send observation to device/gpu
observation = {key: observation[key].to(device, non_blocking=True) for key in observation}
@@ -136,8 +143,8 @@ def eval_policy(
with torch.inference_mode():
action = policy.select_action(observation, step=step)
# convert to cpu numpy
action = postprocess_action(action)
# apply inverse transform to unnormalize the action
action = postprocess_action(action, transform)
# apply the next action
observation, reward, terminated, truncated, info = env.step(action)
@@ -263,34 +270,8 @@ def eval_policy(
data_dict[key].append(img)
data_dict["index"] = torch.arange(0, total_frames, 1)
episode_data_index["from"] = torch.tensor(episode_data_index["from"])
episode_data_index["to"] = torch.tensor(episode_data_index["to"])
# TODO(rcadene): clean this
features = {}
for key in observations:
if "image" in key:
features[key] = Image()
else:
features[key] = Sequence(
length=data_dict[key].shape[1], feature=Value(dtype="float32", id=None)
)
features.update(
{
"action": Sequence(
length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)
),
"episode_index": Value(dtype="int64", id=None),
"frame_index": Value(dtype="int64", id=None),
"timestamp": Value(dtype="float32", id=None),
"next.reward": Value(dtype="float32", id=None),
"next.done": Value(dtype="bool", id=None),
#'next.success': Value(dtype='bool', id=None),
"index": Value(dtype="int64", id=None),
}
)
features = Features(features)
hf_dataset = Dataset.from_dict(data_dict, features=features)
hf_dataset = Dataset.from_dict(data_dict)
hf_dataset.set_transform(hf_transform_to_torch)
if max_episodes_rendered > 0:
@@ -353,7 +334,7 @@ def eval_policy(
return info
def eval(cfg: dict, out_dir=None):
def eval(cfg: dict, out_dir=None, stats_path=None):
if out_dir is None:
raise NotImplementedError()
@@ -368,6 +349,10 @@ def eval(cfg: dict, out_dir=None):
log_output_dir(out_dir)
logging.info("Making transforms.")
# TODO(alexander-soare): Completely decouple datasets from evaluation.
transform = make_dataset(cfg, stats_path=stats_path).transform
logging.info("Making environment.")
env = make_env(cfg, num_parallel_envs=cfg.eval_episodes)
@@ -379,6 +364,7 @@ def eval(cfg: dict, out_dir=None):
policy,
max_episodes_rendered=10,
video_dir=Path(out_dir) / "eval",
transform=transform,
return_episode_data=False,
seed=cfg.seed,
)
@@ -411,13 +397,17 @@ if __name__ == "__main__":
if args.config is not None:
# Note: For the config_path, Hydra wants a path relative to this script file.
cfg = init_hydra_config(args.config, args.overrides)
# TODO(alexander-soare): Save and load stats in trained model directory.
stats_path = None
elif args.hub_id is not None:
folder = Path(snapshot_download(args.hub_id, revision=args.revision))
cfg = init_hydra_config(
folder / "config.yaml", [f"policy.pretrained_model_path={folder / 'model.pt'}", *args.overrides]
)
stats_path = folder / "stats.pth"
eval(
cfg,
out_dir=f"outputs/eval/{dt.now().strftime('%Y-%m-%d/%H-%M-%S')}_{cfg.env.name}_{cfg.policy.name}",
stats_path=stats_path,
)

View File

@@ -160,32 +160,27 @@ def add_episodes_inplace(
Raises:
- AssertionError: If the first episode_id or index in hf_dataset is not 0
"""
first_episode_idx = hf_dataset.select_columns("episode_index")[0]["episode_index"].item()
last_episode_idx = hf_dataset.select_columns("episode_index")[-1]["episode_index"].item()
first_episode_id = hf_dataset.select_columns("episode_index")[0]["episode_index"].item()
first_index = hf_dataset.select_columns("index")[0]["index"].item()
last_index = hf_dataset.select_columns("index")[-1]["index"].item()
# sanity check
assert first_episode_idx == 0, f"{first_episode_idx=} is not 0"
assert first_index == 0, f"{first_index=} is not 0"
assert first_index == episode_data_index["from"][first_episode_idx].item()
assert last_index == episode_data_index["to"][last_episode_idx].item() - 1
assert first_episode_id == 0, f"We expect the first episode_id to be 0 and not {first_episode_id}"
assert first_index == 0, f"We expect the first first_index to be 0 and not {first_index}"
if len(online_dataset) == 0:
# initialize online dataset
online_dataset.hf_dataset = hf_dataset
online_dataset.episode_data_index = episode_data_index
else:
# get the starting indices of the new episodes and frames to be added
start_episode_idx = last_episode_idx + 1
start_index = last_index + 1
# find episode index and data frame indices according to previous episode in online_dataset
start_episode = online_dataset.select_columns("episode_index")[-1]["episode_index"].item() + 1
start_index = online_dataset.select_columns("index")[-1]["index"].item() + 1
def shift_indices(episode_index, index):
def shift_indices(example):
# note: we dont shift "frame_index" since it represents the index of the frame in the episode it belongs to
example = {"episode_index": episode_index + start_episode_idx, "index": index + start_index}
example["episode_index"] += start_episode
example["index"] += start_index
return example
disable_progress_bars() # map has a tqdm progress bar
hf_dataset = hf_dataset.map(shift_indices, input_columns=["episode_index", "index"])
hf_dataset = hf_dataset.map(shift_indices)
enable_progress_bars()
episode_data_index["from"] += start_index
@@ -232,7 +227,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
env = make_env(cfg, num_parallel_envs=cfg.eval_episodes)
logging.info("make_policy")
policy = make_policy(cfg, dataset_stats=offline_dataset.stats)
policy = make_policy(cfg)
num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
num_total_params = sum(p.numel() for p in policy.parameters())
@@ -311,7 +306,6 @@ def train(cfg: dict, out_dir=None, job_name=None):
# create an empty online dataset similar to offline dataset
online_dataset = deepcopy(offline_dataset)
online_dataset.hf_dataset = {}
online_dataset.episode_data_index = {}
# create dataloader for online training
concat_dataset = torch.utils.data.ConcatDataset([offline_dataset, online_dataset])
@@ -339,6 +333,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
eval_info = eval_policy(
rollout_env,
policy,
transform=offline_dataset.transform,
return_episode_data=True,
seed=cfg.seed,
)

View File

@@ -26,9 +26,7 @@ def cat_and_write_video(video_path, frames, fps):
# Expects images in [0, 1].
frame = frames[0]
if frame.ndim == 4:
raise NotImplementedError("We currently dont support multiple timestamps.")
c, h, w = frame.shape
_, c, h, w = frame.shape
assert c < h and c < w, f"expect channel first images, but instead {frame.shape}"
# sanity check that images are float32 in range [0,1]
@@ -50,13 +48,16 @@ def visualize_dataset(cfg: dict, out_dir=None):
log_output_dir(out_dir)
logging.info("make_dataset")
dataset = make_dataset(cfg)
dataset = make_dataset(
cfg,
# remove all transformations such as rescale images from [0,255] to [0,1] or normalization
normalize=False,
)
logging.info("Start rendering episodes from offline buffer")
video_paths = render_dataset(dataset, out_dir, MAX_NUM_STEPS * NUM_EPISODES_TO_RENDER)
video_paths = render_dataset(dataset, out_dir, MAX_NUM_STEPS * NUM_EPISODES_TO_RENDER, cfg.fps)
for video_path in video_paths:
logging.info(video_path)
return video_paths
def render_dataset(dataset, out_dir, max_num_episodes):
@@ -87,7 +88,7 @@ def render_dataset(dataset, out_dir, max_num_episodes):
# add current frame to list of frames to render
frames[im_key].append(item[im_key])
end_of_episode = item["index"].item() == dataset.episode_data_index["to"][ep_id] - 1
end_of_episode = item["index"].item() == item["episode_data_index_to"].item() - 1
out_dir.mkdir(parents=True, exist_ok=True)
for im_key in dataset.image_keys:

23
poetry.lock generated
View File

@@ -522,21 +522,21 @@ toml = ["tomli"]
[[package]]
name = "datasets"
version = "2.19.0"
version = "2.18.0"
description = "HuggingFace community-driven open-source library of datasets"
optional = false
python-versions = ">=3.8.0"
files = [
{file = "datasets-2.19.0-py3-none-any.whl", hash = "sha256:f57c5316e123d4721b970c68c1cb856505f289cda58f5557ffe745b49c011a8e"},
{file = "datasets-2.19.0.tar.gz", hash = "sha256:0b47e08cc7af2c6800a42cadc4657b22a0afc7197786c8986d703c08d90886a6"},
{file = "datasets-2.18.0-py3-none-any.whl", hash = "sha256:f1bbf0e2896917a914de01cbd37075b14deea3837af87ad0d9f697388ccaeb50"},
{file = "datasets-2.18.0.tar.gz", hash = "sha256:cdf8b8c6abf7316377ba4f49f9589a4c74556d6b481afd0abd2284f3d69185cb"},
]
[package.dependencies]
aiohttp = "*"
dill = ">=0.3.0,<0.3.9"
filelock = "*"
fsspec = {version = ">=2023.1.0,<=2024.3.1", extras = ["http"]}
huggingface-hub = ">=0.21.2"
fsspec = {version = ">=2023.1.0,<=2024.2.0", extras = ["http"]}
huggingface-hub = ">=0.19.4"
multiprocess = "*"
numpy = ">=1.17"
packaging = "*"
@@ -552,15 +552,15 @@ xxhash = "*"
apache-beam = ["apache-beam (>=2.26.0)"]
audio = ["librosa", "soundfile (>=0.12.1)"]
benchmarks = ["tensorflow (==2.12.0)", "torch (==2.0.1)", "transformers (==4.30.1)"]
dev = ["Pillow (>=6.2.1)", "absl-py", "apache-beam (>=2.26.0)", "elasticsearch (<8.0.0)", "faiss-cpu (>=1.6.4)", "jax (>=0.3.14)", "jaxlib (>=0.3.14)", "joblib (<1.3.0)", "joblibspark", "librosa", "lz4", "polars[timezone] (>=0.20.0)", "protobuf (<4.0.0)", "py7zr", "pyspark (>=3.4)", "pytest", "pytest-datadir", "pytest-xdist", "rarfile (>=4.0)", "ruff (>=0.3.0)", "s3fs", "s3fs (>=2021.11.1)", "soundfile (>=0.12.1)", "sqlalchemy", "tensorflow (>=2.6.0)", "tiktoken", "torch", "torch (>=2.0.0)", "transformers", "typing-extensions (>=4.6.1)", "zstandard"]
docs = ["s3fs", "tensorflow (>=2.6.0)", "torch", "transformers"]
dev = ["Pillow (>=6.2.1)", "absl-py", "apache-beam (>=2.26.0)", "elasticsearch (<8.0.0)", "faiss-cpu (>=1.6.4)", "jax (>=0.3.14)", "jaxlib (>=0.3.14)", "joblib (<1.3.0)", "joblibspark", "librosa", "lz4", "py7zr", "pyspark (>=3.4)", "pytest", "pytest-datadir", "pytest-xdist", "rarfile (>=4.0)", "ruff (>=0.3.0)", "s3fs", "s3fs (>=2021.11.1)", "soundfile (>=0.12.1)", "sqlalchemy", "tensorflow (>=2.2.0,!=2.6.0,!=2.6.1)", "tensorflow (>=2.3,!=2.6.0,!=2.6.1)", "tensorflow-macos", "tiktoken", "torch", "torch (>=2.0.0)", "transformers", "typing-extensions (>=4.6.1)", "zstandard"]
docs = ["s3fs", "tensorflow (>=2.2.0,!=2.6.0,!=2.6.1)", "tensorflow-macos", "torch", "transformers"]
jax = ["jax (>=0.3.14)", "jaxlib (>=0.3.14)"]
metrics-tests = ["Werkzeug (>=1.0.1)", "accelerate", "bert-score (>=0.3.6)", "jiwer", "langdetect", "mauve-text", "nltk", "requests-file (>=1.5.1)", "rouge-score", "sacrebleu", "sacremoses", "scikit-learn", "scipy", "sentencepiece", "seqeval", "six (>=1.15.0,<1.16.0)", "spacy (>=3.0.0)", "texttable (>=1.6.3)", "tldextract", "tldextract (>=3.1.0)", "toml (>=0.10.1)", "typer (<0.5.0)"]
quality = ["ruff (>=0.3.0)"]
s3 = ["s3fs"]
tensorflow = ["tensorflow (>=2.6.0)"]
tensorflow-gpu = ["tensorflow (>=2.6.0)"]
tests = ["Pillow (>=6.2.1)", "absl-py", "apache-beam (>=2.26.0)", "elasticsearch (<8.0.0)", "faiss-cpu (>=1.6.4)", "jax (>=0.3.14)", "jaxlib (>=0.3.14)", "joblib (<1.3.0)", "joblibspark", "librosa", "lz4", "polars[timezone] (>=0.20.0)", "protobuf (<4.0.0)", "py7zr", "pyspark (>=3.4)", "pytest", "pytest-datadir", "pytest-xdist", "rarfile (>=4.0)", "s3fs (>=2021.11.1)", "soundfile (>=0.12.1)", "sqlalchemy", "tensorflow (>=2.6.0)", "tiktoken", "torch (>=2.0.0)", "transformers", "typing-extensions (>=4.6.1)", "zstandard"]
tensorflow = ["tensorflow (>=2.2.0,!=2.6.0,!=2.6.1)", "tensorflow-macos"]
tensorflow-gpu = ["tensorflow-gpu (>=2.2.0,!=2.6.0,!=2.6.1)"]
tests = ["Pillow (>=6.2.1)", "absl-py", "apache-beam (>=2.26.0)", "elasticsearch (<8.0.0)", "faiss-cpu (>=1.6.4)", "jax (>=0.3.14)", "jaxlib (>=0.3.14)", "joblib (<1.3.0)", "joblibspark", "librosa", "lz4", "py7zr", "pyspark (>=3.4)", "pytest", "pytest-datadir", "pytest-xdist", "rarfile (>=4.0)", "s3fs (>=2021.11.1)", "soundfile (>=0.12.1)", "sqlalchemy", "tensorflow (>=2.3,!=2.6.0,!=2.6.1)", "tensorflow-macos", "tiktoken", "torch (>=2.0.0)", "transformers", "typing-extensions (>=4.6.1)", "zstandard"]
torch = ["torch"]
vision = ["Pillow (>=6.2.1)"]
@@ -2909,6 +2909,7 @@ files = [
{file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"},
{file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"},
{file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"},
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"},
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"},
{file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"},
{file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"},
@@ -4194,4 +4195,4 @@ xarm = ["gym-xarm"]
[metadata]
lock-version = "2.0"
python-versions = "^3.10"
content-hash = "7f5afa48aead953f598e686e767891d3d23f2862b80144f76dc064101ef80b4a"
content-hash = "01ad4eb04061ec9f785d4574bf66d3e5cb4549e2ea11ab175895f94cb62c1f1c"

View File

@@ -53,8 +53,7 @@ pre-commit = {version = "^3.7.0", optional = true}
debugpy = {version = "^1.8.1", optional = true}
pytest = {version = "^8.1.0", optional = true}
pytest-cov = {version = "^5.0.0", optional = true}
datasets = "^2.19.0"
datasets = "^2.18.0"
[tool.poetry.extras]
pusht = ["gym-pusht"]

30
test.py Normal file
View File

@@ -0,0 +1,30 @@
import rerun as rr
from datasets import load_from_disk
# download/load dataset in pyarrow format
print("Loading dataset…")
#dataset = load_dataset("lerobot/aloha_mobile_trossen_block_handoff", split="train")
dataset = load_from_disk("tests/data/aloha_mobile_trossen_block_handoff/train")
# select the frames belonging to episode number 5
print("Select specific episode…")
print("Starting Rerun…")
rr.init("rerun_example_lerobot", spawn=True)
print("Logging to Rerun…")
# for frame_index, timestamp, cam_high, cam_left_wrist, cam_right_wrist, state, action, next_reward in zip(
for d in dataset:
rr.set_time_sequence("frame_index", d["frame_index"])
rr.set_time_seconds("timestamp", d["timestamp"])
rr.log("observation.images.cam_high", rr.Image( d["observation.images.cam_high"]))
rr.log("observation.images.cam_left_wrist", rr.Image(d["observation.images.cam_left_wrist"]))
rr.log("observation.images.cam_right_wrist", rr.Image(d["observation.images.cam_right_wrist"]))
#rr.log("observation/state", rr.BarChart(state))
#rr.log("observation/action", rr.BarChart(action))
for idx, val in enumerate(d["action"]):
rr.log(f"action_{idx}", rr.Scalar(val))
for idx, val in enumerate(d["observation.state"]):
rr.log(f"state_{idx}", rr.Scalar(val))

Some files were not shown because too many files have changed in this diff Show More