Merge remote-tracking branch 'upstream/main' into add_drop_last_keyframes
.github/PULL_REQUEST_TEMPLATE.md (vendored, 26 changes)
@@ -1,11 +1,15 @@
-# What does this PR do?
+## What this does

 Explain what this PR does. Feel free to tag your PR with the appropriate label(s).

 Examples:
-- Fixes # (issue)
-- Adds new dataset
-- Optimizes something
+| Title                | Label             |
+|----------------------|-----------------|
+| Fixes #[issue]       | (🐛 Bug)          |
+| Adds new dataset     | (🗃️ Dataset)      |
+| Optimizes something  | (⚡️ Performance)  |

-## How was it tested?
+## How it was tested

 Explain/show how you tested your changes.

 Examples:
 - Added `test_something` in `tests/test_stuff.py`.
@@ -13,6 +17,7 @@ Examples:
 - Optimized `some_function`, it now runs X times faster than previously.

 ## How to checkout & try? (for the reviewer)
 Provide a simple way for the reviewer to try out your changes.

 Examples:
 ```bash
@@ -22,11 +27,8 @@ DATA_DIR=tests/data pytest -sx tests/test_stuff.py::test_something
 python lerobot/scripts/train.py --some.option=true
 ```

-## Before submitting
-Please read the [contributor guideline](https://github.com/huggingface/lerobot/blob/main/CONTRIBUTING.md#submitting-a-pull-request-pr).
-
-## Who can review?
-
-Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
+## SECTION TO REMOVE BEFORE SUBMITTING YOUR PR
+**Note**: Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
 members/contributors who may be interested in your PR. Try to avoid tagging more than 3 people.
+
+**Note**: Before submitting this PR, please read the [contributor guideline](https://github.com/huggingface/lerobot/blob/main/CONTRIBUTING.md#submitting-a-pull-request-pr).
.github/workflows/test.yml (vendored, 32 changes)
@@ -57,6 +57,38 @@ jobs:
         && rm -rf tests/outputs outputs

+  pytest-minimal:
+    name: Pytest (minimal install)
+    runs-on: ubuntu-latest
+    env:
+      DATA_DIR: tests/data
+      MUJOCO_GL: egl
+    steps:
+      - uses: actions/checkout@v4
+
+      - name: Install poetry
+        run: |
+          pipx install poetry && poetry config virtualenvs.in-project true
+          echo "${{ github.workspace }}/.venv/bin" >> $GITHUB_PATH
+
+      - name: Set up Python 3.10
+        uses: actions/setup-python@v5
+        with:
+          python-version: "3.10"
+
+      - name: Install poetry dependencies
+        run: |
+          poetry install --extras "test"
+
+      - name: Test with pytest
+        run: |
+          pytest tests -v --cov=./lerobot --durations=0 \
+            -W ignore::DeprecationWarning:imageio_ffmpeg._utils:7 \
+            -W ignore::UserWarning:torch.utils.data.dataloader:558 \
+            -W ignore::UserWarning:gymnasium.utils.env_checker:247 \
+            && rm -rf tests/outputs outputs
+
   end-to-end:
     name: End-to-end
     runs-on: ubuntu-latest
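The `-W ignore::<Category>:<module>:<lineno>` flags passed to pytest above use Python's standard warning-filter syntax. As a minimal sketch of the same mechanism via the standard library (the module/lineno fields are omitted here, and `count_warnings` is a hypothetical helper, not part of LeRobot):

```python
import warnings

def count_warnings() -> int:
    """Trigger one DeprecationWarning and report how many escape the filters."""
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        # Mirror of `-W ignore::DeprecationWarning:<module>:<lineno>`:
        # action "ignore", category DeprecationWarning, any message/module.
        warnings.filterwarnings("ignore", category=DeprecationWarning)
        warnings.warn("old API", DeprecationWarning)
        return len(caught)

print(count_warnings())  # 0: the DeprecationWarning was filtered out
```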
@@ -18,7 +18,7 @@ repos:
     hooks:
       - id: pyupgrade
   - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: v0.4.2
+    rev: v0.4.3
     hooks:
       - id: ruff
         args: [--fix]
Makefile (7 changes)
@@ -22,9 +22,8 @@ test-end-to-end:
 	${MAKE} test-act-ete-eval
 	${MAKE} test-diffusion-ete-train
 	${MAKE} test-diffusion-ete-eval
-	# TODO(rcadene, alexander-soare): enable end-to-end tests for tdmpc
-	# ${MAKE} test-tdmpc-ete-train
-	# ${MAKE} test-tdmpc-ete-eval
+	${MAKE} test-tdmpc-ete-train
+	${MAKE} test-tdmpc-ete-eval
 	${MAKE} test-default-ete-eval

test-act-ete-train:
@@ -80,7 +79,7 @@ test-tdmpc-ete-train:
 		policy=tdmpc \
 		env=xarm \
 		env.task=XarmLift-v0 \
-		dataset_repo_id=lerobot/xarm_lift_medium_replay \
+		dataset_repo_id=lerobot/xarm_lift_medium \
 		wandb.enable=False \
 		training.offline_steps=2 \
 		training.online_steps=2 \
README.md (206 changes)
@@ -29,15 +29,15 @@
 ---

-🤗 LeRobot aims to provide models, datasets, and tools for real-world robotics in PyTorch. The goal is to lower the barrier for entry to robotics so that everyone can contribute and benefit from sharing datasets and pretrained models.
+🤗 LeRobot aims to provide models, datasets, and tools for real-world robotics in PyTorch. The goal is to lower the barrier to entry to robotics so that everyone can contribute and benefit from sharing datasets and pretrained models.

 🤗 LeRobot contains state-of-the-art approaches that have been shown to transfer to the real-world with a focus on imitation learning and reinforcement learning.

-🤗 LeRobot already provides a set of pretrained models, datasets with human collected demonstrations, and simulated environments so that everyone can get started. In the coming weeks, the plan is to add more and more support for real-world robotics on the most affordable and capable robots out there.
+🤗 LeRobot already provides a set of pretrained models, datasets with human collected demonstrations, and simulation environments to get started without assembling a robot. In the coming weeks, the plan is to add more and more support for real-world robotics on the most affordable and capable robots out there.

-🤗 LeRobot hosts pretrained models and datasets on this HuggingFace community page: [huggingface.co/lerobot](https://huggingface.co/lerobot)
+🤗 LeRobot hosts pretrained models and datasets on this Hugging Face community page: [huggingface.co/lerobot](https://huggingface.co/lerobot)

-#### Examples of pretrained models and environments
+#### Examples of pretrained models on simulation environments

 <table>
 <tr>
@@ -54,10 +54,11 @@

 ### Acknowledgment

-- ACT policy and ALOHA environment are adapted from [ALOHA](https://tonyzhaozh.github.io/aloha/)
-- Diffusion policy and Pusht environment are adapted from [Diffusion Policy](https://diffusion-policy.cs.columbia.edu/)
-- TDMPC policy and Simxarm environment are adapted from [FOWM](https://www.yunhaifeng.com/FOWM/)
-- Abstractions and utilities for Reinforcement Learning come from [TorchRL](https://github.com/pytorch/rl)
+- Thanks to Tony Zhao, Zipeng Fu and colleagues for open sourcing ACT policy, ALOHA environments and datasets. Ours are adapted from [ALOHA](https://tonyzhaozh.github.io/aloha) and [Mobile ALOHA](https://mobile-aloha.github.io).
+- Thanks to Cheng Chi, Zhenjia Xu and colleagues for open sourcing Diffusion policy, Pusht environment and datasets, as well as UMI datasets. Ours are adapted from [Diffusion Policy](https://diffusion-policy.cs.columbia.edu) and [UMI Gripper](https://umi-gripper.github.io).
+- Thanks to Nicklas Hansen, Yunhai Feng and colleagues for open sourcing TDMPC policy, Simxarm environments and datasets. Ours are adapted from [TDMPC](https://github.com/nicklashansen/tdmpc) and [FOWM](https://www.yunhaifeng.com/FOWM).
+- Thanks to Antonio Loquercio and Ashish Kumar for their early support.

 ## Installation
@@ -86,15 +87,18 @@ For instance, to install 🤗 LeRobot with aloha and pusht, use:
 pip install ".[aloha, pusht]"
 ```

-To use [Weights and Biases](https://docs.wandb.ai/quickstart) for experiments tracking, log in with
+To use [Weights and Biases](https://docs.wandb.ai/quickstart) for experiment tracking, log in with
 ```bash
 wandb login
 ```
+(note: you will also need to enable WandB in the configuration. See below.)

 ## Walkthrough

 ```
 .
+├── examples        # contains demonstration examples, start here to learn about LeRobot
 ├── lerobot
 |   ├── configs     # contains hydra yaml files with all options that you can override in the command line
 |   |   ├── default.yaml  # selected by default, it loads pusht environment and diffusion policy
@@ -103,69 +107,84 @@ wandb login
 |   ├── common      # contains classes and utilities
 |   |   ├── datasets  # various datasets of human demonstrations: aloha, pusht, xarm
 |   |   ├── envs      # various sim environments: aloha, pusht, xarm
-|   |   └── policies  # various policies: act, diffusion, tdmpc
-|   └── scripts     # contains functions to execute via command line
-|       ├── visualize_dataset.py  # load a dataset and render its demonstrations
-|       ├── eval.py               # load policy and evaluate it on an environment
-|       └── train.py              # train a policy via imitation learning and/or reinforcement learning
+|   |   ├── policies  # various policies: act, diffusion, tdmpc
+|   |   └── utils     # various utilities
+|   └── scripts     # contains functions to execute via command line
+|       ├── eval.py                 # load policy and evaluate it on an environment
+|       ├── train.py                # train a policy via imitation learning and/or reinforcement learning
+|       ├── push_dataset_to_hub.py  # convert your dataset into LeRobot dataset format and upload it to the Hugging Face hub
+|       └── visualize_dataset.py    # load a dataset and render its demonstrations
 ├── outputs         # contains results of scripts execution: logs, videos, model checkpoints
-├── .github
-|   └── workflows
-|       └── test.yml  # defines install settings for continuous integration and specifies end-to-end tests
 └── tests           # contains pytest utilities for continuous integration
 ```

 ### Visualize datasets

-Check out [examples](./examples) to see how you can import our dataset class, download the data from the HuggingFace hub and use our rendering utilities.
+Check out [example 1](./examples/1_load_lerobot_dataset.py) that illustrates how to use our dataset class which automatically download data from the Hugging Face hub.

-Or you can achieve the same result by executing our script from the command line:
+You can also locally visualize episodes from a dataset by executing our script from the command line:
 ```bash
 python lerobot/scripts/visualize_dataset.py \
-    env=pusht \
-    hydra.run.dir=outputs/visualize_dataset/example
-# >>> ['outputs/visualize_dataset/example/episode_0.mp4']
+    --repo-id lerobot/pusht \
+    --episode-index 0
 ```

+It will open `rerun.io` and display the camera streams, robot states and actions, like this:
+
+https://github-production-user-asset-6210df.s3.amazonaws.com/4681518/328035972-fd46b787-b532-47e2-bb6f-fd536a55a7ed.mov?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAVCODYLSA53PQK4ZA%2F20240505%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20240505T172924Z&X-Amz-Expires=300&X-Amz-Signature=d680b26c532eeaf80740f08af3320d22ad0b8a4e4da1bcc4f33142c15b509eda&X-Amz-SignedHeaders=host&actor_id=24889239&key_id=0&repo_id=748713144
+
+Our script can also visualize datasets stored on a distant server. See `python lerobot/scripts/visualize_dataset.py --help` for more instructions.

 ### Evaluate a pretrained policy

-Check out [examples](./examples) to see how you can load a pretrained policy from HuggingFace hub, load up the corresponding environment and model, and run an evaluation.
+Check out [example 2](./examples/2_evaluate_pretrained_policy.py) that illustrates how to download a pretrained policy from Hugging Face hub, and run an evaluation on its corresponding environment.

-Or you can achieve the same result by executing our script from the command line:
+We also provide a more capable script to parallelize the evaluation over multiple environments during the same rollout. Here is an example with a pretrained model hosted on [lerobot/diffusion_pusht](https://huggingface.co/lerobot/diffusion_pusht):
 ```bash
 python lerobot/scripts/eval.py \
     -p lerobot/diffusion_pusht \
-    eval_episodes=10 \
-    hydra.run.dir=outputs/eval/example_hub
+    eval.n_episodes=10 \
+    eval.batch_size=10
 ```

-After training your own policy, you can also re-evaluate the checkpoints with:
+Note: After training your own policy, you can re-evaluate the checkpoints with:

 ```bash
 python lerobot/scripts/eval.py \
-    -p PATH/TO/TRAIN/OUTPUT/FOLDER \
-    eval_episodes=10 \
-    hydra.run.dir=outputs/eval/example_dir
+    -p PATH/TO/TRAIN/OUTPUT/FOLDER
 ```

 See `python lerobot/scripts/eval.py --help` for more instructions.

 ### Train your own policy

-Check out [examples](./examples) to see how you can start training a model on a dataset, which will be automatically downloaded if needed.
+Check out [example 3](./examples/3_train_policy.py) that illustrates how to start training a model.

-In general, you can use our training script to easily train any policy on any environment:
+In general, you can use our training script to easily train any policy. Here is an example of training the ACT policy on trajectories collected by humans on the Aloha simulation environment for the insertion task:
 ```bash
 python lerobot/scripts/train.py \
-    env=aloha \
-    task=sim_insertion \
-    repo_id=lerobot/aloha_sim_insertion_scripted \
-    policy=act \
-    hydra.run.dir=outputs/train/aloha_act
+    policy=act \
+    env=aloha \
+    env.task=AlohaInsertion-v0 \
+    dataset_repo_id=lerobot/aloha_sim_insertion_human
 ```

-After training, you may want to revisit model evaluation to change the evaluation settings. In fact, during training every checkpoint is already evaluated but on a low number of episodes for efficiency. Check out [example](./examples) to evaluate any model checkpoint on more episodes to increase statistical significance.
+The experiment directory is automatically generated and will show up in yellow in your terminal. It looks like `outputs/train/2024-05-05/20-21-12_aloha_act_default`. You can manually specify an experiment directory by adding this argument to the `train.py` python command:
+```bash
+    hydra.run.dir=your/new/experiment/dir
+```
+
+To use wandb for logging training and evaluation curves, make sure you've run `wandb login` as a one-time setup step. Then, when running the training command above, enable WandB in the configuration by adding:
+
+```bash
+    wandb.enable=true
+```
+
+A link to the wandb logs for the run will also show up in yellow in your terminal. Here is an example of what they look like in your browser:
+
+[image: wandb logs in the browser]
+
+Note: For efficiency, during training every checkpoint is evaluated on a low number of episodes. After training, you may want to re-evaluate your best checkpoints on more episodes or change the evaluation settings. See `python lerobot/scripts/eval.py --help` for more instructions.

 ## Contribute
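The `eval.n_episodes=10`-style arguments in the commands above are Hydra dotted overrides applied to a nested config. As a toy illustration of the idea only (this is not Hydra itself; `apply_overrides` is a hypothetical helper):

```python
def apply_overrides(config: dict, overrides: list[str]) -> dict:
    """Apply Hydra-style `a.b=value` overrides to a nested dict config."""
    for override in overrides:
        dotted_key, value = override.split("=", 1)
        *parents, leaf = dotted_key.split(".")
        node = config
        for key in parents:
            node = node.setdefault(key, {})  # walk/create nested sections
        node[leaf] = int(value) if value.isdigit() else value
    return config

config = {"eval": {"n_episodes": 1, "batch_size": 1}}
apply_overrides(config, ["eval.n_episodes=10", "eval.batch_size=10"])
print(config)  # {'eval': {'n_episodes': 10, 'batch_size': 10}}
```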
@@ -173,98 +192,40 @@ If you would like to contribute to 🤗 LeRobot, please check out our [contribut

 ### Add a new dataset

-```python
-# TODO(rcadene, AdilZouitine): rewrite this section
-```
-
-To add a dataset to the hub, first login and use a token generated from [huggingface settings](https://huggingface.co/settings/tokens) with write access:
+To add a dataset to the hub, you need to login using a write-access token, which can be generated from the [Hugging Face settings](https://huggingface.co/settings/tokens):
 ```bash
 huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
 ```

-Then you can upload it to the hub with:
+Then move your dataset folder in `data` directory (e.g. `data/aloha_ping_pong`), and push your dataset to the hub with:
 ```bash
-HF_HUB_ENABLE_HF_TRANSFER=1 huggingface-cli upload $HF_USER/$DATASET data/$DATASET \
---repo-type dataset \
---revision v1.0
+python lerobot/scripts/push_dataset_to_hub.py \
+--data-dir data \
+--dataset-id aloha_ping_pong \
+--raw-format aloha_hdf5 \
+--community-id lerobot
 ```

-You will need to set the corresponding version as a default argument in your dataset class:
-```python
-  version: str | None = "v1.1",
-```
-See: [`lerobot/common/datasets/pusht.py`](https://github.com/Cadene/lerobot/blob/main/lerobot/common/datasets/pusht.py)
-
-For instance, for [lerobot/pusht](https://huggingface.co/datasets/lerobot/pusht), we used:
-```bash
-HF_USER=lerobot
-DATASET=pusht
-```
-
-If you want to improve an existing dataset, you can download it locally with:
-```bash
-mkdir -p data/$DATASET
-HF_HUB_ENABLE_HF_TRANSFER=1 huggingface-cli download ${HF_USER}/$DATASET \
---repo-type dataset \
---local-dir data/$DATASET \
---local-dir-use-symlinks=False \
---revision v1.0
-```
-
-Iterate on your code and dataset with:
-```bash
-DATA_DIR=data python train.py
-```
-
-Upload a new version (v2.0 or v1.1 if the changes are respectively more or less significant):
-```bash
-HF_HUB_ENABLE_HF_TRANSFER=1 huggingface-cli upload $HF_USER/$DATASET data/$DATASET \
---repo-type dataset \
---revision v1.1 \
---delete "*"
-```
-
-Then you will need to set the corresponding version as a default argument in your dataset class:
-```python
-  version: str | None = "v1.1",
-```
-See: [`lerobot/common/datasets/pusht.py`](https://github.com/Cadene/lerobot/blob/main/lerobot/common/datasets/pusht.py)
-
-Finally, you might want to mock the dataset if you need to update the unit tests as well:
-```bash
-python tests/scripts/mock_dataset.py --in-data-dir data/$DATASET --out-data-dir tests/data/$DATASET
-```
+See `python lerobot/scripts/push_dataset_to_hub.py --help` for more instructions.
+
+If your dataset format is not supported, implement your own in `lerobot/common/datasets/push_dataset_to_hub/${raw_format}_format.py` by copying examples like [pusht_zarr](https://github.com/huggingface/lerobot/blob/main/lerobot/common/datasets/push_dataset_to_hub/pusht_zarr_format.py), [umi_zarr](https://github.com/huggingface/lerobot/blob/main/lerobot/common/datasets/push_dataset_to_hub/umi_zarr_format.py), [aloha_hdf5](https://github.com/huggingface/lerobot/blob/main/lerobot/common/datasets/push_dataset_to_hub/aloha_hdf5_format.py), or [xarm_pkl](https://github.com/huggingface/lerobot/blob/main/lerobot/common/datasets/push_dataset_to_hub/xarm_pkl_format.py).

 ### Add a pretrained policy

-```python
-# TODO(rcadene, alexander-soare): rewrite this section
-```
-
-Once you have trained a policy you may upload it to the HuggingFace hub.
-
-Firstly, make sure you have a model repository set up on the hub. The hub ID looks like HF_USER/REPO_NAME.
-
-Secondly, assuming you have trained a policy, you need the following (which should all be in any of the subdirectories of `checkpoints` in your training output folder, if you've used the LeRobot training script):
+Once you have trained a policy you may upload it to the Hugging Face hub using a hub id that looks like `${hf_user}/${repo_name}` (e.g. [lerobot/diffusion_pusht](https://huggingface.co/lerobot/diffusion_pusht)).

+You first need to find the checkpoint located inside your experiment directory (e.g. `outputs/train/2024-05-05/20-21-12_aloha_act_default/checkpoints/002500`). It should contain:
 - `config.json`: A serialized version of the policy configuration (following the policy's dataclass config).
-- `model.safetensors`: The `torch.nn.Module` parameters saved in [Hugging Face Safetensors](https://huggingface.co/docs/safetensors/index) format.
-- `config.yaml`: This is the consolidated Hydra training configuration containing the policy, environment, and dataset configs. The policy configuration should match `config.json` exactly. The environment config is useful for anyone who wants to evaluate your policy. The dataset config just serves as a paper trail for reproducibility.
+- `model.safetensors`: A set of `torch.nn.Module` parameters, saved in [Hugging Face Safetensors](https://huggingface.co/docs/safetensors/index) format.
+- `config.yaml`: A consolidated Hydra training configuration containing the policy, environment, and dataset configs. The policy configuration should match `config.json` exactly. The environment config is useful for anyone who wants to evaluate your policy. The dataset config just serves as a paper trail for reproducibility.

-To upload these to the hub, run the following with a desired revision ID.
+To upload these to the hub, run the following:
 ```bash
-huggingface-cli upload $HUB_ID PATH/TO/OUTPUT/DIR --revision $REVISION_ID
+huggingface-cli upload ${hf_user}/${repo_name} path/to/checkpoint/dir
 ```

-If you want this to be the default revision also run the following (don't worry, it won't upload the files again; it will just adjust the file pointers):
-
-```bash
-huggingface-cli upload $HUB_ID PATH/TO/OUTPUT/DIR
-```
-
-See `eval.py` for an example of how a user may use your policy.
+See [eval.py](https://github.com/huggingface/lerobot/blob/main/lerobot/scripts/eval.py) for an example of how other people may use your policy.

 ### Improve your code with profiling
@@ -291,9 +252,14 @@ with profile(
     # insert code to profile, potentially whole body of eval_policy function
 ```

-```bash
-python lerobot/scripts/eval.py \
-    --config outputs/pusht/.hydra/config.yaml \
-    pretrained_model_path=outputs/pusht/model.pt \
-    eval_episodes=7
-```
+## Citation
+
+If you want, you can cite this work with:
+```
+@misc{cadene2024lerobot,
+    author = {Cadene, Remi and Alibert, Simon and Soare, Alexander and Gallouedec, Quentin and Zouitine, Adil and Wolf, Thomas},
+    title = {LeRobot: State-of-the-art Machine Learning for Real-World Robotics in Pytorch},
+    howpublished = "\url{https://github.com/huggingface/lerobot}",
+    year = {2024}
+}
+```
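The profiling snippet referenced above uses `torch.profiler`'s `profile` context manager. As a dependency-free sketch of the same workflow with the standard library's `cProfile` (`eval_policy_stub` is a stand-in, not LeRobot's `eval_policy`):

```python
import cProfile
import io
import pstats

def eval_policy_stub():
    # stand-in for the code you actually want to profile
    return sum(i * i for i in range(100_000))

profiler = cProfile.Profile()
profiler.enable()
eval_policy_stub()  # insert code to profile, potentially a whole function body
profiler.disable()

# report the 5 most expensive calls by cumulative time
stream = io.StringIO()
pstats.Stats(profiler, stream=stream).sort_stats("cumulative").print_stats(5)
print(stream.getvalue())
```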
@@ -7,6 +7,11 @@ ARG DEBIAN_FRONTEND=noninteractive
 # Install apt dependencies
 RUN apt-get update && apt-get install -y --no-install-recommends \
     build-essential cmake \
+    git git-lfs openssh-client \
+    nano vim \
+    htop atop nvtop \
+    sed gawk grep curl wget \
+    tcpdump sysstat screen \
     libglib2.0-0 libgl1-mesa-glx libegl1-mesa \
     python${PYTHON_VERSION} python${PYTHON_VERSION}-venv \
     && apt-get clean && rm -rf /var/lib/apt/lists/*
@@ -18,7 +23,8 @@ ENV PATH="/opt/venv/bin:$PATH"
 RUN echo "source /opt/venv/bin/activate" >> /root/.bashrc

 # Install LeRobot
-COPY . /lerobot
+RUN git lfs install
+RUN git clone https://github.com/huggingface/lerobot.git
 WORKDIR /lerobot
 RUN pip install --upgrade --no-cache-dir pip
 RUN pip install --no-cache-dir ".[test, aloha, xarm, pusht]"
@@ -14,6 +14,7 @@ The script ends with examples of how to batch process data using PyTorch's DataLoader.
 """

 from pathlib import Path
+from pprint import pprint

 import imageio
 import torch
@@ -21,39 +22,36 @@ import torch
 import lerobot
 from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

-print("List of available datasets", lerobot.available_datasets)
-# # >>> ['lerobot/aloha_sim_insertion_human', 'lerobot/aloha_sim_insertion_scripted',
-# #  'lerobot/aloha_sim_transfer_cube_human', 'lerobot/aloha_sim_transfer_cube_scripted',
-# #  'lerobot/pusht', 'lerobot/xarm_lift_medium']
+print("List of available datasets:")
+pprint(lerobot.available_datasets)

+# Let's take one for this example
 repo_id = "lerobot/pusht"

-# You can easily load a dataset from a Hugging Face repositery
+# You can easily load a dataset from a Hugging Face repository
 dataset = LeRobotDataset(repo_id)

-# LeRobotDataset is actually a thin wrapper around an underlying Hugging Face dataset (see https://huggingface.co/docs/datasets/index for more information).
-# TODO(rcadene): update to make the print pretty
-print(f"{dataset=}")
-print(f"{dataset.hf_dataset=}")
+# LeRobotDataset is actually a thin wrapper around an underlying Hugging Face dataset
+# (see https://huggingface.co/docs/datasets/index for more information).
+print(dataset)
+print(dataset.hf_dataset)

-# and provides additional utilities for robotics and compatibility with pytorch
-print(f"number of samples/frames: {dataset.num_samples=}")
-print(f"number of episodes: {dataset.num_episodes=}")
-print(f"average number of frames per episode: {dataset.num_samples / dataset.num_episodes:.3f}")
+# And provides additional utilities for robotics and compatibility with Pytorch
+print(f"\naverage number of frames per episode: {dataset.num_samples / dataset.num_episodes:.3f}")
 print(f"frames per second used during data collection: {dataset.fps=}")
-print(f"keys to access images from cameras: {dataset.image_keys=}")
+print(f"keys to access images from cameras: {dataset.camera_keys=}\n")

 # Access frame indexes associated to first episode
 episode_index = 0
 from_idx = dataset.episode_data_index["from"][episode_index].item()
 to_idx = dataset.episode_data_index["to"][episode_index].item()

-# LeRobot datasets actually subclass PyTorch datasets so you can do everything you know and love from working with the latter, like iterating through the dataset.
-# Here we grab all the image frames.
+# LeRobot datasets actually subclass PyTorch datasets so you can do everything you know and love from working
+# with the latter, like iterating through the dataset. Here we grab all the image frames.
 frames = [dataset[idx]["observation.image"] for idx in range(from_idx, to_idx)]

-# Video frames are now float32 in range [0,1] channel first (c,h,w) to follow pytorch convention.
-# To visualize them, we convert to uint8 range [0,255]
+# Video frames are now float32 in range [0,1] channel first (c,h,w) to follow pytorch convention. To visualize
+# them, we convert to uint8 in range [0,255]
 frames = [(frame * 255).type(torch.uint8) for frame in frames]
 # and to channel last (h,w,c).
 frames = [frame.permute((1, 2, 0)).numpy() for frame in frames]
@@ -62,9 +60,9 @@ frames = [frame.permute((1, 2, 0)).numpy() for frame in frames]
|
|||||||
 Path("outputs/examples/1_load_lerobot_dataset").mkdir(parents=True, exist_ok=True)
 imageio.mimsave("outputs/examples/1_load_lerobot_dataset/episode_0.mp4", frames, fps=dataset.fps)

-# For many machine learning applications we need to load the history of past observations or trajectories of future actions.
-# Our datasets can load previous and future frames for each key/modality,
-# using timestamps differences with the current loaded frame. For instance:
+# For many machine learning applications we need to load the history of past observations or trajectories of
+# future actions. Our datasets can load previous and future frames for each key/modality, using timestamps
+# differences with the current loaded frame. For instance:
 delta_timestamps = {
     # loads 4 images: 1 second before current frame, 500 ms before, 200 ms before, and current frame
     "observation.image": [-1, -0.5, -0.20, 0],
@@ -74,12 +72,12 @@ delta_timestamps = {
     "action": [t / dataset.fps for t in range(64)],
 }
 dataset = LeRobotDataset(repo_id, delta_timestamps=delta_timestamps)
-print(f"{dataset[0]['observation.image'].shape=}")  # (4,c,h,w)
+print(f"\n{dataset[0]['observation.image'].shape=}")  # (4,c,h,w)
 print(f"{dataset[0]['observation.state'].shape=}")  # (8,c)
-print(f"{dataset[0]['action'].shape=}")  # (64,c)
+print(f"{dataset[0]['action'].shape=}\n")  # (64,c)

-# Finally, our datasets are fully compatible with PyTorch dataloaders and samplers
-# because they are just PyTorch datasets.
+# Finally, our datasets are fully compatible with PyTorch dataloaders and samplers because they are just
+# PyTorch datasets.
 dataloader = torch.utils.data.DataLoader(
     dataset,
     num_workers=0,
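A side note on the hunk above: a delta timestamp is simply an offset in seconds that the loader converts into a frame offset at the dataset's fps. A minimal sketch of that mapping, with made-up values (this is not LeRobot's actual lookup code):

```python
# Hypothetical sketch: how a delta_timestamps list maps to frame indices
# at a given fps. Values are illustrative, not from a real dataset.
fps = 10
delta_timestamps = [-1, -0.5, -0.20, 0]  # seconds relative to the current frame
current_idx = 50
query_indices = [current_idx + round(dt * fps) for dt in delta_timestamps]
print(query_indices)  # [40, 45, 48, 50]
```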
@@ -5,23 +5,108 @@ training outputs directory. In the latter case, you might want to run examples/3

 from pathlib import Path

+import gym_pusht  # noqa: F401
+import gymnasium as gym
+import imageio
+import numpy
+import torch
 from huggingface_hub import snapshot_download

-from lerobot.scripts.eval import eval
+from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy

-# Get a pretrained policy from the hub.
-pretrained_policy_name = "lerobot/diffusion_pusht"
-pretrained_policy_path = Path(snapshot_download(pretrained_policy_name))
+# Create a directory to store the video of the evaluation
+output_directory = Path("outputs/eval/example_pusht_diffusion")
+output_directory.mkdir(parents=True, exist_ok=True)

+device = torch.device("cuda")
+
+# Download the diffusion policy for pusht environment
+pretrained_policy_path = Path(snapshot_download("lerobot/diffusion_pusht"))
 # OR uncomment the following to evaluate a policy from the local outputs/train folder.
 # pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")

-# Override some config parameters to do with evaluation.
-overrides = [
-    "eval.n_episodes=10",
-    "eval.batch_size=10",
-    "device=cuda",
-]
+policy = DiffusionPolicy.from_pretrained(pretrained_policy_path)
+policy.eval()
+policy.to(device)

-# Evaluate the policy and save the outputs including metrics and videos.
-# TODO(rcadene, alexander-soare): dont call eval, but add the minimal code snippet to rollout
-eval(pretrained_policy_path=pretrained_policy_path)
+# Initialize evaluation environment to render two observation types:
+# an image of the scene and state/position of the agent. The environment
+# also automatically stops running after 300 interactions/steps.
+env = gym.make(
+    "gym_pusht/PushT-v0",
+    obs_type="pixels_agent_pos",
+    max_episode_steps=300,
+)
+
+# Reset the policy and environments to prepare for rollout
+policy.reset()
+numpy_observation, info = env.reset(seed=42)
+
+# Prepare to collect every reward and all the frames of the episode,
+# from initial state to final state.
+rewards = []
+frames = []
+
+# Render frame of the initial state
+frames.append(env.render())
+
+step = 0
+done = False
+while not done:
+    # Prepare observation for the policy running in Pytorch
+    state = torch.from_numpy(numpy_observation["agent_pos"])
+    image = torch.from_numpy(numpy_observation["pixels"])
+
+    # Convert to float32, with the image going from channel last in [0,255]
+    # to channel first in [0,1]
+    state = state.to(torch.float32)
+    image = image.to(torch.float32) / 255
+    image = image.permute(2, 0, 1)
+
+    # Send data tensors from CPU to GPU
+    state = state.to(device, non_blocking=True)
+    image = image.to(device, non_blocking=True)
+
+    # Add extra (empty) batch dimension, required to forward the policy
+    state = state.unsqueeze(0)
+    image = image.unsqueeze(0)
+
+    # Create the policy input dictionary
+    observation = {
+        "observation.state": state,
+        "observation.image": image,
+    }
+
+    # Predict the next action with respect to the current observation
+    with torch.inference_mode():
+        action = policy.select_action(observation)
+
+    # Prepare the action for the environment
+    numpy_action = action.squeeze(0).to("cpu").numpy()
+
+    # Step through the environment and receive a new observation
+    numpy_observation, reward, terminated, truncated, info = env.step(numpy_action)
+    print(f"{step=} {reward=} {terminated=}")
+
+    # Keep track of all the rewards and frames
+    rewards.append(reward)
+    frames.append(env.render())
+
+    # The rollout is considered done when the success state is reached (i.e. terminated is True),
+    # or the maximum number of iterations is reached (i.e. truncated is True)
+    done = terminated | truncated | done
+    step += 1
+
+if terminated:
+    print("Success!")
+else:
+    print("Failure!")
+
+# Get the speed of environment (i.e. its number of frames per second).
+fps = env.metadata["render_fps"]
+
+# Encode all frames into a mp4 video.
+video_path = output_directory / "rollout.mp4"
+imageio.mimsave(str(video_path), numpy.stack(frames), fps=fps)
+
+print(f"Video of the evaluation is available in '{video_path}'.")
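The rollout above only prints per-step rewards; they could also be folded into summary metrics afterwards. A hedged sketch with illustrative stand-in values (the reward list and the success criterion are assumptions, not output from a real run):

```python
# Aggregate the per-step `rewards` list collected during the rollout.
# The values below are illustrative stand-ins, not from a real run.
rewards = [0.0, 0.2, 0.5, 0.9, 1.0]
sum_reward = sum(rewards)    # cumulative reward over the episode
max_reward = max(rewards)    # best single-step reward
success = max_reward >= 1.0  # assumed success criterion for this sketch
```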
@@ -4,36 +4,42 @@ Once you have trained a model with this script, you can try to evaluate it on
 examples/2_evaluate_pretrained_policy.py
 """

-import os
 from pathlib import Path

 import torch
-from omegaconf import OmegaConf

-from lerobot.common.datasets.factory import make_dataset
+from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
 from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
 from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
-from lerobot.common.utils.utils import init_hydra_config

+# Create a directory to store the training checkpoint.
 output_directory = Path("outputs/train/example_pusht_diffusion")
-os.makedirs(output_directory, exist_ok=True)
+output_directory.mkdir(parents=True, exist_ok=True)

-# Number of offline training steps (we'll only do offline training for this example.
+# Number of offline training steps (we'll only do offline training for this example.)
 # Adjust as you prefer. 5000 steps are needed to get something worth evaluating.
 training_steps = 5000
 device = torch.device("cuda")
 log_freq = 250

 # Set up the dataset.
-hydra_cfg = init_hydra_config("lerobot/configs/default.yaml", overrides=["env=pusht"])
-dataset = make_dataset(hydra_cfg)
+delta_timestamps = {
+    # Load the previous image and state at -0.1 seconds before current frame,
+    # then load current image and state corresponding to 0.0 second.
+    "observation.image": [-0.1, 0.0],
+    "observation.state": [-0.1, 0.0],
+    # Load the previous action (-0.1), the next action to be executed (0.0),
+    # and 14 future actions with a 0.1 seconds spacing. All these actions will be
+    # used to supervise the policy.
+    "action": [-0.1, 0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4],
+}
+dataset = LeRobotDataset("lerobot/pusht", delta_timestamps=delta_timestamps)

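A quick check on the hunk above: the 16 `action` timestamps are just a window from -0.1 s to 1.4 s at 0.1 s spacing (one step before the current frame plus 15 steps from it), so they can be generated rather than typed out. A sketch under that assumption, taking fps=10 to match the 0.1 s spacing:

```python
# Regenerate the 16 action timestamps from the hunk above: -0.1 s to 1.4 s
# at 0.1 s spacing (fps=10 assumed to match the 0.1 s step).
fps = 10
action_deltas = [round(-0.1 + i / fps, 1) for i in range(16)]
print(action_deltas)  # [-0.1, 0.0, 0.1, ..., 1.3, 1.4]
```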
 # Set up the policy.
 # Policies are initialized with a configuration class, in this case `DiffusionConfig`.
 # For this example, no arguments need to be passed because the defaults are set up for PushT.
 # If you're doing something different, you will likely need to change at least some of the defaults.
 cfg = DiffusionConfig()
-# TODO(alexander-soare): Remove LR scheduler from the policy.
 policy = DiffusionPolicy(cfg, dataset_stats=dataset.stats)
 policy.train()
 policy.to(device)
@@ -69,7 +75,5 @@ while not done:
         done = True
         break

-# Save the policy.
+# Save a policy checkpoint.
 policy.save_pretrained(output_directory)
-# Save the Hydra configuration so we have the environment configuration for eval.
-OmegaConf.save(hydra_cfg, output_directory / "config.yaml")
@@ -1,3 +1,18 @@
+#!/usr/bin/env python
+
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 """
 This file contains lists of available environments, datasets and policies to reflect the current state of LeRobot library.
 We do not want to import all the dependencies, but instead we keep it lightweight to ensure fast access to these variables.
@@ -85,13 +100,6 @@ available_datasets = list(
     itertools.chain(*available_datasets_per_env.values(), available_real_world_datasets)
 )

-# TODO(rcadene, aliberts, alexander-soare): Add real-world env with a gym API
-available_datasets_without_env = ["lerobot/umi_cup_in_the_wild"]
-
-available_datasets = list(
-    itertools.chain(*available_datasets_per_env.values(), available_datasets_without_env)
-)

 available_policies = [
     "act",
     "diffusion",
@@ -1,3 +1,18 @@
+#!/usr/bin/env python
+
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 """To enable `lerobot.__version__`"""

 from importlib.metadata import PackageNotFoundError, version
@@ -37,16 +37,16 @@ How to decode videos?
 ## Variables

 **Image content**
-We don't expect the same optimal settings for a dataset of images from a simulation, or from real-world in an appartment, or in a factory, or outdoor, etc. Hence, we run this bechmark on two datasets: `pusht` (simulation) and `umi` (real-world outdoor).
+We don't expect the same optimal settings for a dataset of images from a simulation, or from the real world in an apartment, in a factory, or outdoors, etc. Hence, we run this benchmark on two datasets: `pusht` (simulation) and `umi` (real-world outdoor).

 **Requested timestamps**
-In this benchmark, we focus on the loading time of random access, so we are not interested about sequentially loading all frames of a video like in a movie. However, the number of consecutive timestamps requested and their spacing can greatly affect the `load_time_factor`. In fact, it is expected to get faster loading time by decoding a large number of consecutive frames from a video, than to load the same data from individual images. To reflect our robotics use case, we consider a few settings:
+In this benchmark, we focus on the loading time of random access, so we are not interested in sequentially loading all frames of a video like in a movie. However, the number of consecutive timestamps requested and their spacing can greatly affect the `load_time_factor`. In fact, it is expected to get faster loading time by decoding a large number of consecutive frames from a video, than to load the same data from individual images. To reflect our robotics use case, we consider a few settings:
 - `single_frame`: 1 frame,
 - `2_frames`: 2 consecutive frames (e.g. `[t, t + 1 / fps]`),
 - `2_frames_4_space`: 2 consecutive frames with 4 frames of spacing (e.g. `[t, t + 4 / fps]`),

 **Data augmentations**
-We might revisit this benchmark and find better settings if we train our policies with various data augmentations to make them more robusts (e.g. robust to color changes, compression, etc.).
+We might revisit this benchmark and find better settings if we train our policies with various data augmentations to make them more robust (e.g. robust to color changes, compression, etc.).


 ## Results
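The requested-timestamp settings listed above can be written out concretely. A small sketch with an illustrative `fps` and current time `t` (both assumed values, not taken from the benchmark config):

```python
# The benchmark's three requested-timestamp settings, spelled out
# for an assumed fps=30 and current time t=1.0 s.
fps = 30
t = 1.0
settings = {
    "single_frame": [t],
    "2_frames": [t, t + 1 / fps],
    "2_frames_4_space": [t, t + 4 / fps],
}
```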
@@ -1,3 +1,18 @@
+#!/usr/bin/env python
+
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 import json
 import random
 import shutil
@@ -1,3 +1,18 @@
+#!/usr/bin/env python
+
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 import logging

 import torch
@@ -1,3 +1,18 @@
+#!/usr/bin/env python
+
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 import os
 from pathlib import Path

@@ -72,6 +87,7 @@ class LeRobotDataset(torch.utils.data.Dataset):

     @property
     def fps(self) -> int:
+        """Frames per second used during data collection."""
         return self.info["fps"]

     @property
@@ -86,15 +102,22 @@ class LeRobotDataset(torch.utils.data.Dataset):
         return self.hf_dataset.features

     @property
-    def image_keys(self) -> list[str]:
-        image_keys = []
+    def camera_keys(self) -> list[str]:
+        """Keys to access image and video streams from cameras."""
+        keys = []
         for key, feats in self.hf_dataset.features.items():
-            if isinstance(feats, datasets.Image):
-                image_keys.append(key)
-        return image_keys + self.video_frame_keys
+            if isinstance(feats, (datasets.Image, VideoFrame)):
+                keys.append(key)
+        return keys

     @property
-    def video_frame_keys(self):
+    def video_frame_keys(self) -> list[str]:
+        """Keys to access video frames that need to be decoded into images.
+
+        Note: It is empty if the dataset contains images only,
+        or equal to `self.cameras` if the dataset contains videos only,
+        or can even be a subset of `self.cameras` in a case of a mixed image/video dataset.
+        """
         video_frame_keys = []
         for key, feats in self.hf_dataset.features.items():
             if isinstance(feats, VideoFrame):
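To make the `camera_keys` / `video_frame_keys` split above concrete, here is a toy model using stand-in feature type strings (the real code checks `datasets.Image` and `VideoFrame` instances, and the key names here are hypothetical):

```python
# Toy model of the camera_keys / video_frame_keys distinction for a mixed
# image/video dataset. Feature types are stand-in strings, not the real
# datasets.Image / VideoFrame classes.
features = {
    "observation.image": "Image",           # stored as individual images
    "observation.image_top": "VideoFrame",  # stored as frames inside an mp4
    "observation.state": "Sequence",        # not a camera stream
}
camera_keys = [key for key, feat in features.items() if feat in ("Image", "VideoFrame")]
video_frame_keys = [key for key, feat in features.items() if feat == "VideoFrame"]
print(camera_keys)       # ['observation.image', 'observation.image_top']
print(video_frame_keys)  # ['observation.image_top']
```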
@@ -103,10 +126,15 @@ class LeRobotDataset(torch.utils.data.Dataset):

     @property
     def num_samples(self) -> int:
+        """Number of possible samples in the dataset.
+
+        This is equivalent to the number of frames in the dataset minus n_end_keyframes_dropped.
+        """
         return len(self.index)

     @property
     def num_episodes(self) -> int:
+        """Number of episodes."""
         return len(self.hf_dataset.unique("episode_index"))

     @property
@@ -146,6 +174,22 @@ class LeRobotDataset(torch.utils.data.Dataset):

         return item

+    def __repr__(self):
+        return (
+            f"{self.__class__.__name__}(\n"
+            f"  Repository ID: '{self.repo_id}',\n"
+            f"  Version: '{self.version}',\n"
+            f"  Split: '{self.split}',\n"
+            f"  Number of Samples: {self.num_samples},\n"
+            f"  Number of Episodes: {self.num_episodes},\n"
+            f"  Type: {'video (.mp4)' if self.video else 'image (.png)'},\n"
+            f"  Recorded Frames per Second: {self.fps},\n"
+            f"  Camera Keys: {self.camera_keys},\n"
+            f"  Video Frame Keys: {self.video_frame_keys if self.video else 'N/A'},\n"
+            f"  Transformations: {self.transform},\n"
+            f")"
+        )
+
     @classmethod
     def from_preloaded(
         cls,
@@ -1,3 +1,18 @@
+#!/usr/bin/env python
+
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 """Helper code for loading PushT dataset from Diffusion Policy (https://diffusion-policy.cs.columbia.edu/)

 Copied from the original Diffusion Policy repository and used in our `download_and_upload_dataset.py` script.
@@ -1,3 +1,18 @@
+#!/usr/bin/env python
+
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 """
 This file contains all obsolete download scripts. They are centralized here to not have to load
 useless dependencies when using datasets.
@@ -1,3 +1,18 @@
+#!/usr/bin/env python
+
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 # imagecodecs/numcodecs.py

 # Copyright (c) 2021-2022, Christoph Gohlke
@@ -1,3 +1,18 @@
+#!/usr/bin/env python
+
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 """
 Contains utilities to process raw data format of HDF5 files like in: https://github.com/tonyzhaozh/act
 """
@@ -142,12 +157,12 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug):
 def to_hf_dataset(data_dict, video) -> Dataset:
     features = {}

-    image_keys = [key for key in data_dict if "observation.images." in key]
-    for image_key in image_keys:
+    keys = [key for key in data_dict if "observation.images." in key]
+    for key in keys:
         if video:
-            features[image_key] = VideoFrame()
+            features[key] = VideoFrame()
         else:
-            features[image_key] = Image()
+            features[key] = Image()

     features["observation.state"] = Sequence(
         length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
@@ -1,3 +1,18 @@
+#!/usr/bin/env python
+
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 from copy import deepcopy
 from math import ceil

@@ -1,3 +1,18 @@
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
"""Process zarr files formatted like in: https://github.com/real-stanford/diffusion_policy"""
|
"""Process zarr files formatted like in: https://github.com/real-stanford/diffusion_policy"""
|
||||||
|
|
||||||
import shutil
|
import shutil
|
||||||
|
|||||||
```diff
@@ -1,3 +1,18 @@
+#!/usr/bin/env python
+
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 """Process UMI (Universal Manipulation Interface) data stored in Zarr format like in: https://github.com/real-stanford/universal_manipulation_interface"""
 
 import logging
```
```diff
@@ -1,3 +1,18 @@
+#!/usr/bin/env python
+
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 from concurrent.futures import ThreadPoolExecutor
 from pathlib import Path
 
```
```diff
@@ -1,3 +1,18 @@
+#!/usr/bin/env python
+
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 """Process pickle files formatted like in: https://github.com/fyhMer/fowm"""
 
 import pickle
```
```diff
@@ -1,3 +1,18 @@
+#!/usr/bin/env python
+
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 import json
 from pathlib import Path
 
```
```diff
@@ -1,3 +1,18 @@
+#!/usr/bin/env python
+
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 import logging
 import subprocess
 import warnings
```
```diff
@@ -1,3 +1,18 @@
+#!/usr/bin/env python
+
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 import importlib
 
 import gymnasium as gym
```
```diff
@@ -1,3 +1,18 @@
+#!/usr/bin/env python
+
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 import einops
 import numpy as np
 import torch
```
```diff
@@ -1,3 +1,18 @@
+#!/usr/bin/env python
+
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 # TODO(rcadene, alexander-soare): clean this file
 """Borrowed from https://github.com/fyhMer/fowm/blob/main/src/logger.py"""
 
```
```diff
@@ -17,9 +32,10 @@ def log_output_dir(out_dir):
 
 
 def cfg_to_group(cfg, return_list=False):
-    """Return a wandb-safe group name for logging. Optionally returns group name as list."""
-    # lst = [cfg.task, cfg.modality, re.sub("[^0-9a-zA-Z]+", "-", cfg.exp_name)]
+    """Return a group name for logging. Optionally returns group name as list."""
     lst = [
+        f"policy:{cfg.policy.name}",
+        f"dataset:{cfg.dataset_repo_id}",
         f"env:{cfg.env.name}",
         f"seed:{cfg.seed}",
     ]
```
```diff
@@ -81,9 +97,9 @@ class Logger:
         # Also save the full Hydra config for the env configuration.
         OmegaConf.save(self._cfg, save_dir / "config.yaml")
         if self._wandb and not self._disable_wandb_artifact:
-            # note wandb artifact does not accept ":" in its name
+            # note wandb artifact does not accept ":" or "/" in its name
             artifact = self._wandb.Artifact(
-                self._group.replace(":", "_") + "-" + str(self._seed) + "-" + str(identifier),
+                f"{self._group.replace(':', '_').replace('/', '_')}-{self._seed}-{identifier}",
                 type="model",
             )
             artifact.add_file(save_dir / SAFETENSORS_SINGLE_FILE)
```
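The renaming above exists because wandb artifact names reject certain characters. The sanitized name format is easy to exercise in isolation; a minimal sketch (the `sanitize_artifact_name` helper is hypothetical — only the replace chain and the `<group>-<seed>-<identifier>` layout come from the diff):

```python
def sanitize_artifact_name(group: str, seed: int, identifier: str) -> str:
    # wandb artifact names do not accept ":" or "/", so both are mapped to "_"
    # before composing "<group>-<seed>-<identifier>".
    return f"{group.replace(':', '_').replace('/', '_')}-{seed}-{identifier}"

name = sanitize_artifact_name("policy:act/dataset:aloha", 42, "best")
print(name)  # policy_act_dataset_aloha-42-best
```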
```diff
@@ -93,9 +109,10 @@ class Logger:
         self._buffer_dir.mkdir(parents=True, exist_ok=True)
         fp = self._buffer_dir / f"{str(identifier)}.pkl"
         buffer.save(fp)
-        if self._wandb:
+        if self._wandb and not self._disable_wandb_artifact:
+            # note wandb artifact does not accept ":" or "/" in its name
             artifact = self._wandb.Artifact(
-                self._group + "-" + str(self._seed) + "-" + str(identifier),
+                f"{self._group.replace(':', '_').replace('/', '_')}-{self._seed}-{identifier}",
                 type="buffer",
             )
             artifact.add_file(fp)
```
```diff
@@ -113,6 +130,11 @@ class Logger:
         assert mode in {"train", "eval"}
         if self._wandb is not None:
             for k, v in d.items():
+                if not isinstance(v, (int, float, str)):
+                    logging.warning(
+                        f'WandB logging of key "{k}" was ignored as its type is not handled by this wrapper.'
+                    )
+                    continue
                 self._wandb.log({f"{mode}/{k}": v}, step=step)
 
     def log_video(self, video_path: str, step: int, mode: str = "train"):
```
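The guard added in this hunk forwards only scalar types to `wandb.log` and warns on everything else. The same filtering can be sketched as a standalone function (the `scalar_items` helper is hypothetical; the type check and warning message mirror the diff):

```python
import logging

def scalar_items(d: dict, mode: str = "train") -> dict:
    # Keep only int/float/str values; anything else (tensors, arrays, ...)
    # is skipped with a warning, mirroring the guard added in the diff.
    out = {}
    for k, v in d.items():
        if not isinstance(v, (int, float, str)):
            logging.warning(f'WandB logging of key "{k}" was ignored as its type is not handled by this wrapper.')
            continue
        out[f"{mode}/{k}"] = v
    return out

print(scalar_items({"loss": 0.5, "grads": [0.1, 0.2]}, mode="train"))  # {'train/loss': 0.5}
```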
```diff
@@ -1,3 +1,18 @@
+#!/usr/bin/env python
+
+# Copyright 2024 Tony Z. Zhao and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 from dataclasses import dataclass, field
 
 
```
```diff
@@ -51,8 +66,12 @@ class ACTConfig:
             documentation in the policy class).
         latent_dim: The VAE's latent dimension.
         n_vae_encoder_layers: The number of transformer layers to use for the VAE's encoder.
-        use_temporal_aggregation: Whether to blend the actions of multiple policy invocations for any given
-            environment step.
+        temporal_ensemble_momentum: Exponential moving average (EMA) momentum parameter (α) for ensembling
+            actions for a given time step over multiple policy invocations. Updates are calculated as:
+            x⁻ₙ = αx⁻ₙ₋₁ + (1-α)xₙ. Note that the ACT paper and original ACT code describes a different
+            parameter here: they refer to a weighting scheme wᵢ = exp(-m⋅i) and set m = 0.01. With our
+            formulation, this is equivalent to α = exp(-0.01) ≈ 0.99. When this parameter is provided, we
+            require `n_action_steps == 1` (since we need to query the policy every step anyway).
         dropout: Dropout to use in the transformer layers (see code for details).
         kl_weight: The weight to use for the KL-divergence component of the loss if the variational objective
             is enabled. Loss is then calculated as: `reconstruction_loss + kl_weight * kld_loss`.
```
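The equivalence claimed in this docstring between the EMA momentum α and the original exp(-m·i) weighting can be checked numerically. A sketch, assuming m = 0.01 as in the original ACT code: unrolling x⁻ₙ = αx⁻ₙ₋₁ + (1-α)xₙ shows that a prediction that is `age` steps old carries weight (1-α)·α^age, the same geometric profile as wᵢ = exp(-m·i):

```python
import math

m = 0.01
alpha = math.exp(-m)  # the equivalent EMA momentum, ≈ 0.99

# Unrolled EMA weight on a prediction that is `age` steps old: (1-α)·α**age.
# Consecutive weights then decay by exactly exp(m) = 1/α, matching exp(-m·i).
weights = [(1 - alpha) * alpha**age for age in range(5)]
ratios = [weights[i] / weights[i + 1] for i in range(len(weights) - 1)]
print(all(abs(r - math.exp(m)) < 1e-12 for r in ratios))  # True
```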
```diff
@@ -100,6 +119,9 @@ class ACTConfig:
     dim_feedforward: int = 3200
     feedforward_activation: str = "relu"
     n_encoder_layers: int = 4
+    # Note: Although the original ACT implementation has 7 for `n_decoder_layers`, there is a bug in the code
+    # that means only the first layer is used. Here we match the original implementation by setting this to 1.
+    # See this issue https://github.com/tonyzhaozh/act/issues/25#issue-2258740521.
     n_decoder_layers: int = 1
     # VAE.
     use_vae: bool = True
```
```diff
@@ -107,7 +129,7 @@ class ACTConfig:
     n_vae_encoder_layers: int = 4
 
     # Inference.
-    use_temporal_aggregation: bool = False
+    temporal_ensemble_momentum: float | None = None
 
     # Training and loss computation.
     dropout: float = 0.1
```
```diff
@@ -119,8 +141,11 @@ class ACTConfig:
             raise ValueError(
                 f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
             )
-        if self.use_temporal_aggregation:
-            raise NotImplementedError("Temporal aggregation is not yet implemented.")
+        if self.temporal_ensemble_momentum is not None and self.n_action_steps > 1:
+            raise NotImplementedError(
+                "`n_action_steps` must be 1 when using temporal ensembling. This is "
+                "because the policy needs to be queried every step to compute the ensembled action."
+            )
         if self.n_action_steps > self.chunk_size:
             raise ValueError(
                 f"The chunk size is the upper bound for the number of action steps per model invocation. Got "
```
```diff
@@ -130,10 +155,3 @@ class ACTConfig:
             raise ValueError(
                 f"Multiple observation steps not handled yet. Got `nobs_steps={self.n_obs_steps}`"
             )
-        # Check that there is only one image.
-        # TODO(alexander-soare): generalize this to multiple images.
-        if (
-            sum(k.startswith("observation.images.") for k in self.input_shapes) != 1
-            or "observation.images.top" not in self.input_shapes
-        ):
-            raise ValueError('For now, only "observation.images.top" is accepted for an image input.')
```
```diff
@@ -1,3 +1,18 @@
+#!/usr/bin/env python
+
+# Copyright 2024 Tony Z. Zhao and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 """Action Chunking Transformer Policy
 
 As per Learning Fine-Grained Bimanual Manipulation with Low-Cost Hardware (https://arxiv.org/abs/2304.13705).
```
```diff
@@ -46,7 +61,8 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
         super().__init__()
         if config is None:
             config = ACTConfig()
-        self.config = config
+        self.config: ACTConfig = config
+
         self.normalize_inputs = Normalize(
             config.input_shapes, config.input_normalization_modes, dataset_stats
         )
```
```diff
@@ -56,11 +72,18 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
         self.unnormalize_outputs = Unnormalize(
             config.output_shapes, config.output_normalization_modes, dataset_stats
         )
 
         self.model = ACT(config)
 
+        self.expected_image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
+
+        self.reset()
+
     def reset(self):
         """This should be called whenever the environment is reset."""
-        if self.config.n_action_steps is not None:
+        if self.config.temporal_ensemble_momentum is not None:
+            self._ensembled_actions = None
+        else:
             self._action_queue = deque([], maxlen=self.config.n_action_steps)
 
     @torch.no_grad
```
```diff
@@ -71,37 +94,56 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
         environment. It works by managing the actions in a queue and only calling `select_actions` when the
         queue is empty.
         """
-        assert "observation.images.top" in batch
-        assert "observation.state" in batch
-
         self.eval()
 
         batch = self.normalize_inputs(batch)
-        self._stack_images(batch)
+        batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
+
+        # If we are doing temporal ensembling, keep track of the exponential moving average (EMA), and return
+        # the first action.
+        if self.config.temporal_ensemble_momentum is not None:
+            actions = self.model(batch)[0]  # (batch_size, chunk_size, action_dim)
+            actions = self.unnormalize_outputs({"action": actions})["action"]
+            if self._ensembled_actions is None:
+                # Initializes `self._ensembled_action` to the sequence of actions predicted during the first
+                # time step of the episode.
+                self._ensembled_actions = actions.clone()
+            else:
+                # self._ensembled_actions will have shape (batch_size, chunk_size - 1, action_dim). Compute
+                # the EMA update for those entries.
+                alpha = self.config.temporal_ensemble_momentum
+                self._ensembled_actions = alpha * self._ensembled_actions + (1 - alpha) * actions[:, :-1]
+                # The last action, which has no prior moving average, needs to get concatenated onto the end.
+                self._ensembled_actions = torch.cat([self._ensembled_actions, actions[:, -1:]], dim=1)
+            # "Consume" the first action.
+            action, self._ensembled_actions = self._ensembled_actions[:, 0], self._ensembled_actions[:, 1:]
+            return action
+
+        # Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by
+        # querying the policy.
         if len(self._action_queue) == 0:
-            # `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue
-            # effectively has shape (n_action_steps, batch_size, *), hence the transpose.
-            actions = self.model(batch)[0][: self.config.n_action_steps]
+            actions = self.model(batch)[0][:, : self.config.n_action_steps]
 
             # TODO(rcadene): make _forward return output dictionary?
             actions = self.unnormalize_outputs({"action": actions})["action"]
+
+            # `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue
+            # effectively has shape (n_action_steps, batch_size, *), hence the transpose.
             self._action_queue.extend(actions.transpose(0, 1))
         return self._action_queue.popleft()
 
     def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
         """Run the batch through the model and compute the loss for training or validation."""
         batch = self.normalize_inputs(batch)
+        batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
         batch = self.normalize_targets(batch)
-        self._stack_images(batch)
         actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)
 
         l1_loss = (
             F.l1_loss(batch["action"], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1)
         ).mean()
 
-        loss_dict = {"l1_loss": l1_loss}
+        loss_dict = {"l1_loss": l1_loss.item()}
         if self.config.use_vae:
             # Calculate Dₖₗ(latent_pdf || standard_normal). Note: After computing the KL-divergence for
             # each dimension independently, we sum over the latent dimension to get the total
```
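The ensembling branch above reduces to a short update rule. Here it is re-enacted with plain Python floats (a scalar sketch, not the tensor implementation; `ensemble_step` and the toy α and chunks are illustrative):

```python
# Scalar re-enactment of temporal ensembling: each call predicts a fresh chunk
# of future actions; entries overlapping the running average are EMA-blended,
# the newest prediction is appended, and the first entry is consumed.
def ensemble_step(ensembled, chunk, alpha):
    if ensembled is None:
        # First time step: initialize the running average to the first chunk.
        ensembled = list(chunk)
    else:
        # EMA-update the chunk_size - 1 overlapping entries...
        ensembled = [alpha * e + (1 - alpha) * x for e, x in zip(ensembled, chunk[:-1])]
        # ...and append the newest prediction, which has no running average yet.
        ensembled.append(chunk[-1])
    # "Consume" the first action.
    return ensembled[0], ensembled[1:]

alpha = 0.5
action, state = ensemble_step(None, [1.0, 2.0, 3.0], alpha)   # consumes 1.0
action, state = ensemble_step(state, [4.0, 5.0, 6.0], alpha)  # 0.5*2.0 + 0.5*4.0
print(action, state)  # 3.0 [4.0, 6.0]
```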
```diff
@@ -110,28 +152,13 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
             mean_kld = (
                 (-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean()
             )
-            loss_dict["kld_loss"] = mean_kld
+            loss_dict["kld_loss"] = mean_kld.item()
             loss_dict["loss"] = l1_loss + mean_kld * self.config.kl_weight
         else:
             loss_dict["loss"] = l1_loss
 
         return loss_dict
 
-    def _stack_images(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
-        """Stacks all the images in a batch and puts them in a new key: "observation.images".
-
-        This function expects `batch` to have (at least):
-        {
-            "observation.state": (B, state_dim) batch of robot states.
-            "observation.images.{name}": (B, C, H, W) tensor of images.
-        }
-        """
-        # Stack images in the order dictated by input_shapes.
-        batch["observation.images"] = torch.stack(
-            [batch[k] for k in self.config.input_shapes if k.startswith("observation.images.")],
-            dim=-4,
-        )
-
 
 class ACT(nn.Module):
     """Action Chunking Transformer: The underlying neural network for ACTPolicy.
```
```diff
@@ -161,10 +188,10 @@ class ACT(nn.Module):
     │ encoder │ │ │ │Transf.│ │
     │         │ │ │ │encoder│ │
     └───▲─────┘ │ │ │       │ │
-        │       │ │ └───▲───┘ │
-        │       │ │     │     │
-    inputs      └─────┼─────┘ │
-                      │       │
+        │       │ │ └▲──▲─▲─┘ │
+        │       │ │  │  │ │   │
+    inputs      └─────┼──┘ │ image emb. │
+                      │ state emb. │
                       └───────────────────────┘
     """
 
```
```diff
@@ -306,18 +333,18 @@ class ACT(nn.Module):
             all_cam_features.append(cam_features)
             all_cam_pos_embeds.append(cam_pos_embed)
         # Concatenate camera observation feature maps and positional embeddings along the width dimension.
-        encoder_in = torch.cat(all_cam_features, axis=3)
-        cam_pos_embed = torch.cat(all_cam_pos_embeds, axis=3)
+        encoder_in = torch.cat(all_cam_features, axis=-1)
+        cam_pos_embed = torch.cat(all_cam_pos_embeds, axis=-1)
 
         # Get positional embeddings for robot state and latent.
-        robot_state_embed = self.encoder_robot_state_input_proj(batch["observation.state"])
-        latent_embed = self.encoder_latent_input_proj(latent_sample)
+        robot_state_embed = self.encoder_robot_state_input_proj(batch["observation.state"])  # (B, C)
+        latent_embed = self.encoder_latent_input_proj(latent_sample)  # (B, C)
 
         # Stack encoder input and positional embeddings moving to (S, B, C).
         encoder_in = torch.cat(
             [
                 torch.stack([latent_embed, robot_state_embed], axis=0),
-                encoder_in.flatten(2).permute(2, 0, 1),
+                einops.rearrange(encoder_in, "b c h w -> (h w) b c"),
             ]
         )
         pos_embed = torch.cat(
```
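This hunk swaps `encoder_in.flatten(2).permute(2, 0, 1)` for the equivalent, more readable `einops.rearrange(encoder_in, "b c h w -> (h w) b c")`. The equivalence can be checked with plain numpy, using reshape/transpose as stand-ins for torch's flatten/permute (a sketch with arbitrary small dimensions):

```python
import numpy as np

b, c, h, w = 2, 4, 3, 5
x = np.arange(b * c * h * w).reshape(b, c, h, w)

# torch equivalent: x.flatten(2).permute(2, 0, 1) -> (h*w, b, c)
flat_then_permute = x.reshape(b, c, h * w).transpose(2, 0, 1)
# einops equivalent: rearrange(x, "b c h w -> (h w) b c")
rearranged = x.transpose(2, 3, 0, 1).reshape(h * w, b, c)

print(np.array_equal(flat_then_permute, rearranged))  # True
```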
```diff
@@ -1,3 +1,19 @@
+#!/usr/bin/env python
+
+# Copyright 2024 Columbia Artificial Intelligence, Robotics Lab,
+# and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 from dataclasses import dataclass, field
 
 
```
```diff
@@ -51,6 +67,7 @@ class DiffusionConfig:
         use_film_scale_modulation: FiLM (https://arxiv.org/abs/1709.07871) is used for the Unet conditioning.
             Bias modulation is used be default, while this parameter indicates whether to also use scale
             modulation.
+        noise_scheduler_type: Name of the noise scheduler to use. Supported options: ["DDPM", "DDIM"].
         num_train_timesteps: Number of diffusion steps for the forward diffusion schedule.
         beta_schedule: Name of the diffusion beta schedule as per DDPMScheduler from Hugging Face diffusers.
         beta_start: Beta value for the first forward-diffusion step.
```
```diff
@@ -64,6 +81,9 @@ class DiffusionConfig:
         clip_sample_range: The magnitude of the clipping range as described above.
         num_inference_steps: Number of reverse diffusion steps to use at inference time (steps are evenly
             spaced). If not provided, this defaults to be the same as `num_train_timesteps`.
+        do_mask_loss_for_padding: Whether to mask the loss when there are copy-padded actions. See
+            `LeRobotDataset` and `load_previous_and_future_frames` for mor information. Note, this defaults
+            to False as the original Diffusion Policy implementation does the same.
     """
 
     # Inputs / output structure.
```
```diff
@@ -107,6 +127,7 @@ class DiffusionConfig:
    diffusion_step_embed_dim: int = 128
    use_film_scale_modulation: bool = True
    # Noise scheduler.
+    noise_scheduler_type: str = "DDPM"
    num_train_timesteps: int = 100
    beta_schedule: str = "squaredcos_cap_v2"
    beta_start: float = 0.0001
```
@@ -118,23 +139,39 @@ class DiffusionConfig:
|
|||||||
# Inference
|
# Inference
|
||||||
num_inference_steps: int | None = None
|
num_inference_steps: int | None = None
|
||||||
|
|
||||||
|
# Loss computation
|
||||||
|
do_mask_loss_for_padding: bool = False
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
"""Input validation (not exhaustive)."""
|
"""Input validation (not exhaustive)."""
|
||||||
if not self.vision_backbone.startswith("resnet"):
|
if not self.vision_backbone.startswith("resnet"):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
|
f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
|
||||||
)
|
)
|
||||||
|
# There should only be one image key.
|
||||||
|
image_keys = {k for k in self.input_shapes if k.startswith("observation.image")}
|
||||||
|
if len(image_keys) != 1:
|
||||||
|
raise ValueError(
|
||||||
|
f"{self.__class__.__name__} only handles one image for now. Got image keys {image_keys}."
|
||||||
|
)
|
||||||
|
image_key = next(iter(image_keys))
|
||||||
if (
|
if (
|
||||||
self.crop_shape[0] > self.input_shapes["observation.image"][1]
|
self.crop_shape[0] > self.input_shapes[image_key][1]
|
||||||
or self.crop_shape[1] > self.input_shapes["observation.image"][2]
|
or self.crop_shape[1] > self.input_shapes[image_key][2]
|
||||||
):
|
):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f'`crop_shape` should fit within `input_shapes["observation.image"]`. Got {self.crop_shape} '
|
f"`crop_shape` should fit within `input_shapes[{image_key}]`. Got {self.crop_shape} "
|
||||||
f'for `crop_shape` and {self.input_shapes["observation.image"]} for '
|
f"for `crop_shape` and {self.input_shapes[image_key]} for "
|
||||||
'`input_shapes["observation.image"]`.'
|
"`input_shapes[{image_key}]`."
|
||||||
)
|
)
|
||||||
supported_prediction_types = ["epsilon", "sample"]
|
supported_prediction_types = ["epsilon", "sample"]
|
||||||
if self.prediction_type not in supported_prediction_types:
|
if self.prediction_type not in supported_prediction_types:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"`prediction_type` must be one of {supported_prediction_types}. Got {self.prediction_type}."
|
f"`prediction_type` must be one of {supported_prediction_types}. Got {self.prediction_type}."
|
||||||
)
|
)
|
||||||
|
supported_noise_schedulers = ["DDPM", "DDIM"]
|
||||||
|
if self.noise_scheduler_type not in supported_noise_schedulers:
|
||||||
|
raise ValueError(
|
||||||
|
f"`noise_scheduler_type` must be one of {supported_noise_schedulers}. "
|
||||||
|
f"Got {self.noise_scheduler_type}."
|
||||||
|
)
|
||||||
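The single-image-key and scheduler-name checks added to `__post_init__` can be sketched with a minimal stand-alone dataclass (the class and its defaults below are hypothetical stand-ins for `DiffusionConfig`, not the real config):

```python
from dataclasses import dataclass, field

@dataclass
class MiniDiffusionConfig:
    # Hypothetical, trimmed-down stand-in for DiffusionConfig.
    input_shapes: dict = field(default_factory=lambda: {
        "observation.image": [3, 96, 96],
        "observation.state": [2],
    })
    noise_scheduler_type: str = "DDPM"

    def __post_init__(self):
        # Exactly one image key is supported for now.
        image_keys = {k for k in self.input_shapes if k.startswith("observation.image")}
        if len(image_keys) != 1:
            raise ValueError(f"Only one image key is handled. Got {image_keys}.")
        # The scheduler name must be one of the supported options.
        if self.noise_scheduler_type not in ["DDPM", "DDIM"]:
            raise ValueError(f"Unsupported scheduler: {self.noise_scheduler_type}")

MiniDiffusionConfig()  # default values pass validation
```

Because validation lives in `__post_init__`, a bad field value fails at construction time rather than deep inside training.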
@@ -1,8 +1,24 @@
+#!/usr/bin/env python
+
+# Copyright 2024 Columbia Artificial Intelligence, Robotics Lab,
+# and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 """Diffusion Policy as per "Diffusion Policy: Visuomotor Policy Learning via Action Diffusion"

 TODO(alexander-soare):
-  - Remove reliance on Robomimic for SpatialSoftmax.
   - Remove reliance on diffusers for DDPMScheduler and LR scheduler.
+  - Make compatible with multiple image keys.
 """

 import math
@@ -10,12 +26,13 @@ from collections import deque
 from typing import Callable

 import einops
+import numpy as np
 import torch
 import torch.nn.functional as F  # noqa: N812
 import torchvision
+from diffusers.schedulers.scheduling_ddim import DDIMScheduler
 from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
 from huggingface_hub import PyTorchModelHubMixin
-from robomimic.models.base_nets import SpatialSoftmax
 from torch import Tensor, nn

 from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
@@ -66,10 +83,18 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):

         self.diffusion = DiffusionModel(config)

+        image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
+        # Note: This check is covered in the post-init of the config but have a sanity check just in case.
+        if len(image_keys) != 1:
+            raise NotImplementedError(
+                f"{self.__class__.__name__} only handles one image for now. Got image keys {image_keys}."
+            )
+        self.input_image_key = image_keys[0]
+
+        self.reset()
+
     def reset(self):
-        """
-        Clear observation and action queues. Should be called on `env.reset()`
-        """
+        """Clear observation and action queues. Should be called on `env.reset()`"""
         self._queues = {
             "observation.image": deque(maxlen=self.config.n_obs_steps),
             "observation.state": deque(maxlen=self.config.n_obs_steps),
@@ -98,16 +123,14 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
         "horizon" may not be the best name to describe what the variable actually means, because this period
         is actually measured from the first observation which (if `n_obs_steps` > 1) happened in the past.
         """
-        assert "observation.image" in batch
-        assert "observation.state" in batch
-
         batch = self.normalize_inputs(batch)
+        batch["observation.image"] = batch[self.input_image_key]

         self._queues = populate_queues(self._queues, batch)

         if len(self._queues["action"]) == 0:
             # stack n latest observations from the queue
-            batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch}
+            batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues}
             actions = self.diffusion.generate_actions(batch)

         # TODO(rcadene): make above methods return output dictionary?
@@ -121,11 +144,25 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
     def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
         """Run the batch through the model and compute the loss for training or validation."""
         batch = self.normalize_inputs(batch)
+        batch["observation.image"] = batch[self.input_image_key]
         batch = self.normalize_targets(batch)
         loss = self.diffusion.compute_loss(batch)
         return {"loss": loss}

+
+def _make_noise_scheduler(name: str, **kwargs: dict) -> DDPMScheduler | DDIMScheduler:
+    """
+    Factory for noise scheduler instances of the requested type. All kwargs are passed
+    to the scheduler.
+    """
+    if name == "DDPM":
+        return DDPMScheduler(**kwargs)
+    elif name == "DDIM":
+        return DDIMScheduler(**kwargs)
+    else:
+        raise ValueError(f"Unsupported noise scheduler type {name}")
+

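The dispatch pattern in the new `_make_noise_scheduler` can be sketched in isolation; `StubDDPM` and `StubDDIM` below are hypothetical stand-ins for the diffusers scheduler classes, used only so the sketch runs without the library:

```python
# Sketch of the factory-dispatch pattern used by _make_noise_scheduler,
# with hypothetical stub classes standing in for diffusers' schedulers.
class StubDDPM:
    def __init__(self, **kwargs):
        self.kwargs = kwargs

class StubDDIM:
    def __init__(self, **kwargs):
        self.kwargs = kwargs

def make_noise_scheduler(name: str, **kwargs):
    """Return a scheduler instance of the requested type; kwargs pass straight through."""
    schedulers = {"DDPM": StubDDPM, "DDIM": StubDDIM}
    if name not in schedulers:
        raise ValueError(f"Unsupported noise scheduler type {name}")
    return schedulers[name](**kwargs)

scheduler = make_noise_scheduler("DDIM", num_train_timesteps=100)
```

Passing everything through `**kwargs` works here because DDPM and DDIM in diffusers share the constructor arguments this policy uses (`num_train_timesteps`, `beta_start`, `beta_end`, `beta_schedule`, `clip_sample`, `prediction_type`).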
 class DiffusionModel(nn.Module):
     def __init__(self, config: DiffusionConfig):
         super().__init__()
@@ -138,12 +175,12 @@ class DiffusionModel(nn.Module):
             * config.n_obs_steps,
         )

-        self.noise_scheduler = DDPMScheduler(
+        self.noise_scheduler = _make_noise_scheduler(
+            config.noise_scheduler_type,
             num_train_timesteps=config.num_train_timesteps,
             beta_start=config.beta_start,
             beta_end=config.beta_end,
             beta_schedule=config.beta_schedule,
-            variance_type="fixed_small",
             clip_sample=config.clip_sample,
             clip_sample_range=config.clip_sample_range,
             prediction_type=config.prediction_type,
@@ -185,13 +222,12 @@ class DiffusionModel(nn.Module):

     def generate_actions(self, batch: dict[str, Tensor]) -> Tensor:
         """
-        This function expects `batch` to have (at least):
+        This function expects `batch` to have:
         {
             "observation.state": (B, n_obs_steps, state_dim)
             "observation.image": (B, n_obs_steps, C, H, W)
         }
         """
-        assert set(batch).issuperset({"observation.state", "observation.image"})
         batch_size, n_obs_steps = batch["observation.state"].shape[:2]
         assert n_obs_steps == self.config.n_obs_steps

@@ -268,13 +304,84 @@ class DiffusionModel(nn.Module):
         loss = F.mse_loss(pred, target, reduction="none")

         # Mask loss wherever the action is padded with copies (edges of the dataset trajectory).
-        if "action_is_pad" in batch:
+        if self.config.do_mask_loss_for_padding and "action_is_pad" in batch:
             in_episode_bound = ~batch["action_is_pad"]
             loss = loss * in_episode_bound.unsqueeze(-1)

         return loss.mean()
+
+
+class SpatialSoftmax(nn.Module):
+    """
+    Spatial Soft Argmax operation described in "Deep Spatial Autoencoders for Visuomotor Learning" by Finn et al.
+    (https://arxiv.org/pdf/1509.06113). A minimal port of the robomimic implementation.
+
+    At a high level, this takes 2D feature maps (from a convnet/ViT) and returns the "center of mass"
+    of activations of each channel, i.e., keypoints in the image space for the policy to focus on.
+
+    Example: take feature maps of size (512x10x12). We generate a grid of normalized coordinates (10x12x2):
+    -----------------------------------------------------
+    | (-1., -1.)   | (-0.82, -1.)   | ... | (1., -1.)   |
+    | (-1., -0.78) | (-0.82, -0.78) | ... | (1., -0.78) |
+    | ...          | ...            | ... | ...         |
+    | (-1., 1.)    | (-0.82, 1.)    | ... | (1., 1.)    |
+    -----------------------------------------------------
+    This is achieved by applying channel-wise softmax over the activations (512x120) and computing the dot
+    product with the coordinates (120x2) to get expected points of maximal activation (512x2).
+
+    The example above results in 512 keypoints (corresponding to the 512 input channels). We can optionally
+    provide num_kp != None to control the number of keypoints. This is achieved by first applying a learnable
+    linear mapping (in_channels, H, W) -> (num_kp, H, W).
+    """
+
+    def __init__(self, input_shape, num_kp=None):
+        """
+        Args:
+            input_shape (list): (C, H, W) input feature map shape.
+            num_kp (int): number of keypoints in output. If None, output will have the same number of
+                channels as input.
+        """
+        super().__init__()
+
+        assert len(input_shape) == 3
+        self._in_c, self._in_h, self._in_w = input_shape
+
+        if num_kp is not None:
+            self.nets = torch.nn.Conv2d(self._in_c, num_kp, kernel_size=1)
+            self._out_c = num_kp
+        else:
+            self.nets = None
+            self._out_c = self._in_c
+
+        # we could use torch.linspace directly but that seems to behave slightly differently than numpy
+        # and causes a small degradation in pc_success of pre-trained models.
+        pos_x, pos_y = np.meshgrid(np.linspace(-1.0, 1.0, self._in_w), np.linspace(-1.0, 1.0, self._in_h))
+        pos_x = torch.from_numpy(pos_x.reshape(self._in_h * self._in_w, 1)).float()
+        pos_y = torch.from_numpy(pos_y.reshape(self._in_h * self._in_w, 1)).float()
+        # register as buffer so it's moved to the correct device.
+        self.register_buffer("pos_grid", torch.cat([pos_x, pos_y], dim=1))
+
+    def forward(self, features: Tensor) -> Tensor:
+        """
+        Args:
+            features: (B, C, H, W) input feature maps.
+        Returns:
+            (B, K, 2) image-space coordinates of keypoints.
+        """
+        if self.nets is not None:
+            features = self.nets(features)
+
+        # [B, K, H, W] -> [B * K, H * W] where K is number of keypoints
+        features = features.reshape(-1, self._in_h * self._in_w)
+        # 2d softmax normalization
+        attention = F.softmax(features, dim=-1)
+        # [B * K, H * W] x [H * W, 2] -> [B * K, 2] for spatial coordinate mean in x and y dimensions
+        expected_xy = attention @ self.pos_grid
+        # reshape to [B, K, 2]
+        feature_keypoints = expected_xy.view(-1, self._out_c, 2)
+
+        return feature_keypoints
+
+
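The spatial soft-argmax described in the docstring above can be worked through on a tiny single-channel feature map without torch. This is a pure-Python sketch of the math (softmax over all spatial positions, then a dot product with a normalized coordinate grid), not the ported class itself:

```python
import math

def spatial_soft_argmax(feature_map):
    """Softmax over all spatial positions, then dot with a normalized (x, y) grid.

    feature_map: H x W nested list of activations for one channel (H, W >= 2).
    Returns the expected (x, y) keypoint in [-1, 1] coordinates.
    """
    h, w = len(feature_map), len(feature_map[0])
    # Normalized coordinate grid, matching np.linspace(-1, 1, n).
    xs = [-1.0 + 2.0 * j / (w - 1) for j in range(w)]
    ys = [-1.0 + 2.0 * i / (h - 1) for i in range(h)]
    # Flatten H x W -> H*W, then apply a numerically stable softmax.
    flat = [feature_map[i][j] for i in range(h) for j in range(w)]
    m = max(flat)
    weights = [math.exp(v - m) for v in flat]
    total = sum(weights)
    attention = [wt / total for wt in weights]
    # Expected coordinates = attention-weighted mean of the grid.
    ex = sum(a * xs[k % w] for k, a in enumerate(attention))
    ey = sum(a * ys[k // w] for k, a in enumerate(attention))
    return ex, ey

# A strong activation in the top-right corner pulls the keypoint toward (1, -1).
x, y = spatial_soft_argmax([[0.0, 0.0, 10.0],
                            [0.0, 0.0, 0.0],
                            [0.0, 0.0, 0.0]])
```

A uniform feature map gives the grid's center of mass, (0, 0), which is why the operation behaves like a differentiable "center of activation" rather than a hard argmax.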
 class DiffusionRgbEncoder(nn.Module):
     """Encodes an RGB image into a 1D feature vector.

@@ -315,11 +422,16 @@ class DiffusionRgbEncoder(nn.Module):

         # Set up pooling and final layers.
         # Use a dry run to get the feature map shape.
+        # The dummy input should take the number of image channels from `config.input_shapes` and it should
+        # use the height and width from `config.crop_shape`.
+        image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
+        assert len(image_keys) == 1
+        image_key = image_keys[0]
+        dummy_input = torch.zeros(size=(1, config.input_shapes[image_key][0], *config.crop_shape))
         with torch.inference_mode():
-            feat_map_shape = tuple(
-                self.backbone(torch.zeros(size=(1, *config.input_shapes["observation.image"]))).shape[1:]
-            )
-        self.pool = SpatialSoftmax(feat_map_shape, num_kp=config.spatial_softmax_num_keypoints)
+            dummy_feature_map = self.backbone(dummy_input)
+        feature_map_shape = tuple(dummy_feature_map.shape[1:])
+        self.pool = SpatialSoftmax(feature_map_shape, num_kp=config.spatial_softmax_num_keypoints)
         self.feature_dim = config.spatial_softmax_num_keypoints * 2
         self.out = nn.Linear(config.spatial_softmax_num_keypoints * 2, self.feature_dim)
         self.relu = nn.ReLU()
@@ -1,4 +1,20 @@
+#!/usr/bin/env python
+
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 import inspect
+import logging

 from omegaconf import DictConfig, OmegaConf

@@ -8,9 +24,10 @@ from lerobot.common.utils.utils import get_safe_torch_device

 def _policy_cfg_from_hydra_cfg(policy_cfg_class, hydra_cfg):
     expected_kwargs = set(inspect.signature(policy_cfg_class).parameters)
-    assert set(hydra_cfg.policy).issuperset(
-        expected_kwargs
-    ), f"Hydra config is missing arguments: {set(expected_kwargs).difference(hydra_cfg.policy)}"
+    if not set(hydra_cfg.policy).issuperset(expected_kwargs):
+        logging.warning(
+            f"Hydra config is missing arguments: {set(expected_kwargs).difference(hydra_cfg.policy)}"
+        )
     policy_cfg = policy_cfg_class(
         **{
             k: v
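The assert-to-warning change in `_policy_cfg_from_hydra_cfg` can be sketched with plain sets and the stdlib `logging` module (the key names below are made up for illustration):

```python
import logging

def check_policy_kwargs(hydra_policy_keys: set, expected_kwargs: set) -> set:
    """Warn (rather than assert) when the config is missing expected arguments,
    and return the missing names so the caller can fall back to defaults."""
    missing = expected_kwargs - hydra_policy_keys
    if not hydra_policy_keys.issuperset(expected_kwargs):
        logging.warning("Hydra config is missing arguments: %s", missing)
    return missing

# Made-up keys for illustration: "horizon" is expected but absent.
missing = check_policy_kwargs({"name", "n_obs_steps"}, {"n_obs_steps", "horizon"})
```

The behavioral difference from the old `assert` is that a config missing some dataclass fields now proceeds with the field defaults instead of crashing, which matters when loading older configs against a newer config class.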
@@ -62,11 +79,18 @@ def make_policy(

     policy_cls, policy_cfg_class = get_policy_and_config_classes(hydra_cfg.policy.name)

+    policy_cfg = _policy_cfg_from_hydra_cfg(policy_cfg_class, hydra_cfg)
     if pretrained_policy_name_or_path is None:
-        policy_cfg = _policy_cfg_from_hydra_cfg(policy_cfg_class, hydra_cfg)
+        # Make a fresh policy.
         policy = policy_cls(policy_cfg, dataset_stats)
     else:
-        policy = policy_cls.from_pretrained(pretrained_policy_name_or_path)
+        # Load a pretrained policy and override the config if needed (for example, if there are
+        # inference-time hyperparameters that we want to vary).
+        # TODO(alexander-soare): This hack makes use of huggingface_hub's tooling to load the policy with
+        # pretrained weights which are then loaded into a fresh policy with the desired config. This PR in
+        # huggingface_hub should make it possible to avoid the hack:
+        # https://github.com/huggingface/huggingface_hub/pull/2274.
+        policy = policy_cls(policy_cfg)
+        policy.load_state_dict(policy_cls.from_pretrained(pretrained_policy_name_or_path).state_dict())

     policy.to(get_safe_torch_device(hydra_cfg.device))
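The loading hack in `make_policy` (instantiate a fresh policy with the desired config, then copy weights from the hub-loaded instance) can be sketched with a hypothetical stand-in class that mimics the `state_dict`/`load_state_dict` contract:

```python
class StubPolicy:
    """Hypothetical stand-in for a policy with state_dict semantics."""
    def __init__(self, cfg):
        self.cfg = cfg
        self.weights = {"layer.w": 0.0}

    def state_dict(self):
        return dict(self.weights)

    def load_state_dict(self, state):
        self.weights.update(state)

    @classmethod
    def from_pretrained(cls, path):
        # Pretend this came from the hub: it carries the checkpoint's own
        # config and trained weights.
        policy = cls(cfg={"source": path})
        policy.weights = {"layer.w": 1.23}
        return policy

# Fresh policy built with *our* config; weights copied from the pretrained one.
desired_cfg = {"num_inference_steps": 10}  # e.g. an inference-time override
policy = StubPolicy(desired_cfg)
policy.load_state_dict(StubPolicy.from_pretrained("some/repo").state_dict())
```

The resulting object keeps `desired_cfg` while carrying the pretrained weights, which is exactly why the two-step dance is needed: `from_pretrained` alone would also bring along the checkpoint's config.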
@@ -1,3 +1,18 @@
+#!/usr/bin/env python
+
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 import torch
 from torch import Tensor, nn

@@ -1,3 +1,18 @@
+#!/usr/bin/env python
+
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 """A protocol that all policies should follow.

 This provides a mechanism for type-hinting and isinstance checks without requiring the policies classes
@@ -38,7 +53,8 @@ class Policy(Protocol):
     def forward(self, batch: dict[str, Tensor]) -> dict:
         """Run the batch through the model and compute the loss for training or validation.

-        Returns a dictionary with "loss" and maybe other information.
+        Returns a dictionary with "loss" and potentially other information. Apart from "loss" which
+        is a Tensor, all other items should be logging-friendly, native Python types.
         """

     def select_action(self, batch: dict[str, Tensor]):
@@ -1,3 +1,19 @@
+#!/usr/bin/env python
+
+# Copyright 2024 Nicklas Hansen, Xiaolong Wang, Hao Su,
+# and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 from dataclasses import dataclass, field


@@ -47,7 +63,7 @@ class TDMPCConfig:
         elite_weighting_temperature: The temperature to use for softmax weighting (by trajectory value) of the
             elites, when updating the gaussian parameters for CEM.
         gaussian_mean_momentum: Momentum (α) used for EMA updates of the mean parameter μ of the gaussian
-            paramters optimized in CEM. Updates are calculated as μ⁻ ← αμ⁻ + (1-α)μ.
+            parameters optimized in CEM. Updates are calculated as μ⁻ ← αμ⁻ + (1-α)μ.
         max_random_shift_ratio: Maximum random shift (as a proportion of the image size) to apply to the
             image(s) (in units of pixels) for training-time augmentation. If set to 0, no such augmentation
             is applied. Note that the input images are assumed to be square for this augmentation.
@@ -131,12 +147,18 @@ class TDMPCConfig:

     def __post_init__(self):
         """Input validation (not exhaustive)."""
-        if self.input_shapes["observation.image"][-2] != self.input_shapes["observation.image"][-1]:
+        # There should only be one image key.
+        image_keys = {k for k in self.input_shapes if k.startswith("observation.image")}
+        if len(image_keys) != 1:
+            raise ValueError(
+                f"{self.__class__.__name__} only handles one image for now. Got image keys {image_keys}."
+            )
+        image_key = next(iter(image_keys))
+        if self.input_shapes[image_key][-2] != self.input_shapes[image_key][-1]:
             # TODO(alexander-soare): This limitation is solely because of code in the random shift
             # augmentation. It should be able to be removed.
             raise ValueError(
-                "Only square images are handled now. Got image shape "
-                f"{self.input_shapes['observation.image']}."
+                f"Only square images are handled now. Got image shape {self.input_shapes[image_key]}."
             )
         if self.n_gaussian_samples <= 0:
             raise ValueError(
@@ -1,3 +1,19 @@
+#!/usr/bin/env python
+
+# Copyright 2024 Nicklas Hansen, Xiaolong Wang, Hao Su,
+# and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 """Implementation of Finetuning Offline World Models in the Real World.

 The comments in this code may sometimes refer to these references:
@@ -96,13 +112,12 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
             config.output_shapes, config.output_normalization_modes, dataset_stats
         )

-    def save(self, fp):
-        """Save state dict of TOLD model to filepath."""
-        torch.save(self.state_dict(), fp)
+        image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
+        # Note: This check is covered in the post-init of the config but have a sanity check just in case.
+        assert len(image_keys) == 1
+        self.input_image_key = image_keys[0]

-    def load(self, fp):
-        """Load a saved state dict from filepath into current agent."""
-        self.load_state_dict(torch.load(fp))
+        self.reset()

     def reset(self):
         """
@@ -121,10 +136,8 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
     @torch.no_grad()
     def select_action(self, batch: dict[str, Tensor]):
         """Select a single action given environment observations."""
-        assert "observation.image" in batch
-        assert "observation.state" in batch
-
         batch = self.normalize_inputs(batch)
+        batch["observation.image"] = batch[self.input_image_key]

         self._queues = populate_queues(self._queues, batch)

@@ -303,13 +316,11 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
         device = get_device_from_parameters(self)

         batch = self.normalize_inputs(batch)
+        batch["observation.image"] = batch[self.input_image_key]
         batch = self.normalize_targets(batch)

         info = {}

-        # TODO(alexander-soare): Refactor TDMPC and make it comply with the policy interface documentation.
-        batch_size = batch["index"].shape[0]
-
         # (b, t) -> (t, b)
         for key in batch:
             if batch[key].ndim > 1:
@@ -337,6 +348,7 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
         # Run latent rollout using the latent dynamics model and policy model.
         # Note this has shape `horizon+1` because there are `horizon` actions and a current `z`. Each action
         # gives us a next `z`.
+        batch_size = batch["index"].shape[0]
         z_preds = torch.empty(horizon + 1, batch_size, self.config.latent_dim, device=device)
         z_preds[0] = self.model.encode(current_observation)
         reward_preds = torch.empty_like(reward, device=device)
@@ -1,9 +1,28 @@
+#!/usr/bin/env python
+
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 import torch
 from torch import nn


 def populate_queues(queues, batch):
     for key in batch:
+        # Ignore keys not in the queues already (leaving the responsibility to the caller to make sure the
+        # queues have the keys they want).
+        if key not in queues:
+            continue
         if len(queues[key]) != queues[key].maxlen:
             # initialize by copying the first observation several times until the queue is full
             while len(queues[key]) != queues[key].maxlen:
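The hunk above cuts off mid-function, so the tail of `populate_queues` is not shown; the roll-forward branch in this sketch (appending once when the queue is already full) is an assumption about how the full function behaves, kept here only to make the example runnable:

```python
from collections import deque

def populate_queues(queues, batch):
    # On first use, each queue is filled by repeating the first observation
    # until full; afterwards each call rolls the queue forward by one entry.
    for key in batch:
        if key not in queues:
            continue  # the caller decides which keys get queued
        if len(queues[key]) != queues[key].maxlen:
            while len(queues[key]) != queues[key].maxlen:
                queues[key].append(batch[key])
        else:
            queues[key].append(batch[key])
    return queues

queues = {"observation.state": deque(maxlen=3)}
populate_queues(queues, {"observation.state": 1, "ignored": 0})
populate_queues(queues, {"observation.state": 2})
print(list(queues["observation.state"]))  # [1, 1, 2]
```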
@@ -1,3 +1,18 @@
+#!/usr/bin/env python
+
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 import importlib
 import logging

@@ -1,3 +1,18 @@
+#!/usr/bin/env python
+
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 import warnings

 import imageio
@@ -1,8 +1,25 @@
+#!/usr/bin/env python
+
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 import logging
 import os.path as osp
 import random
+from contextlib import contextmanager
 from datetime import datetime
 from pathlib import Path
+from typing import Generator

 import hydra
 import numpy as np
@@ -39,6 +56,31 @@ def set_global_seed(seed):
     torch.cuda.manual_seed_all(seed)


+@contextmanager
+def seeded_context(seed: int) -> Generator[None, None, None]:
+    """Set the seed when entering a context, and restore the prior random state at exit.
+
+    Example usage:
+
+    ```
+    a = random.random()  # produces some random number
+    with seeded_context(1337):
+        b = random.random()  # produces some other random number
+    c = random.random()  # produces yet another random number, but the same it would have if we never made `b`
+    ```
+    """
+    random_state = random.getstate()
+    np_random_state = np.random.get_state()
+    torch_random_state = torch.random.get_rng_state()
+    torch_cuda_random_state = torch.cuda.random.get_rng_state()
+    set_global_seed(seed)
+    yield None
+    random.setstate(random_state)
+    np.random.set_state(np_random_state)
+    torch.random.set_rng_state(torch_random_state)
+    torch.cuda.random.set_rng_state(torch_cuda_random_state)
+
+
 def init_logging():
     def custom_format(record):
         dt = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
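The `seeded_context` helper added above saves and restores four RNG states (`random`, NumPy, torch CPU, torch CUDA). The same save/reseed/restore pattern can be sketched with just the standard library, which makes the key property easy to check: the outer stream continues exactly as if the seeded block never drew a number.

```python
import random
from contextlib import contextmanager

@contextmanager
def seeded(seed):
    # Save the interpreter's RNG state before reseeding, and restore it on
    # exit so draws inside the block do not perturb the outer stream.
    state = random.getstate()
    random.seed(seed)
    try:
        yield
    finally:
        random.setstate(state)

random.seed(0)
a = random.random()
with seeded(1337):
    b = random.random()
c = random.random()  # same value as if the seeded block never drew
```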
@@ -25,7 +25,7 @@ training:
   eval_freq: ???
   save_freq: ???
   log_freq: 250
-  save_model: false
+  save_model: true

 eval:
   n_episodes: 1
@@ -35,7 +35,7 @@ eval:
   use_async_envs: false

 wandb:
-  enable: true
+  enable: false
   # Set to true to disable saving an artifact despite save_model == True
   disable_artifact: false
   project: lerobot
@@ -3,6 +3,12 @@
 seed: 1000
 dataset_repo_id: lerobot/aloha_sim_insertion_human

+override_dataset_stats:
+  observation.images.top:
+    # stats from imagenet, since we use a pretrained vision model
+    mean: [[[0.485]], [[0.456]], [[0.406]]]  # (c,1,1)
+    std: [[[0.229]], [[0.224]], [[0.225]]]  # (c,1,1)
+
 training:
   offline_steps: 80000
   online_steps: 0
@@ -18,12 +24,6 @@ training:
   grad_clip_norm: 10
   online_steps_between_rollouts: 1

-  override_dataset_stats:
-    observation.images.top:
-      # stats from imagenet, since we use a pretrained vision model
-      mean: [[[0.485]], [[0.456]], [[0.406]]]  # (c,1,1)
-      std: [[[0.229]], [[0.224]], [[0.225]]]  # (c,1,1)
-
   delta_timestamps:
     action: "[i / ${fps} for i in range(${policy.chunk_size})]"

@@ -66,6 +66,9 @@ policy:
   dim_feedforward: 3200
   feedforward_activation: relu
   n_encoder_layers: 4
+  # Note: Although the original ACT implementation has 7 for `n_decoder_layers`, there is a bug in the code
+  # that means only the first layer is used. Here we match the original implementation by setting this to 1.
+  # See this issue https://github.com/tonyzhaozh/act/issues/25#issue-2258740521.
   n_decoder_layers: 1
   # VAE.
   use_vae: true
@@ -73,7 +76,7 @@ policy:
   n_vae_encoder_layers: 4

   # Inference.
-  use_temporal_aggregation: false
+  temporal_ensemble_momentum: null

   # Training and loss computation.
   dropout: 0.1
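The renamed `temporal_ensemble_momentum` option suggests an exponential-moving-average blend over successive predictions for the same timestep; that reading is a guess from the name alone, since the diff does not show the implementation. A minimal EMA sketch under that assumption:

```python
def ensemble(prev, new, momentum=0.5):
    # Blend the running estimate with the newest prediction; `momentum`
    # controls how much weight older predictions retain.
    return new if prev is None else momentum * prev + (1 - momentum) * new

est = None
for pred in [1.0, 3.0, 5.0]:
    est = ensemble(est, pred)
print(est)  # 3.5
```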
@@ -7,6 +7,20 @@
 seed: 100000
 dataset_repo_id: lerobot/pusht

+override_dataset_stats:
+  # TODO(rcadene, alexander-soare): should we remove image stats as well? do we use a pretrained vision model?
+  observation.image:
+    mean: [[[0.5]], [[0.5]], [[0.5]]]  # (c,1,1)
+    std: [[[0.5]], [[0.5]], [[0.5]]]  # (c,1,1)
+  # TODO(rcadene, alexander-soare): we override state and action stats to use the same as the pretrained model
+  # from the original codebase, but we should remove these and train our own pretrained model
+  observation.state:
+    min: [13.456424, 32.938293]
+    max: [496.14618, 510.9579]
+  action:
+    min: [12.0, 25.0]
+    max: [511.0, 511.0]
+
 training:
   offline_steps: 200000
   online_steps: 0
@@ -44,20 +58,6 @@ eval:
   n_episodes: 50
   batch_size: 50

-  override_dataset_stats:
-    # TODO(rcadene, alexander-soare): should we remove image stats as well? do we use a pretrained vision model?
-    observation.image:
-      mean: [[[0.5]], [[0.5]], [[0.5]]]  # (c,1,1)
-      std: [[[0.5]], [[0.5]], [[0.5]]]  # (c,1,1)
-    # TODO(rcadene, alexander-soare): we override state and action stats to use the same as the pretrained model
-    # from the original codebase, but we should remove these and train our own pretrained model
-    observation.state:
-      min: [13.456424, 32.938293]
-      max: [496.14618, 510.9579]
-    action:
-      min: [12.0, 25.0]
-      max: [511.0, 511.0]
-
 policy:
   name: diffusion

@@ -95,6 +95,7 @@ policy:
   diffusion_step_embed_dim: 128
   use_film_scale_modulation: True
   # Noise scheduler.
+  noise_scheduler_type: DDPM
   num_train_timesteps: 100
   beta_schedule: squaredcos_cap_v2
   beta_start: 0.0001
@@ -105,3 +106,6 @@ policy:

   # Inference
   num_inference_steps: 100
+
+  # Loss computation
+  do_mask_loss_for_padding: false
@@ -1,7 +1,7 @@
 # @package _global_

 seed: 1
-dataset_repo_id: lerobot/xarm_lift_medium_replay
+dataset_repo_id: lerobot/xarm_lift_medium

 training:
   offline_steps: 25000
@@ -1,3 +1,18 @@
+#!/usr/bin/env python
+
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 import platform

 import huggingface_hub
@@ -1,3 +1,18 @@
+#!/usr/bin/env python
+
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 """Evaluate a policy on an environment by running rollouts and computing metrics.

 Usage examples:
@@ -583,17 +598,18 @@ if __name__ == "__main__":
         pretrained_policy_path = Path(
             snapshot_download(args.pretrained_policy_name_or_path, revision=args.revision)
         )
-    except HFValidationError:
-        logging.warning(
-            "The provided pretrained_policy_name_or_path is not a valid Hugging Face Hub repo ID. "
-            "Treating it as a local directory."
-        )
-    except RepositoryNotFoundError:
-        logging.warning(
-            "The provided pretrained_policy_name_or_path was not found on the Hugging Face Hub. Treating "
-            "it as a local directory."
-        )
+    except (HFValidationError, RepositoryNotFoundError) as e:
+        if isinstance(e, HFValidationError):
+            error_message = (
+                "The provided pretrained_policy_name_or_path is not a valid Hugging Face Hub repo ID."
+            )
+        else:
+            error_message = (
+                "The provided pretrained_policy_name_or_path was not found on the Hugging Face Hub."
+            )
+
+        logging.warning(f"{error_message} Treating it as a local directory.")
         pretrained_policy_path = Path(args.pretrained_policy_name_or_path)
     if not pretrained_policy_path.is_dir() or not pretrained_policy_path.exists():
         raise ValueError(
             "The provided pretrained_policy_name_or_path is not a valid/existing Hugging Face Hub "
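The refactor above collapses two `except` blocks into one clause that catches both exception types and branches on `isinstance`, so the shared fallback logic is written once. A self-contained sketch of the same pattern, with `ValueError`/`KeyError` standing in for `HFValidationError`/`RepositoryNotFoundError`:

```python
import logging

def describe(exc: Exception) -> str:
    # A single except clause can catch several exception types; isinstance
    # then picks the message while the fallback stays shared.
    try:
        raise exc
    except (ValueError, KeyError) as e:
        if isinstance(e, ValueError):
            msg = "not a valid repo ID."
        else:
            msg = "not found on the Hub."
        logging.warning("%s Treating it as a local directory.", msg)
        return msg

print(describe(ValueError()))  # not a valid repo ID.
print(describe(KeyError()))    # not found on the Hub.
```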
@@ -1,3 +1,18 @@
+#!/usr/bin/env python
+
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 """
 Use this script to convert your dataset into LeRobot dataset format and upload it to the Hugging Face hub,
 or store it locally. LeRobot dataset format is lightweight, fast to load from, and does not require any
@@ -60,7 +75,7 @@ import torch
 from huggingface_hub import HfApi
 from safetensors.torch import save_file

-from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
+from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
 from lerobot.common.datasets.push_dataset_to_hub._download_raw import download_raw
 from lerobot.common.datasets.push_dataset_to_hub.compute_stats import compute_stats
 from lerobot.common.datasets.utils import flatten_dict
@@ -252,7 +267,7 @@ def main():
     parser.add_argument(
         "--revision",
         type=str,
-        default="v1.2",
+        default=CODEBASE_VERSION,
         help="Codebase version used to generate the dataset.",
     )
     parser.add_argument(
@@ -1,3 +1,18 @@
+#!/usr/bin/env python
+
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 import logging
 import time
 from copy import deepcopy
@@ -8,7 +23,7 @@ import hydra
 import torch
 from datasets import concatenate_datasets
 from datasets.utils import disable_progress_bars, enable_progress_bars
-from diffusers.optimization import get_scheduler
+from omegaconf import DictConfig

 from lerobot.common.datasets.factory import make_dataset
 from lerobot.common.datasets.utils import cycle
@@ -55,6 +70,8 @@ def make_optimizer_and_scheduler(cfg, policy):
             cfg.training.adam_weight_decay,
         )
         assert cfg.training.online_steps == 0, "Diffusion Policy does not handle online training."
+        from diffusers.optimization import get_scheduler
+
         lr_scheduler = get_scheduler(
             cfg.training.lr_scheduler,
             optimizer=optimizer,
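The hunk above moves `get_scheduler` from a module-level import into the diffusion-only branch, a standard lazy-import pattern: modules that import this file never pay for the heavy dependency unless that code path runs. A minimal sketch, using `json` as a stand-in for a heavy optional dependency:

```python
def make_scheduler(kind: str) -> str:
    # The dependency is imported inside the function, so importing this
    # module stays cheap; the cost is paid only when this path executes.
    import json  # stand-in for a heavy library such as diffusers
    return json.dumps({"scheduler": kind})

print(make_scheduler("cosine"))  # {"scheduler": "cosine"}
```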
@@ -71,6 +88,7 @@ def make_optimizer_and_scheduler(cfg, policy):


 def update_policy(policy, batch, optimizer, grad_clip_norm, lr_scheduler=None):
+    """Returns a dictionary of items for logging."""
     start_time = time.time()
     policy.train()
     output_dict = policy.forward(batch)
@@ -98,6 +116,7 @@ def update_policy(policy, batch, optimizer, grad_clip_norm, lr_scheduler=None):
         "grad_norm": float(grad_norm),
         "lr": optimizer.param_groups[0]["lr"],
         "update_s": time.time() - start_time,
+        **{k: v for k, v in output_dict.items() if k != "loss"},
     }

     return info
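The added line above merges every auxiliary metric from the policy's output into the logging dict via filtered dict unpacking, excluding the `"loss"` scalar that is already logged separately. A sketch with made-up metric names:

```python
output_dict = {"loss": 0.5, "l1_loss": 0.3, "kld_loss": 0.2}

info = {
    "lr": 1e-4,
    # Fold every auxiliary metric into the log dict, except "loss",
    # which is already handled on its own.
    **{k: v for k, v in output_dict.items() if k != "loss"},
}

print(sorted(info))  # ['kld_loss', 'l1_loss', 'lr']
```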
@@ -121,7 +140,7 @@ def train_notebook(out_dir=None, job_name=None, config_name="default", config_pa
     train(cfg, out_dir=out_dir, job_name=job_name)


-def log_train_info(logger, info, step, cfg, dataset, is_offline):
+def log_train_info(logger: Logger, info, step, cfg, dataset, is_offline):
     loss = info["loss"]
     grad_norm = info["grad_norm"]
     lr = info["lr"]
@@ -289,7 +308,7 @@ def add_episodes_inplace(
     sampler.num_samples = len(concat_dataset)


-def train(cfg: dict, out_dir=None, job_name=None):
+def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
     if out_dir is None:
         raise NotImplementedError()
     if job_name is None:
@@ -336,7 +355,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
     logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")

     # Note: this helper will be used in offline and online training loops.
-    def _maybe_eval_and_maybe_save(step):
+    def evaluate_and_checkpoint_if_needed(step):
         if step % cfg.training.eval_freq == 0:
             logging.info(f"Eval policy at step {step}")
             eval_info = eval_policy(
@@ -392,9 +411,9 @@ def train(cfg: dict, out_dir=None, job_name=None):
         if step % cfg.training.log_freq == 0:
             log_train_info(logger, train_info, step, cfg, offline_dataset, is_offline)

-        # Note: _maybe_eval_and_maybe_save happens **after** the `step`th training update has completed, so we pass in
-        # step + 1.
-        _maybe_eval_and_maybe_save(step + 1)
+        # Note: evaluate_and_checkpoint_if_needed happens **after** the `step`th training update has completed,
+        # so we pass in step + 1.
+        evaluate_and_checkpoint_if_needed(step + 1)

         step += 1

@@ -460,9 +479,9 @@ def train(cfg: dict, out_dir=None, job_name=None):
         if step % cfg.training.log_freq == 0:
             log_train_info(logger, train_info, step, cfg, online_dataset, is_offline)

-        # Note: _maybe_eval_and_maybe_save happens **after** the `step`th training update has completed, so we pass
-        # in step + 1.
-        _maybe_eval_and_maybe_save(step + 1)
+        # Note: evaluate_and_checkpoint_if_needed happens **after** the `step`th training update has completed,
+        # so we pass in step + 1.
+        evaluate_and_checkpoint_if_needed(step + 1)

         step += 1
         online_step += 1
@@ -1,3 +1,18 @@
+#!/usr/bin/env python
+
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 """ Visualize data of **all** frames of any episode of a dataset of type LeRobotDataset.

 Note: The last frame of the episode doesnt always correspond to a final state.
@@ -32,7 +47,7 @@ local$ rerun lerobot_pusht_episode_0.rrd
 ```

 - Visualize data stored on a distant machine through streaming:
 (You need to forward the websocket port to the distant machine, with
 `ssh -L 9087:localhost:9087 username@remote-host`)
 ```
 distant$ python lerobot/scripts/visualize_dataset.py \
@@ -47,6 +62,7 @@ local$ rerun ws://localhost:9087
 """

 import argparse
+import gc
 import logging
 import time
 from pathlib import Path
@@ -115,15 +131,17 @@ def visualize_dataset(

     spawn_local_viewer = mode == "local" and not save
     rr.init(f"{repo_id}/episode_{episode_index}", spawn=spawn_local_viewer)

+    # Manually call python garbage collector after `rr.init` to avoid hanging in a blocking flush
+    # when iterating on a dataloader with `num_workers` > 0
+    # TODO(rcadene): remove `gc.collect` when rerun version 0.16 is out, which includes a fix
+    gc.collect()
+
     if mode == "distant":
         rr.serve(open_browser=False, web_port=web_port, ws_port=ws_port)

     logging.info("Logging to Rerun")

-    if num_workers > 0:
-        # TODO(rcadene): fix data workers hanging when `rr.init` is called
-        logging.warning("If data loader is hanging, try `--num-workers 0`.")
-
     for batch in tqdm.tqdm(dataloader, total=len(dataloader)):
         # iterate over the batch
         for i in range(len(batch["index"])):
@@ -131,7 +149,7 @@ def visualize_dataset(
             rr.set_time_seconds("timestamp", batch["timestamp"][i].item())

             # display each camera image
-            for key in dataset.image_keys:
+            for key in dataset.camera_keys:
                 # TODO(rcadene): add `.compress()`? is it lossless?
                 rr.log(key, rr.Image(to_hwc_uint8_numpy(batch[key][i])))

@@ -196,7 +214,7 @@ def main():
     parser.add_argument(
         "--num-workers",
         type=int,
-        default=0,
+        default=4,
         help="Number of processes of Dataloader for loading the data.",
     )
     parser.add_argument(
BIN  media/wandb.png  (new file; binary file not shown — Size: 407 KiB)
774  poetry.lock  (generated; file diff suppressed because it is too large)
@@ -4,9 +4,9 @@ version = "0.1.0"
 description = "🤗 LeRobot: State-of-the-art Machine Learning for Real-World Robotics in Pytorch"
 authors = [
     "Rémi Cadène <re.cadene@gmail.com>",
+    "Simon Alibert <alibert.sim@gmail.com>",
     "Alexander Soare <alexander.soare159@gmail.com>",
     "Quentin Gallouédec <quentin.gallouedec@ec-lyon.fr>",
-    "Simon Alibert <alibert.sim@gmail.com>",
     "Adil Zouitine <adilzouitinegm@gmail.com>",
     "Thomas Wolf <thomaswolfcontact@gmail.com>",
 ]
@@ -28,37 +28,36 @@ packages = [{include = "lerobot"}]

 [tool.poetry.dependencies]
 python = ">=3.10,<3.13"
-termcolor = "^2.4.0"
-omegaconf = "^2.3.0"
-wandb = "^0.16.3"
-imageio = {extras = ["ffmpeg"], version = "^2.34.0"}
-gdown = "^5.1.0"
-hydra-core = "^1.3.2"
-einops = "^0.8.0"
-pymunk = "^6.6.0"
-zarr = "^2.17.0"
-numba = "^0.59.0"
+termcolor = ">=2.4.0"
+omegaconf = ">=2.3.0"
+wandb = ">=0.16.3"
+imageio = {extras = ["ffmpeg"], version = ">=2.34.0"}
+gdown = ">=5.1.0"
+hydra-core = ">=1.3.2"
+einops = ">=0.8.0"
+pymunk = ">=6.6.0"
+zarr = ">=2.17.0"
+numba = ">=0.59.0"
 torch = "^2.2.1"
-opencv-python = "^4.9.0.80"
+opencv-python = ">=4.9.0"
 diffusers = "^0.27.2"
-torchvision = "^0.18.0"
-h5py = "^3.10.0"
-huggingface-hub = "^0.21.4"
-robomimic = "0.2.0"
-gymnasium = "^0.29.1"
-cmake = "^3.29.0.1"
-gym-pusht = { version = "^0.1.0", optional = true}
-gym-xarm = { version = "^0.1.0", optional = true}
-gym-aloha = { version = "^0.1.0", optional = true}
-pre-commit = {version = "^3.7.0", optional = true}
-debugpy = {version = "^1.8.1", optional = true}
-pytest = {version = "^8.1.0", optional = true}
-pytest-cov = {version = "^5.0.0", optional = true}
+torchvision = ">=0.18.0"
+h5py = ">=3.10.0"
+huggingface-hub = ">=0.21.4"
+gymnasium = ">=0.29.1"
+cmake = ">=3.29.0.1"
+gym-pusht = { version = ">=0.1.3", optional = true}
+gym-xarm = { version = ">=0.1.1", optional = true}
+gym-aloha = { version = ">=0.1.1", optional = true}
+pre-commit = {version = ">=3.7.0", optional = true}
+debugpy = {version = ">=1.8.1", optional = true}
+pytest = {version = ">=8.1.0", optional = true}
+pytest-cov = {version = ">=5.0.0", optional = true}
 datasets = "^2.19.0"
-imagecodecs = { version = "^2024.1.1", optional = true }
-pyav = "^12.0.5"
-moviepy = "^1.0.3"
-rerun-sdk = "^0.15.1"
+imagecodecs = { version = ">=2024.1.1", optional = true }
+pyav = ">=12.0.5"
+moviepy = ">=1.0.3"
+rerun-sdk = ">=0.15.1"


 [tool.poetry.extras]
@@ -104,5 +103,5 @@ ignore-init-module-imports = true
|
|||||||
|
|
||||||
|
|
||||||
[build-system]
|
[build-system]
|
||||||
requires = ["poetry-core>=1.5.0"]
|
requires = ["poetry-core"]
|
||||||
build-backend = "poetry.core.masonry.api"
|
build-backend = "poetry.core.masonry.api"
|
||||||
|
|||||||
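Most of the pyproject.toml changes above swap Poetry's caret (`^`) constraints for open-ended `>=` ranges. A caret constraint keeps an implicit upper bound at the next breaking version: `^2.4.0` means `>=2.4.0,<3.0.0`, and for pre-1.0 packages `^0.1.0` means `>=0.1.0,<0.2.0`, so switching to `>=` removes those upper bounds. A minimal sketch of the caret rule (hypothetical helper names; plain dotted-integer versions only):

```python
def caret_bounds(constraint: str) -> tuple[tuple[int, ...], tuple[int, ...]]:
    """Expand Poetry's caret shorthand "^X.Y.Z" into [lower, upper) bounds."""
    parts = [int(p) for p in constraint.split(".")]
    upper = list(parts)
    # Bump the leftmost non-zero component: "^2.4.0" -> <3.0.0,
    # but "^0.1.0" -> <0.2.0 (pre-1.0 minors are treated as breaking).
    for i, p in enumerate(parts):
        if p != 0 or i == len(parts) - 1:
            upper = parts[:i] + [p + 1] + [0] * (len(parts) - i - 1)
            break
    return tuple(parts), tuple(upper)


def satisfies_caret(candidate: str, constraint: str) -> bool:
    # Compare versions as integer tuples against the expanded bounds.
    low, high = caret_bounds(constraint)
    version = tuple(int(p) for p in candidate.split("."))
    return low <= version < high


print(satisfies_caret("2.9.1", "2.4.0"))  # True: inside ^2.4.0
print(satisfies_caret("3.0.0", "2.4.0"))  # False: ^2.4.0 excludes 3.x, ">=2.4.0" would not
print(satisfies_caret("0.2.0", "0.1.0"))  # False: ^0.1.0 excludes 0.2.x
```

The practical effect of the diff is that downstream resolvers may now pick major versions newer than the ones the caret constraints would have allowed.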
@@ -1,3 +1,18 @@
+#!/usr/bin/env python
+
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 from .utils import DEVICE

Binary files not shown (12 files).
@@ -1,3 +1,18 @@
(Apache 2.0 license header added, identical to the first diff above)
 """
 This script provides a utility for saving a dataset as safetensors files for the purpose of testing backward compatibility
 when updating the data format. It uses the `PushtDataset` to create a DataLoader and saves selected frame from the
@@ -1,3 +1,18 @@
(Apache 2.0 license header added, identical to the first diff above)
 import shutil
 from pathlib import Path

@@ -1,3 +1,18 @@
(Apache 2.0 license header added, identical to the first diff above)
 import importlib

 import gymnasium as gym
@@ -15,7 +30,7 @@ from tests.utils import require_env
 def test_available_env_task(env_name: str, task_name: list):
     """
     This test verifies that all environments listed in `lerobot/__init__.py` can
-    be sucessfully imported — if they're installed — and that their
+    be successfully imported — if they're installed — and that their
     `available_tasks_per_env` are valid.
     """
     package_name = f"gym_{env_name}"
@@ -1,3 +1,18 @@
(Apache 2.0 license header added, identical to the first diff above)
 import json
 import logging
 from copy import deepcopy
@@ -41,7 +56,7 @@ def test_factory(env_name, repo_id, policy_name):
     )
     dataset = make_dataset(cfg)
     delta_timestamps = dataset.delta_timestamps
-    image_keys = dataset.image_keys
+    camera_keys = dataset.camera_keys

     item = dataset[0]

@@ -71,7 +86,7 @@ def test_factory(env_name, repo_id, policy_name):
         else:
             assert item[key].ndim == ndim, f"{key}"

-        if key in image_keys:
+        if key in camera_keys:
             assert item[key].dtype == torch.float32, f"{key}"
             # TODO(rcadene): we assume for now that image normalization takes place in the model
             assert item[key].max() <= 1.0, f"{key}"
@@ -1,3 +1,18 @@
(Apache 2.0 license header added, identical to the first diff above)
 import importlib

 import gymnasium as gym
@@ -1,8 +1,25 @@
(Apache 2.0 license header added, identical to the first diff above)
 # TODO(aliberts): Mute logging for these tests
 import subprocess
 import sys
 from pathlib import Path

+from tests.utils import require_package
+

 def _find_and_replace(text: str, finds_and_replaces: list[tuple[str, str]]) -> str:
     for f, r in finds_and_replaces:
@@ -21,6 +38,7 @@ def test_example_1():
     assert Path("outputs/examples/1_load_lerobot_dataset/episode_0.mp4").exists()


+@require_package("gym_pusht")
 def test_examples_3_and_2():
     """
     Train a model with example 3, check the outputs.
@@ -46,7 +64,7 @@ def test_examples_3_and_2():
     # Pass empty globals to allow dictionary comprehension https://stackoverflow.com/a/32897127/4391249.
     exec(file_contents, {})

-    for file_name in ["model.safetensors", "config.json", "config.yaml"]:
+    for file_name in ["model.safetensors", "config.json"]:
         assert Path(f"outputs/train/example_pusht_diffusion/{file_name}").exists()

     path = "examples/2_evaluate_pretrained_policy.py"
@@ -58,16 +76,16 @@ def test_examples_3_and_2():
     file_contents = _find_and_replace(
         file_contents,
         [
-            ('pretrained_policy_name = "lerobot/diffusion_pusht"', ""),
-            ("pretrained_policy_path = Path(snapshot_download(pretrained_policy_name))", ""),
+            ('pretrained_policy_path = Path(snapshot_download("lerobot/diffusion_pusht"))', ""),
             (
                 '# pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")',
                 'pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")',
             ),
-            ('"eval.n_episodes=10"', '"eval.n_episodes=1"'),
-            ('"eval.batch_size=10"', '"eval.batch_size=1"'),
-            ('"device=cuda"', '"device=cpu"'),
+            ('device = torch.device("cuda")', 'device = torch.device("cpu")'),
+            ("step += 1", "break"),
         ],
     )

-    assert Path("outputs/train/example_pusht_diffusion").exists()
+    exec(file_contents, {})
+
+    assert Path("outputs/eval/example_pusht_diffusion/rollout.mp4").exists()
@@ -1,3 +1,18 @@
(Apache 2.0 license header added, identical to the first diff above)
 import inspect
 from pathlib import Path

@@ -49,6 +64,14 @@ def test_get_policy_and_config_classes(policy_name: str):
         "act",
         ["env.task=AlohaTransferCube-v0", "dataset_repo_id=lerobot/aloha_sim_transfer_cube_scripted"],
     ),
+    # Note: these parameters also need custom logic in the test function for overriding the Hydra config.
+    (
+        "aloha",
+        "diffusion",
+        ["env.task=AlohaInsertion-v0", "dataset_repo_id=lerobot/aloha_sim_insertion_human"],
+    ),
+    # Note: these parameters also need custom logic in the test function for overriding the Hydra config.
+    ("pusht", "act", ["env.task=PushT-v0", "dataset_repo_id=lerobot/pusht"]),
 ],
 )
 @require_env
@@ -72,6 +95,31 @@ def test_policy(env_name, policy_name, extra_overrides):
         + extra_overrides,
     )

+    # Additional config override logic.
+    if env_name == "aloha" and policy_name == "diffusion":
+        for keys in [
+            ("training", "delta_timestamps"),
+            ("policy", "input_shapes"),
+            ("policy", "input_normalization_modes"),
+        ]:
+            dct = dict(cfg[keys[0]][keys[1]])
+            dct["observation.images.top"] = dct["observation.image"]
+            del dct["observation.image"]
+            cfg[keys[0]][keys[1]] = dct
+        cfg.override_dataset_stats = None
+
+    # Additional config override logic.
+    if env_name == "pusht" and policy_name == "act":
+        for keys in [
+            ("policy", "input_shapes"),
+            ("policy", "input_normalization_modes"),
+        ]:
+            dct = dict(cfg[keys[0]][keys[1]])
+            dct["observation.image"] = dct["observation.images.top"]
+            del dct["observation.images.top"]
+            cfg[keys[0]][keys[1]] = dct
+        cfg.override_dataset_stats = None
+
     # Check that we can make the policy object.
     dataset = make_dataset(cfg)
     policy = make_policy(hydra_cfg=cfg, dataset_stats=dataset.stats)
@@ -236,7 +284,7 @@ def test_normalize(insert_temporal_dim):
 @pytest.mark.parametrize(
     "env_name, policy_name, extra_overrides",
     [
-        # ("xarm", "tdmpc", ["policy.n_action_repeats=2"]),
+        ("xarm", "tdmpc", []),
         (
             "pusht",
             "diffusion",
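The override blocks added to `test_policy` above re-key nested config dicts so a policy config written for one camera key (`observation.image`) can be exercised against an environment that uses another (`observation.images.top`), and vice versa. The re-keying step in isolation, as a hypothetical standalone helper:

```python
def rekey(d: dict, old: str, new: str) -> dict:
    # Copy the dict, move the value from the old key to the new key,
    # and drop the old key, mirroring the dct manipulation in the test.
    out = dict(d)
    out[new] = out.pop(old)
    return out


input_shapes = {"observation.image": [3, 96, 96], "observation.state": [2]}
print(rekey(input_shapes, "observation.image", "observation.images.top"))
# {'observation.state': [2], 'observation.images.top': [3, 96, 96]}
```

Clearing `cfg.override_dataset_stats` afterwards avoids stats entries that still reference the old key.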
38 tests/test_utils.py (new file)
@@ -0,0 +1,38 @@
+import random
+from typing import Callable
+
+import numpy as np
+import pytest
+import torch
+
+from lerobot.common.utils.utils import seeded_context, set_global_seed
+
+
+@pytest.mark.parametrize(
+    "rand_fn",
+    [
+        random.random,
+        np.random.random,
+        lambda: torch.rand(1).item(),
+    ]
+    + [lambda: torch.rand(1, device="cuda")]
+    if torch.cuda.is_available()
+    else [],
+)
+def test_seeding(rand_fn: Callable[[], int]):
+    set_global_seed(0)
+    a = rand_fn()
+    with seeded_context(1337):
+        c = rand_fn()
+    b = rand_fn()
+    set_global_seed(0)
+    a_ = rand_fn()
+    b_ = rand_fn()
+    # Check that `set_global_seed` lets us reproduce a and b.
+    assert a_ == a
+    # Additionally, check that the `seeded_context` didn't interrupt the global RNG.
+    assert b_ == b
+    set_global_seed(1337)
+    c_ = rand_fn()
+    # Check that `seeded_context` and `global_seed` give the same reproducibility.
+    assert c_ == c
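The new `tests/test_utils.py` pins down the contract of `seeded_context`: draws inside the block follow the given seed, and on exit the previous global RNG stream resumes untouched. For stdlib `random` alone, that contract could be sketched like this (a hypothetical re-implementation; LeRobot's actual version also covers NumPy and torch):

```python
import random
from contextlib import contextmanager


@contextmanager
def seeded_context_sketch(seed: int):
    # Save the global RNG state, reseed for the duration of the block,
    # then restore the saved state so outer code is unaffected.
    state = random.getstate()
    random.seed(seed)
    try:
        yield
    finally:
        random.setstate(state)


random.seed(0)
a = random.random()
with seeded_context_sketch(1337):
    c = random.random()
b = random.random()

random.seed(0)
assert (a, b) == (random.random(), random.random())  # outer stream undisturbed
random.seed(1337)
assert c == random.random()  # inner draw matches a plain reseed
```

This is exactly what the `a_ == a`, `b_ == b`, and `c_ == c` assertions in the test verify for each RNG backend.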
@@ -1,3 +1,18 @@
(Apache 2.0 license header added, identical to the first diff above)
 import pytest

 from lerobot.scripts.visualize_dataset import visualize_dataset
@@ -1,4 +1,20 @@
(Apache 2.0 license header added, identical to the first diff above)
 import platform
+from functools import wraps

 import pytest
 import torch
@@ -61,7 +77,6 @@ def require_env(func):
     Decorator that skips the test if the required environment package is not installed.
     As it need 'env_name' in args, it also checks whether it is provided as an argument.
     """
-    from functools import wraps

     @wraps(func)
     def wrapper(*args, **kwargs):
@@ -82,3 +97,20 @@ def require_env(func):
         return func(*args, **kwargs)

     return wrapper
+
+
+def require_package(package_name):
+    """
+    Decorator that skips the test if the specified package is not installed.
+    """
+
+    def decorator(func):
+        @wraps(func)
+        def wrapper(*args, **kwargs):
+            if not is_package_available(package_name):
+                pytest.skip(f"{package_name} not installed")
+            return func(*args, **kwargs)
+
+        return wrapper
+
+    return decorator