Merge remote-tracking branch 'upstream/main' into add_drop_last_keyframes

This commit is contained in:
Alexander Soare
2024-05-20 09:24:11 +01:00
77 changed files with 1810 additions and 880 deletions

View File

@@ -1,11 +1,15 @@
# What does this PR do?
## What this does
Explain what this PR does. Feel free to tag your PR with the appropriate label(s).
Examples:
- Fixes # (issue)
- Adds new dataset
- Optimizes something
| Title | Label |
|----------------------|-----------------|
| Fixes #[issue] | (🐛 Bug) |
| Adds new dataset | (🗃️ Dataset) |
| Optimizes something | (⚡️ Performance) |
## How was it tested?
## How it was tested
Explain/show how you tested your changes.
Examples:
- Added `test_something` in `tests/test_stuff.py`.
@@ -13,6 +17,7 @@ Examples:
- Optimized `some_function`, it now runs X times faster than previously.
## How to checkout & try? (for the reviewer)
Provide a simple way for the reviewer to try out your changes.
Examples:
```bash
@@ -22,11 +27,8 @@ DATA_DIR=tests/data pytest -sx tests/test_stuff.py::test_something
python lerobot/scripts/train.py --some.option=true
```
## Before submitting
Please read the [contributor guideline](https://github.com/huggingface/lerobot/blob/main/CONTRIBUTING.md#submitting-a-pull-request-pr).
## Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
## SECTION TO REMOVE BEFORE SUBMITTING YOUR PR
**Note**: Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR. Try to avoid tagging more than 3 people.
**Note**: Before submitting this PR, please read the [contributor guideline](https://github.com/huggingface/lerobot/blob/main/CONTRIBUTING.md#submitting-a-pull-request-pr).

View File

@@ -57,6 +57,38 @@ jobs:
&& rm -rf tests/outputs outputs
pytest-minimal:
name: Pytest (minimal install)
runs-on: ubuntu-latest
env:
DATA_DIR: tests/data
MUJOCO_GL: egl
steps:
- uses: actions/checkout@v4
- name: Install poetry
run: |
pipx install poetry && poetry config virtualenvs.in-project true
echo "${{ github.workspace }}/.venv/bin" >> $GITHUB_PATH
- name: Set up Python 3.10
uses: actions/setup-python@v5
with:
python-version: "3.10"
- name: Install poetry dependencies
run: |
poetry install --extras "test"
- name: Test with pytest
run: |
pytest tests -v --cov=./lerobot --durations=0 \
-W ignore::DeprecationWarning:imageio_ffmpeg._utils:7 \
-W ignore::UserWarning:torch.utils.data.dataloader:558 \
-W ignore::UserWarning:gymnasium.utils.env_checker:247 \
&& rm -rf tests/outputs outputs
end-to-end:
name: End-to-end
runs-on: ubuntu-latest

View File

@@ -18,7 +18,7 @@ repos:
hooks:
- id: pyupgrade
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.4.2
rev: v0.4.3
hooks:
- id: ruff
args: [--fix]

View File

@@ -22,9 +22,8 @@ test-end-to-end:
${MAKE} test-act-ete-eval
${MAKE} test-diffusion-ete-train
${MAKE} test-diffusion-ete-eval
# TODO(rcadene, alexander-soare): enable end-to-end tests for tdmpc
# ${MAKE} test-tdmpc-ete-train
# ${MAKE} test-tdmpc-ete-eval
${MAKE} test-tdmpc-ete-train
${MAKE} test-tdmpc-ete-eval
${MAKE} test-default-ete-eval
test-act-ete-train:
@@ -80,7 +79,7 @@ test-tdmpc-ete-train:
policy=tdmpc \
env=xarm \
env.task=XarmLift-v0 \
dataset_repo_id=lerobot/xarm_lift_medium_replay \
dataset_repo_id=lerobot/xarm_lift_medium \
wandb.enable=False \
training.offline_steps=2 \
training.online_steps=2 \

206
README.md
View File

@@ -29,15 +29,15 @@
---
🤗 LeRobot aims to provide models, datasets, and tools for real-world robotics in PyTorch. The goal is to lower the barrier for entry to robotics so that everyone can contribute and benefit from sharing datasets and pretrained models.
🤗 LeRobot aims to provide models, datasets, and tools for real-world robotics in PyTorch. The goal is to lower the barrier to entry to robotics so that everyone can contribute and benefit from sharing datasets and pretrained models.
🤗 LeRobot contains state-of-the-art approaches that have been shown to transfer to the real-world with a focus on imitation learning and reinforcement learning.
🤗 LeRobot already provides a set of pretrained models, datasets with human collected demonstrations, and simulated environments so that everyone can get started. In the coming weeks, the plan is to add more and more support for real-world robotics on the most affordable and capable robots out there.
🤗 LeRobot already provides a set of pretrained models, datasets with human collected demonstrations, and simulation environments to get started without assembling a robot. In the coming weeks, the plan is to add more and more support for real-world robotics on the most affordable and capable robots out there.
🤗 LeRobot hosts pretrained models and datasets on this HuggingFace community page: [huggingface.co/lerobot](https://huggingface.co/lerobot)
🤗 LeRobot hosts pretrained models and datasets on this Hugging Face community page: [huggingface.co/lerobot](https://huggingface.co/lerobot)
#### Examples of pretrained models and environments
#### Examples of pretrained models on simulation environments
<table>
<tr>
@@ -54,10 +54,11 @@
### Acknowledgment
- ACT policy and ALOHA environment are adapted from [ALOHA](https://tonyzhaozh.github.io/aloha/)
- Diffusion policy and Pusht environment are adapted from [Diffusion Policy](https://diffusion-policy.cs.columbia.edu/)
- TDMPC policy and Simxarm environment are adapted from [FOWM](https://www.yunhaifeng.com/FOWM/)
- Abstractions and utilities for Reinforcement Learning come from [TorchRL](https://github.com/pytorch/rl)
- Thanks to Tony Zaho, Zipeng Fu and colleagues for open sourcing ACT policy, ALOHA environments and datasets. Ours are adapted from [ALOHA](https://tonyzhaozh.github.io/aloha) and [Mobile ALOHA](https://mobile-aloha.github.io).
- Thanks to Cheng Chi, Zhenjia Xu and colleagues for open sourcing Diffusion policy, Pusht environment and datasets, as well as UMI datasets. Ours are adapted from [Diffusion Policy](https://diffusion-policy.cs.columbia.edu) and [UMI Gripper](https://umi-gripper.github.io).
- Thanks to Nicklas Hansen, Yunhai Feng and colleagues for open sourcing TDMPC policy, Simxarm environments and datasets. Ours are adapted from [TDMPC](https://github.com/nicklashansen/tdmpc) and [FOWM](https://www.yunhaifeng.com/FOWM).
- Thanks to Antonio Loquercio and Ashish Kumar for their early support.
## Installation
@@ -86,15 +87,18 @@ For instance, to install 🤗 LeRobot with aloha and pusht, use:
pip install ".[aloha, pusht]"
```
To use [Weights and Biases](https://docs.wandb.ai/quickstart) for experiments tracking, log in with
To use [Weights and Biases](https://docs.wandb.ai/quickstart) for experiment tracking, log in with
```bash
wandb login
```
(note: you will also need to enable WandB in the configuration. See below.)
## Walkthrough
```
.
├── examples # contains demonstration examples, start here to learn about LeRobot
├── lerobot
| ├── configs # contains hydra yaml files with all options that you can override in the command line
| | ├── default.yaml # selected by default, it loads pusht environment and diffusion policy
@@ -103,69 +107,84 @@ wandb login
| ├── common # contains classes and utilities
| | ├── datasets # various datasets of human demonstrations: aloha, pusht, xarm
| | ├── envs # various sim environments: aloha, pusht, xarm
| | ── policies # various policies: act, diffusion, tdmpc
| └── scripts # contains functions to execute via command line
| ├── visualize_dataset.py # load a dataset and render its demonstrations
| ├── eval.py # load policy and evaluate it on an environment
| ── train.py # train a policy via imitation learning and/or reinforcement learning
| | ── policies # various policies: act, diffusion, tdmpc
| | └── utils # various utilities
| └── scripts # contains functions to execute via command line
| ├── eval.py # load policy and evaluate it on an environment
| ── train.py # train a policy via imitation learning and/or reinforcement learning
| ├── push_dataset_to_hub.py # convert your dataset into LeRobot dataset format and upload it to the Hugging Face hub
| └── visualize_dataset.py # load a dataset and render its demonstrations
├── outputs # contains results of scripts execution: logs, videos, model checkpoints
├── .github
| └── workflows
| └── test.yml # defines install settings for continuous integration and specifies end-to-end tests
└── tests # contains pytest utilities for continuous integration
```
### Visualize datasets
Check out [examples](./examples) to see how you can import our dataset class, download the data from the HuggingFace hub and use our rendering utilities.
Check out [example 1](./examples/1_load_lerobot_dataset.py) that illustrates how to use our dataset class which automatically download data from the Hugging Face hub.
Or you can achieve the same result by executing our script from the command line:
You can also locally visualize episodes from a dataset by executing our script from the command line:
```bash
python lerobot/scripts/visualize_dataset.py \
env=pusht \
hydra.run.dir=outputs/visualize_dataset/example
# >>> ['outputs/visualize_dataset/example/episode_0.mp4']
--repo-id lerobot/pusht \
--episode-index 0
```
It will open `rerun.io` and display the camera streams, robot states and actions, like this:
https://github-production-user-asset-6210df.s3.amazonaws.com/4681518/328035972-fd46b787-b532-47e2-bb6f-fd536a55a7ed.mov?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAVCODYLSA53PQK4ZA%2F20240505%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20240505T172924Z&X-Amz-Expires=300&X-Amz-Signature=d680b26c532eeaf80740f08af3320d22ad0b8a4e4da1bcc4f33142c15b509eda&X-Amz-SignedHeaders=host&actor_id=24889239&key_id=0&repo_id=748713144
Our script can also visualize datasets stored on a distant server. See `python lerobot/scripts/visualize_dataset.py --help` for more instructions.
### Evaluate a pretrained policy
Check out [examples](./examples) to see how you can load a pretrained policy from HuggingFace hub, load up the corresponding environment and model, and run an evaluation.
Check out [example 2](./examples/2_evaluate_pretrained_policy.py) that illustrates how to download a pretrained policy from Hugging Face hub, and run an evaluation on its corresponding environment.
Or you can achieve the same result by executing our script from the command line:
We also provide a more capable script to parallelize the evaluation over multiple environments during the same rollout. Here is an example with a pretrained model hosted on [lerobot/diffusion_pusht](https://huggingface.co/lerobot/diffusion_pusht):
```bash
python lerobot/scripts/eval.py \
-p lerobot/diffusion_pusht \
eval_episodes=10 \
hydra.run.dir=outputs/eval/example_hub
-p lerobot/diffusion_pusht \
eval.n_episodes=10 \
eval.batch_size=10
```
After training your own policy, you can also re-evaluate the checkpoints with:
Note: After training your own policy, you can re-evaluate the checkpoints with:
```bash
python lerobot/scripts/eval.py \
-p PATH/TO/TRAIN/OUTPUT/FOLDER \
eval_episodes=10 \
hydra.run.dir=outputs/eval/example_dir
-p PATH/TO/TRAIN/OUTPUT/FOLDER
```
See `python lerobot/scripts/eval.py --help` for more instructions.
### Train your own policy
Check out [examples](./examples) to see how you can start training a model on a dataset, which will be automatically downloaded if needed.
Check out [example 3](./examples/3_train_policy.py) that illustrates how to start training a model.
In general, you can use our training script to easily train any policy on any environment:
In general, you can use our training script to easily train any policy. Here is an example of training the ACT policy on trajectories collected by humans on the Aloha simulation environment for the insertion task:
```bash
python lerobot/scripts/train.py \
env=aloha \
task=sim_insertion \
repo_id=lerobot/aloha_sim_insertion_scripted \
policy=act \
hydra.run.dir=outputs/train/aloha_act
policy=act \
env=aloha \
env.task=AlohaInsertion-v0 \
dataset_repo_id=lerobot/aloha_sim_insertion_human \
```
After training, you may want to revisit model evaluation to change the evaluation settings. In fact, during training every checkpoint is already evaluated but on a low number of episodes for efficiency. Check out [example](./examples) to evaluate any model checkpoint on more episodes to increase statistical significance.
The experiment directory is automatically generated and will show up in yellow in your terminal. It looks like `outputs/train/2024-05-05/20-21-12_aloha_act_default`. You can manually specify an experiment directory by adding this argument to the `train.py` python command:
```bash
hydra.run.dir=your/new/experiment/dir
```
To use wandb for logging training and evaluation curves, make sure you've run `wandb login` as a one-time setup step. Then, when running the training command above, enable WandB in the configuration by adding:
```bash
wandb.enable=true
```
A link to the wandb logs for the run will also show up in yellow in your terminal. Here is an example of what they look like in your browser:
![](media/wandb.png)
Note: For efficiency, during training every checkpoint is evaluated on a low number of episodes. After training, you may want to re-evaluate your best checkpoints on more episodes or change the evaluation settings. See `python lerobot/scripts/eval.py --help` for more instructions.
## Contribute
@@ -173,98 +192,40 @@ If you would like to contribute to 🤗 LeRobot, please check out our [contribut
### Add a new dataset
```python
# TODO(rcadene, AdilZouitine): rewrite this section
```
To add a dataset to the hub, first login and use a token generated from [huggingface settings](https://huggingface.co/settings/tokens) with write access:
To add a dataset to the hub, you need to login using a write-access token, which can be generated from the [Hugging Face settings](https://huggingface.co/settings/tokens):
```bash
huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
```
Then you can upload it to the hub with:
Then move your dataset folder in `data` directory (e.g. `data/aloha_ping_pong`), and push your dataset to the hub with:
```bash
HF_HUB_ENABLE_HF_TRANSFER=1 huggingface-cli upload $HF_USER/$DATASET data/$DATASET \
--repo-type dataset \
--revision v1.0
python lerobot/scripts/push_dataset_to_hub.py \
--data-dir data \
--dataset-id aloha_ping_ping \
--raw-format aloha_hdf5 \
--community-id lerobot
```
You will need to set the corresponding version as a default argument in your dataset class:
```python
version: str | None = "v1.1",
```
See: [`lerobot/common/datasets/pusht.py`](https://github.com/Cadene/lerobot/blob/main/lerobot/common/datasets/pusht.py)
See `python lerobot/scripts/push_dataset_to_hub.py --help` for more instructions.
For instance, for [lerobot/pusht](https://huggingface.co/datasets/lerobot/pusht), we used:
```bash
HF_USER=lerobot
DATASET=pusht
```
If your dataset format is not supported, implement your own in `lerobot/common/datasets/push_dataset_to_hub/${raw_format}_format.py` by copying examples like [pusht_zarr](https://github.com/huggingface/lerobot/blob/main/lerobot/common/datasets/push_dataset_to_hub/pusht_zarr_format.py), [umi_zarr](https://github.com/huggingface/lerobot/blob/main/lerobot/common/datasets/push_dataset_to_hub/umi_zarr_format.py), [aloha_hdf5](https://github.com/huggingface/lerobot/blob/main/lerobot/common/datasets/push_dataset_to_hub/aloha_hdf5_format.py), or [xarm_pkl](https://github.com/huggingface/lerobot/blob/main/lerobot/common/datasets/push_dataset_to_hub/xarm_pkl_format.py).
If you want to improve an existing dataset, you can download it locally with:
```bash
mkdir -p data/$DATASET
HF_HUB_ENABLE_HF_TRANSFER=1 huggingface-cli download ${HF_USER}/$DATASET \
--repo-type dataset \
--local-dir data/$DATASET \
--local-dir-use-symlinks=False \
--revision v1.0
```
Iterate on your code and dataset with:
```bash
DATA_DIR=data python train.py
```
Upload a new version (v2.0 or v1.1 if the changes are respectively more or less significant):
```bash
HF_HUB_ENABLE_HF_TRANSFER=1 huggingface-cli upload $HF_USER/$DATASET data/$DATASET \
--repo-type dataset \
--revision v1.1 \
--delete "*"
```
Then you will need to set the corresponding version as a default argument in your dataset class:
```python
version: str | None = "v1.1",
```
See: [`lerobot/common/datasets/pusht.py`](https://github.com/Cadene/lerobot/blob/main/lerobot/common/datasets/pusht.py)
Finally, you might want to mock the dataset if you need to update the unit tests as well:
```bash
python tests/scripts/mock_dataset.py --in-data-dir data/$DATASET --out-data-dir tests/data/$DATASET
```
### Add a pretrained policy
```python
# TODO(rcadene, alexander-soare): rewrite this section
```
Once you have trained a policy you may upload it to the HuggingFace hub.
Firstly, make sure you have a model repository set up on the hub. The hub ID looks like HF_USER/REPO_NAME.
Secondly, assuming you have trained a policy, you need the following (which should all be in any of the subdirectories of `checkpoints` in your training output folder, if you've used the LeRobot training script):
Once you have trained a policy you may upload it to the Hugging Face hub using a hub id that looks like `${hf_user}/${repo_name}` (e.g. [lerobot/diffusion_pusht](https://huggingface.co/lerobot/diffusion_pusht)).
You first need to find the checkpoint located inside your experiment directory (e.g. `outputs/train/2024-05-05/20-21-12_aloha_act_default/checkpoints/002500`). It should contain:
- `config.json`: A serialized version of the policy configuration (following the policy's dataclass config).
- `model.safetensors`: The `torch.nn.Module` parameters saved in [Hugging Face Safetensors](https://huggingface.co/docs/safetensors/index) format.
- `config.yaml`: This is the consolidated Hydra training configuration containing the policy, environment, and dataset configs. The policy configuration should match `config.json` exactly. The environment config is useful for anyone who wants to evaluate your policy. The dataset config just serves as a paper trail for reproducibility.
To upload these to the hub, run the following with a desired revision ID.
- `model.safetensors`: A set of `torch.nn.Module` parameters, saved in [Hugging Face Safetensors](https://huggingface.co/docs/safetensors/index) format.
- `config.yaml`: A consolidated Hydra training configuration containing the policy, environment, and dataset configs. The policy configuration should match `config.json` exactly. The environment config is useful for anyone who wants to evaluate your policy. The dataset config just serves as a paper trail for reproducibility.
To upload these to the hub, run the following:
```bash
huggingface-cli upload $HUB_ID PATH/TO/OUTPUT/DIR --revision $REVISION_ID
huggingface-cli upload ${hf_user}/${repo_name} path/to/checkpoint/dir
```
If you want this to be the default revision also run the following (don't worry, it won't upload the files again; it will just adjust the file pointers):
```bash
huggingface-cli upload $HUB_ID PATH/TO/OUTPUT/DIR
```
See `eval.py` for an example of how a user may use your policy.
See [eval.py](https://github.com/huggingface/lerobot/blob/main/lerobot/scripts/eval.py) for an example of how other people may use your policy.
### Improve your code with profiling
@@ -291,9 +252,14 @@ with profile(
# insert code to profile, potentially whole body of eval_policy function
```
```bash
python lerobot/scripts/eval.py \
--config outputs/pusht/.hydra/config.yaml \
pretrained_model_path=outputs/pusht/model.pt \
eval_episodes=7
## Citation
If you want, you can cite this work with:
```
@misc{cadene2024lerobot,
author = {Cadene, Remi and Alibert, Simon and Soare, Alexander and Gallouedec, Quentin and Zouitine, Adil and Wolf, Thomas},
title = {LeRobot: State-of-the-art Machine Learning for Real-World Robotics in Pytorch},
howpublished = "\url{https://github.com/huggingface/lerobot}",
year = {2024}
}
```

View File

@@ -7,6 +7,11 @@ ARG DEBIAN_FRONTEND=noninteractive
# Install apt dependencies
RUN apt-get update && apt-get install -y --no-install-recommends \
build-essential cmake \
git git-lfs openssh-client \
nano vim \
htop atop nvtop \
sed gawk grep curl wget \
tcpdump sysstat screen \
libglib2.0-0 libgl1-mesa-glx libegl1-mesa \
python${PYTHON_VERSION} python${PYTHON_VERSION}-venv \
&& apt-get clean && rm -rf /var/lib/apt/lists/*
@@ -18,7 +23,8 @@ ENV PATH="/opt/venv/bin:$PATH"
RUN echo "source /opt/venv/bin/activate" >> /root/.bashrc
# Install LeRobot
COPY . /lerobot
RUN git lfs install
RUN git clone https://github.com/huggingface/lerobot.git
WORKDIR /lerobot
RUN pip install --upgrade --no-cache-dir pip
RUN pip install --no-cache-dir ".[test, aloha, xarm, pusht]"

View File

@@ -14,6 +14,7 @@ The script ends with examples of how to batch process data using PyTorch's DataL
"""
from pathlib import Path
from pprint import pprint
import imageio
import torch
@@ -21,39 +22,36 @@ import torch
import lerobot
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
print("List of available datasets", lerobot.available_datasets)
# # >>> ['lerobot/aloha_sim_insertion_human', 'lerobot/aloha_sim_insertion_scripted',
# # 'lerobot/aloha_sim_transfer_cube_human', 'lerobot/aloha_sim_transfer_cube_scripted',
# # 'lerobot/pusht', 'lerobot/xarm_lift_medium']
print("List of available datasets:")
pprint(lerobot.available_datasets)
# Let's take one for this example
repo_id = "lerobot/pusht"
# You can easily load a dataset from a Hugging Face repositery
# You can easily load a dataset from a Hugging Face repository
dataset = LeRobotDataset(repo_id)
# LeRobotDataset is actually a thin wrapper around an underlying Hugging Face dataset (see https://huggingface.co/docs/datasets/index for more information).
# TODO(rcadene): update to make the print pretty
print(f"{dataset=}")
print(f"{dataset.hf_dataset=}")
# LeRobotDataset is actually a thin wrapper around an underlying Hugging Face dataset
# (see https://huggingface.co/docs/datasets/index for more information).
print(dataset)
print(dataset.hf_dataset)
# and provides additional utilities for robotics and compatibility with pytorch
print(f"number of samples/frames: {dataset.num_samples=}")
print(f"number of episodes: {dataset.num_episodes=}")
print(f"average number of frames per episode: {dataset.num_samples / dataset.num_episodes:.3f}")
# And provides additional utilities for robotics and compatibility with Pytorch
print(f"\naverage number of frames per episode: {dataset.num_samples / dataset.num_episodes:.3f}")
print(f"frames per second used during data collection: {dataset.fps=}")
print(f"keys to access images from cameras: {dataset.image_keys=}")
print(f"keys to access images from cameras: {dataset.camera_keys=}\n")
# Access frame indexes associated to first episode
episode_index = 0
from_idx = dataset.episode_data_index["from"][episode_index].item()
to_idx = dataset.episode_data_index["to"][episode_index].item()
# LeRobot datasets actually subclass PyTorch datasets so you can do everything you know and love from working with the latter, like iterating through the dataset.
# Here we grab all the image frames.
# LeRobot datasets actually subclass PyTorch datasets so you can do everything you know and love from working
# with the latter, like iterating through the dataset. Here we grab all the image frames.
frames = [dataset[idx]["observation.image"] for idx in range(from_idx, to_idx)]
# Video frames are now float32 in range [0,1] channel first (c,h,w) to follow pytorch convention.
# To visualize them, we convert to uint8 range [0,255]
# Video frames are now float32 in range [0,1] channel first (c,h,w) to follow pytorch convention. To visualize
# them, we convert to uint8 in range [0,255]
frames = [(frame * 255).type(torch.uint8) for frame in frames]
# and to channel last (h,w,c).
frames = [frame.permute((1, 2, 0)).numpy() for frame in frames]
@@ -62,9 +60,9 @@ frames = [frame.permute((1, 2, 0)).numpy() for frame in frames]
Path("outputs/examples/1_load_lerobot_dataset").mkdir(parents=True, exist_ok=True)
imageio.mimsave("outputs/examples/1_load_lerobot_dataset/episode_0.mp4", frames, fps=dataset.fps)
# For many machine learning applications we need to load the history of past observations or trajectories of future actions.
# Our datasets can load previous and future frames for each key/modality,
# using timestamps differences with the current loaded frame. For instance:
# For many machine learning applications we need to load the history of past observations or trajectories of
# future actions. Our datasets can load previous and future frames for each key/modality, using timestamps
# differences with the current loaded frame. For instance:
delta_timestamps = {
# loads 4 images: 1 second before current frame, 500 ms before, 200 ms before, and current frame
"observation.image": [-1, -0.5, -0.20, 0],
@@ -74,12 +72,12 @@ delta_timestamps = {
"action": [t / dataset.fps for t in range(64)],
}
dataset = LeRobotDataset(repo_id, delta_timestamps=delta_timestamps)
print(f"{dataset[0]['observation.image'].shape=}") # (4,c,h,w)
print(f"\n{dataset[0]['observation.image'].shape=}") # (4,c,h,w)
print(f"{dataset[0]['observation.state'].shape=}") # (8,c)
print(f"{dataset[0]['action'].shape=}") # (64,c)
print(f"{dataset[0]['action'].shape=}\n") # (64,c)
# Finally, our datasets are fully compatible with PyTorch dataloaders and samplers
# because they are just PyTorch datasets.
# Finally, our datasets are fully compatible with PyTorch dataloaders and samplers because they are just
# PyTorch datasets.
dataloader = torch.utils.data.DataLoader(
dataset,
num_workers=0,

View File

@@ -5,23 +5,108 @@ training outputs directory. In the latter case, you might want to run examples/3
from pathlib import Path
import gym_pusht # noqa: F401
import gymnasium as gym
import imageio
import numpy
import torch
from huggingface_hub import snapshot_download
from lerobot.scripts.eval import eval
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
# Get a pretrained policy from the hub.
pretrained_policy_name = "lerobot/diffusion_pusht"
pretrained_policy_path = Path(snapshot_download(pretrained_policy_name))
# Create a directory to store the video of the evaluation
output_directory = Path("outputs/eval/example_pusht_diffusion")
output_directory.mkdir(parents=True, exist_ok=True)
device = torch.device("cuda")
# Download the diffusion policy for pusht environment
pretrained_policy_path = Path(snapshot_download("lerobot/diffusion_pusht"))
# OR uncomment the following to evaluate a policy from the local outputs/train folder.
# pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")
# Override some config parameters to do with evaluation.
overrides = [
"eval.n_episodes=10",
"eval.batch_size=10",
"device=cuda",
]
policy = DiffusionPolicy.from_pretrained(pretrained_policy_path)
policy.eval()
policy.to(device)
# Evaluate the policy and save the outputs including metrics and videos.
# TODO(rcadene, alexander-soare): dont call eval, but add the minimal code snippet to rollout
eval(pretrained_policy_path=pretrained_policy_path)
# Initialize evaluation environment to render two observation types:
# an image of the scene and state/position of the agent. The environment
# also automatically stops running after 300 interactions/steps.
env = gym.make(
"gym_pusht/PushT-v0",
obs_type="pixels_agent_pos",
max_episode_steps=300,
)
# Reset the policy and environmens to prepare for rollout
policy.reset()
numpy_observation, info = env.reset(seed=42)
# Prepare to collect every rewards and all the frames of the episode,
# from initial state to final state.
rewards = []
frames = []
# Render frame of the initial state
frames.append(env.render())
step = 0
done = False
while not done:
# Prepare observation for the policy running in Pytorch
state = torch.from_numpy(numpy_observation["agent_pos"])
image = torch.from_numpy(numpy_observation["pixels"])
# Convert to float32 with image from channel first in [0,255]
# to channel last in [0,1]
state = state.to(torch.float32)
image = image.to(torch.float32) / 255
image = image.permute(2, 0, 1)
# Send data tensors from CPU to GPU
state = state.to(device, non_blocking=True)
image = image.to(device, non_blocking=True)
# Add extra (empty) batch dimension, required to forward the policy
state = state.unsqueeze(0)
image = image.unsqueeze(0)
# Create the policy input dictionary
observation = {
"observation.state": state,
"observation.image": image,
}
# Predict the next action with respect to the current observation
with torch.inference_mode():
action = policy.select_action(observation)
# Prepare the action for the environment
numpy_action = action.squeeze(0).to("cpu").numpy()
# Step through the environment and receive a new observation
numpy_observation, reward, terminated, truncated, info = env.step(numpy_action)
print(f"{step=} {reward=} {terminated=}")
# Keep track of all the rewards and frames
rewards.append(reward)
frames.append(env.render())
# The rollout is considered done when the success state is reach (i.e. terminated is True),
# or the maximum number of iterations is reached (i.e. truncated is True)
done = terminated | truncated | done
step += 1
if terminated:
print("Success!")
else:
print("Failure!")
# Get the speed of environment (i.e. its number of frames per second).
fps = env.metadata["render_fps"]
# Encode all frames into a mp4 video.
video_path = output_directory / "rollout.mp4"
imageio.mimsave(str(video_path), numpy.stack(frames), fps=fps)
print(f"Video of the evaluation is available in '{video_path}'.")

View File

@@ -4,36 +4,42 @@ Once you have trained a model with this script, you can try to evaluate it on
examples/2_evaluate_pretrained_policy.py
"""
import os
from pathlib import Path
import torch
from omegaconf import OmegaConf
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
from lerobot.common.utils.utils import init_hydra_config
# Create a directory to store the training checkpoint.
output_directory = Path("outputs/train/example_pusht_diffusion")
os.makedirs(output_directory, exist_ok=True)
output_directory.mkdir(parents=True, exist_ok=True)
# Number of offline training steps (we'll only do offline training for this example.
# Number of offline training steps (we'll only do offline training for this example.)
# Adjust as you prefer. 5000 steps are needed to get something worth evaluating.
training_steps = 5000
device = torch.device("cuda")
log_freq = 250
# Set up the dataset.
hydra_cfg = init_hydra_config("lerobot/configs/default.yaml", overrides=["env=pusht"])
dataset = make_dataset(hydra_cfg)
delta_timestamps = {
# Load the previous image and state at -0.1 seconds before current frame,
# then load current image and state corresponding to 0.0 second.
"observation.image": [-0.1, 0.0],
"observation.state": [-0.1, 0.0],
# Load the previous action (-0.1), the next action to be executed (0.0),
# and 14 future actions with a 0.1 seconds spacing. All these actions will be
# used to supervise the policy.
"action": [-0.1, 0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4],
}
dataset = LeRobotDataset("lerobot/pusht", delta_timestamps=delta_timestamps)
# Set up the the policy.
# Policies are initialized with a configuration class, in this case `DiffusionConfig`.
# For this example, no arguments need to be passed because the defaults are set up for PushT.
# If you're doing something different, you will likely need to change at least some of the defaults.
cfg = DiffusionConfig()
# TODO(alexander-soare): Remove LR scheduler from the policy.
policy = DiffusionPolicy(cfg, dataset_stats=dataset.stats)
policy.train()
policy.to(device)
@@ -69,7 +75,5 @@ while not done:
done = True
break
# Save the policy.
# Save a policy checkpoint.
policy.save_pretrained(output_directory)
# Save the Hydra configuration so we have the environment configuration for eval.
OmegaConf.save(hydra_cfg, output_directory / "config.yaml")

View File

@@ -1,3 +1,18 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file contains lists of available environments, dataset and policies to reflect the current state of LeRobot library.
We do not want to import all the dependencies, but instead we keep it lightweight to ensure fast access to these variables.
@@ -85,13 +100,6 @@ available_datasets = list(
itertools.chain(*available_datasets_per_env.values(), available_real_world_datasets)
)
# TODO(rcadene, aliberts, alexander-soare): Add real-world env with a gym API
available_datasets_without_env = ["lerobot/umi_cup_in_the_wild"]
available_datasets = list(
itertools.chain(*available_datasets_per_env.values(), available_datasets_without_env)
)
available_policies = [
"act",
"diffusion",

View File

@@ -1,3 +1,18 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""To enable `lerobot.__version__`"""
from importlib.metadata import PackageNotFoundError, version

View File

@@ -37,16 +37,16 @@ How to decode videos?
## Variables
**Image content**
We don't expect the same optimal settings for a dataset of images from a simulation, or from real-world in an appartment, or in a factory, or outdoor, etc. Hence, we run this bechmark on two datasets: `pusht` (simulation) and `umi` (real-world outdoor).
We don't expect the same optimal settings for a dataset of images from a simulation, or from real-world in an appartment, or in a factory, or outdoor, etc. Hence, we run this benchmark on two datasets: `pusht` (simulation) and `umi` (real-world outdoor).
**Requested timestamps**
In this benchmark, we focus on the loading time of random access, so we are not interested about sequentially loading all frames of a video like in a movie. However, the number of consecutive timestamps requested and their spacing can greatly affect the `load_time_factor`. In fact, it is expected to get faster loading time by decoding a large number of consecutive frames from a video, than to load the same data from individual images. To reflect our robotics use case, we consider a few settings:
In this benchmark, we focus on the loading time of random access, so we are not interested in sequentially loading all frames of a video like in a movie. However, the number of consecutive timestamps requested and their spacing can greatly affect the `load_time_factor`. In fact, it is expected to get faster loading time by decoding a large number of consecutive frames from a video, than to load the same data from individual images. To reflect our robotics use case, we consider a few settings:
- `single_frame`: 1 frame,
- `2_frames`: 2 consecutive frames (e.g. `[t, t + 1 / fps]`),
- `2_frames_4_space`: 2 consecutive frames with 4 frames of spacing (e.g `[t, t + 4 / fps]`),
**Data augmentations**
We might revisit this benchmark and find better settings if we train our policies with various data augmentations to make them more robusts (e.g. robust to color changes, compression, etc.).
We might revisit this benchmark and find better settings if we train our policies with various data augmentations to make them more robust (e.g. robust to color changes, compression, etc.).
## Results

View File

@@ -1,3 +1,18 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import random
import shutil

View File

@@ -1,3 +1,18 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import torch

View File

@@ -1,3 +1,18 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from pathlib import Path
@@ -72,6 +87,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
@property
def fps(self) -> int:
"""Frames per second used during data collection."""
return self.info["fps"]
@property
@@ -86,15 +102,22 @@ class LeRobotDataset(torch.utils.data.Dataset):
return self.hf_dataset.features
@property
def image_keys(self) -> list[str]:
image_keys = []
def camera_keys(self) -> list[str]:
"""Keys to access image and video stream from cameras."""
keys = []
for key, feats in self.hf_dataset.features.items():
if isinstance(feats, datasets.Image):
image_keys.append(key)
return image_keys + self.video_frame_keys
if isinstance(feats, (datasets.Image, VideoFrame)):
keys.append(key)
return keys
@property
def video_frame_keys(self):
def video_frame_keys(self) -> list[str]:
"""Keys to access video frames that requires to be decoded into images.
Note: It is empty if the dataset contains images only,
or equal to `self.cameras` if the dataset contains videos only,
or can even be a subset of `self.cameras` in a case of a mixed image/video dataset.
"""
video_frame_keys = []
for key, feats in self.hf_dataset.features.items():
if isinstance(feats, VideoFrame):
@@ -103,10 +126,15 @@ class LeRobotDataset(torch.utils.data.Dataset):
@property
def num_samples(self) -> int:
"""Number of possible samples in the dataset.
This is equivalent to the number of frames in the dataset minus n_end_keyframes_dropped.
"""
return len(self.index)
@property
def num_episodes(self) -> int:
"""Number of episodes."""
return len(self.hf_dataset.unique("episode_index"))
@property
@@ -146,6 +174,22 @@ class LeRobotDataset(torch.utils.data.Dataset):
return item
def __repr__(self):
return (
f"{self.__class__.__name__}(\n"
f" Repository ID: '{self.repo_id}',\n"
f" Version: '{self.version}',\n"
f" Split: '{self.split}',\n"
f" Number of Samples: {self.num_samples},\n"
f" Number of Episodes: {self.num_episodes},\n"
f" Type: {'video (.mp4)' if self.video else 'image (.png)'},\n"
f" Recorded Frames per Second: {self.fps},\n"
f" Camera Keys: {self.camera_keys},\n"
f" Video Frame Keys: {self.video_frame_keys if self.video else 'N/A'},\n"
f" Transformations: {self.transform},\n"
f")"
)
@classmethod
def from_preloaded(
cls,

View File

@@ -1,3 +1,18 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Helper code for loading PushT dataset from Diffusion Policy (https://diffusion-policy.cs.columbia.edu/)
Copied from the original Diffusion Policy repository and used in our `download_and_upload_dataset.py` script.

View File

@@ -1,3 +1,18 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file contains all obsolete download scripts. They are centralized here to not have to load
useless dependencies when using datasets.

View File

@@ -1,3 +1,18 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# imagecodecs/numcodecs.py
# Copyright (c) 2021-2022, Christoph Gohlke

View File

@@ -1,3 +1,18 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Contains utilities to process raw data format of HDF5 files like in: https://github.com/tonyzhaozh/act
"""
@@ -142,12 +157,12 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug):
def to_hf_dataset(data_dict, video) -> Dataset:
features = {}
image_keys = [key for key in data_dict if "observation.images." in key]
for image_key in image_keys:
keys = [key for key in data_dict if "observation.images." in key]
for key in keys:
if video:
features[image_key] = VideoFrame()
features[key] = VideoFrame()
else:
features[image_key] = Image()
features[key] = Image()
features["observation.state"] = Sequence(
length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)

View File

@@ -1,3 +1,18 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from copy import deepcopy
from math import ceil

View File

@@ -1,3 +1,18 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Process zarr files formatted like in: https://github.com/real-stanford/diffusion_policy"""
import shutil

View File

@@ -1,3 +1,18 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Process UMI (Universal Manipulation Interface) data stored in Zarr format like in: https://github.com/real-stanford/universal_manipulation_interface"""
import logging

View File

@@ -1,3 +1,18 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path

View File

@@ -1,3 +1,18 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Process pickle files formatted like in: https://github.com/fyhMer/fowm"""
import pickle

View File

@@ -1,3 +1,18 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from pathlib import Path

View File

@@ -1,3 +1,18 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import subprocess
import warnings

View File

@@ -1,3 +1,18 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import gymnasium as gym

View File

@@ -1,3 +1,18 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import einops
import numpy as np
import torch

View File

@@ -1,3 +1,18 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# TODO(rcadene, alexander-soare): clean this file
"""Borrowed from https://github.com/fyhMer/fowm/blob/main/src/logger.py"""
@@ -17,9 +32,10 @@ def log_output_dir(out_dir):
def cfg_to_group(cfg, return_list=False):
"""Return a wandb-safe group name for logging. Optionally returns group name as list."""
# lst = [cfg.task, cfg.modality, re.sub("[^0-9a-zA-Z]+", "-", cfg.exp_name)]
"""Return a group name for logging. Optionally returns group name as list."""
lst = [
f"policy:{cfg.policy.name}",
f"dataset:{cfg.dataset_repo_id}",
f"env:{cfg.env.name}",
f"seed:{cfg.seed}",
]
@@ -81,9 +97,9 @@ class Logger:
# Also save the full Hydra config for the env configuration.
OmegaConf.save(self._cfg, save_dir / "config.yaml")
if self._wandb and not self._disable_wandb_artifact:
# note wandb artifact does not accept ":" in its name
# note wandb artifact does not accept ":" or "/" in its name
artifact = self._wandb.Artifact(
self._group.replace(":", "_") + "-" + str(self._seed) + "-" + str(identifier),
f"{self._group.replace(':', '_').replace('/', '_')}-{self._seed}-{identifier}",
type="model",
)
artifact.add_file(save_dir / SAFETENSORS_SINGLE_FILE)
@@ -93,9 +109,10 @@ class Logger:
self._buffer_dir.mkdir(parents=True, exist_ok=True)
fp = self._buffer_dir / f"{str(identifier)}.pkl"
buffer.save(fp)
if self._wandb:
if self._wandb and not self._disable_wandb_artifact:
# note wandb artifact does not accept ":" or "/" in its name
artifact = self._wandb.Artifact(
self._group + "-" + str(self._seed) + "-" + str(identifier),
f"{self._group.replace(':', '_').replace('/', '_')}-{self._seed}-{identifier}",
type="buffer",
)
artifact.add_file(fp)
@@ -113,6 +130,11 @@ class Logger:
assert mode in {"train", "eval"}
if self._wandb is not None:
for k, v in d.items():
if not isinstance(v, (int, float, str)):
logging.warning(
f'WandB logging of key "{k}" was ignored as its type is not handled by this wrapper.'
)
continue
self._wandb.log({f"{mode}/{k}": v}, step=step)
def log_video(self, video_path: str, step: int, mode: str = "train"):

View File

@@ -1,3 +1,18 @@
#!/usr/bin/env python
# Copyright 2024 Tony Z. Zhao and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
@@ -51,8 +66,12 @@ class ACTConfig:
documentation in the policy class).
latent_dim: The VAE's latent dimension.
n_vae_encoder_layers: The number of transformer layers to use for the VAE's encoder.
use_temporal_aggregation: Whether to blend the actions of multiple policy invocations for any given
environment step.
temporal_ensemble_momentum: Exponential moving average (EMA) momentum parameter (α) for ensembling
actions for a given time step over multiple policy invocations. Updates are calculated as:
x⁻ₙ = αx⁻ₙ₋₁ + (1-α)xₙ. Note that the ACT paper and original ACT code describes a different
parameter here: they refer to a weighting scheme wᵢ = exp(-m⋅i) and set m = 0.01. With our
formulation, this is equivalent to α = exp(-0.01) ≈ 0.99. When this parameter is provided, we
require `n_action_steps == 1` (since we need to query the policy every step anyway).
dropout: Dropout to use in the transformer layers (see code for details).
kl_weight: The weight to use for the KL-divergence component of the loss if the variational objective
is enabled. Loss is then calculated as: `reconstruction_loss + kl_weight * kld_loss`.
@@ -100,6 +119,9 @@ class ACTConfig:
dim_feedforward: int = 3200
feedforward_activation: str = "relu"
n_encoder_layers: int = 4
# Note: Although the original ACT implementation has 7 for `n_decoder_layers`, there is a bug in the code
# that means only the first layer is used. Here we match the original implementation by setting this to 1.
# See this issue https://github.com/tonyzhaozh/act/issues/25#issue-2258740521.
n_decoder_layers: int = 1
# VAE.
use_vae: bool = True
@@ -107,7 +129,7 @@ class ACTConfig:
n_vae_encoder_layers: int = 4
# Inference.
use_temporal_aggregation: bool = False
temporal_ensemble_momentum: float | None = None
# Training and loss computation.
dropout: float = 0.1
@@ -119,8 +141,11 @@ class ACTConfig:
raise ValueError(
f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
)
if self.use_temporal_aggregation:
raise NotImplementedError("Temporal aggregation is not yet implemented.")
if self.temporal_ensemble_momentum is not None and self.n_action_steps > 1:
raise NotImplementedError(
"`n_action_steps` must be 1 when using temporal ensembling. This is "
"because the policy needs to be queried every step to compute the ensembled action."
)
if self.n_action_steps > self.chunk_size:
raise ValueError(
f"The chunk size is the upper bound for the number of action steps per model invocation. Got "
@@ -130,10 +155,3 @@ class ACTConfig:
raise ValueError(
f"Multiple observation steps not handled yet. Got `nobs_steps={self.n_obs_steps}`"
)
# Check that there is only one image.
# TODO(alexander-soare): generalize this to multiple images.
if (
sum(k.startswith("observation.images.") for k in self.input_shapes) != 1
or "observation.images.top" not in self.input_shapes
):
raise ValueError('For now, only "observation.images.top" is accepted for an image input.')

View File

@@ -1,3 +1,18 @@
#!/usr/bin/env python
# Copyright 2024 Tony Z. Zhao and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Action Chunking Transformer Policy
As per Learning Fine-Grained Bimanual Manipulation with Low-Cost Hardware (https://arxiv.org/abs/2304.13705).
@@ -46,7 +61,8 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
super().__init__()
if config is None:
config = ACTConfig()
self.config = config
self.config: ACTConfig = config
self.normalize_inputs = Normalize(
config.input_shapes, config.input_normalization_modes, dataset_stats
)
@@ -56,11 +72,18 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
self.unnormalize_outputs = Unnormalize(
config.output_shapes, config.output_normalization_modes, dataset_stats
)
self.model = ACT(config)
self.expected_image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
self.reset()
def reset(self):
"""This should be called whenever the environment is reset."""
if self.config.n_action_steps is not None:
if self.config.temporal_ensemble_momentum is not None:
self._ensembled_actions = None
else:
self._action_queue = deque([], maxlen=self.config.n_action_steps)
@torch.no_grad
@@ -71,37 +94,56 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
environment. It works by managing the actions in a queue and only calling `select_actions` when the
queue is empty.
"""
assert "observation.images.top" in batch
assert "observation.state" in batch
self.eval()
batch = self.normalize_inputs(batch)
self._stack_images(batch)
batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
# If we are doing temporal ensembling, keep track of the exponential moving average (EMA), and return
# the first action.
if self.config.temporal_ensemble_momentum is not None:
actions = self.model(batch)[0] # (batch_size, chunk_size, action_dim)
actions = self.unnormalize_outputs({"action": actions})["action"]
if self._ensembled_actions is None:
# Initializes `self._ensembled_action` to the sequence of actions predicted during the first
# time step of the episode.
self._ensembled_actions = actions.clone()
else:
# self._ensembled_actions will have shape (batch_size, chunk_size - 1, action_dim). Compute
# the EMA update for those entries.
alpha = self.config.temporal_ensemble_momentum
self._ensembled_actions = alpha * self._ensembled_actions + (1 - alpha) * actions[:, :-1]
# The last action, which has no prior moving average, needs to get concatenated onto the end.
self._ensembled_actions = torch.cat([self._ensembled_actions, actions[:, -1:]], dim=1)
# "Consume" the first action.
action, self._ensembled_actions = self._ensembled_actions[:, 0], self._ensembled_actions[:, 1:]
return action
# Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by
# querying the policy.
if len(self._action_queue) == 0:
# `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue
# effectively has shape (n_action_steps, batch_size, *), hence the transpose.
actions = self.model(batch)[0][: self.config.n_action_steps]
actions = self.model(batch)[0][:, : self.config.n_action_steps]
# TODO(rcadene): make _forward return output dictionary?
actions = self.unnormalize_outputs({"action": actions})["action"]
# `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue
# effectively has shape (n_action_steps, batch_size, *), hence the transpose.
self._action_queue.extend(actions.transpose(0, 1))
return self._action_queue.popleft()
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
"""Run the batch through the model and compute the loss for training or validation."""
batch = self.normalize_inputs(batch)
batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
batch = self.normalize_targets(batch)
self._stack_images(batch)
actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)
l1_loss = (
F.l1_loss(batch["action"], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1)
).mean()
loss_dict = {"l1_loss": l1_loss}
loss_dict = {"l1_loss": l1_loss.item()}
if self.config.use_vae:
# Calculate Dₖₗ(latent_pdf || standard_normal). Note: After computing the KL-divergence for
# each dimension independently, we sum over the latent dimension to get the total
@@ -110,28 +152,13 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
mean_kld = (
(-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean()
)
loss_dict["kld_loss"] = mean_kld
loss_dict["kld_loss"] = mean_kld.item()
loss_dict["loss"] = l1_loss + mean_kld * self.config.kl_weight
else:
loss_dict["loss"] = l1_loss
return loss_dict
def _stack_images(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
"""Stacks all the images in a batch and puts them in a new key: "observation.images".
This function expects `batch` to have (at least):
{
"observation.state": (B, state_dim) batch of robot states.
"observation.images.{name}": (B, C, H, W) tensor of images.
}
"""
# Stack images in the order dictated by input_shapes.
batch["observation.images"] = torch.stack(
[batch[k] for k in self.config.input_shapes if k.startswith("observation.images.")],
dim=-4,
)
class ACT(nn.Module):
"""Action Chunking Transformer: The underlying neural network for ACTPolicy.
@@ -161,10 +188,10 @@ class ACT(nn.Module):
│ encoder │ │ │ │Transf.│ │
│ │ │ │ │encoder│ │
└───▲─────┘ │ │ │ │ │
│ │ │ └──▲──┘ │
│ │ │
inputs └─────┼─────┘
│ │ │ └──▲──┘ │
│ │ │
inputs └─────┼──┘ │ image emb.
state emb.
└───────────────────────┘
"""
@@ -306,18 +333,18 @@ class ACT(nn.Module):
all_cam_features.append(cam_features)
all_cam_pos_embeds.append(cam_pos_embed)
# Concatenate camera observation feature maps and positional embeddings along the width dimension.
encoder_in = torch.cat(all_cam_features, axis=3)
cam_pos_embed = torch.cat(all_cam_pos_embeds, axis=3)
encoder_in = torch.cat(all_cam_features, axis=-1)
cam_pos_embed = torch.cat(all_cam_pos_embeds, axis=-1)
# Get positional embeddings for robot state and latent.
robot_state_embed = self.encoder_robot_state_input_proj(batch["observation.state"])
latent_embed = self.encoder_latent_input_proj(latent_sample)
robot_state_embed = self.encoder_robot_state_input_proj(batch["observation.state"]) # (B, C)
latent_embed = self.encoder_latent_input_proj(latent_sample) # (B, C)
# Stack encoder input and positional embeddings moving to (S, B, C).
encoder_in = torch.cat(
[
torch.stack([latent_embed, robot_state_embed], axis=0),
encoder_in.flatten(2).permute(2, 0, 1),
einops.rearrange(encoder_in, "b c h w -> (h w) b c"),
]
)
pos_embed = torch.cat(

View File

@@ -1,3 +1,19 @@
#!/usr/bin/env python
# Copyright 2024 Columbia Artificial Intelligence, Robotics Lab,
# and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
@@ -51,6 +67,7 @@ class DiffusionConfig:
use_film_scale_modulation: FiLM (https://arxiv.org/abs/1709.07871) is used for the Unet conditioning.
Bias modulation is used be default, while this parameter indicates whether to also use scale
modulation.
noise_scheduler_type: Name of the noise scheduler to use. Supported options: ["DDPM", "DDIM"].
num_train_timesteps: Number of diffusion steps for the forward diffusion schedule.
beta_schedule: Name of the diffusion beta schedule as per DDPMScheduler from Hugging Face diffusers.
beta_start: Beta value for the first forward-diffusion step.
@@ -64,6 +81,9 @@ class DiffusionConfig:
clip_sample_range: The magnitude of the clipping range as described above.
num_inference_steps: Number of reverse diffusion steps to use at inference time (steps are evenly
spaced). If not provided, this defaults to be the same as `num_train_timesteps`.
do_mask_loss_for_padding: Whether to mask the loss when there are copy-padded actions. See
`LeRobotDataset` and `load_previous_and_future_frames` for mor information. Note, this defaults
to False as the original Diffusion Policy implementation does the same.
"""
# Inputs / output structure.
@@ -107,6 +127,7 @@ class DiffusionConfig:
diffusion_step_embed_dim: int = 128
use_film_scale_modulation: bool = True
# Noise scheduler.
noise_scheduler_type: str = "DDPM"
num_train_timesteps: int = 100
beta_schedule: str = "squaredcos_cap_v2"
beta_start: float = 0.0001
@@ -118,23 +139,39 @@ class DiffusionConfig:
# Inference
num_inference_steps: int | None = None
# Loss computation
do_mask_loss_for_padding: bool = False
def __post_init__(self):
"""Input validation (not exhaustive)."""
if not self.vision_backbone.startswith("resnet"):
raise ValueError(
f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
)
# There should only be one image key.
image_keys = {k for k in self.input_shapes if k.startswith("observation.image")}
if len(image_keys) != 1:
raise ValueError(
f"{self.__class__.__name__} only handles one image for now. Got image keys {image_keys}."
)
image_key = next(iter(image_keys))
if (
self.crop_shape[0] > self.input_shapes["observation.image"][1]
or self.crop_shape[1] > self.input_shapes["observation.image"][2]
self.crop_shape[0] > self.input_shapes[image_key][1]
or self.crop_shape[1] > self.input_shapes[image_key][2]
):
raise ValueError(
f'`crop_shape` should fit within `input_shapes["observation.image"]`. Got {self.crop_shape} '
f'for `crop_shape` and {self.input_shapes["observation.image"]} for '
'`input_shapes["observation.image"]`.'
f"`crop_shape` should fit within `input_shapes[{image_key}]`. Got {self.crop_shape} "
f"for `crop_shape` and {self.input_shapes[image_key]} for "
"`input_shapes[{image_key}]`."
)
supported_prediction_types = ["epsilon", "sample"]
if self.prediction_type not in supported_prediction_types:
raise ValueError(
f"`prediction_type` must be one of {supported_prediction_types}. Got {self.prediction_type}."
)
supported_noise_schedulers = ["DDPM", "DDIM"]
if self.noise_scheduler_type not in supported_noise_schedulers:
raise ValueError(
f"`noise_scheduler_type` must be one of {supported_noise_schedulers}. "
f"Got {self.noise_scheduler_type}."
)

View File

@@ -1,8 +1,24 @@
#!/usr/bin/env python
# Copyright 2024 Columbia Artificial Intelligence, Robotics Lab,
# and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Diffusion Policy as per "Diffusion Policy: Visuomotor Policy Learning via Action Diffusion"
TODO(alexander-soare):
- Remove reliance on Robomimic for SpatialSoftmax.
- Remove reliance on diffusers for DDPMScheduler and LR scheduler.
- Make compatible with multiple image keys.
"""
import math
@@ -10,12 +26,13 @@ from collections import deque
from typing import Callable
import einops
import numpy as np
import torch
import torch.nn.functional as F # noqa: N812
import torchvision
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from huggingface_hub import PyTorchModelHubMixin
from robomimic.models.base_nets import SpatialSoftmax
from torch import Tensor, nn
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
@@ -66,10 +83,18 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
self.diffusion = DiffusionModel(config)
image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
# Note: This check is covered in the post-init of the config but have a sanity check just in case.
if len(image_keys) != 1:
raise NotImplementedError(
f"{self.__class__.__name__} only handles one image for now. Got image keys {image_keys}."
)
self.input_image_key = image_keys[0]
self.reset()
def reset(self):
"""
Clear observation and action queues. Should be called on `env.reset()`
"""
"""Clear observation and action queues. Should be called on `env.reset()`"""
self._queues = {
"observation.image": deque(maxlen=self.config.n_obs_steps),
"observation.state": deque(maxlen=self.config.n_obs_steps),
@@ -98,16 +123,14 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
"horizon" may not the best name to describe what the variable actually means, because this period is
actually measured from the first observation which (if `n_obs_steps` > 1) happened in the past.
"""
assert "observation.image" in batch
assert "observation.state" in batch
batch = self.normalize_inputs(batch)
batch["observation.image"] = batch[self.input_image_key]
self._queues = populate_queues(self._queues, batch)
if len(self._queues["action"]) == 0:
# stack n latest observations from the queue
batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch}
batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues}
actions = self.diffusion.generate_actions(batch)
# TODO(rcadene): make above methods return output dictionary?
@@ -121,11 +144,25 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
"""Run the batch through the model and compute the loss for training or validation."""
batch = self.normalize_inputs(batch)
batch["observation.image"] = batch[self.input_image_key]
batch = self.normalize_targets(batch)
loss = self.diffusion.compute_loss(batch)
return {"loss": loss}
def _make_noise_scheduler(name: str, **kwargs: dict) -> DDPMScheduler | DDIMScheduler:
"""
Factory for noise scheduler instances of the requested type. All kwargs are passed
to the scheduler.
"""
if name == "DDPM":
return DDPMScheduler(**kwargs)
elif name == "DDIM":
return DDIMScheduler(**kwargs)
else:
raise ValueError(f"Unsupported noise scheduler type {name}")
class DiffusionModel(nn.Module):
def __init__(self, config: DiffusionConfig):
super().__init__()
@@ -138,12 +175,12 @@ class DiffusionModel(nn.Module):
* config.n_obs_steps,
)
self.noise_scheduler = DDPMScheduler(
self.noise_scheduler = _make_noise_scheduler(
config.noise_scheduler_type,
num_train_timesteps=config.num_train_timesteps,
beta_start=config.beta_start,
beta_end=config.beta_end,
beta_schedule=config.beta_schedule,
variance_type="fixed_small",
clip_sample=config.clip_sample,
clip_sample_range=config.clip_sample_range,
prediction_type=config.prediction_type,
@@ -185,13 +222,12 @@ class DiffusionModel(nn.Module):
def generate_actions(self, batch: dict[str, Tensor]) -> Tensor:
"""
This function expects `batch` to have (at least):
This function expects `batch` to have:
{
"observation.state": (B, n_obs_steps, state_dim)
"observation.image": (B, n_obs_steps, C, H, W)
}
"""
assert set(batch).issuperset({"observation.state", "observation.image"})
batch_size, n_obs_steps = batch["observation.state"].shape[:2]
assert n_obs_steps == self.config.n_obs_steps
@@ -268,13 +304,84 @@ class DiffusionModel(nn.Module):
loss = F.mse_loss(pred, target, reduction="none")
# Mask loss wherever the action is padded with copies (edges of the dataset trajectory).
if "action_is_pad" in batch:
if self.config.do_mask_loss_for_padding and "action_is_pad" in batch:
in_episode_bound = ~batch["action_is_pad"]
loss = loss * in_episode_bound.unsqueeze(-1)
return loss.mean()
class SpatialSoftmax(nn.Module):
"""
Spatial Soft Argmax operation described in "Deep Spatial Autoencoders for Visuomotor Learning" by Finn et al.
(https://arxiv.org/pdf/1509.06113). A minimal port of the robomimic implementation.
At a high level, this takes 2D feature maps (from a convnet/ViT) and returns the "center of mass"
of activations of each channel, i.e., keypoints in the image space for the policy to focus on.
Example: take feature maps of size (512x10x12). We generate a grid of normalized coordinates (10x12x2):
-----------------------------------------------------
| (-1., -1.) | (-0.82, -1.) | ... | (1., -1.) |
| (-1., -0.78) | (-0.82, -0.78) | ... | (1., -0.78) |
| ... | ... | ... | ... |
| (-1., 1.) | (-0.82, 1.) | ... | (1., 1.) |
-----------------------------------------------------
This is achieved by applying channel-wise softmax over the activations (512x120) and computing the dot
product with the coordinates (120x2) to get expected points of maximal activation (512x2).
The example above results in 512 keypoints (corresponding to the 512 input channels). We can optionally
provide num_kp != None to control the number of keypoints. This is achieved by a first applying a learnable
linear mapping (in_channels, H, W) -> (num_kp, H, W).
"""
def __init__(self, input_shape, num_kp=None):
"""
Args:
input_shape (list): (C, H, W) input feature map shape.
num_kp (int): number of keypoints in output. If None, output will have the same number of channels as input.
"""
super().__init__()
assert len(input_shape) == 3
self._in_c, self._in_h, self._in_w = input_shape
if num_kp is not None:
self.nets = torch.nn.Conv2d(self._in_c, num_kp, kernel_size=1)
self._out_c = num_kp
else:
self.nets = None
self._out_c = self._in_c
# we could use torch.linspace directly but that seems to behave slightly differently than numpy
# and causes a small degradation in pc_success of pre-trained models.
pos_x, pos_y = np.meshgrid(np.linspace(-1.0, 1.0, self._in_w), np.linspace(-1.0, 1.0, self._in_h))
pos_x = torch.from_numpy(pos_x.reshape(self._in_h * self._in_w, 1)).float()
pos_y = torch.from_numpy(pos_y.reshape(self._in_h * self._in_w, 1)).float()
# register as buffer so it's moved to the correct device.
self.register_buffer("pos_grid", torch.cat([pos_x, pos_y], dim=1))
def forward(self, features: Tensor) -> Tensor:
"""
Args:
features: (B, C, H, W) input feature maps.
Returns:
(B, K, 2) image-space coordinates of keypoints.
"""
if self.nets is not None:
features = self.nets(features)
# [B, K, H, W] -> [B * K, H * W] where K is number of keypoints
features = features.reshape(-1, self._in_h * self._in_w)
# 2d softmax normalization
attention = F.softmax(features, dim=-1)
# [B * K, H * W] x [H * W, 2] -> [B * K, 2] for spatial coordinate mean in x and y dimensions
expected_xy = attention @ self.pos_grid
# reshape to [B, K, 2]
feature_keypoints = expected_xy.view(-1, self._out_c, 2)
return feature_keypoints
class DiffusionRgbEncoder(nn.Module):
"""Encoder an RGB image into a 1D feature vector.
@@ -315,11 +422,16 @@ class DiffusionRgbEncoder(nn.Module):
# Set up pooling and final layers.
# Use a dry run to get the feature map shape.
# The dummy input should take the number of image channels from `config.input_shapes` and it should
# use the height and width from `config.crop_shape`.
image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
assert len(image_keys) == 1
image_key = image_keys[0]
dummy_input = torch.zeros(size=(1, config.input_shapes[image_key][0], *config.crop_shape))
with torch.inference_mode():
feat_map_shape = tuple(
self.backbone(torch.zeros(size=(1, *config.input_shapes["observation.image"]))).shape[1:]
)
self.pool = SpatialSoftmax(feat_map_shape, num_kp=config.spatial_softmax_num_keypoints)
dummy_feature_map = self.backbone(dummy_input)
feature_map_shape = tuple(dummy_feature_map.shape[1:])
self.pool = SpatialSoftmax(feature_map_shape, num_kp=config.spatial_softmax_num_keypoints)
self.feature_dim = config.spatial_softmax_num_keypoints * 2
self.out = nn.Linear(config.spatial_softmax_num_keypoints * 2, self.feature_dim)
self.relu = nn.ReLU()

View File

@@ -1,4 +1,20 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import logging
from omegaconf import DictConfig, OmegaConf
@@ -8,9 +24,10 @@ from lerobot.common.utils.utils import get_safe_torch_device
def _policy_cfg_from_hydra_cfg(policy_cfg_class, hydra_cfg):
expected_kwargs = set(inspect.signature(policy_cfg_class).parameters)
assert set(hydra_cfg.policy).issuperset(
expected_kwargs
), f"Hydra config is missing arguments: {set(expected_kwargs).difference(hydra_cfg.policy)}"
if not set(hydra_cfg.policy).issuperset(expected_kwargs):
logging.warning(
f"Hydra config is missing arguments: {set(expected_kwargs).difference(hydra_cfg.policy)}"
)
policy_cfg = policy_cfg_class(
**{
k: v
@@ -62,11 +79,18 @@ def make_policy(
policy_cls, policy_cfg_class = get_policy_and_config_classes(hydra_cfg.policy.name)
policy_cfg = _policy_cfg_from_hydra_cfg(policy_cfg_class, hydra_cfg)
if pretrained_policy_name_or_path is None:
policy_cfg = _policy_cfg_from_hydra_cfg(policy_cfg_class, hydra_cfg)
# Make a fresh policy.
policy = policy_cls(policy_cfg, dataset_stats)
else:
policy = policy_cls.from_pretrained(pretrained_policy_name_or_path)
# Load a pretrained policy and override the config if needed (for example, if there are inference-time
# hyperparameters that we want to vary).
# TODO(alexander-soare): This hack makes use of huggingface_hub's tooling to load the policy with, pretrained
# weights which are then loaded into a fresh policy with the desired config. This PR in huggingface_hub should
# make it possible to avoid the hack: https://github.com/huggingface/huggingface_hub/pull/2274.
policy = policy_cls(policy_cfg)
policy.load_state_dict(policy_cls.from_pretrained(pretrained_policy_name_or_path).state_dict())
policy.to(get_safe_torch_device(hydra_cfg.device))

View File

@@ -1,3 +1,18 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torch import Tensor, nn

View File

@@ -1,3 +1,18 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A protocol that all policies should follow.
This provides a mechanism for type-hinting and isinstance checks without requiring the policies classes
@@ -38,7 +53,8 @@ class Policy(Protocol):
def forward(self, batch: dict[str, Tensor]) -> dict:
"""Run the batch through the model and compute the loss for training or validation.
Returns a dictionary with "loss" and maybe other information.
Returns a dictionary with "loss" and potentially other information. Apart from "loss" which is a Tensor, all
other items should be logging-friendly, native Python types.
"""
def select_action(self, batch: dict[str, Tensor]):

View File

@@ -1,3 +1,19 @@
#!/usr/bin/env python
# Copyright 2024 Nicklas Hansen, Xiaolong Wang, Hao Su,
# and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
@@ -47,7 +63,7 @@ class TDMPCConfig:
elite_weighting_temperature: The temperature to use for softmax weighting (by trajectory value) of the
elites, when updating the gaussian parameters for CEM.
gaussian_mean_momentum: Momentum (α) used for EMA updates of the mean parameter μ of the gaussian
paramters optimized in CEM. Updates are calculated as μ⁻ ← αμ⁻ + (1-α)μ.
parameters optimized in CEM. Updates are calculated as μ⁻ ← αμ⁻ + (1-α)μ.
max_random_shift_ratio: Maximum random shift (as a proportion of the image size) to apply to the
image(s) (in units of pixels) for training-time augmentation. If set to 0, no such augmentation
is applied. Note that the input images are assumed to be square for this augmentation.
@@ -131,12 +147,18 @@ class TDMPCConfig:
def __post_init__(self):
"""Input validation (not exhaustive)."""
if self.input_shapes["observation.image"][-2] != self.input_shapes["observation.image"][-1]:
# There should only be one image key.
image_keys = {k for k in self.input_shapes if k.startswith("observation.image")}
if len(image_keys) != 1:
raise ValueError(
f"{self.__class__.__name__} only handles one image for now. Got image keys {image_keys}."
)
image_key = next(iter(image_keys))
if self.input_shapes[image_key][-2] != self.input_shapes[image_key][-1]:
# TODO(alexander-soare): This limitation is solely because of code in the random shift
# augmentation. It should be able to be removed.
raise ValueError(
"Only square images are handled now. Got image shape "
f"{self.input_shapes['observation.image']}."
f"Only square images are handled now. Got image shape {self.input_shapes[image_key]}."
)
if self.n_gaussian_samples <= 0:
raise ValueError(

View File

@@ -1,3 +1,19 @@
#!/usr/bin/env python
# Copyright 2024 Nicklas Hansen, Xiaolong Wang, Hao Su,
# and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implementation of Finetuning Offline World Models in the Real World.
The comments in this code may sometimes refer to these references:
@@ -96,13 +112,12 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
config.output_shapes, config.output_normalization_modes, dataset_stats
)
def save(self, fp):
"""Save state dict of TOLD model to filepath."""
torch.save(self.state_dict(), fp)
image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
# Note: This check is covered in the post-init of the config but have a sanity check just in case.
assert len(image_keys) == 1
self.input_image_key = image_keys[0]
def load(self, fp):
"""Load a saved state dict from filepath into current agent."""
self.load_state_dict(torch.load(fp))
self.reset()
def reset(self):
"""
@@ -121,10 +136,8 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
@torch.no_grad()
def select_action(self, batch: dict[str, Tensor]):
"""Select a single action given environment observations."""
assert "observation.image" in batch
assert "observation.state" in batch
batch = self.normalize_inputs(batch)
batch["observation.image"] = batch[self.input_image_key]
self._queues = populate_queues(self._queues, batch)
@@ -303,13 +316,11 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
device = get_device_from_parameters(self)
batch = self.normalize_inputs(batch)
batch["observation.image"] = batch[self.input_image_key]
batch = self.normalize_targets(batch)
info = {}
# TODO(alexander-soare): Refactor TDMPC and make it comply with the policy interface documentation.
batch_size = batch["index"].shape[0]
# (b, t) -> (t, b)
for key in batch:
if batch[key].ndim > 1:
@@ -337,6 +348,7 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
# Run latent rollout using the latent dynamics model and policy model.
# Note this has shape `horizon+1` because there are `horizon` actions and a current `z`. Each action
# gives us a next `z`.
batch_size = batch["index"].shape[0]
z_preds = torch.empty(horizon + 1, batch_size, self.config.latent_dim, device=device)
z_preds[0] = self.model.encode(current_observation)
reward_preds = torch.empty_like(reward, device=device)

View File

@@ -1,9 +1,28 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torch import nn
def populate_queues(queues, batch):
for key in batch:
# Ignore keys not in the queues already (leaving the responsibility to the caller to make sure the
# queues have the keys they want).
if key not in queues:
continue
if len(queues[key]) != queues[key].maxlen:
# initialize by copying the first observation several times until the queue is full
while len(queues[key]) != queues[key].maxlen:

View File

@@ -1,3 +1,18 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import logging

View File

@@ -1,3 +1,18 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
import imageio

View File

@@ -1,8 +1,25 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os.path as osp
import random
from contextlib import contextmanager
from datetime import datetime
from pathlib import Path
from typing import Generator
import hydra
import numpy as np
@@ -39,6 +56,31 @@ def set_global_seed(seed):
torch.cuda.manual_seed_all(seed)
@contextmanager
def seeded_context(seed: int) -> Generator[None, None, None]:
"""Set the seed when entering a context, and restore the prior random state at exit.
Example usage:
```
a = random.random() # produces some random number
with seeded_context(1337):
b = random.random() # produces some other random number
c = random.random() # produces yet another random number, but the same it would have if we never made `b`
```
"""
random_state = random.getstate()
np_random_state = np.random.get_state()
torch_random_state = torch.random.get_rng_state()
torch_cuda_random_state = torch.cuda.random.get_rng_state()
set_global_seed(seed)
yield None
random.setstate(random_state)
np.random.set_state(np_random_state)
torch.random.set_rng_state(torch_random_state)
torch.cuda.random.set_rng_state(torch_cuda_random_state)
def init_logging():
def custom_format(record):
dt = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

View File

@@ -25,7 +25,7 @@ training:
eval_freq: ???
save_freq: ???
log_freq: 250
save_model: false
save_model: true
eval:
n_episodes: 1
@@ -35,7 +35,7 @@ eval:
use_async_envs: false
wandb:
enable: true
enable: false
# Set to true to disable saving an artifact despite save_model == True
disable_artifact: false
project: lerobot

View File

@@ -3,6 +3,12 @@
seed: 1000
dataset_repo_id: lerobot/aloha_sim_insertion_human
override_dataset_stats:
observation.images.top:
# stats from imagenet, since we use a pretrained vision model
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
training:
offline_steps: 80000
online_steps: 0
@@ -18,12 +24,6 @@ training:
grad_clip_norm: 10
online_steps_between_rollouts: 1
override_dataset_stats:
observation.images.top:
# stats from imagenet, since we use a pretrained vision model
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
delta_timestamps:
action: "[i / ${fps} for i in range(${policy.chunk_size})]"
@@ -66,6 +66,9 @@ policy:
dim_feedforward: 3200
feedforward_activation: relu
n_encoder_layers: 4
# Note: Although the original ACT implementation has 7 for `n_decoder_layers`, there is a bug in the code
# that means only the first layer is used. Here we match the original implementation by setting this to 1.
# See this issue https://github.com/tonyzhaozh/act/issues/25#issue-2258740521.
n_decoder_layers: 1
# VAE.
use_vae: true
@@ -73,7 +76,7 @@ policy:
n_vae_encoder_layers: 4
# Inference.
use_temporal_aggregation: false
temporal_ensemble_momentum: null
# Training and loss computation.
dropout: 0.1

View File

@@ -7,6 +7,20 @@
seed: 100000
dataset_repo_id: lerobot/pusht
override_dataset_stats:
# TODO(rcadene, alexander-soare): should we remove image stats as well? do we use a pretrained vision model?
observation.image:
mean: [[[0.5]], [[0.5]], [[0.5]]] # (c,1,1)
std: [[[0.5]], [[0.5]], [[0.5]]] # (c,1,1)
# TODO(rcadene, alexander-soare): we override state and action stats to use the same as the pretrained model
# from the original codebase, but we should remove these and train our own pretrained model
observation.state:
min: [13.456424, 32.938293]
max: [496.14618, 510.9579]
action:
min: [12.0, 25.0]
max: [511.0, 511.0]
training:
offline_steps: 200000
online_steps: 0
@@ -44,20 +58,6 @@ eval:
n_episodes: 50
batch_size: 50
override_dataset_stats:
# TODO(rcadene, alexander-soare): should we remove image stats as well? do we use a pretrained vision model?
observation.image:
mean: [[[0.5]], [[0.5]], [[0.5]]] # (c,1,1)
std: [[[0.5]], [[0.5]], [[0.5]]] # (c,1,1)
# TODO(rcadene, alexander-soare): we override state and action stats to use the same as the pretrained model
# from the original codebase, but we should remove these and train our own pretrained model
observation.state:
min: [13.456424, 32.938293]
max: [496.14618, 510.9579]
action:
min: [12.0, 25.0]
max: [511.0, 511.0]
policy:
name: diffusion
@@ -95,6 +95,7 @@ policy:
diffusion_step_embed_dim: 128
use_film_scale_modulation: True
# Noise scheduler.
noise_scheduler_type: DDPM
num_train_timesteps: 100
beta_schedule: squaredcos_cap_v2
beta_start: 0.0001
@@ -105,3 +106,6 @@ policy:
# Inference
num_inference_steps: 100
# Loss computation
do_mask_loss_for_padding: false

View File

@@ -1,7 +1,7 @@
# @package _global_
seed: 1
dataset_repo_id: lerobot/xarm_lift_medium_replay
dataset_repo_id: lerobot/xarm_lift_medium
training:
offline_steps: 25000

View File

@@ -1,3 +1,18 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import platform
import huggingface_hub

View File

@@ -1,3 +1,18 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Evaluate a policy on an environment by running rollouts and computing metrics.
Usage examples:
@@ -583,17 +598,18 @@ if __name__ == "__main__":
pretrained_policy_path = Path(
snapshot_download(args.pretrained_policy_name_or_path, revision=args.revision)
)
except HFValidationError:
logging.warning(
"The provided pretrained_policy_name_or_path is not a valid Hugging Face Hub repo ID. "
"Treating it as a local directory."
)
except RepositoryNotFoundError:
logging.warning(
"The provided pretrained_policy_name_or_path was not found on the Hugging Face Hub. Treating "
"it as a local directory."
)
pretrained_policy_path = Path(args.pretrained_policy_name_or_path)
except (HFValidationError, RepositoryNotFoundError) as e:
if isinstance(e, HFValidationError):
error_message = (
"The provided pretrained_policy_name_or_path is not a valid Hugging Face Hub repo ID."
)
else:
error_message = (
"The provided pretrained_policy_name_or_path was not found on the Hugging Face Hub."
)
logging.warning(f"{error_message} Treating it as a local directory.")
pretrained_policy_path = Path(args.pretrained_policy_name_or_path)
if not pretrained_policy_path.is_dir() or not pretrained_policy_path.exists():
raise ValueError(
"The provided pretrained_policy_name_or_path is not a valid/existing Hugging Face Hub "

View File

@@ -1,3 +1,18 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Use this script to convert your dataset into LeRobot dataset format and upload it to the Hugging Face hub,
or store it locally. LeRobot dataset format is lightweight, fast to load from, and does not require any
@@ -60,7 +75,7 @@ import torch
from huggingface_hub import HfApi
from safetensors.torch import save_file
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
from lerobot.common.datasets.push_dataset_to_hub._download_raw import download_raw
from lerobot.common.datasets.push_dataset_to_hub.compute_stats import compute_stats
from lerobot.common.datasets.utils import flatten_dict
@@ -252,7 +267,7 @@ def main():
parser.add_argument(
"--revision",
type=str,
default="v1.2",
default=CODEBASE_VERSION,
help="Codebase version used to generate the dataset.",
)
parser.add_argument(

View File

@@ -1,3 +1,18 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import time
from copy import deepcopy
@@ -8,7 +23,7 @@ import hydra
import torch
from datasets import concatenate_datasets
from datasets.utils import disable_progress_bars, enable_progress_bars
from diffusers.optimization import get_scheduler
from omegaconf import DictConfig
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.datasets.utils import cycle
@@ -55,6 +70,8 @@ def make_optimizer_and_scheduler(cfg, policy):
cfg.training.adam_weight_decay,
)
assert cfg.training.online_steps == 0, "Diffusion Policy does not handle online training."
from diffusers.optimization import get_scheduler
lr_scheduler = get_scheduler(
cfg.training.lr_scheduler,
optimizer=optimizer,
@@ -71,6 +88,7 @@ def make_optimizer_and_scheduler(cfg, policy):
def update_policy(policy, batch, optimizer, grad_clip_norm, lr_scheduler=None):
"""Returns a dictionary of items for logging."""
start_time = time.time()
policy.train()
output_dict = policy.forward(batch)
@@ -98,6 +116,7 @@ def update_policy(policy, batch, optimizer, grad_clip_norm, lr_scheduler=None):
"grad_norm": float(grad_norm),
"lr": optimizer.param_groups[0]["lr"],
"update_s": time.time() - start_time,
**{k: v for k, v in output_dict.items() if k != "loss"},
}
return info
@@ -121,7 +140,7 @@ def train_notebook(out_dir=None, job_name=None, config_name="default", config_pa
train(cfg, out_dir=out_dir, job_name=job_name)
def log_train_info(logger, info, step, cfg, dataset, is_offline):
def log_train_info(logger: Logger, info, step, cfg, dataset, is_offline):
loss = info["loss"]
grad_norm = info["grad_norm"]
lr = info["lr"]
@@ -289,7 +308,7 @@ def add_episodes_inplace(
sampler.num_samples = len(concat_dataset)
def train(cfg: dict, out_dir=None, job_name=None):
def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
if out_dir is None:
raise NotImplementedError()
if job_name is None:
@@ -336,7 +355,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
# Note: this helper will be used in offline and online training loops.
def _maybe_eval_and_maybe_save(step):
def evaluate_and_checkpoint_if_needed(step):
if step % cfg.training.eval_freq == 0:
logging.info(f"Eval policy at step {step}")
eval_info = eval_policy(
@@ -392,9 +411,9 @@ def train(cfg: dict, out_dir=None, job_name=None):
if step % cfg.training.log_freq == 0:
log_train_info(logger, train_info, step, cfg, offline_dataset, is_offline)
# Note: _maybe_eval_and_maybe_save happens **after** the `step`th training update has completed, so we pass in
# step + 1.
_maybe_eval_and_maybe_save(step + 1)
# Note: evaluate_and_checkpoint_if_needed happens **after** the `step`th training update has completed,
# so we pass in step + 1.
evaluate_and_checkpoint_if_needed(step + 1)
step += 1
@@ -460,9 +479,9 @@ def train(cfg: dict, out_dir=None, job_name=None):
if step % cfg.training.log_freq == 0:
log_train_info(logger, train_info, step, cfg, online_dataset, is_offline)
# Note: _maybe_eval_and_maybe_save happens **after** the `step`th training update has completed, so we pass
# in step + 1.
_maybe_eval_and_maybe_save(step + 1)
# Note: evaluate_and_checkpoint_if_needed happens **after** the `step`th training update has completed,
# so we pass in step + 1.
evaluate_and_checkpoint_if_needed(step + 1)
step += 1
online_step += 1

View File

@@ -1,3 +1,18 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Visualize data of **all** frames of any episode of a dataset of type LeRobotDataset.
Note: The last frame of the episode doesnt always correspond to a final state.
@@ -32,7 +47,7 @@ local$ rerun lerobot_pusht_episode_0.rrd
```
- Visualize data stored on a distant machine through streaming:
(You need to forward the websocket port to the distant machine, with
(You need to forward the websocket port to the distant machine, with
`ssh -L 9087:localhost:9087 username@remote-host`)
```
distant$ python lerobot/scripts/visualize_dataset.py \
@@ -47,6 +62,7 @@ local$ rerun ws://localhost:9087
"""
import argparse
import gc
import logging
import time
from pathlib import Path
@@ -115,15 +131,17 @@ def visualize_dataset(
spawn_local_viewer = mode == "local" and not save
rr.init(f"{repo_id}/episode_{episode_index}", spawn=spawn_local_viewer)
# Manually call python garbage collector after `rr.init` to avoid hanging in a blocking flush
# when iterating on a dataloader with `num_workers` > 0
# TODO(rcadene): remove `gc.collect` when rerun version 0.16 is out, which includes a fix
gc.collect()
if mode == "distant":
rr.serve(open_browser=False, web_port=web_port, ws_port=ws_port)
logging.info("Logging to Rerun")
if num_workers > 0:
# TODO(rcadene): fix data workers hanging when `rr.init` is called
logging.warning("If data loader is hanging, try `--num-workers 0`.")
for batch in tqdm.tqdm(dataloader, total=len(dataloader)):
# iterate over the batch
for i in range(len(batch["index"])):
@@ -131,7 +149,7 @@ def visualize_dataset(
rr.set_time_seconds("timestamp", batch["timestamp"][i].item())
# display each camera image
for key in dataset.image_keys:
for key in dataset.camera_keys:
# TODO(rcadene): add `.compress()`? is it lossless?
rr.log(key, rr.Image(to_hwc_uint8_numpy(batch[key][i])))
@@ -196,7 +214,7 @@ def main():
parser.add_argument(
"--num-workers",
type=int,
default=0,
default=4,
help="Number of processes of Dataloader for loading the data.",
)
parser.add_argument(

BIN
media/wandb.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 407 KiB

774
poetry.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -4,9 +4,9 @@ version = "0.1.0"
description = "🤗 LeRobot: State-of-the-art Machine Learning for Real-World Robotics in Pytorch"
authors = [
"Rémi Cadène <re.cadene@gmail.com>",
"Simon Alibert <alibert.sim@gmail.com>",
"Alexander Soare <alexander.soare159@gmail.com>",
"Quentin Gallouédec <quentin.gallouedec@ec-lyon.fr>",
"Simon Alibert <alibert.sim@gmail.com>",
"Adil Zouitine <adilzouitinegm@gmail.com>",
"Thomas Wolf <thomaswolfcontact@gmail.com>",
]
@@ -28,37 +28,36 @@ packages = [{include = "lerobot"}]
[tool.poetry.dependencies]
python = ">=3.10,<3.13"
termcolor = "^2.4.0"
omegaconf = "^2.3.0"
wandb = "^0.16.3"
imageio = {extras = ["ffmpeg"], version = "^2.34.0"}
gdown = "^5.1.0"
hydra-core = "^1.3.2"
einops = "^0.8.0"
pymunk = "^6.6.0"
zarr = "^2.17.0"
numba = "^0.59.0"
termcolor = ">=2.4.0"
omegaconf = ">=2.3.0"
wandb = ">=0.16.3"
imageio = {extras = ["ffmpeg"], version = ">=2.34.0"}
gdown = ">=5.1.0"
hydra-core = ">=1.3.2"
einops = ">=0.8.0"
pymunk = ">=6.6.0"
zarr = ">=2.17.0"
numba = ">=0.59.0"
torch = "^2.2.1"
opencv-python = "^4.9.0.80"
opencv-python = ">=4.9.0"
diffusers = "^0.27.2"
torchvision = "^0.18.0"
h5py = "^3.10.0"
huggingface-hub = "^0.21.4"
robomimic = "0.2.0"
gymnasium = "^0.29.1"
cmake = "^3.29.0.1"
gym-pusht = { version = "^0.1.0", optional = true}
gym-xarm = { version = "^0.1.0", optional = true}
gym-aloha = { version = "^0.1.0", optional = true}
pre-commit = {version = "^3.7.0", optional = true}
debugpy = {version = "^1.8.1", optional = true}
pytest = {version = "^8.1.0", optional = true}
pytest-cov = {version = "^5.0.0", optional = true}
torchvision = ">=0.18.0"
h5py = ">=3.10.0"
huggingface-hub = ">=0.21.4"
gymnasium = ">=0.29.1"
cmake = ">=3.29.0.1"
gym-pusht = { version = ">=0.1.3", optional = true}
gym-xarm = { version = ">=0.1.1", optional = true}
gym-aloha = { version = ">=0.1.1", optional = true}
pre-commit = {version = ">=3.7.0", optional = true}
debugpy = {version = ">=1.8.1", optional = true}
pytest = {version = ">=8.1.0", optional = true}
pytest-cov = {version = ">=5.0.0", optional = true}
datasets = "^2.19.0"
imagecodecs = { version = "^2024.1.1", optional = true }
pyav = "^12.0.5"
moviepy = "^1.0.3"
rerun-sdk = "^0.15.1"
imagecodecs = { version = ">=2024.1.1", optional = true }
pyav = ">=12.0.5"
moviepy = ">=1.0.3"
rerun-sdk = ">=0.15.1"
[tool.poetry.extras]
@@ -104,5 +103,5 @@ ignore-init-module-imports = true
[build-system]
requires = ["poetry-core>=1.5.0"]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"

View File

@@ -1,3 +1,18 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .utils import DEVICE

View File

@@ -1,3 +1,18 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script provides a utility for saving a dataset as safetensors files for the purpose of testing backward compatibility
when updating the data format. It uses the `PushtDataset` to create a DataLoader and saves selected frame from the

View File

@@ -1,3 +1,18 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import shutil
from pathlib import Path

View File

@@ -1,3 +1,18 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import gymnasium as gym
@@ -15,7 +30,7 @@ from tests.utils import require_env
def test_available_env_task(env_name: str, task_name: list):
"""
This test verifies that all environments listed in `lerobot/__init__.py` can
be sucessfully imported — if they're installed — and that their
be successfully imported — if they're installed — and that their
`available_tasks_per_env` are valid.
"""
package_name = f"gym_{env_name}"

View File

@@ -1,3 +1,18 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import logging
from copy import deepcopy
@@ -41,7 +56,7 @@ def test_factory(env_name, repo_id, policy_name):
)
dataset = make_dataset(cfg)
delta_timestamps = dataset.delta_timestamps
image_keys = dataset.image_keys
camera_keys = dataset.camera_keys
item = dataset[0]
@@ -71,7 +86,7 @@ def test_factory(env_name, repo_id, policy_name):
else:
assert item[key].ndim == ndim, f"{key}"
if key in image_keys:
if key in camera_keys:
assert item[key].dtype == torch.float32, f"{key}"
# TODO(rcadene): we assume for now that image normalization takes place in the model
assert item[key].max() <= 1.0, f"{key}"

View File

@@ -1,3 +1,18 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import gymnasium as gym

View File

@@ -1,8 +1,25 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# TODO(aliberts): Mute logging for these tests
import subprocess
import sys
from pathlib import Path
from tests.utils import require_package
def _find_and_replace(text: str, finds_and_replaces: list[tuple[str, str]]) -> str:
for f, r in finds_and_replaces:
@@ -21,6 +38,7 @@ def test_example_1():
assert Path("outputs/examples/1_load_lerobot_dataset/episode_0.mp4").exists()
@require_package("gym_pusht")
def test_examples_3_and_2():
"""
Train a model with example 3, check the outputs.
@@ -46,7 +64,7 @@ def test_examples_3_and_2():
# Pass empty globals to allow dictionary comprehension https://stackoverflow.com/a/32897127/4391249.
exec(file_contents, {})
for file_name in ["model.safetensors", "config.json", "config.yaml"]:
for file_name in ["model.safetensors", "config.json"]:
assert Path(f"outputs/train/example_pusht_diffusion/{file_name}").exists()
path = "examples/2_evaluate_pretrained_policy.py"
@@ -58,16 +76,16 @@ def test_examples_3_and_2():
file_contents = _find_and_replace(
file_contents,
[
('pretrained_policy_name = "lerobot/diffusion_pusht"', ""),
("pretrained_policy_path = Path(snapshot_download(pretrained_policy_name))", ""),
('pretrained_policy_path = Path(snapshot_download("lerobot/diffusion_pusht"))', ""),
(
'# pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")',
'pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")',
),
('"eval.n_episodes=10"', '"eval.n_episodes=1"'),
('"eval.batch_size=10"', '"eval.batch_size=1"'),
('"device=cuda"', '"device=cpu"'),
('device = torch.device("cuda")', 'device = torch.device("cpu")'),
("step += 1", "break"),
],
)
assert Path("outputs/train/example_pusht_diffusion").exists()
exec(file_contents, {})
assert Path("outputs/eval/example_pusht_diffusion/rollout.mp4").exists()

View File

@@ -1,3 +1,18 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from pathlib import Path
@@ -49,6 +64,14 @@ def test_get_policy_and_config_classes(policy_name: str):
"act",
["env.task=AlohaTransferCube-v0", "dataset_repo_id=lerobot/aloha_sim_transfer_cube_scripted"],
),
# Note: these parameters also need custom logic in the test function for overriding the Hydra config.
(
"aloha",
"diffusion",
["env.task=AlohaInsertion-v0", "dataset_repo_id=lerobot/aloha_sim_insertion_human"],
),
# Note: these parameters also need custom logic in the test function for overriding the Hydra config.
("pusht", "act", ["env.task=PushT-v0", "dataset_repo_id=lerobot/pusht"]),
],
)
@require_env
@@ -72,6 +95,31 @@ def test_policy(env_name, policy_name, extra_overrides):
+ extra_overrides,
)
# Additional config override logic.
if env_name == "aloha" and policy_name == "diffusion":
for keys in [
("training", "delta_timestamps"),
("policy", "input_shapes"),
("policy", "input_normalization_modes"),
]:
dct = dict(cfg[keys[0]][keys[1]])
dct["observation.images.top"] = dct["observation.image"]
del dct["observation.image"]
cfg[keys[0]][keys[1]] = dct
cfg.override_dataset_stats = None
# Additional config override logic.
if env_name == "pusht" and policy_name == "act":
for keys in [
("policy", "input_shapes"),
("policy", "input_normalization_modes"),
]:
dct = dict(cfg[keys[0]][keys[1]])
dct["observation.image"] = dct["observation.images.top"]
del dct["observation.images.top"]
cfg[keys[0]][keys[1]] = dct
cfg.override_dataset_stats = None
# Check that we can make the policy object.
dataset = make_dataset(cfg)
policy = make_policy(hydra_cfg=cfg, dataset_stats=dataset.stats)
@@ -236,7 +284,7 @@ def test_normalize(insert_temporal_dim):
@pytest.mark.parametrize(
"env_name, policy_name, extra_overrides",
[
# ("xarm", "tdmpc", ["policy.n_action_repeats=2"]),
("xarm", "tdmpc", []),
(
"pusht",
"diffusion",

38
tests/test_utils.py Normal file
View File

@@ -0,0 +1,38 @@
import random
from typing import Callable
import numpy as np
import pytest
import torch
from lerobot.common.utils.utils import seeded_context, set_global_seed
@pytest.mark.parametrize(
"rand_fn",
[
random.random,
np.random.random,
lambda: torch.rand(1).item(),
]
+ [lambda: torch.rand(1, device="cuda")]
if torch.cuda.is_available()
else [],
)
def test_seeding(rand_fn: Callable[[], int]):
set_global_seed(0)
a = rand_fn()
with seeded_context(1337):
c = rand_fn()
b = rand_fn()
set_global_seed(0)
a_ = rand_fn()
b_ = rand_fn()
# Check that `set_global_seed` lets us reproduce a and b.
assert a_ == a
# Additionally, check that the `seeded_context` didn't interrupt the global RNG.
assert b_ == b
set_global_seed(1337)
c_ = rand_fn()
# Check that `seeded_context` and `global_seed` give the same reproducibility.
assert c_ == c

View File

@@ -1,3 +1,18 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
from lerobot.scripts.visualize_dataset import visualize_dataset

View File

@@ -1,4 +1,20 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import platform
from functools import wraps
import pytest
import torch
@@ -61,7 +77,6 @@ def require_env(func):
Decorator that skips the test if the required environment package is not installed.
As it need 'env_name' in args, it also checks whether it is provided as an argument.
"""
from functools import wraps
@wraps(func)
def wrapper(*args, **kwargs):
@@ -82,3 +97,20 @@ def require_env(func):
return func(*args, **kwargs)
return wrapper
def require_package(package_name):
"""
Decorator that skips the test if the specified package is not installed.
"""
def decorator(func):
@wraps(func)
def wrapper(*args, **kwargs):
if not is_package_available(package_name):
pytest.skip(f"{package_name} not installed")
return func(*args, **kwargs)
return wrapper
return decorator