Merge remote-tracking branch 'upstream/main' into add_drop_last_keyframes

Alexander Soare
2024-05-20 09:24:11 +01:00
77 changed files with 1810 additions and 880 deletions


@@ -1,11 +1,15 @@
-# What does this PR do?
+## What this does
Explain what this PR does. Feel free to tag your PR with the appropriate label(s).
Examples:
-- Fixes # (issue)
-- Adds new dataset
-- Optimizes something
+| Title | Label |
+|----------------------|-----------------|
+| Fixes #[issue] | (🐛 Bug) |
+| Adds new dataset | (🗃️ Dataset) |
+| Optimizes something | (⚡️ Performance) |
-## How was it tested?
+## How it was tested
Explain/show how you tested your changes.
Examples:
- Added `test_something` in `tests/test_stuff.py`.
@@ -13,6 +17,7 @@ Examples:
- Optimized `some_function`, it now runs X times faster than previously.
## How to checkout & try? (for the reviewer)
Provide a simple way for the reviewer to try out your changes.
Examples:
```bash
@@ -22,11 +27,8 @@ DATA_DIR=tests/data pytest -sx tests/test_stuff.py::test_something
python lerobot/scripts/train.py --some.option=true
```
-## Before submitting
-Please read the [contributor guideline](https://github.com/huggingface/lerobot/blob/main/CONTRIBUTING.md#submitting-a-pull-request-pr).
-## Who can review?
-Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
+## SECTION TO REMOVE BEFORE SUBMITTING YOUR PR
+**Note**: Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR. Try to avoid tagging more than 3 people.
+**Note**: Before submitting this PR, please read the [contributor guideline](https://github.com/huggingface/lerobot/blob/main/CONTRIBUTING.md#submitting-a-pull-request-pr).


@@ -57,6 +57,38 @@ jobs:
          && rm -rf tests/outputs outputs
+  pytest-minimal:
+    name: Pytest (minimal install)
+    runs-on: ubuntu-latest
+    env:
+      DATA_DIR: tests/data
+      MUJOCO_GL: egl
+    steps:
+      - uses: actions/checkout@v4
+      - name: Install poetry
+        run: |
+          pipx install poetry && poetry config virtualenvs.in-project true
+          echo "${{ github.workspace }}/.venv/bin" >> $GITHUB_PATH
+      - name: Set up Python 3.10
+        uses: actions/setup-python@v5
+        with:
+          python-version: "3.10"
+      - name: Install poetry dependencies
+        run: |
+          poetry install --extras "test"
+      - name: Test with pytest
+        run: |
+          pytest tests -v --cov=./lerobot --durations=0 \
+            -W ignore::DeprecationWarning:imageio_ffmpeg._utils:7 \
+            -W ignore::UserWarning:torch.utils.data.dataloader:558 \
+            -W ignore::UserWarning:gymnasium.utils.env_checker:247 \
+            && rm -rf tests/outputs outputs
  end-to-end:
    name: End-to-end
    runs-on: ubuntu-latest


@@ -18,7 +18,7 @@ repos:
    hooks:
      - id: pyupgrade
  - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: v0.4.2
+    rev: v0.4.3
    hooks:
      - id: ruff
        args: [--fix]


@@ -22,9 +22,8 @@ test-end-to-end:
	${MAKE} test-act-ete-eval
	${MAKE} test-diffusion-ete-train
	${MAKE} test-diffusion-ete-eval
-	# TODO(rcadene, alexander-soare): enable end-to-end tests for tdmpc
-	# ${MAKE} test-tdmpc-ete-train
-	# ${MAKE} test-tdmpc-ete-eval
+	${MAKE} test-tdmpc-ete-train
+	${MAKE} test-tdmpc-ete-eval
	${MAKE} test-default-ete-eval
test-act-ete-train:
@@ -80,7 +79,7 @@ test-tdmpc-ete-train:
		policy=tdmpc \
		env=xarm \
		env.task=XarmLift-v0 \
-		dataset_repo_id=lerobot/xarm_lift_medium_replay \
+		dataset_repo_id=lerobot/xarm_lift_medium \
		wandb.enable=False \
		training.offline_steps=2 \
		training.online_steps=2 \

README.md

@@ -29,15 +29,15 @@
---
-🤗 LeRobot aims to provide models, datasets, and tools for real-world robotics in PyTorch. The goal is to lower the barrier for entry to robotics so that everyone can contribute and benefit from sharing datasets and pretrained models.
+🤗 LeRobot aims to provide models, datasets, and tools for real-world robotics in PyTorch. The goal is to lower the barrier to entry to robotics so that everyone can contribute and benefit from sharing datasets and pretrained models.
🤗 LeRobot contains state-of-the-art approaches that have been shown to transfer to the real-world with a focus on imitation learning and reinforcement learning.
-🤗 LeRobot already provides a set of pretrained models, datasets with human collected demonstrations, and simulated environments so that everyone can get started. In the coming weeks, the plan is to add more and more support for real-world robotics on the most affordable and capable robots out there.
+🤗 LeRobot already provides a set of pretrained models, datasets with human collected demonstrations, and simulation environments to get started without assembling a robot. In the coming weeks, the plan is to add more and more support for real-world robotics on the most affordable and capable robots out there.
-🤗 LeRobot hosts pretrained models and datasets on this HuggingFace community page: [huggingface.co/lerobot](https://huggingface.co/lerobot)
+🤗 LeRobot hosts pretrained models and datasets on this Hugging Face community page: [huggingface.co/lerobot](https://huggingface.co/lerobot)
-#### Examples of pretrained models and environments
+#### Examples of pretrained models on simulation environments
<table>
<tr>
@@ -54,10 +54,11 @@
### Acknowledgment
-- ACT policy and ALOHA environment are adapted from [ALOHA](https://tonyzhaozh.github.io/aloha/)
-- Diffusion policy and Pusht environment are adapted from [Diffusion Policy](https://diffusion-policy.cs.columbia.edu/)
-- TDMPC policy and Simxarm environment are adapted from [FOWM](https://www.yunhaifeng.com/FOWM/)
-- Abstractions and utilities for Reinforcement Learning come from [TorchRL](https://github.com/pytorch/rl)
+- Thanks to Tony Zhao, Zipeng Fu and colleagues for open sourcing ACT policy, ALOHA environments and datasets. Ours are adapted from [ALOHA](https://tonyzhaozh.github.io/aloha) and [Mobile ALOHA](https://mobile-aloha.github.io).
+- Thanks to Cheng Chi, Zhenjia Xu and colleagues for open sourcing Diffusion policy, Pusht environment and datasets, as well as UMI datasets. Ours are adapted from [Diffusion Policy](https://diffusion-policy.cs.columbia.edu) and [UMI Gripper](https://umi-gripper.github.io).
+- Thanks to Nicklas Hansen, Yunhai Feng and colleagues for open sourcing TDMPC policy, Simxarm environments and datasets. Ours are adapted from [TDMPC](https://github.com/nicklashansen/tdmpc) and [FOWM](https://www.yunhaifeng.com/FOWM).
+- Thanks to Antonio Loquercio and Ashish Kumar for their early support.
## Installation
@@ -86,15 +87,18 @@ For instance, to install 🤗 LeRobot with aloha and pusht, use:
pip install ".[aloha, pusht]"
```
-To use [Weights and Biases](https://docs.wandb.ai/quickstart) for experiments tracking, log in with
+To use [Weights and Biases](https://docs.wandb.ai/quickstart) for experiment tracking, log in with
```bash
wandb login
```
+(note: you will also need to enable WandB in the configuration. See below.)
## Walkthrough
```
.
+├── examples # contains demonstration examples, start here to learn about LeRobot
├── lerobot
| ├── configs # contains hydra yaml files with all options that you can override in the command line
| | ├── default.yaml # selected by default, it loads pusht environment and diffusion policy
@@ -103,69 +107,84 @@ wandb login
| ├── common # contains classes and utilities
| | ├── datasets # various datasets of human demonstrations: aloha, pusht, xarm
| | ├── envs # various sim environments: aloha, pusht, xarm
| | ├── policies # various policies: act, diffusion, tdmpc
+| | └── utils # various utilities
| └── scripts # contains functions to execute via command line
-| ├── visualize_dataset.py # load a dataset and render its demonstrations
| ├── eval.py # load policy and evaluate it on an environment
| ├── train.py # train a policy via imitation learning and/or reinforcement learning
+| ├── push_dataset_to_hub.py # convert your dataset into LeRobot dataset format and upload it to the Hugging Face hub
+| └── visualize_dataset.py # load a dataset and render its demonstrations
├── outputs # contains results of scripts execution: logs, videos, model checkpoints
+├── .github
+|   └── workflows
+|       └── test.yml # defines install settings for continuous integration and specifies end-to-end tests
└── tests # contains pytest utilities for continuous integration
```
### Visualize datasets
-Check out [examples](./examples) to see how you can import our dataset class, download the data from the HuggingFace hub and use our rendering utilities.
+Check out [example 1](./examples/1_load_lerobot_dataset.py) that illustrates how to use our dataset class, which automatically downloads data from the Hugging Face hub.
-Or you can achieve the same result by executing our script from the command line:
+You can also locally visualize episodes from a dataset by executing our script from the command line:
```bash
python lerobot/scripts/visualize_dataset.py \
-env=pusht \
-hydra.run.dir=outputs/visualize_dataset/example
-# >>> ['outputs/visualize_dataset/example/episode_0.mp4']
+--repo-id lerobot/pusht \
+--episode-index 0
```
+It will open `rerun.io` and display the camera streams, robot states and actions, like this:
+https://github-production-user-asset-6210df.s3.amazonaws.com/4681518/328035972-fd46b787-b532-47e2-bb6f-fd536a55a7ed.mov?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAVCODYLSA53PQK4ZA%2F20240505%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20240505T172924Z&X-Amz-Expires=300&X-Amz-Signature=d680b26c532eeaf80740f08af3320d22ad0b8a4e4da1bcc4f33142c15b509eda&X-Amz-SignedHeaders=host&actor_id=24889239&key_id=0&repo_id=748713144
+Our script can also visualize datasets stored on a distant server. See `python lerobot/scripts/visualize_dataset.py --help` for more instructions.
### Evaluate a pretrained policy
-Check out [examples](./examples) to see how you can load a pretrained policy from HuggingFace hub, load up the corresponding environment and model, and run an evaluation.
+Check out [example 2](./examples/2_evaluate_pretrained_policy.py) that illustrates how to download a pretrained policy from Hugging Face hub, and run an evaluation on its corresponding environment.
-Or you can achieve the same result by executing our script from the command line:
+We also provide a more capable script to parallelize the evaluation over multiple environments during the same rollout. Here is an example with a pretrained model hosted on [lerobot/diffusion_pusht](https://huggingface.co/lerobot/diffusion_pusht):
```bash
python lerobot/scripts/eval.py \
-p lerobot/diffusion_pusht \
-eval_episodes=10 \
-hydra.run.dir=outputs/eval/example_hub
+eval.n_episodes=10 \
+eval.batch_size=10
```
-After training your own policy, you can also re-evaluate the checkpoints with:
+Note: After training your own policy, you can re-evaluate the checkpoints with:
```bash
python lerobot/scripts/eval.py \
--p PATH/TO/TRAIN/OUTPUT/FOLDER \
-eval_episodes=10 \
-hydra.run.dir=outputs/eval/example_dir
+-p PATH/TO/TRAIN/OUTPUT/FOLDER
```
See `python lerobot/scripts/eval.py --help` for more instructions.
### Train your own policy
-Check out [examples](./examples) to see how you can start training a model on a dataset, which will be automatically downloaded if needed.
+Check out [example 3](./examples/3_train_policy.py) that illustrates how to start training a model.
-In general, you can use our training script to easily train any policy on any environment:
+In general, you can use our training script to easily train any policy. Here is an example of training the ACT policy on trajectories collected by humans on the Aloha simulation environment for the insertion task:
```bash
python lerobot/scripts/train.py \
-env=aloha \
-task=sim_insertion \
-repo_id=lerobot/aloha_sim_insertion_scripted \
-policy=act \
-hydra.run.dir=outputs/train/aloha_act
+policy=act \
+env=aloha \
+env.task=AlohaInsertion-v0 \
+dataset_repo_id=lerobot/aloha_sim_insertion_human \
```
-After training, you may want to revisit model evaluation to change the evaluation settings. In fact, during training every checkpoint is already evaluated but on a low number of episodes for efficiency. Check out [example](./examples) to evaluate any model checkpoint on more episodes to increase statistical significance.
+The experiment directory is automatically generated and will show up in yellow in your terminal. It looks like `outputs/train/2024-05-05/20-21-12_aloha_act_default`. You can manually specify an experiment directory by adding this argument to the `train.py` python command:
+```bash
+hydra.run.dir=your/new/experiment/dir
+```
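As an aside on the directory name shown above: the auto-generated path follows a date/time run-dir pattern. A minimal sketch of how such a path can be formed (the helper below is illustrative, not LeRobot code):

```python
from datetime import datetime

# Illustrative helper (not part of LeRobot): build a run directory resembling
# the auto-generated pattern outputs/train/YYYY-MM-DD/HH-MM-SS_<name>.
def experiment_dir(name, now=None):
    now = now or datetime.now()
    return f"outputs/train/{now:%Y-%m-%d}/{now:%H-%M-%S}_{name}"

print(experiment_dir("aloha_act_default", datetime(2024, 5, 5, 20, 21, 12)))
# outputs/train/2024-05-05/20-21-12_aloha_act_default
```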
+To use wandb for logging training and evaluation curves, make sure you've run `wandb login` as a one-time setup step. Then, when running the training command above, enable WandB in the configuration by adding:
+```bash
+wandb.enable=true
+```
+A link to the wandb logs for the run will also show up in yellow in your terminal. Here is an example of what they look like in your browser:
+![](media/wandb.png)
+Note: For efficiency, during training every checkpoint is evaluated on a low number of episodes. After training, you may want to re-evaluate your best checkpoints on more episodes or change the evaluation settings. See `python lerobot/scripts/eval.py --help` for more instructions.
## Contribute
@@ -173,98 +192,40 @@ If you would like to contribute to 🤗 LeRobot, please check out our [contribut
### Add a new dataset
-```python
-# TODO(rcadene, AdilZouitine): rewrite this section
-```
-To add a dataset to the hub, first login and use a token generated from [huggingface settings](https://huggingface.co/settings/tokens) with write access:
+To add a dataset to the hub, you need to login using a write-access token, which can be generated from the [Hugging Face settings](https://huggingface.co/settings/tokens):
```bash
huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
```
-Then you can upload it to the hub with:
+Then move your dataset folder in `data` directory (e.g. `data/aloha_ping_pong`), and push your dataset to the hub with:
```bash
-HF_HUB_ENABLE_HF_TRANSFER=1 huggingface-cli upload $HF_USER/$DATASET data/$DATASET \
---repo-type dataset \
---revision v1.0
+python lerobot/scripts/push_dataset_to_hub.py \
+--data-dir data \
+--dataset-id aloha_ping_pong \
+--raw-format aloha_hdf5 \
+--community-id lerobot
```
-You will need to set the corresponding version as a default argument in your dataset class:
-```python
-version: str | None = "v1.1",
-```
-See: [`lerobot/common/datasets/pusht.py`](https://github.com/Cadene/lerobot/blob/main/lerobot/common/datasets/pusht.py)
-For instance, for [lerobot/pusht](https://huggingface.co/datasets/lerobot/pusht), we used:
+See `python lerobot/scripts/push_dataset_to_hub.py --help` for more instructions.
+If your dataset format is not supported, implement your own in `lerobot/common/datasets/push_dataset_to_hub/${raw_format}_format.py` by copying examples like [pusht_zarr](https://github.com/huggingface/lerobot/blob/main/lerobot/common/datasets/push_dataset_to_hub/pusht_zarr_format.py), [umi_zarr](https://github.com/huggingface/lerobot/blob/main/lerobot/common/datasets/push_dataset_to_hub/umi_zarr_format.py), [aloha_hdf5](https://github.com/huggingface/lerobot/blob/main/lerobot/common/datasets/push_dataset_to_hub/aloha_hdf5_format.py), or [xarm_pkl](https://github.com/huggingface/lerobot/blob/main/lerobot/common/datasets/push_dataset_to_hub/xarm_pkl_format.py).
-```bash
-HF_USER=lerobot
-DATASET=pusht
-```
-If you want to improve an existing dataset, you can download it locally with:
-```bash
-mkdir -p data/$DATASET
-HF_HUB_ENABLE_HF_TRANSFER=1 huggingface-cli download ${HF_USER}/$DATASET \
---repo-type dataset \
---local-dir data/$DATASET \
---local-dir-use-symlinks=False \
---revision v1.0
-```
-Iterate on your code and dataset with:
-```bash
-DATA_DIR=data python train.py
-```
-Upload a new version (v2.0 or v1.1 if the changes are respectively more or less significant):
-```bash
-HF_HUB_ENABLE_HF_TRANSFER=1 huggingface-cli upload $HF_USER/$DATASET data/$DATASET \
---repo-type dataset \
---revision v1.1 \
---delete "*"
-```
-Then you will need to set the corresponding version as a default argument in your dataset class:
-```python
-version: str | None = "v1.1",
-```
-See: [`lerobot/common/datasets/pusht.py`](https://github.com/Cadene/lerobot/blob/main/lerobot/common/datasets/pusht.py)
-Finally, you might want to mock the dataset if you need to update the unit tests as well:
-```bash
-python tests/scripts/mock_dataset.py --in-data-dir data/$DATASET --out-data-dir tests/data/$DATASET
-```
### Add a pretrained policy
-```python
-# TODO(rcadene, alexander-soare): rewrite this section
-```
-Once you have trained a policy you may upload it to the HuggingFace hub.
-Firstly, make sure you have a model repository set up on the hub. The hub ID looks like HF_USER/REPO_NAME.
-Secondly, assuming you have trained a policy, you need the following (which should all be in any of the subdirectories of `checkpoints` in your training output folder, if you've used the LeRobot training script):
+Once you have trained a policy you may upload it to the Hugging Face hub using a hub id that looks like `${hf_user}/${repo_name}` (e.g. [lerobot/diffusion_pusht](https://huggingface.co/lerobot/diffusion_pusht)).
+You first need to find the checkpoint located inside your experiment directory (e.g. `outputs/train/2024-05-05/20-21-12_aloha_act_default/checkpoints/002500`). It should contain:
- `config.json`: A serialized version of the policy configuration (following the policy's dataclass config).
-- `model.safetensors`: The `torch.nn.Module` parameters saved in [Hugging Face Safetensors](https://huggingface.co/docs/safetensors/index) format.
+- `model.safetensors`: A set of `torch.nn.Module` parameters, saved in [Hugging Face Safetensors](https://huggingface.co/docs/safetensors/index) format.
-- `config.yaml`: This is the consolidated Hydra training configuration containing the policy, environment, and dataset configs. The policy configuration should match `config.json` exactly. The environment config is useful for anyone who wants to evaluate your policy. The dataset config just serves as a paper trail for reproducibility.
+- `config.yaml`: A consolidated Hydra training configuration containing the policy, environment, and dataset configs. The policy configuration should match `config.json` exactly. The environment config is useful for anyone who wants to evaluate your policy. The dataset config just serves as a paper trail for reproducibility.
-To upload these to the hub, run the following with a desired revision ID.
+To upload these to the hub, run the following:
```bash
-huggingface-cli upload $HUB_ID PATH/TO/OUTPUT/DIR --revision $REVISION_ID
+huggingface-cli upload ${hf_user}/${repo_name} path/to/checkpoint/dir
```
-If you want this to be the default revision also run the following (don't worry, it won't upload the files again; it will just adjust the file pointers):
-```bash
-huggingface-cli upload $HUB_ID PATH/TO/OUTPUT/DIR
-```
-See `eval.py` for an example of how a user may use your policy.
+See [eval.py](https://github.com/huggingface/lerobot/blob/main/lerobot/scripts/eval.py) for an example of how other people may use your policy.
### Improve your code with profiling
@@ -291,9 +252,14 @@ with profile(
# insert code to profile, potentially whole body of eval_policy function
```
-```bash
-python lerobot/scripts/eval.py \
---config outputs/pusht/.hydra/config.yaml \
-pretrained_model_path=outputs/pusht/model.pt \
-eval_episodes=7
-```
+## Citation
+If you want, you can cite this work with:
+```
+@misc{cadene2024lerobot,
+    author = {Cadene, Remi and Alibert, Simon and Soare, Alexander and Gallouedec, Quentin and Zouitine, Adil and Wolf, Thomas},
+    title = {LeRobot: State-of-the-art Machine Learning for Real-World Robotics in Pytorch},
+    howpublished = "\url{https://github.com/huggingface/lerobot}",
+    year = {2024}
+}
+```


@@ -7,6 +7,11 @@ ARG DEBIAN_FRONTEND=noninteractive
# Install apt dependencies
RUN apt-get update && apt-get install -y --no-install-recommends \
    build-essential cmake \
+    git git-lfs openssh-client \
+    nano vim \
+    htop atop nvtop \
+    sed gawk grep curl wget \
+    tcpdump sysstat screen \
    libglib2.0-0 libgl1-mesa-glx libegl1-mesa \
    python${PYTHON_VERSION} python${PYTHON_VERSION}-venv \
    && apt-get clean && rm -rf /var/lib/apt/lists/*
@@ -18,7 +23,8 @@ ENV PATH="/opt/venv/bin:$PATH"
RUN echo "source /opt/venv/bin/activate" >> /root/.bashrc
# Install LeRobot
-COPY . /lerobot
+RUN git lfs install
+RUN git clone https://github.com/huggingface/lerobot.git
WORKDIR /lerobot
RUN pip install --upgrade --no-cache-dir pip
RUN pip install --no-cache-dir ".[test, aloha, xarm, pusht]"


@@ -14,6 +14,7 @@ The script ends with examples of how to batch process data using PyTorch's DataL
"""
from pathlib import Path
+from pprint import pprint
import imageio
import torch
@@ -21,39 +22,36 @@ import torch
import lerobot
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
-print("List of available datasets", lerobot.available_datasets)
-# # >>> ['lerobot/aloha_sim_insertion_human', 'lerobot/aloha_sim_insertion_scripted',
-# # 'lerobot/aloha_sim_transfer_cube_human', 'lerobot/aloha_sim_transfer_cube_scripted',
-# # 'lerobot/pusht', 'lerobot/xarm_lift_medium']
-# Let's take one for this example
+print("List of available datasets:")
+pprint(lerobot.available_datasets)
repo_id = "lerobot/pusht"
-# You can easily load a dataset from a Hugging Face repositery
+# You can easily load a dataset from a Hugging Face repository
dataset = LeRobotDataset(repo_id)
-# LeRobotDataset is actually a thin wrapper around an underlying Hugging Face dataset (see https://huggingface.co/docs/datasets/index for more information).
-# TODO(rcadene): update to make the print pretty
-print(f"{dataset=}")
-print(f"{dataset.hf_dataset=}")
+# LeRobotDataset is actually a thin wrapper around an underlying Hugging Face dataset
+# (see https://huggingface.co/docs/datasets/index for more information).
+print(dataset)
+print(dataset.hf_dataset)
-# and provides additional utilities for robotics and compatibility with pytorch
-print(f"number of samples/frames: {dataset.num_samples=}")
-print(f"number of episodes: {dataset.num_episodes=}")
-print(f"average number of frames per episode: {dataset.num_samples / dataset.num_episodes:.3f}")
+# And provides additional utilities for robotics and compatibility with Pytorch
+print(f"\naverage number of frames per episode: {dataset.num_samples / dataset.num_episodes:.3f}")
print(f"frames per second used during data collection: {dataset.fps=}")
-print(f"keys to access images from cameras: {dataset.image_keys=}")
+print(f"keys to access images from cameras: {dataset.camera_keys=}\n")
# Access frame indexes associated to first episode
episode_index = 0
from_idx = dataset.episode_data_index["from"][episode_index].item()
to_idx = dataset.episode_data_index["to"][episode_index].item()
-# LeRobot datasets actually subclass PyTorch datasets so you can do everything you know and love from working with the latter, like iterating through the dataset.
-# Here we grab all the image frames.
+# LeRobot datasets actually subclass PyTorch datasets so you can do everything you know and love from working
+# with the latter, like iterating through the dataset. Here we grab all the image frames.
frames = [dataset[idx]["observation.image"] for idx in range(from_idx, to_idx)]
-# Video frames are now float32 in range [0,1] channel first (c,h,w) to follow pytorch convention.
-# To visualize them, we convert to uint8 range [0,255]
+# Video frames are now float32 in range [0,1] channel first (c,h,w) to follow pytorch convention. To visualize
+# them, we convert to uint8 in range [0,255]
frames = [(frame * 255).type(torch.uint8) for frame in frames]
# and to channel last (h,w,c).
frames = [frame.permute((1, 2, 0)).numpy() for frame in frames]
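The `episode_data_index` bookkeeping used above can be sketched without the library (the episode lengths below are invented for illustration; real values come from the dataset):

```python
# Hypothetical sketch: map each episode to its [from, to) frame range in one
# flat index space, analogous to dataset.episode_data_index.
episode_lengths = [100, 120, 90]  # invented example values

starts, total = [], 0
for length in episode_lengths:
    starts.append(total)
    total += length

episode_data_index = {
    "from": starts,
    "to": [start + length for start, length in zip(starts, episode_lengths)],
}
# Episode 1 spans flat indices [100, 220)
print(episode_data_index["from"][1], episode_data_index["to"][1])  # 100 220
```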
@@ -62,9 +60,9 @@ frames = [frame.permute((1, 2, 0)).numpy() for frame in frames]
Path("outputs/examples/1_load_lerobot_dataset").mkdir(parents=True, exist_ok=True)
imageio.mimsave("outputs/examples/1_load_lerobot_dataset/episode_0.mp4", frames, fps=dataset.fps)
-# For many machine learning applications we need to load the history of past observations or trajectories of future actions.
-# Our datasets can load previous and future frames for each key/modality,
-# using timestamps differences with the current loaded frame. For instance:
+# For many machine learning applications we need to load the history of past observations or trajectories of
+# future actions. Our datasets can load previous and future frames for each key/modality, using timestamps
+# differences with the current loaded frame. For instance:
delta_timestamps = {
    # loads 4 images: 1 second before current frame, 500 ms before, 200 ms before, and current frame
    "observation.image": [-1, -0.5, -0.20, 0],
@@ -74,12 +72,12 @@ delta_timestamps = {
    "action": [t / dataset.fps for t in range(64)],
}
dataset = LeRobotDataset(repo_id, delta_timestamps=delta_timestamps)
-print(f"{dataset[0]['observation.image'].shape=}") # (4,c,h,w)
+print(f"\n{dataset[0]['observation.image'].shape=}") # (4,c,h,w)
print(f"{dataset[0]['observation.state'].shape=}") # (8,c)
-print(f"{dataset[0]['action'].shape=}") # (64,c)
+print(f"{dataset[0]['action'].shape=}\n") # (64,c)
-# Finally, our datasets are fully compatible with PyTorch dataloaders and samplers
-# because they are just PyTorch datasets.
+# Finally, our datasets are fully compatible with PyTorch dataloaders and samplers because they are just
+# PyTorch datasets.
dataloader = torch.utils.data.DataLoader(
    dataset,
    num_workers=0,
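As a rough illustration of the `delta_timestamps` mechanism used in this example, here is a dependency-free sketch of how timestamp deltas can map to frame indices at a given fps (the helper name and the nearest-frame rounding are assumptions, not the library's actual implementation):

```python
# Hypothetical helper: map timestamp deltas (in seconds) to frame indices,
# assuming frames are evenly spaced at `fps` frames per second.
def deltas_to_indices(current_idx, deltas, fps):
    return [current_idx + round(delta * fps) for delta in deltas]

# With the image deltas from the example above and a 10 fps dataset,
# the sample at frame 50 would load frames 40, 45, 48 and 50.
print(deltas_to_indices(50, [-1, -0.5, -0.20, 0], 10))  # [40, 45, 48, 50]
```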


@@ -5,23 +5,108 @@ training outputs directory. In the latter case, you might want to run examples/3
from pathlib import Path
+import gym_pusht  # noqa: F401
+import gymnasium as gym
+import imageio
+import numpy
+import torch
from huggingface_hub import snapshot_download
-from lerobot.scripts.eval import eval
+from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
-# Get a pretrained policy from the hub.
-pretrained_policy_name = "lerobot/diffusion_pusht"
-pretrained_policy_path = Path(snapshot_download(pretrained_policy_name))
+# Create a directory to store the video of the evaluation
+output_directory = Path("outputs/eval/example_pusht_diffusion")
+output_directory.mkdir(parents=True, exist_ok=True)
+device = torch.device("cuda")
+# Download the diffusion policy for pusht environment
+pretrained_policy_path = Path(snapshot_download("lerobot/diffusion_pusht"))
# OR uncomment the following to evaluate a policy from the local outputs/train folder.
# pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")
-# Override some config parameters to do with evaluation.
-overrides = [
-    "eval.n_episodes=10",
-    "eval.batch_size=10",
-    "device=cuda",
-]
-# Evaluate the policy and save the outputs including metrics and videos.
-# TODO(rcadene, alexander-soare): dont call eval, but add the minimal code snippet to rollout
-eval(pretrained_policy_path=pretrained_policy_path)
+policy = DiffusionPolicy.from_pretrained(pretrained_policy_path)
+policy.eval()
+policy.to(device)
+# Initialize evaluation environment to render two observation types:
+# an image of the scene and state/position of the agent. The environment
+# also automatically stops running after 300 interactions/steps.
env = gym.make(
"gym_pusht/PushT-v0",
obs_type="pixels_agent_pos",
max_episode_steps=300,
)
# Reset the policy and environmens to prepare for rollout
policy.reset()
numpy_observation, info = env.reset(seed=42)
# Prepare to collect every rewards and all the frames of the episode,
# from initial state to final state.
rewards = []
frames = []
# Render frame of the initial state
frames.append(env.render())
step = 0
done = False
while not done:
# Prepare observation for the policy running in Pytorch
state = torch.from_numpy(numpy_observation["agent_pos"])
image = torch.from_numpy(numpy_observation["pixels"])
# Convert to float32 with image from channel first in [0,255]
# to channel last in [0,1]
state = state.to(torch.float32)
image = image.to(torch.float32) / 255
image = image.permute(2, 0, 1)
# Send data tensors from CPU to GPU
state = state.to(device, non_blocking=True)
image = image.to(device, non_blocking=True)
# Add extra (empty) batch dimension, required to forward the policy
state = state.unsqueeze(0)
image = image.unsqueeze(0)
# Create the policy input dictionary
observation = {
"observation.state": state,
"observation.image": image,
}
# Predict the next action with respect to the current observation
with torch.inference_mode():
action = policy.select_action(observation)
# Prepare the action for the environment
numpy_action = action.squeeze(0).to("cpu").numpy()
# Step through the environment and receive a new observation
numpy_observation, reward, terminated, truncated, info = env.step(numpy_action)
print(f"{step=} {reward=} {terminated=}")
# Keep track of all the rewards and frames
rewards.append(reward)
frames.append(env.render())
# The rollout is considered done when the success state is reach (i.e. terminated is True),
# or the maximum number of iterations is reached (i.e. truncated is True)
done = terminated | truncated | done
step += 1
if terminated:
print("Success!")
else:
print("Failure!")
# Get the speed of environment (i.e. its number of frames per second).
fps = env.metadata["render_fps"]
# Encode all frames into a mp4 video.
video_path = output_directory / "rollout.mp4"
imageio.mimsave(str(video_path), numpy.stack(frames), fps=fps)
print(f"Video of the evaluation is available in '{video_path}'.")
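The observation preprocessing in the rollout loop above (a channel-last uint8 image in [0, 255] converted to a channel-first float image in [0, 1]) can be sketched without torch. `hwc_to_chw_normalized` below is a hypothetical helper written only to illustrate the conversion; it is not part of LeRobot:

```python
# Minimal sketch of the image conversion done in the rollout loop,
# using plain nested lists instead of torch tensors.
def hwc_to_chw_normalized(img):
    """Convert a (H, W, C) image with values in [0, 255] to (C, H, W) in [0, 1]."""
    h, w, c = len(img), len(img[0]), len(img[0][0])
    return [[[img[y][x][ch] / 255 for x in range(w)] for y in range(h)] for ch in range(c)]


# A single red pixel: channel-last (1, 1, 3) becomes channel-first (3, 1, 1).
print(hwc_to_chw_normalized([[[255, 0, 0]]]))  # [[[1.0]], [[0.0]], [[0.0]]]
```

In the real script, `image.permute(2, 0, 1)` and the division by 255 do the same job on a tensor in one pass.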


@@ -4,36 +4,42 @@ Once you have trained a model with this script, you can try to evaluate it on
examples/2_evaluate_pretrained_policy.py
"""

from pathlib import Path

import torch

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy

# Create a directory to store the training checkpoint.
output_directory = Path("outputs/train/example_pusht_diffusion")
output_directory.mkdir(parents=True, exist_ok=True)

# Number of offline training steps (we'll only do offline training for this example).
# Adjust as you prefer. 5000 steps are needed to get something worth evaluating.
training_steps = 5000
device = torch.device("cuda")
log_freq = 250

# Set up the dataset.
delta_timestamps = {
    # Load the previous image and state at -0.1 seconds before current frame,
    # then load current image and state corresponding to 0.0 second.
    "observation.image": [-0.1, 0.0],
    "observation.state": [-0.1, 0.0],
    # Load the previous action (-0.1), the next action to be executed (0.0),
    # and 14 future actions with a 0.1 seconds spacing. All these actions will be
    # used to supervise the policy.
    "action": [-0.1, 0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4],
}
dataset = LeRobotDataset("lerobot/pusht", delta_timestamps=delta_timestamps)

# Set up the policy.
# Policies are initialized with a configuration class, in this case `DiffusionConfig`.
# For this example, no arguments need to be passed because the defaults are set up for PushT.
# If you're doing something different, you will likely need to change at least some of the defaults.
cfg = DiffusionConfig()
policy = DiffusionPolicy(cfg, dataset_stats=dataset.stats)
policy.train()
policy.to(device)
@@ -69,7 +75,5 @@ while not done:
            done = True
            break

# Save a policy checkpoint.
policy.save_pretrained(output_directory)
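To make the `delta_timestamps` mechanism above concrete: each delta in seconds maps to a frame offset of `delta * fps` relative to the current frame. A minimal illustrative sketch (the 10 fps value matches the PushT dataset, but the offset computation here is for illustration only, not LeRobot API):

```python
fps = 10  # PushT episodes are recorded at 10 frames per second

delta_timestamps = {
    "observation.image": [-0.1, 0.0],
    "action": [-0.1, 0.0, 0.1, 0.2],
}

# Each delta in seconds corresponds to a frame offset relative to the current frame.
frame_offsets = {key: [round(delta * fps) for delta in deltas] for key, deltas in delta_timestamps.items()}
print(frame_offsets["action"])  # [-1, 0, 1, 2]
```

So `"action": [-0.1, 0.0, 0.1, ..., 1.4]` loads 1 past action, the current action, and 14 future actions, i.e. one frame apart at 10 fps.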


@@ -1,3 +1,18 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" """
This file contains lists of available environments, dataset and policies to reflect the current state of LeRobot library. This file contains lists of available environments, dataset and policies to reflect the current state of LeRobot library.
We do not want to import all the dependencies, but instead we keep it lightweight to ensure fast access to these variables. We do not want to import all the dependencies, but instead we keep it lightweight to ensure fast access to these variables.
@@ -85,13 +100,6 @@ available_datasets = list(
    itertools.chain(*available_datasets_per_env.values(), available_real_world_datasets)
)

available_policies = [
    "act",
    "diffusion",


@@ -1,3 +1,18 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""To enable `lerobot.__version__`""" """To enable `lerobot.__version__`"""
from importlib.metadata import PackageNotFoundError, version from importlib.metadata import PackageNotFoundError, version

View File

@@ -37,16 +37,16 @@ How to decode videos?
## Variables

**Image content**

We don't expect the same optimal settings for a dataset of images from a simulation, from the real world in an apartment, in a factory, outdoors, etc. Hence, we run this benchmark on two datasets: `pusht` (simulation) and `umi` (real-world outdoor).

**Requested timestamps**

In this benchmark, we focus on the loading time of random access, so we are not interested in sequentially loading all frames of a video like in a movie. However, the number of consecutive timestamps requested and their spacing can greatly affect the `load_time_factor`. In fact, decoding a large number of consecutive frames from a video is expected to be faster than loading the same data from individual images. To reflect our robotics use case, we consider a few settings:
- `single_frame`: 1 frame,
- `2_frames`: 2 consecutive frames (e.g. `[t, t + 1 / fps]`),
- `2_frames_4_space`: 2 consecutive frames with 4 frames of spacing (e.g. `[t, t + 4 / fps]`),
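The timestamp settings above can be written out explicitly. This is an illustrative sketch (the `fps` and query time `t` are arbitrary values chosen for the example, not benchmark parameters):

```python
fps = 30  # assumed frame rate, for illustration only
t = 2.0   # an arbitrary query time in seconds

settings = {
    "single_frame": [t],
    "2_frames": [t, t + 1 / fps],
    "2_frames_4_space": [t, t + 4 / fps],
}
print(settings["2_frames_4_space"])  # two timestamps, 4 frames (4/fps seconds) apart
```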
**Data augmentations**

We might revisit this benchmark and find better settings if we train our policies with various data augmentations to make them more robust (e.g. robust to color changes, compression, etc.).

## Results


@@ -1,3 +1,18 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import random
import shutil


@@ -1,3 +1,18 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging

import torch


@@ -1,3 +1,18 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from pathlib import Path
@@ -72,6 +87,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
    @property
    def fps(self) -> int:
        """Frames per second used during data collection."""
        return self.info["fps"]

    @property
@@ -86,15 +102,22 @@ class LeRobotDataset(torch.utils.data.Dataset):
        return self.hf_dataset.features

    @property
    def camera_keys(self) -> list[str]:
        """Keys to access image and video streams from cameras."""
        keys = []
        for key, feats in self.hf_dataset.features.items():
            if isinstance(feats, (datasets.Image, VideoFrame)):
                keys.append(key)
        return keys

    @property
    def video_frame_keys(self) -> list[str]:
        """Keys to access video frames that require to be decoded into images.

        Note: It is empty if the dataset contains images only,
        or equal to `self.camera_keys` if the dataset contains videos only,
        or can even be a subset of `self.camera_keys` in the case of a mixed image/video dataset.
        """
        video_frame_keys = []
        for key, feats in self.hf_dataset.features.items():
            if isinstance(feats, VideoFrame):
@@ -103,10 +126,15 @@ class LeRobotDataset(torch.utils.data.Dataset):
    @property
    def num_samples(self) -> int:
        """Number of possible samples in the dataset.

        This is equivalent to the number of frames in the dataset minus n_end_keyframes_dropped.
        """
        return len(self.index)

    @property
    def num_episodes(self) -> int:
        """Number of episodes."""
        return len(self.hf_dataset.unique("episode_index"))

    @property
@@ -146,6 +174,22 @@ class LeRobotDataset(torch.utils.data.Dataset):
        return item

    def __repr__(self):
        return (
            f"{self.__class__.__name__}(\n"
            f"  Repository ID: '{self.repo_id}',\n"
            f"  Version: '{self.version}',\n"
            f"  Split: '{self.split}',\n"
            f"  Number of Samples: {self.num_samples},\n"
            f"  Number of Episodes: {self.num_episodes},\n"
            f"  Type: {'video (.mp4)' if self.video else 'image (.png)'},\n"
            f"  Recorded Frames per Second: {self.fps},\n"
            f"  Camera Keys: {self.camera_keys},\n"
            f"  Video Frame Keys: {self.video_frame_keys if self.video else 'N/A'},\n"
            f"  Transformations: {self.transform},\n"
            f")"
        )

    @classmethod
    def from_preloaded(
        cls,
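The key-filtering logic behind `camera_keys` and `video_frame_keys` can be sketched with plain stand-in classes. `Image` and `VideoFrame` below are hypothetical stand-ins for the `datasets.Image` and LeRobot `VideoFrame` feature types, for illustration only:

```python
# Hypothetical stand-ins for the datasets.Image and VideoFrame feature types.
class Image:
    pass


class VideoFrame:
    pass


features = {
    "observation.images.top": VideoFrame(),  # stored as an encoded video stream
    "observation.images.wrist": Image(),     # stored as individual images
    "observation.state": object(),           # not a camera stream
}

# camera_keys: any image or video stream; video_frame_keys: only streams needing video decoding.
camera_keys = [key for key, feats in features.items() if isinstance(feats, (Image, VideoFrame))]
video_frame_keys = [key for key, feats in features.items() if isinstance(feats, VideoFrame)]

print(camera_keys)       # ['observation.images.top', 'observation.images.wrist']
print(video_frame_keys)  # ['observation.images.top']
```

This mixed case shows why `video_frame_keys` can be a strict subset of `camera_keys`.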


@@ -1,3 +1,18 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Helper code for loading PushT dataset from Diffusion Policy (https://diffusion-policy.cs.columbia.edu/) """Helper code for loading PushT dataset from Diffusion Policy (https://diffusion-policy.cs.columbia.edu/)
Copied from the original Diffusion Policy repository and used in our `download_and_upload_dataset.py` script. Copied from the original Diffusion Policy repository and used in our `download_and_upload_dataset.py` script.


@@ -1,3 +1,18 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" """
This file contains all obsolete download scripts. They are centralized here to not have to load This file contains all obsolete download scripts. They are centralized here to not have to load
useless dependencies when using datasets. useless dependencies when using datasets.


@@ -1,3 +1,18 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# imagecodecs/numcodecs.py
# Copyright (c) 2021-2022, Christoph Gohlke


@@ -1,3 +1,18 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" """
Contains utilities to process raw data format of HDF5 files like in: https://github.com/tonyzhaozh/act Contains utilities to process raw data format of HDF5 files like in: https://github.com/tonyzhaozh/act
""" """
@@ -142,12 +157,12 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug):
def to_hf_dataset(data_dict, video) -> Dataset:
    features = {}

    keys = [key for key in data_dict if "observation.images." in key]
    for key in keys:
        if video:
            features[key] = VideoFrame()
        else:
            features[key] = Image()

    features["observation.state"] = Sequence(
        length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)


@@ -1,3 +1,18 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from copy import deepcopy
from math import ceil


@@ -1,3 +1,18 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Process zarr files formatted like in: https://github.com/real-stanford/diffusion_policy""" """Process zarr files formatted like in: https://github.com/real-stanford/diffusion_policy"""
import shutil import shutil


@@ -1,3 +1,18 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Process UMI (Universal Manipulation Interface) data stored in Zarr format like in: https://github.com/real-stanford/universal_manipulation_interface""" """Process UMI (Universal Manipulation Interface) data stored in Zarr format like in: https://github.com/real-stanford/universal_manipulation_interface"""
import logging import logging


@@ -1,3 +1,18 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path


@@ -1,3 +1,18 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Process pickle files formatted like in: https://github.com/fyhMer/fowm""" """Process pickle files formatted like in: https://github.com/fyhMer/fowm"""
import pickle import pickle


@@ -1,3 +1,18 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json import json
from pathlib import Path from pathlib import Path


@@ -1,3 +1,18 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import subprocess
import warnings


@@ -1,3 +1,18 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib

import gymnasium as gym


@@ -1,3 +1,18 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import einops
import numpy as np
import torch


@@ -1,3 +1,18 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# TODO(rcadene, alexander-soare): clean this file
"""Borrowed from https://github.com/fyhMer/fowm/blob/main/src/logger.py"""
@@ -17,9 +32,10 @@ def log_output_dir(out_dir):
def cfg_to_group(cfg, return_list=False):
    """Return a group name for logging. Optionally returns group name as list."""
    lst = [
        f"policy:{cfg.policy.name}",
        f"dataset:{cfg.dataset_repo_id}",
        f"env:{cfg.env.name}",
        f"seed:{cfg.seed}",
    ]
@@ -81,9 +97,9 @@ class Logger:
        # Also save the full Hydra config for the env configuration.
        OmegaConf.save(self._cfg, save_dir / "config.yaml")
        if self._wandb and not self._disable_wandb_artifact:
            # note wandb artifact does not accept ":" or "/" in its name
            artifact = self._wandb.Artifact(
                f"{self._group.replace(':', '_').replace('/', '_')}-{self._seed}-{identifier}",
                type="model",
            )
            artifact.add_file(save_dir / SAFETENSORS_SINGLE_FILE)
@@ -93,9 +109,10 @@ class Logger:
        self._buffer_dir.mkdir(parents=True, exist_ok=True)
        fp = self._buffer_dir / f"{str(identifier)}.pkl"
        buffer.save(fp)
        if self._wandb and not self._disable_wandb_artifact:
            # note wandb artifact does not accept ":" or "/" in its name
            artifact = self._wandb.Artifact(
                f"{self._group.replace(':', '_').replace('/', '_')}-{self._seed}-{identifier}",
                type="buffer",
            )
            artifact.add_file(fp)
@@ -113,6 +130,11 @@ class Logger:
        assert mode in {"train", "eval"}
        if self._wandb is not None:
            for k, v in d.items():
                if not isinstance(v, (int, float, str)):
                    logging.warning(
                        f'WandB logging of key "{k}" was ignored as its type is not handled by this wrapper.'
                    )
                    continue
                self._wandb.log({f"{mode}/{k}": v}, step=step)

    def log_video(self, video_path: str, step: int, mode: str = "train"):
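The two wandb constraints handled above, artifact-name sanitization and scalar-only logging, can be captured in two small helpers. These are illustrative sketches written for this explanation, not LeRobot functions:

```python
def safe_artifact_name(group: str, seed: int, identifier: str) -> str:
    # wandb artifact names may not contain ":" or "/", so both are replaced with "_".
    return f"{group.replace(':', '_').replace('/', '_')}-{seed}-{identifier}"


def loggable_items(d: dict) -> dict:
    # Keep only the scalar types the wrapper forwards to wandb; everything else is skipped.
    return {k: v for k, v in d.items() if isinstance(v, (int, float, str))}


print(safe_artifact_name("policy:act/dataset:aloha", 1000, "best"))  # policy_act_dataset_aloha-1000-best
print(loggable_items({"loss": 0.25, "grad_norms": [0.1, 0.2]}))      # {'loss': 0.25}
```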


@@ -1,3 +1,18 @@
#!/usr/bin/env python
# Copyright 2024 Tony Z. Zhao and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
@@ -51,8 +66,12 @@ class ACTConfig:
            documentation in the policy class).
        latent_dim: The VAE's latent dimension.
        n_vae_encoder_layers: The number of transformer layers to use for the VAE's encoder.
        temporal_ensemble_momentum: Exponential moving average (EMA) momentum parameter (α) for ensembling
            actions for a given time step over multiple policy invocations. Updates are calculated as:
            x⁻ₙ = αx⁻ₙ₋₁ + (1-α)xₙ. Note that the ACT paper and original ACT code describe a different
            parameter here: they refer to a weighting scheme wᵢ = exp(-m⋅i) and set m = 0.01. With our
            formulation, this is equivalent to α = exp(-0.01) ≈ 0.99. When this parameter is provided, we
            require `n_action_steps == 1` (since we need to query the policy every step anyway).
        dropout: Dropout to use in the transformer layers (see code for details).
        kl_weight: The weight to use for the KL-divergence component of the loss if the variational objective
            is enabled. Loss is then calculated as: `reconstruction_loss + kl_weight * kld_loss`.
@@ -100,6 +119,9 @@ class ACTConfig:
dim_feedforward: int = 3200 dim_feedforward: int = 3200
feedforward_activation: str = "relu" feedforward_activation: str = "relu"
n_encoder_layers: int = 4 n_encoder_layers: int = 4
# Note: Although the original ACT implementation has 7 for `n_decoder_layers`, there is a bug in the code
# that means only the first layer is used. Here we match the original implementation by setting this to 1.
# See this issue https://github.com/tonyzhaozh/act/issues/25#issue-2258740521.
n_decoder_layers: int = 1 n_decoder_layers: int = 1
# VAE. # VAE.
use_vae: bool = True use_vae: bool = True
@@ -107,7 +129,7 @@ class ACTConfig:
n_vae_encoder_layers: int = 4 n_vae_encoder_layers: int = 4
# Inference. # Inference.
use_temporal_aggregation: bool = False temporal_ensemble_momentum: float | None = None
# Training and loss computation. # Training and loss computation.
dropout: float = 0.1 dropout: float = 0.1
@@ -119,8 +141,11 @@ class ACTConfig:
raise ValueError( raise ValueError(
f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}." f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
) )
if self.use_temporal_aggregation: if self.temporal_ensemble_momentum is not None and self.n_action_steps > 1:
raise NotImplementedError("Temporal aggregation is not yet implemented.") raise NotImplementedError(
"`n_action_steps` must be 1 when using temporal ensembling. This is "
"because the policy needs to be queried every step to compute the ensembled action."
)
if self.n_action_steps > self.chunk_size: if self.n_action_steps > self.chunk_size:
raise ValueError( raise ValueError(
f"The chunk size is the upper bound for the number of action steps per model invocation. Got " f"The chunk size is the upper bound for the number of action steps per model invocation. Got "
@@ -130,10 +155,3 @@ class ACTConfig:
raise ValueError( raise ValueError(
f"Multiple observation steps not handled yet. Got `nobs_steps={self.n_obs_steps}`" f"Multiple observation steps not handled yet. Got `nobs_steps={self.n_obs_steps}`"
) )
# Check that there is only one image.
# TODO(alexander-soare): generalize this to multiple images.
if (
sum(k.startswith("observation.images.") for k in self.input_shapes) != 1
or "observation.images.top" not in self.input_shapes
):
raise ValueError('For now, only "observation.images.top" is accepted for an image input.')
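The new docstring claims the paper's weighting scheme wᵢ = exp(-m⋅i) is equivalent to an EMA with momentum α = exp(-m). A quick numeric check of that equivalence (variable names here are illustrative):

```python
import math

# Under the EMA update x̄ₙ = α·x̄ₙ₋₁ + (1-α)·xₙ, the unnormalized weight of the
# prediction made i steps ago is proportional to α**i = exp(-m)**i = exp(-m*i),
# which matches the paper's scheme w_i = exp(-m * i) with m = 0.01.
m = 0.01
alpha = math.exp(-m)  # ≈ 0.99, as stated in the docstring

paper_weights = [math.exp(-m * i) for i in range(5)]
ema_weights = [alpha**i for i in range(5)]
```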

View File

@@ -1,3 +1,18 @@
+#!/usr/bin/env python
+# Copyright 2024 Tony Z. Zhao and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 """Action Chunking Transformer Policy

 As per Learning Fine-Grained Bimanual Manipulation with Low-Cost Hardware (https://arxiv.org/abs/2304.13705).
@@ -46,7 +61,8 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
         super().__init__()
         if config is None:
             config = ACTConfig()
-        self.config = config
+        self.config: ACTConfig = config
         self.normalize_inputs = Normalize(
             config.input_shapes, config.input_normalization_modes, dataset_stats
         )
@@ -56,11 +72,18 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
         self.unnormalize_outputs = Unnormalize(
             config.output_shapes, config.output_normalization_modes, dataset_stats
         )

         self.model = ACT(config)

+        self.expected_image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
+
+        self.reset()
+
     def reset(self):
         """This should be called whenever the environment is reset."""
-        if self.config.n_action_steps is not None:
+        if self.config.temporal_ensemble_momentum is not None:
+            self._ensembled_actions = None
+        else:
             self._action_queue = deque([], maxlen=self.config.n_action_steps)

     @torch.no_grad
@@ -71,37 +94,56 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
         environment. It works by managing the actions in a queue and only calling `select_actions` when the
         queue is empty.
         """
-        assert "observation.images.top" in batch
-        assert "observation.state" in batch
         self.eval()

         batch = self.normalize_inputs(batch)
-        self._stack_images(batch)
+        batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)

+        # If we are doing temporal ensembling, keep track of the exponential moving average (EMA), and return
+        # the first action.
+        if self.config.temporal_ensemble_momentum is not None:
+            actions = self.model(batch)[0]  # (batch_size, chunk_size, action_dim)
+            actions = self.unnormalize_outputs({"action": actions})["action"]
+            if self._ensembled_actions is None:
+                # Initializes `self._ensembled_action` to the sequence of actions predicted during the first
+                # time step of the episode.
+                self._ensembled_actions = actions.clone()
+            else:
+                # self._ensembled_actions will have shape (batch_size, chunk_size - 1, action_dim). Compute
+                # the EMA update for those entries.
+                alpha = self.config.temporal_ensemble_momentum
+                self._ensembled_actions = alpha * self._ensembled_actions + (1 - alpha) * actions[:, :-1]
+                # The last action, which has no prior moving average, needs to get concatenated onto the end.
+                self._ensembled_actions = torch.cat([self._ensembled_actions, actions[:, -1:]], dim=1)
+            # "Consume" the first action.
+            action, self._ensembled_actions = self._ensembled_actions[:, 0], self._ensembled_actions[:, 1:]
+            return action

+        # Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by
+        # querying the policy.
         if len(self._action_queue) == 0:
-            # `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue
-            # effectively has shape (n_action_steps, batch_size, *), hence the transpose.
-            actions = self.model(batch)[0][: self.config.n_action_steps]
+            actions = self.model(batch)[0][:, : self.config.n_action_steps]

             # TODO(rcadene): make _forward return output dictionary?
             actions = self.unnormalize_outputs({"action": actions})["action"]

+            # `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue
+            # effectively has shape (n_action_steps, batch_size, *), hence the transpose.
             self._action_queue.extend(actions.transpose(0, 1))
         return self._action_queue.popleft()

     def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
         """Run the batch through the model and compute the loss for training or validation."""
         batch = self.normalize_inputs(batch)
+        batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
         batch = self.normalize_targets(batch)
-        self._stack_images(batch)
         actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)

         l1_loss = (
             F.l1_loss(batch["action"], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1)
         ).mean()

-        loss_dict = {"l1_loss": l1_loss}
+        loss_dict = {"l1_loss": l1_loss.item()}
         if self.config.use_vae:
             # Calculate Dₖₗ(latent_pdf || standard_normal). Note: After computing the KL-divergence for
             # each dimension independently, we sum over the latent dimension to get the total
@@ -110,28 +152,13 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
             mean_kld = (
                 (-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean()
             )
-            loss_dict["kld_loss"] = mean_kld
+            loss_dict["kld_loss"] = mean_kld.item()
             loss_dict["loss"] = l1_loss + mean_kld * self.config.kl_weight
         else:
             loss_dict["loss"] = l1_loss

         return loss_dict

-    def _stack_images(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
-        """Stacks all the images in a batch and puts them in a new key: "observation.images".
-
-        This function expects `batch` to have (at least):
-        {
-            "observation.state": (B, state_dim) batch of robot states.
-            "observation.images.{name}": (B, C, H, W) tensor of images.
-        }
-        """
-        # Stack images in the order dictated by input_shapes.
-        batch["observation.images"] = torch.stack(
-            [batch[k] for k in self.config.input_shapes if k.startswith("observation.images.")],
-            dim=-4,
-        )

 class ACT(nn.Module):
     """Action Chunking Transformer: The underlying neural network for ACTPolicy.
@@ -161,10 +188,10 @@ class ACT(nn.Module):
     │ encoder │ │ │ │Transf.│ │
     │ │ │ │ │encoder│ │
     └───▲─────┘ │ │ │ │ │
     │ │ │ └──▲──┘ │
     │ │ │
-    inputs └─────┼─────┘
+    inputs └─────┼──┘ │ image emb.
+                       │ state emb.
     └───────────────────────┘
     """
@@ -306,18 +333,18 @@ class ACT(nn.Module):
             all_cam_features.append(cam_features)
             all_cam_pos_embeds.append(cam_pos_embed)
         # Concatenate camera observation feature maps and positional embeddings along the width dimension.
-        encoder_in = torch.cat(all_cam_features, axis=3)
-        cam_pos_embed = torch.cat(all_cam_pos_embeds, axis=3)
+        encoder_in = torch.cat(all_cam_features, axis=-1)
+        cam_pos_embed = torch.cat(all_cam_pos_embeds, axis=-1)

         # Get positional embeddings for robot state and latent.
-        robot_state_embed = self.encoder_robot_state_input_proj(batch["observation.state"])
-        latent_embed = self.encoder_latent_input_proj(latent_sample)
+        robot_state_embed = self.encoder_robot_state_input_proj(batch["observation.state"])  # (B, C)
+        latent_embed = self.encoder_latent_input_proj(latent_sample)  # (B, C)

         # Stack encoder input and positional embeddings moving to (S, B, C).
         encoder_in = torch.cat(
             [
                 torch.stack([latent_embed, robot_state_embed], axis=0),
-                encoder_in.flatten(2).permute(2, 0, 1),
+                einops.rearrange(encoder_in, "b c h w -> (h w) b c"),
             ]
         )
         pos_embed = torch.cat(
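The EMA bookkeeping in the new `select_action` branch above can be sketched with plain Python lists instead of tensors (hypothetical `ensemble_step` helper; batch size 1, scalar actions, purely illustrative):

```python
def ensemble_step(ensembled, new_chunk, alpha):
    """One inference step of temporal ensembling.

    `ensembled` is the running EMA of overlapping action predictions (None on the
    first step); `new_chunk` is the freshly predicted action chunk, which is one
    entry longer than `ensembled` because one action was consumed last step.
    """
    if ensembled is None:
        # First step of the episode: initialize with the predicted chunk.
        ensembled = list(new_chunk)
    else:
        # EMA-update the overlapping entries...
        ensembled = [alpha * e + (1 - alpha) * n for e, n in zip(ensembled, new_chunk[:-1])]
        # ...and append the newest prediction, which has no prior moving average.
        ensembled.append(new_chunk[-1])
    # "Consume" the first action, returning it along with the remaining state.
    return ensembled[0], ensembled[1:]
```

Calling it repeatedly with overlapping chunks mirrors querying the policy every environment step with `n_action_steps == 1`.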

View File

@@ -1,3 +1,19 @@
+#!/usr/bin/env python
+# Copyright 2024 Columbia Artificial Intelligence, Robotics Lab,
+# and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 from dataclasses import dataclass, field
@@ -51,6 +67,7 @@ class DiffusionConfig:
         use_film_scale_modulation: FiLM (https://arxiv.org/abs/1709.07871) is used for the Unet conditioning.
             Bias modulation is used by default, while this parameter indicates whether to also use scale
             modulation.
+        noise_scheduler_type: Name of the noise scheduler to use. Supported options: ["DDPM", "DDIM"].
         num_train_timesteps: Number of diffusion steps for the forward diffusion schedule.
         beta_schedule: Name of the diffusion beta schedule as per DDPMScheduler from Hugging Face diffusers.
         beta_start: Beta value for the first forward-diffusion step.
@@ -64,6 +81,9 @@ class DiffusionConfig:
         clip_sample_range: The magnitude of the clipping range as described above.
         num_inference_steps: Number of reverse diffusion steps to use at inference time (steps are evenly
             spaced). If not provided, this defaults to be the same as `num_train_timesteps`.
+        do_mask_loss_for_padding: Whether to mask the loss when there are copy-padded actions. See
+            `LeRobotDataset` and `load_previous_and_future_frames` for more information. Note, this defaults
+            to False as the original Diffusion Policy implementation does the same.
     """

     # Inputs / output structure.
@@ -107,6 +127,7 @@ class DiffusionConfig:
     diffusion_step_embed_dim: int = 128
     use_film_scale_modulation: bool = True
     # Noise scheduler.
+    noise_scheduler_type: str = "DDPM"
     num_train_timesteps: int = 100
     beta_schedule: str = "squaredcos_cap_v2"
     beta_start: float = 0.0001
@@ -118,23 +139,39 @@ class DiffusionConfig:
     # Inference
     num_inference_steps: int | None = None
+    # Loss computation
+    do_mask_loss_for_padding: bool = False

     def __post_init__(self):
         """Input validation (not exhaustive)."""
         if not self.vision_backbone.startswith("resnet"):
             raise ValueError(
                 f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
             )
+        # There should only be one image key.
+        image_keys = {k for k in self.input_shapes if k.startswith("observation.image")}
+        if len(image_keys) != 1:
+            raise ValueError(
+                f"{self.__class__.__name__} only handles one image for now. Got image keys {image_keys}."
+            )
+        image_key = next(iter(image_keys))
         if (
-            self.crop_shape[0] > self.input_shapes["observation.image"][1]
-            or self.crop_shape[1] > self.input_shapes["observation.image"][2]
+            self.crop_shape[0] > self.input_shapes[image_key][1]
+            or self.crop_shape[1] > self.input_shapes[image_key][2]
         ):
             raise ValueError(
-                f'`crop_shape` should fit within `input_shapes["observation.image"]`. Got {self.crop_shape} '
-                f'for `crop_shape` and {self.input_shapes["observation.image"]} for '
-                '`input_shapes["observation.image"]`.'
+                f"`crop_shape` should fit within `input_shapes[{image_key}]`. Got {self.crop_shape} "
+                f"for `crop_shape` and {self.input_shapes[image_key]} for "
+                "`input_shapes[{image_key}]`."
             )
         supported_prediction_types = ["epsilon", "sample"]
         if self.prediction_type not in supported_prediction_types:
             raise ValueError(
                 f"`prediction_type` must be one of {supported_prediction_types}. Got {self.prediction_type}."
             )
+        supported_noise_schedulers = ["DDPM", "DDIM"]
+        if self.noise_scheduler_type not in supported_noise_schedulers:
+            raise ValueError(
+                f"`noise_scheduler_type` must be one of {supported_noise_schedulers}. "
+                f"Got {self.noise_scheduler_type}."
+            )

View File

@@ -1,8 +1,24 @@
+#!/usr/bin/env python
+# Copyright 2024 Columbia Artificial Intelligence, Robotics Lab,
+# and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 """Diffusion Policy as per "Diffusion Policy: Visuomotor Policy Learning via Action Diffusion"

 TODO(alexander-soare):
-  - Remove reliance on Robomimic for SpatialSoftmax.
   - Remove reliance on diffusers for DDPMScheduler and LR scheduler.
+  - Make compatible with multiple image keys.
 """

 import math
@@ -10,12 +26,13 @@ from collections import deque
 from typing import Callable

 import einops
+import numpy as np
 import torch
 import torch.nn.functional as F  # noqa: N812
 import torchvision
+from diffusers.schedulers.scheduling_ddim import DDIMScheduler
 from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
 from huggingface_hub import PyTorchModelHubMixin
-from robomimic.models.base_nets import SpatialSoftmax
 from torch import Tensor, nn

 from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
@@ -66,10 +83,18 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):

         self.diffusion = DiffusionModel(config)

+        image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
+        # Note: This check is covered in the post-init of the config but have a sanity check just in case.
+        if len(image_keys) != 1:
+            raise NotImplementedError(
+                f"{self.__class__.__name__} only handles one image for now. Got image keys {image_keys}."
+            )
+        self.input_image_key = image_keys[0]
+
+        self.reset()
+
     def reset(self):
-        """
-        Clear observation and action queues. Should be called on `env.reset()`
-        """
+        """Clear observation and action queues. Should be called on `env.reset()`"""
         self._queues = {
             "observation.image": deque(maxlen=self.config.n_obs_steps),
             "observation.state": deque(maxlen=self.config.n_obs_steps),
@@ -98,16 +123,14 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
         "horizon" may not the best name to describe what the variable actually means, because this period is
         actually measured from the first observation which (if `n_obs_steps` > 1) happened in the past.
         """
-        assert "observation.image" in batch
-        assert "observation.state" in batch
-
         batch = self.normalize_inputs(batch)
+        batch["observation.image"] = batch[self.input_image_key]

         self._queues = populate_queues(self._queues, batch)

         if len(self._queues["action"]) == 0:
             # stack n latest observations from the queue
-            batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch}
+            batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues}
             actions = self.diffusion.generate_actions(batch)

             # TODO(rcadene): make above methods return output dictionary?
@@ -121,11 +144,25 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
     def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
         """Run the batch through the model and compute the loss for training or validation."""
         batch = self.normalize_inputs(batch)
+        batch["observation.image"] = batch[self.input_image_key]
         batch = self.normalize_targets(batch)
         loss = self.diffusion.compute_loss(batch)
         return {"loss": loss}

+def _make_noise_scheduler(name: str, **kwargs: dict) -> DDPMScheduler | DDIMScheduler:
+    """
+    Factory for noise scheduler instances of the requested type. All kwargs are passed
+    to the scheduler.
+    """
+    if name == "DDPM":
+        return DDPMScheduler(**kwargs)
+    elif name == "DDIM":
+        return DDIMScheduler(**kwargs)
+    else:
+        raise ValueError(f"Unsupported noise scheduler type {name}")
+
 class DiffusionModel(nn.Module):
     def __init__(self, config: DiffusionConfig):
         super().__init__()
@@ -138,12 +175,12 @@ class DiffusionModel(nn.Module):
             * config.n_obs_steps,
         )

-        self.noise_scheduler = DDPMScheduler(
+        self.noise_scheduler = _make_noise_scheduler(
+            config.noise_scheduler_type,
             num_train_timesteps=config.num_train_timesteps,
             beta_start=config.beta_start,
             beta_end=config.beta_end,
             beta_schedule=config.beta_schedule,
-            variance_type="fixed_small",
             clip_sample=config.clip_sample,
             clip_sample_range=config.clip_sample_range,
             prediction_type=config.prediction_type,
@@ -185,13 +222,12 @@ class DiffusionModel(nn.Module):
     def generate_actions(self, batch: dict[str, Tensor]) -> Tensor:
         """
-        This function expects `batch` to have (at least):
+        This function expects `batch` to have:
         {
             "observation.state": (B, n_obs_steps, state_dim)
             "observation.image": (B, n_obs_steps, C, H, W)
         }
         """
-        assert set(batch).issuperset({"observation.state", "observation.image"})
         batch_size, n_obs_steps = batch["observation.state"].shape[:2]
         assert n_obs_steps == self.config.n_obs_steps

@@ -268,13 +304,84 @@ class DiffusionModel(nn.Module):
         loss = F.mse_loss(pred, target, reduction="none")

         # Mask loss wherever the action is padded with copies (edges of the dataset trajectory).
-        if "action_is_pad" in batch:
+        if self.config.do_mask_loss_for_padding and "action_is_pad" in batch:
             in_episode_bound = ~batch["action_is_pad"]
             loss = loss * in_episode_bound.unsqueeze(-1)

         return loss.mean()

+class SpatialSoftmax(nn.Module):
+    """
+    Spatial Soft Argmax operation described in "Deep Spatial Autoencoders for Visuomotor Learning" by Finn
+    et al. (https://arxiv.org/pdf/1509.06113). A minimal port of the robomimic implementation.
+
+    At a high level, this takes 2D feature maps (from a convnet/ViT) and returns the "center of mass"
+    of activations of each channel, i.e., keypoints in the image space for the policy to focus on.
+
+    Example: take feature maps of size (512x10x12). We generate a grid of normalized coordinates (10x12x2):
+    -----------------------------------------------------
+    | (-1., -1.)   | (-0.82, -1.)   | ... | (1., -1.)   |
+    | (-1., -0.78) | (-0.82, -0.78) | ... | (1., -0.78) |
+    | ...          | ...            | ... | ...         |
+    | (-1., 1.)    | (-0.82, 1.)    | ... | (1., 1.)    |
+    -----------------------------------------------------
+    This is achieved by applying channel-wise softmax over the activations (512x120) and computing the dot
+    product with the coordinates (120x2) to get expected points of maximal activation (512x2).
+
+    The example above results in 512 keypoints (corresponding to the 512 input channels). We can optionally
+    provide num_kp != None to control the number of keypoints. This is achieved by first applying a
+    learnable linear mapping (in_channels, H, W) -> (num_kp, H, W).
+    """
+
+    def __init__(self, input_shape, num_kp=None):
+        """
+        Args:
+            input_shape (list): (C, H, W) input feature map shape.
+            num_kp (int): number of keypoints in output. If None, output will have the same number of
+                channels as input.
+        """
+        super().__init__()
+
+        assert len(input_shape) == 3
+        self._in_c, self._in_h, self._in_w = input_shape
+
+        if num_kp is not None:
+            self.nets = torch.nn.Conv2d(self._in_c, num_kp, kernel_size=1)
+            self._out_c = num_kp
+        else:
+            self.nets = None
+            self._out_c = self._in_c
+
+        # we could use torch.linspace directly but that seems to behave slightly differently than numpy
+        # and causes a small degradation in pc_success of pre-trained models.
+        pos_x, pos_y = np.meshgrid(np.linspace(-1.0, 1.0, self._in_w), np.linspace(-1.0, 1.0, self._in_h))
+        pos_x = torch.from_numpy(pos_x.reshape(self._in_h * self._in_w, 1)).float()
+        pos_y = torch.from_numpy(pos_y.reshape(self._in_h * self._in_w, 1)).float()
+        # register as buffer so it's moved to the correct device.
+        self.register_buffer("pos_grid", torch.cat([pos_x, pos_y], dim=1))
+
+    def forward(self, features: Tensor) -> Tensor:
+        """
+        Args:
+            features: (B, C, H, W) input feature maps.
+        Returns:
+            (B, K, 2) image-space coordinates of keypoints.
+        """
+        if self.nets is not None:
+            features = self.nets(features)
+
+        # [B, K, H, W] -> [B * K, H * W] where K is number of keypoints
+        features = features.reshape(-1, self._in_h * self._in_w)
+        # 2d softmax normalization
+        attention = F.softmax(features, dim=-1)
+        # [B * K, H * W] x [H * W, 2] -> [B * K, 2] for spatial coordinate mean in x and y dimensions
+        expected_xy = attention @ self.pos_grid
+        # reshape to [B, K, 2]
+        feature_keypoints = expected_xy.view(-1, self._out_c, 2)
+
+        return feature_keypoints
+
 class DiffusionRgbEncoder(nn.Module):
     """Encoder an RGB image into a 1D feature vector.
@@ -315,11 +422,16 @@ class DiffusionRgbEncoder(nn.Module):
         # Set up pooling and final layers.
         # Use a dry run to get the feature map shape.
+        # The dummy input should take the number of image channels from `config.input_shapes` and it should
+        # use the height and width from `config.crop_shape`.
+        image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
+        assert len(image_keys) == 1
+        image_key = image_keys[0]
+        dummy_input = torch.zeros(size=(1, config.input_shapes[image_key][0], *config.crop_shape))
         with torch.inference_mode():
-            feat_map_shape = tuple(
-                self.backbone(torch.zeros(size=(1, *config.input_shapes["observation.image"]))).shape[1:]
-            )
-        self.pool = SpatialSoftmax(feat_map_shape, num_kp=config.spatial_softmax_num_keypoints)
+            dummy_feature_map = self.backbone(dummy_input)
+        feature_map_shape = tuple(dummy_feature_map.shape[1:])
+        self.pool = SpatialSoftmax(feature_map_shape, num_kp=config.spatial_softmax_num_keypoints)
         self.feature_dim = config.spatial_softmax_num_keypoints * 2
         self.out = nn.Linear(config.spatial_softmax_num_keypoints * 2, self.feature_dim)
         self.relu = nn.ReLU()
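The ported `SpatialSoftmax` above boils down to a softmax-weighted average over a normalized coordinate grid. A torch-free, single-channel sketch of the same idea (hypothetical function; assumes H, W ≥ 2):

```python
import math

def spatial_soft_argmax_1ch(feature_map):
    """Expected (x, y) over a single-channel H×W map, coordinates normalized to [-1, 1].

    Mirrors SpatialSoftmax's forward pass for one channel: softmax over the flattened
    activations, then a dot product with the coordinate grid.
    """
    h, w = len(feature_map), len(feature_map[0])
    # Normalized coordinate grid, analogous to np.linspace(-1.0, 1.0, ...) above.
    xs = [-1.0 + 2.0 * j / (w - 1) for j in range(w)]
    ys = [-1.0 + 2.0 * i / (h - 1) for i in range(h)]
    flat = [feature_map[i][j] for i in range(h) for j in range(w)]
    # Numerically stable softmax over the flattened map.
    z = max(flat)
    exps = [math.exp(v - z) for v in flat]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Dot product with the coordinate grid gives the expected keypoint location.
    ex = sum(p * xs[k % w] for k, p in enumerate(probs))
    ey = sum(p * ys[k // w] for k, p in enumerate(probs))
    return ex, ey
```

A sharply peaked activation yields a keypoint near that pixel's normalized coordinate, while a uniform map yields the grid center (0, 0).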

View File

@@ -1,4 +1,20 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect import inspect
import logging
from omegaconf import DictConfig, OmegaConf from omegaconf import DictConfig, OmegaConf
@@ -8,9 +24,10 @@ from lerobot.common.utils.utils import get_safe_torch_device
def _policy_cfg_from_hydra_cfg(policy_cfg_class, hydra_cfg):
    expected_kwargs = set(inspect.signature(policy_cfg_class).parameters)
    if not set(hydra_cfg.policy).issuperset(expected_kwargs):
        logging.warning(
            f"Hydra config is missing arguments: {set(expected_kwargs).difference(hydra_cfg.policy)}"
        )
    policy_cfg = policy_cfg_class(
        **{
            k: v
@@ -62,11 +79,18 @@ def make_policy(
    policy_cls, policy_cfg_class = get_policy_and_config_classes(hydra_cfg.policy.name)

    policy_cfg = _policy_cfg_from_hydra_cfg(policy_cfg_class, hydra_cfg)
    if pretrained_policy_name_or_path is None:
        # Make a fresh policy.
        policy = policy_cls(policy_cfg, dataset_stats)
    else:
        # Load a pretrained policy and override the config if needed (for example, if there are
        # inference-time hyperparameters that we want to vary).
        # TODO(alexander-soare): This hack makes use of huggingface_hub's tooling to load the policy with
        # pretrained weights, which are then loaded into a fresh policy with the desired config. This PR in
        # huggingface_hub should make it possible to avoid the hack:
        # https://github.com/huggingface/huggingface_hub/pull/2274.
        policy = policy_cls(policy_cfg)
        policy.load_state_dict(policy_cls.from_pretrained(pretrained_policy_name_or_path).state_dict())

    policy.to(get_safe_torch_device(hydra_cfg.device))
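The workaround in the `else` branch — instantiate with the desired config, then copy weights in via `state_dict` — works because `load_state_dict` matches parameters by name only, so config attributes that are not weights survive. A minimal sketch with a hypothetical `TinyPolicy` class (the `scale` attribute plays the role of an inference-time hyperparameter):

```python
import torch
from torch import nn

class TinyPolicy(nn.Module):
    # Hypothetical stand-in for a policy class.
    def __init__(self, scale: float):
        super().__init__()
        self.scale = scale  # config value, not part of the weights
        self.linear = nn.Linear(4, 2)

pretrained = TinyPolicy(scale=1.0)  # stands in for `from_pretrained(...)`
fresh = TinyPolicy(scale=0.5)       # fresh policy with the desired config
fresh.load_state_dict(pretrained.state_dict())

# The weights now match the pretrained model, but the new config survives.
assert torch.equal(fresh.linear.weight, pretrained.linear.weight)
assert fresh.scale == 0.5
```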

View File

@@ -1,3 +1,18 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torch import Tensor, nn

View File

@@ -1,3 +1,18 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A protocol that all policies should follow. """A protocol that all policies should follow.
This provides a mechanism for type-hinting and isinstance checks without requiring the policies classes This provides a mechanism for type-hinting and isinstance checks without requiring the policies classes
@@ -38,7 +53,8 @@ class Policy(Protocol):
    def forward(self, batch: dict[str, Tensor]) -> dict:
        """Run the batch through the model and compute the loss for training or validation.

        Returns a dictionary with "loss" and potentially other information. Apart from "loss" which is a
        Tensor, all other items should be logging-friendly, native Python types.
        """

    def select_action(self, batch: dict[str, Tensor]):
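As the docstring says, a `Protocol` gives type hints and `isinstance` checks without inheritance: any class that happens to have the right methods satisfies it structurally. A minimal sketch (the `MyPolicy` class is hypothetical, and note that `runtime_checkable` only checks method *presence*, not signatures):

```python
from typing import Protocol, runtime_checkable

@runtime_checkable
class Policy(Protocol):
    # Simplified structural interface, following the snippet above.
    def forward(self, batch: dict) -> dict: ...
    def select_action(self, batch: dict): ...

class MyPolicy:  # hypothetical policy; note: no base class
    def forward(self, batch: dict) -> dict:
        return {"loss": 0.0}

    def select_action(self, batch: dict):
        return batch.get("action")

assert isinstance(MyPolicy(), Policy)  # structural check passes
```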

View File

@@ -1,3 +1,19 @@
#!/usr/bin/env python
# Copyright 2024 Nicklas Hansen, Xiaolong Wang, Hao Su,
# and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
@@ -47,7 +63,7 @@ class TDMPCConfig:
        elite_weighting_temperature: The temperature to use for softmax weighting (by trajectory value) of the
            elites, when updating the gaussian parameters for CEM.
        gaussian_mean_momentum: Momentum (α) used for EMA updates of the mean parameter μ of the gaussian
            parameters optimized in CEM. Updates are calculated as μ⁻ ← αμ⁻ + (1-α)μ.
        max_random_shift_ratio: Maximum random shift (as a proportion of the image size) to apply to the
            image(s) (in units of pixels) for training-time augmentation. If set to 0, no such augmentation
            is applied. Note that the input images are assumed to be square for this augmentation.
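The EMA update μ⁻ ← αμ⁻ + (1-α)μ documented for `gaussian_mean_momentum` is a one-liner; the helper name below is illustrative, not from the codebase. With a higher α the running mean tracks new values more slowly:

```python
def ema_update(mu_prev: float, mu_new: float, alpha: float) -> float:
    # μ⁻ ← αμ⁻ + (1-α)μ
    return alpha * mu_prev + (1 - alpha) * mu_new

mu = 0.0
for observed in [1.0, 1.0, 1.0]:
    mu = ema_update(mu, observed, alpha=0.9)
print(round(mu, 3))  # 0.271 — still far from 1.0 after three updates at alpha=0.9
```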
@@ -131,12 +147,18 @@ class TDMPCConfig:
    def __post_init__(self):
        """Input validation (not exhaustive)."""
        # There should only be one image key.
        image_keys = {k for k in self.input_shapes if k.startswith("observation.image")}
        if len(image_keys) != 1:
            raise ValueError(
                f"{self.__class__.__name__} only handles one image for now. Got image keys {image_keys}."
            )
        image_key = next(iter(image_keys))
        if self.input_shapes[image_key][-2] != self.input_shapes[image_key][-1]:
            # TODO(alexander-soare): This limitation is solely because of code in the random shift
            # augmentation. It should be able to be removed.
            raise ValueError(
                f"Only square images are handled now. Got image shape {self.input_shapes[image_key]}."
            )
        if self.n_gaussian_samples <= 0:
            raise ValueError(

View File

@@ -1,3 +1,19 @@
#!/usr/bin/env python
# Copyright 2024 Nicklas Hansen, Xiaolong Wang, Hao Su,
# and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implementation of Finetuning Offline World Models in the Real World. """Implementation of Finetuning Offline World Models in the Real World.
The comments in this code may sometimes refer to these references: The comments in this code may sometimes refer to these references:
@@ -96,13 +112,12 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
            config.output_shapes, config.output_normalization_modes, dataset_stats
        )

        image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
        # Note: This check is covered in the post-init of the config but have a sanity check just in case.
        assert len(image_keys) == 1
        self.input_image_key = image_keys[0]

        self.reset()
    def reset(self):
        """
@@ -121,10 +136,8 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
    @torch.no_grad()
    def select_action(self, batch: dict[str, Tensor]):
        """Select a single action given environment observations."""
        batch = self.normalize_inputs(batch)
        batch["observation.image"] = batch[self.input_image_key]

        self._queues = populate_queues(self._queues, batch)
@@ -303,13 +316,11 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
        device = get_device_from_parameters(self)

        batch = self.normalize_inputs(batch)
        batch["observation.image"] = batch[self.input_image_key]
        batch = self.normalize_targets(batch)

        info = {}

        # (b, t) -> (t, b)
        for key in batch:
            if batch[key].ndim > 1:
@@ -337,6 +348,7 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
        # Run latent rollout using the latent dynamics model and policy model.
        # Note this has shape `horizon+1` because there are `horizon` actions and a current `z`. Each action
        # gives us a next `z`.
        batch_size = batch["index"].shape[0]
        z_preds = torch.empty(horizon + 1, batch_size, self.config.latent_dim, device=device)
        z_preds[0] = self.model.encode(current_observation)
        reward_preds = torch.empty_like(reward, device=device)

View File

@@ -1,9 +1,28 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torch import nn


def populate_queues(queues, batch):
    for key in batch:
        # Ignore keys not in the queues already (leaving the responsibility to the caller to make sure the
        # queues have the keys they want).
        if key not in queues:
            continue
        if len(queues[key]) != queues[key].maxlen:
            # initialize by copying the first observation several times until the queue is full
            while len(queues[key]) != queues[key].maxlen:
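The guard added in this hunk (skip batch keys that no queue was set up for) can be exercised with a standalone version of `populate_queues`; the queue size and key names below are arbitrary examples:

```python
from collections import deque

def populate_queues(queues, batch):
    # Simplified from the snippet above: ignore keys without a queue, and
    # fill a fresh queue by repeating the first observation until it is full.
    for key in batch:
        if key not in queues:
            continue
        if len(queues[key]) != queues[key].maxlen:
            while len(queues[key]) != queues[key].maxlen:
                queues[key].append(batch[key])
        else:
            queues[key].append(batch[key])
    return queues

queues = {"observation.image": deque(maxlen=2)}
batch = {"observation.image": "img0", "observation.state": "s0"}  # extra key is ignored
queues = populate_queues(queues, batch)
print(list(queues["observation.image"]))  # ['img0', 'img0']
```

Without the `continue` guard, a batch key with no corresponding queue would raise a `KeyError` inside the length check.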

View File

@@ -1,3 +1,18 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import logging

View File

@@ -1,3 +1,18 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings

import imageio

View File

@@ -1,8 +1,25 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os.path as osp
import random
from contextlib import contextmanager
from datetime import datetime
from pathlib import Path
from typing import Generator

import hydra
import numpy as np
@@ -39,6 +56,31 @@ def set_global_seed(seed):
        torch.cuda.manual_seed_all(seed)
@contextmanager
def seeded_context(seed: int) -> Generator[None, None, None]:
"""Set the seed when entering a context, and restore the prior random state at exit.
Example usage:
```
a = random.random() # produces some random number
with seeded_context(1337):
b = random.random() # produces some other random number
        c = random.random()  # produces yet another random number, but the same one it would have produced had we never made `b`
```
"""
random_state = random.getstate()
np_random_state = np.random.get_state()
torch_random_state = torch.random.get_rng_state()
torch_cuda_random_state = torch.cuda.random.get_rng_state()
set_global_seed(seed)
yield None
random.setstate(random_state)
np.random.set_state(np_random_state)
torch.random.set_rng_state(torch_random_state)
torch.cuda.random.set_rng_state(torch_cuda_random_state)
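The save-seed-restore pattern of `seeded_context` can be demonstrated with a simplified version that handles only Python's `random` module (the real helper above also saves and restores numpy and torch RNG state, and does not use `try/finally`, which is added here as a defensive touch):

```python
import random
from contextlib import contextmanager

@contextmanager
def seeded_context(seed):
    # Simplified sketch handling only `random`.
    state = random.getstate()
    random.seed(seed)
    try:
        yield
    finally:
        random.setstate(state)

random.seed(0)
a = random.random()
c_without_b = random.random()  # what the stream produces next if uninterrupted

random.seed(0)
a2 = random.random()
with seeded_context(1337):
    b = random.random()  # drawn from the seeded stream
c = random.random()

assert a == a2 and c == c_without_b  # the outer stream is unaffected by `b`
```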
def init_logging():
    def custom_format(record):
        dt = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

View File

@@ -25,7 +25,7 @@ training:
  eval_freq: ???
  save_freq: ???
  log_freq: 250
  save_model: true

eval:
  n_episodes: 1
@@ -35,7 +35,7 @@ eval:
  use_async_envs: false

wandb:
  enable: false
  # Set to true to disable saving an artifact despite save_model == True
  disable_artifact: false
  project: lerobot

View File

@@ -3,6 +3,12 @@
seed: 1000
dataset_repo_id: lerobot/aloha_sim_insertion_human
override_dataset_stats:
observation.images.top:
# stats from imagenet, since we use a pretrained vision model
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
training:
  offline_steps: 80000
  online_steps: 0
@@ -18,12 +24,6 @@ training:
  grad_clip_norm: 10
  online_steps_between_rollouts: 1

delta_timestamps:
  action: "[i / ${fps} for i in range(${policy.chunk_size})]"
@@ -66,6 +66,9 @@ policy:
  dim_feedforward: 3200
  feedforward_activation: relu
  n_encoder_layers: 4
# Note: Although the original ACT implementation has 7 for `n_decoder_layers`, there is a bug in the code
# that means only the first layer is used. Here we match the original implementation by setting this to 1.
# See this issue https://github.com/tonyzhaozh/act/issues/25#issue-2258740521.
  n_decoder_layers: 1
  # VAE.
  use_vae: true
@@ -73,7 +76,7 @@ policy:
  n_vae_encoder_layers: 4
  # Inference.
  temporal_ensemble_momentum: null
  # Training and loss computation.
  dropout: 0.1

View File

@@ -7,6 +7,20 @@
seed: 100000
dataset_repo_id: lerobot/pusht
override_dataset_stats:
# TODO(rcadene, alexander-soare): should we remove image stats as well? do we use a pretrained vision model?
observation.image:
mean: [[[0.5]], [[0.5]], [[0.5]]] # (c,1,1)
std: [[[0.5]], [[0.5]], [[0.5]]] # (c,1,1)
# TODO(rcadene, alexander-soare): we override state and action stats to use the same as the pretrained model
# from the original codebase, but we should remove these and train our own pretrained model
observation.state:
min: [13.456424, 32.938293]
max: [496.14618, 510.9579]
action:
min: [12.0, 25.0]
max: [511.0, 511.0]
training:
  offline_steps: 200000
  online_steps: 0
@@ -44,20 +58,6 @@ eval:
  n_episodes: 50
  batch_size: 50

policy:
  name: diffusion
@@ -95,6 +95,7 @@ policy:
  diffusion_step_embed_dim: 128
  use_film_scale_modulation: True
  # Noise scheduler.
noise_scheduler_type: DDPM
  num_train_timesteps: 100
  beta_schedule: squaredcos_cap_v2
  beta_start: 0.0001
@@ -105,3 +106,6 @@ policy:
  # Inference
  num_inference_steps: 100
# Loss computation
do_mask_loss_for_padding: false

View File

@@ -1,7 +1,7 @@
# @package _global_

seed: 1
dataset_repo_id: lerobot/xarm_lift_medium

training:
  offline_steps: 25000

View File

@@ -1,3 +1,18 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import platform

import huggingface_hub

View File

@@ -1,3 +1,18 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Evaluate a policy on an environment by running rollouts and computing metrics. """Evaluate a policy on an environment by running rollouts and computing metrics.
Usage examples: Usage examples:
@@ -583,17 +598,18 @@ if __name__ == "__main__":
        pretrained_policy_path = Path(
            snapshot_download(args.pretrained_policy_name_or_path, revision=args.revision)
        )
    except (HFValidationError, RepositoryNotFoundError) as e:
        if isinstance(e, HFValidationError):
            error_message = (
                "The provided pretrained_policy_name_or_path is not a valid Hugging Face Hub repo ID."
            )
        else:
            error_message = (
                "The provided pretrained_policy_name_or_path was not found on the Hugging Face Hub."
            )

        logging.warning(f"{error_message} Treating it as a local directory.")
        pretrained_policy_path = Path(args.pretrained_policy_name_or_path)

    if not pretrained_policy_path.is_dir() or not pretrained_policy_path.exists():
        raise ValueError(
            "The provided pretrained_policy_name_or_path is not a valid/existing Hugging Face Hub "

View File

@@ -1,3 +1,18 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" """
Use this script to convert your dataset into LeRobot dataset format and upload it to the Hugging Face hub, Use this script to convert your dataset into LeRobot dataset format and upload it to the Hugging Face hub,
or store it locally. LeRobot dataset format is lightweight, fast to load from, and does not require any or store it locally. LeRobot dataset format is lightweight, fast to load from, and does not require any
@@ -60,7 +75,7 @@ import torch
from huggingface_hub import HfApi
from safetensors.torch import save_file

from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
from lerobot.common.datasets.push_dataset_to_hub._download_raw import download_raw
from lerobot.common.datasets.push_dataset_to_hub.compute_stats import compute_stats
from lerobot.common.datasets.utils import flatten_dict
@@ -252,7 +267,7 @@ def main():
    parser.add_argument(
        "--revision",
        type=str,
        default=CODEBASE_VERSION,
        help="Codebase version used to generate the dataset.",
    )
    parser.add_argument(

View File

@@ -1,3 +1,18 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import time
from copy import deepcopy
@@ -8,7 +23,7 @@ import hydra
import torch
from datasets import concatenate_datasets
from datasets.utils import disable_progress_bars, enable_progress_bars
from omegaconf import DictConfig

from lerobot.common.datasets.factory import make_dataset
from lerobot.common.datasets.utils import cycle
@@ -55,6 +70,8 @@ def make_optimizer_and_scheduler(cfg, policy):
            cfg.training.adam_weight_decay,
        )
        assert cfg.training.online_steps == 0, "Diffusion Policy does not handle online training."
        from diffusers.optimization import get_scheduler

        lr_scheduler = get_scheduler(
            cfg.training.lr_scheduler,
            optimizer=optimizer,
@@ -71,6 +88,7 @@ def make_optimizer_and_scheduler(cfg, policy):
def update_policy(policy, batch, optimizer, grad_clip_norm, lr_scheduler=None):
    """Returns a dictionary of items for logging."""
    start_time = time.time()
    policy.train()
    output_dict = policy.forward(batch)
@@ -98,6 +116,7 @@ def update_policy(policy, batch, optimizer, grad_clip_norm, lr_scheduler=None):
        "grad_norm": float(grad_norm),
        "lr": optimizer.param_groups[0]["lr"],
        "update_s": time.time() - start_time,
        **{k: v for k, v in output_dict.items() if k != "loss"},
    }

    return info
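The new `**{...}` line merges whatever extra logging items a policy returns from `forward` into the info dict, while keeping `"loss"` out (it is logged separately). In isolation, with hypothetical output keys:

```python
# Hypothetical policy outputs; only the keys matter here.
output_dict = {"loss": 0.123, "kl_loss": 0.01, "recon_loss": 0.11}

info = {
    "lr": 1e-4,
    **{k: v for k, v in output_dict.items() if k != "loss"},
}
print(sorted(info))  # ['kl_loss', 'lr', 'recon_loss']
```

This keeps `update_policy` agnostic to which auxiliary losses a given policy reports, per the Policy protocol's note that non-"loss" items should be logging-friendly native types.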
@@ -121,7 +140,7 @@ def train_notebook(out_dir=None, job_name=None, config_name="default", config_pa
    train(cfg, out_dir=out_dir, job_name=job_name)


def log_train_info(logger: Logger, info, step, cfg, dataset, is_offline):
    loss = info["loss"]
    grad_norm = info["grad_norm"]
    lr = info["lr"]
@@ -289,7 +308,7 @@ def add_episodes_inplace(
    sampler.num_samples = len(concat_dataset)


def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
    if out_dir is None:
        raise NotImplementedError()
    if job_name is None:
@@ -336,7 +355,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})") logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
# Note: this helper will be used in offline and online training loops. # Note: this helper will be used in offline and online training loops.
def _maybe_eval_and_maybe_save(step): def evaluate_and_checkpoint_if_needed(step):
if step % cfg.training.eval_freq == 0: if step % cfg.training.eval_freq == 0:
logging.info(f"Eval policy at step {step}") logging.info(f"Eval policy at step {step}")
eval_info = eval_policy( eval_info = eval_policy(
@@ -392,9 +411,9 @@ def train(cfg: dict, out_dir=None, job_name=None):
        if step % cfg.training.log_freq == 0:
            log_train_info(logger, train_info, step, cfg, offline_dataset, is_offline)

        # Note: evaluate_and_checkpoint_if_needed happens **after** the `step`th training update has completed,
        # so we pass in step + 1.
        evaluate_and_checkpoint_if_needed(step + 1)

        step += 1
@@ -460,9 +479,9 @@ def train(cfg: dict, out_dir=None, job_name=None):
            if step % cfg.training.log_freq == 0:
                log_train_info(logger, train_info, step, cfg, online_dataset, is_offline)

            # Note: evaluate_and_checkpoint_if_needed happens **after** the `step`th training update has
            # completed, so we pass in step + 1.
            evaluate_and_checkpoint_if_needed(step + 1)

            step += 1
            online_step += 1

View File

@@ -1,3 +1,18 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Visualize data of **all** frames of any episode of a dataset of type LeRobotDataset. """ Visualize data of **all** frames of any episode of a dataset of type LeRobotDataset.
Note: The last frame of the episode doesn't always correspond to a final state. Note: The last frame of the episode doesn't always correspond to a final state.
@@ -32,7 +47,7 @@ local$ rerun lerobot_pusht_episode_0.rrd
``` ```
- Visualize data stored on a distant machine through streaming: - Visualize data stored on a distant machine through streaming:
(You need to forward the websocket port to the distant machine, with (You need to forward the websocket port to the distant machine, with
`ssh -L 9087:localhost:9087 username@remote-host`) `ssh -L 9087:localhost:9087 username@remote-host`)
``` ```
distant$ python lerobot/scripts/visualize_dataset.py \ distant$ python lerobot/scripts/visualize_dataset.py \
@@ -47,6 +62,7 @@ local$ rerun ws://localhost:9087
""" """
import argparse import argparse
import gc
import logging import logging
import time import time
from pathlib import Path from pathlib import Path
@@ -115,15 +131,17 @@ def visualize_dataset(
spawn_local_viewer = mode == "local" and not save spawn_local_viewer = mode == "local" and not save
rr.init(f"{repo_id}/episode_{episode_index}", spawn=spawn_local_viewer) rr.init(f"{repo_id}/episode_{episode_index}", spawn=spawn_local_viewer)
# Manually call python garbage collector after `rr.init` to avoid hanging in a blocking flush
# when iterating on a dataloader with `num_workers` > 0
# TODO(rcadene): remove `gc.collect` when rerun version 0.16 is out, which includes a fix
gc.collect()
if mode == "distant": if mode == "distant":
rr.serve(open_browser=False, web_port=web_port, ws_port=ws_port) rr.serve(open_browser=False, web_port=web_port, ws_port=ws_port)
logging.info("Logging to Rerun") logging.info("Logging to Rerun")
if num_workers > 0:
# TODO(rcadene): fix data workers hanging when `rr.init` is called
logging.warning("If data loader is hanging, try `--num-workers 0`.")
for batch in tqdm.tqdm(dataloader, total=len(dataloader)): for batch in tqdm.tqdm(dataloader, total=len(dataloader)):
# iterate over the batch # iterate over the batch
for i in range(len(batch["index"])): for i in range(len(batch["index"])):
@@ -131,7 +149,7 @@ def visualize_dataset(
rr.set_time_seconds("timestamp", batch["timestamp"][i].item()) rr.set_time_seconds("timestamp", batch["timestamp"][i].item())
# display each camera image # display each camera image
for key in dataset.image_keys: for key in dataset.camera_keys:
# TODO(rcadene): add `.compress()`? is it lossless? # TODO(rcadene): add `.compress()`? is it lossless?
rr.log(key, rr.Image(to_hwc_uint8_numpy(batch[key][i]))) rr.log(key, rr.Image(to_hwc_uint8_numpy(batch[key][i])))
@@ -196,7 +214,7 @@ def main():
parser.add_argument( parser.add_argument(
"--num-workers", "--num-workers",
type=int, type=int,
default=0, default=4,
help="Number of processes of Dataloader for loading the data.", help="Number of processes of Dataloader for loading the data.",
) )
parser.add_argument( parser.add_argument(

media/wandb.png (new binary file, 407 KiB; contents not shown)

poetry.lock (generated, 774 lines changed; diff suppressed because it is too large)


@@ -4,9 +4,9 @@ version = "0.1.0"
description = "🤗 LeRobot: State-of-the-art Machine Learning for Real-World Robotics in Pytorch" description = "🤗 LeRobot: State-of-the-art Machine Learning for Real-World Robotics in Pytorch"
authors = [ authors = [
"Rémi Cadène <re.cadene@gmail.com>", "Rémi Cadène <re.cadene@gmail.com>",
"Simon Alibert <alibert.sim@gmail.com>",
"Alexander Soare <alexander.soare159@gmail.com>", "Alexander Soare <alexander.soare159@gmail.com>",
"Quentin Gallouédec <quentin.gallouedec@ec-lyon.fr>", "Quentin Gallouédec <quentin.gallouedec@ec-lyon.fr>",
"Simon Alibert <alibert.sim@gmail.com>",
"Adil Zouitine <adilzouitinegm@gmail.com>", "Adil Zouitine <adilzouitinegm@gmail.com>",
"Thomas Wolf <thomaswolfcontact@gmail.com>", "Thomas Wolf <thomaswolfcontact@gmail.com>",
] ]
@@ -28,37 +28,36 @@ packages = [{include = "lerobot"}]
[tool.poetry.dependencies] [tool.poetry.dependencies]
python = ">=3.10,<3.13" python = ">=3.10,<3.13"
termcolor = "^2.4.0" termcolor = ">=2.4.0"
omegaconf = "^2.3.0" omegaconf = ">=2.3.0"
wandb = "^0.16.3" wandb = ">=0.16.3"
imageio = {extras = ["ffmpeg"], version = "^2.34.0"} imageio = {extras = ["ffmpeg"], version = ">=2.34.0"}
gdown = "^5.1.0" gdown = ">=5.1.0"
hydra-core = "^1.3.2" hydra-core = ">=1.3.2"
einops = "^0.8.0" einops = ">=0.8.0"
pymunk = "^6.6.0" pymunk = ">=6.6.0"
zarr = "^2.17.0" zarr = ">=2.17.0"
numba = "^0.59.0" numba = ">=0.59.0"
torch = "^2.2.1" torch = "^2.2.1"
opencv-python = "^4.9.0.80" opencv-python = ">=4.9.0"
diffusers = "^0.27.2" diffusers = "^0.27.2"
torchvision = "^0.18.0" torchvision = ">=0.18.0"
h5py = "^3.10.0" h5py = ">=3.10.0"
huggingface-hub = "^0.21.4" huggingface-hub = ">=0.21.4"
robomimic = "0.2.0" gymnasium = ">=0.29.1"
gymnasium = "^0.29.1" cmake = ">=3.29.0.1"
cmake = "^3.29.0.1" gym-pusht = { version = ">=0.1.3", optional = true}
gym-pusht = { version = "^0.1.0", optional = true} gym-xarm = { version = ">=0.1.1", optional = true}
gym-xarm = { version = "^0.1.0", optional = true} gym-aloha = { version = ">=0.1.1", optional = true}
gym-aloha = { version = "^0.1.0", optional = true} pre-commit = {version = ">=3.7.0", optional = true}
pre-commit = {version = "^3.7.0", optional = true} debugpy = {version = ">=1.8.1", optional = true}
debugpy = {version = "^1.8.1", optional = true} pytest = {version = ">=8.1.0", optional = true}
pytest = {version = "^8.1.0", optional = true} pytest-cov = {version = ">=5.0.0", optional = true}
pytest-cov = {version = "^5.0.0", optional = true}
datasets = "^2.19.0" datasets = "^2.19.0"
imagecodecs = { version = "^2024.1.1", optional = true } imagecodecs = { version = ">=2024.1.1", optional = true }
pyav = "^12.0.5" pyav = ">=12.0.5"
moviepy = "^1.0.3" moviepy = ">=1.0.3"
rerun-sdk = "^0.15.1" rerun-sdk = ">=0.15.1"
[tool.poetry.extras] [tool.poetry.extras]
@@ -104,5 +103,5 @@ ignore-init-module-imports = true
[build-system] [build-system]
requires = ["poetry-core>=1.5.0"] requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api" build-backend = "poetry.core.masonry.api"
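The dependency changes above mostly relax Poetry's caret constraints (`^2.4.0`, meaning `>=2.4.0,<3.0.0`) to plain lower bounds (`>=2.4.0`). A simplified model of the difference, ignoring pre-releases and the special `0.x` caret rules:

```python
def parse(version):
    return tuple(int(part) for part in version.split("."))

def caret_allows(candidate, base):
    # Poetry's ^base (for base major >= 1): >= base, but below the next major.
    return parse(candidate) >= parse(base) and parse(candidate)[0] == parse(base)[0]

def gte_allows(candidate, base):
    # Plain >=base: any equal-or-newer version, no upper bound.
    return parse(candidate) >= parse(base)

# ^2.4.0 rejects the next major release; >=2.4.0 accepts it,
# which gives downstream dependency resolvers more room.
print(caret_allows("3.0.0", "2.4.0"), gte_allows("3.0.0", "2.4.0"))  # prints: False True
```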


@@ -1,3 +1,18 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .utils import DEVICE from .utils import DEVICE


@@ -1,3 +1,18 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" """
This script provides a utility for saving a dataset as safetensors files for the purpose of testing backward compatibility This script provides a utility for saving a dataset as safetensors files for the purpose of testing backward compatibility
when updating the data format. It uses the `PushtDataset` to create a DataLoader and saves selected frame from the when updating the data format. It uses the `PushtDataset` to create a DataLoader and saves selected frame from the


@@ -1,3 +1,18 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import shutil import shutil
from pathlib import Path from pathlib import Path


@@ -1,3 +1,18 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib import importlib
import gymnasium as gym import gymnasium as gym
@@ -15,7 +30,7 @@ from tests.utils import require_env
def test_available_env_task(env_name: str, task_name: list): def test_available_env_task(env_name: str, task_name: list):
""" """
This test verifies that all environments listed in `lerobot/__init__.py` can This test verifies that all environments listed in `lerobot/__init__.py` can
be sucessfully imported — if they're installed — and that their be successfully imported — if they're installed — and that their
`available_tasks_per_env` are valid. `available_tasks_per_env` are valid.
""" """
package_name = f"gym_{env_name}" package_name = f"gym_{env_name}"


@@ -1,3 +1,18 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json import json
import logging import logging
from copy import deepcopy from copy import deepcopy
@@ -41,7 +56,7 @@ def test_factory(env_name, repo_id, policy_name):
) )
dataset = make_dataset(cfg) dataset = make_dataset(cfg)
delta_timestamps = dataset.delta_timestamps delta_timestamps = dataset.delta_timestamps
image_keys = dataset.image_keys camera_keys = dataset.camera_keys
item = dataset[0] item = dataset[0]
@@ -71,7 +86,7 @@ def test_factory(env_name, repo_id, policy_name):
else: else:
assert item[key].ndim == ndim, f"{key}" assert item[key].ndim == ndim, f"{key}"
if key in image_keys: if key in camera_keys:
assert item[key].dtype == torch.float32, f"{key}" assert item[key].dtype == torch.float32, f"{key}"
# TODO(rcadene): we assume for now that image normalization takes place in the model # TODO(rcadene): we assume for now that image normalization takes place in the model
assert item[key].max() <= 1.0, f"{key}" assert item[key].max() <= 1.0, f"{key}"


@@ -1,3 +1,18 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib import importlib
import gymnasium as gym import gymnasium as gym


@@ -1,8 +1,25 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# TODO(aliberts): Mute logging for these tests # TODO(aliberts): Mute logging for these tests
import subprocess import subprocess
import sys import sys
from pathlib import Path from pathlib import Path
from tests.utils import require_package
def _find_and_replace(text: str, finds_and_replaces: list[tuple[str, str]]) -> str: def _find_and_replace(text: str, finds_and_replaces: list[tuple[str, str]]) -> str:
for f, r in finds_and_replaces: for f, r in finds_and_replaces:
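From the signature and loop shown, the full helper presumably looks something like this; the `assert` guarding against stale find strings is an assumption on our part:

```python
def _find_and_replace(text: str, finds_and_replaces: list[tuple[str, str]]) -> str:
    """Apply each (find, replace) pair in order, requiring every find to match."""
    for f, r in finds_and_replaces:
        assert f in text, f"{f!r} not found in text"
        text = text.replace(f, r)
    return text

# Typical use in the examples tests: patch a script before exec'ing it.
patched = _find_and_replace(
    'device = torch.device("cuda")',
    [('"cuda"', '"cpu"')],
)
```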
@@ -21,6 +38,7 @@ def test_example_1():
assert Path("outputs/examples/1_load_lerobot_dataset/episode_0.mp4").exists() assert Path("outputs/examples/1_load_lerobot_dataset/episode_0.mp4").exists()
@require_package("gym_pusht")
def test_examples_3_and_2(): def test_examples_3_and_2():
""" """
Train a model with example 3, check the outputs. Train a model with example 3, check the outputs.
@@ -46,7 +64,7 @@ def test_examples_3_and_2():
# Pass empty globals to allow dictionary comprehension https://stackoverflow.com/a/32897127/4391249. # Pass empty globals to allow dictionary comprehension https://stackoverflow.com/a/32897127/4391249.
exec(file_contents, {}) exec(file_contents, {})
for file_name in ["model.safetensors", "config.json", "config.yaml"]: for file_name in ["model.safetensors", "config.json"]:
assert Path(f"outputs/train/example_pusht_diffusion/{file_name}").exists() assert Path(f"outputs/train/example_pusht_diffusion/{file_name}").exists()
path = "examples/2_evaluate_pretrained_policy.py" path = "examples/2_evaluate_pretrained_policy.py"
@@ -58,16 +76,16 @@ def test_examples_3_and_2():
file_contents = _find_and_replace( file_contents = _find_and_replace(
file_contents, file_contents,
[ [
('pretrained_policy_name = "lerobot/diffusion_pusht"', ""), ('pretrained_policy_path = Path(snapshot_download("lerobot/diffusion_pusht"))', ""),
("pretrained_policy_path = Path(snapshot_download(pretrained_policy_name))", ""),
( (
'# pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")', '# pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")',
'pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")', 'pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")',
), ),
('"eval.n_episodes=10"', '"eval.n_episodes=1"'), ('device = torch.device("cuda")', 'device = torch.device("cpu")'),
('"eval.batch_size=10"', '"eval.batch_size=1"'), ("step += 1", "break"),
('"device=cuda"', '"device=cpu"'),
], ],
) )
assert Path("outputs/train/example_pusht_diffusion").exists() exec(file_contents, {})
assert Path("outputs/eval/example_pusht_diffusion/rollout.mp4").exists()


@@ -1,3 +1,18 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect import inspect
from pathlib import Path from pathlib import Path
@@ -49,6 +64,14 @@ def test_get_policy_and_config_classes(policy_name: str):
"act", "act",
["env.task=AlohaTransferCube-v0", "dataset_repo_id=lerobot/aloha_sim_transfer_cube_scripted"], ["env.task=AlohaTransferCube-v0", "dataset_repo_id=lerobot/aloha_sim_transfer_cube_scripted"],
), ),
# Note: these parameters also need custom logic in the test function for overriding the Hydra config.
(
"aloha",
"diffusion",
["env.task=AlohaInsertion-v0", "dataset_repo_id=lerobot/aloha_sim_insertion_human"],
),
# Note: these parameters also need custom logic in the test function for overriding the Hydra config.
("pusht", "act", ["env.task=PushT-v0", "dataset_repo_id=lerobot/pusht"]),
], ],
) )
@require_env @require_env
@@ -72,6 +95,31 @@ def test_policy(env_name, policy_name, extra_overrides):
+ extra_overrides, + extra_overrides,
) )
# Additional config override logic.
if env_name == "aloha" and policy_name == "diffusion":
for keys in [
("training", "delta_timestamps"),
("policy", "input_shapes"),
("policy", "input_normalization_modes"),
]:
dct = dict(cfg[keys[0]][keys[1]])
dct["observation.images.top"] = dct["observation.image"]
del dct["observation.image"]
cfg[keys[0]][keys[1]] = dct
cfg.override_dataset_stats = None
# Additional config override logic.
if env_name == "pusht" and policy_name == "act":
for keys in [
("policy", "input_shapes"),
("policy", "input_normalization_modes"),
]:
dct = dict(cfg[keys[0]][keys[1]])
dct["observation.image"] = dct["observation.images.top"]
del dct["observation.images.top"]
cfg[keys[0]][keys[1]] = dct
cfg.override_dataset_stats = None
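Both override branches above repeat the same move-a-key dance on nested config dicts. The pattern in isolation (the helper name is ours, not from the codebase):

```python
def rename_key(mapping: dict, old: str, new: str) -> dict:
    """Return a copy of `mapping` with the entry `old` renamed to `new`."""
    out = dict(mapping)
    out[new] = out.pop(old)
    return out

# e.g. retargeting a pusht-trained shape spec at the aloha camera key
shapes = {"observation.image": [3, 96, 96], "observation.state": [2]}
shapes = rename_key(shapes, "observation.image", "observation.images.top")
```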
# Check that we can make the policy object. # Check that we can make the policy object.
dataset = make_dataset(cfg) dataset = make_dataset(cfg)
policy = make_policy(hydra_cfg=cfg, dataset_stats=dataset.stats) policy = make_policy(hydra_cfg=cfg, dataset_stats=dataset.stats)
@@ -236,7 +284,7 @@ def test_normalize(insert_temporal_dim):
@pytest.mark.parametrize( @pytest.mark.parametrize(
"env_name, policy_name, extra_overrides", "env_name, policy_name, extra_overrides",
[ [
# ("xarm", "tdmpc", ["policy.n_action_repeats=2"]), ("xarm", "tdmpc", []),
( (
"pusht", "pusht",
"diffusion", "diffusion",

tests/test_utils.py (new file, 38 lines)

@@ -0,0 +1,38 @@
import random
from typing import Callable

import numpy as np
import pytest
import torch

from lerobot.common.utils.utils import seeded_context, set_global_seed


@pytest.mark.parametrize(
    "rand_fn",
    [
        random.random,
        np.random.random,
        lambda: torch.rand(1).item(),
    ]
    # Parenthesize the conditional: without parentheses, `a + b if cond else c`
    # parses as `(a + b) if cond else c`, leaving the parameter list empty on
    # machines without CUDA.
    + ([lambda: torch.rand(1, device="cuda")] if torch.cuda.is_available() else []),
)
def test_seeding(rand_fn: Callable[[], float]):
    set_global_seed(0)
    a = rand_fn()
    with seeded_context(1337):
        c = rand_fn()
    b = rand_fn()
    set_global_seed(0)
    a_ = rand_fn()
    b_ = rand_fn()
    # Check that `set_global_seed` lets us reproduce a and b.
    assert a_ == a
    # Additionally, check that the `seeded_context` didn't interrupt the global RNG.
    assert b_ == b
    set_global_seed(1337)
    c_ = rand_fn()
    # Check that `seeded_context` and `set_global_seed` give the same reproducibility.
    assert c_ == c
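For intuition, here is a minimal stdlib-only sketch of what `seeded_context` has to do: save the RNG state, reseed, and restore on exit. The real implementation also covers NumPy and PyTorch; this version handles only Python's `random`:

```python
import random
from contextlib import contextmanager

@contextmanager
def seeded_context(seed):
    # Save the current RNG state so the surrounding stream resumes untouched.
    state = random.getstate()
    random.seed(seed)
    try:
        yield
    finally:
        random.setstate(state)

random.seed(0)
a = random.random()
with seeded_context(1337):
    c = random.random()
b = random.random()

random.seed(0)
assert a == random.random()  # reproducible from the global seed
assert b == random.random()  # the seeded block didn't disturb the stream
```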


@@ -1,3 +1,18 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest import pytest
from lerobot.scripts.visualize_dataset import visualize_dataset from lerobot.scripts.visualize_dataset import visualize_dataset


@@ -1,4 +1,20 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import platform import platform
from functools import wraps
import pytest import pytest
import torch import torch
@@ -61,7 +77,6 @@ def require_env(func):
Decorator that skips the test if the required environment package is not installed. Decorator that skips the test if the required environment package is not installed.
As it needs 'env_name' in args, it also checks whether it is provided as an argument. As it needs 'env_name' in args, it also checks whether it is provided as an argument.
""" """
from functools import wraps
@wraps(func) @wraps(func)
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
@@ -82,3 +97,20 @@ def require_env(func):
return func(*args, **kwargs) return func(*args, **kwargs)
return wrapper return wrapper
def require_package(package_name):
    """
    Decorator that skips the test if the specified package is not installed.
    """

    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            if not is_package_available(package_name):
                pytest.skip(f"{package_name} not installed")
            return func(*args, **kwargs)

        return wrapper

    return decorator
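`require_package` leans on an `is_package_available` check that isn't shown in this hunk. A plausible stdlib implementation (the name is taken from the call above; the body is an assumption):

```python
import importlib.util

def is_package_available(package_name: str) -> bool:
    """True if `package_name` can be imported, without actually importing it."""
    return importlib.util.find_spec(package_name) is not None
```

Using `find_spec` avoids triggering the package's import-time side effects just to test for its presence.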