Compare commits


18 Commits

Author  SHA1  Message  Date
Thomas Wolf  b670d3c43e  Update lerobot/scripts/push_dataset_to_hub.py (Co-authored-by: Remi <re.cadene@gmail.com>)  2024-05-29 15:30:39 +02:00
Thomas Wolf  efd3357124  Update lerobot/scripts/push_dataset_to_hub.py (Co-authored-by: Remi <re.cadene@gmail.com>)  2024-05-29 15:30:33 +02:00
Thomas Wolf  785db44bd5  Update lerobot/common/datasets/push_dataset_to_hub/aloha_dora_format.py (Co-authored-by: Remi <re.cadene@gmail.com>)  2024-05-29 15:30:21 +02:00
Thomas Wolf  e2f690e779  proposal for a more general Dora env  2024-05-29 15:29:41 +02:00
Thomas Wolf  68a680a9eb  make aloha dora more flexible for A koch arm  2024-05-29 11:40:02 +02:00
Thomas Wolf  ce5329cf44  push_to_hub less hardcoded  2024-05-29 11:39:25 +02:00
Remi Cadene  f409bee6b1  WIP  2024-05-24 09:37:11 +00:00
Remi Cadene  c91ececc75  WIP  2024-05-24 09:06:29 +00:00
Simon Alibert  8a10e8442b  Add boilerplate code (adding dora node logic; Add some documentation; refactor)  2024-05-23 12:59:39 +00:00
Remi Cadene  10c5151fc6  remove hardcoding  2024-05-23 09:14:17 +00:00
Remi Cadene  a7d24f6dc2  works  2024-05-23 07:32:37 +00:00
Remi Cadene  642f1d0328  Add act_real_world.yaml  2024-05-22 15:15:36 +00:00
Remi Cadene  4843988d81  fix  2024-05-22 15:15:13 +00:00
Remi Cadene  772927616a  fix  2024-05-22 09:12:34 +00:00
Remi Cadene  d52c6037e8  fix  2024-05-22 09:03:43 +00:00
Remi Cadene  b0cb342795  WIP Add aloha_dora_format  2024-05-21 21:16:31 +00:00
Remi Cadene  8460ea6f83  rename + format  2024-05-20 16:37:21 +00:00
haixuantao  9a3b0b738a  Adding dora-record script  2024-05-20 16:34:21 +00:00
57 changed files with 1016 additions and 1510 deletions

View File

@@ -10,6 +10,7 @@ on:
 env:
   PYTHON_VERSION: "3.10"
+  # CI_SLACK_CHANNEL: ${{ secrets.CI_DOCKER_CHANNEL }}
 jobs:
   latest-cpu:
@@ -34,8 +35,6 @@ jobs:
       - name: Check out code
         uses: actions/checkout@v4
-        with:
-          lfs: true
       - name: Login to DockerHub
         uses: docker/login-action@v3
@@ -52,6 +51,30 @@ jobs:
           tags: huggingface/lerobot-cpu
           build-args: PYTHON_VERSION=${{ env.PYTHON_VERSION }}
+      # - name: Post to a Slack channel
+      #   id: slack
+      #   #uses: slackapi/slack-github-action@v1.25.0
+      #   uses: slackapi/slack-github-action@6c661ce58804a1a20f6dc5fbee7f0381b469e001
+      #   with:
+      #     # Slack channel id, channel name, or user id to post message.
+      #     # See also: https://api.slack.com/methods/chat.postMessage#channels
+      #     channel-id: ${{ env.CI_SLACK_CHANNEL }}
+      #     # For posting a rich message using Block Kit
+      #     payload: |
+      #       {
+      #         "text": "lerobot-cpu Docker Image build result: ${{ job.status }}\n${{ github.event.pull_request.html_url || github.event.head_commit.url }}",
+      #         "blocks": [
+      #           {
+      #             "type": "section",
+      #             "text": {
+      #               "type": "mrkdwn",
+      #               "text": "lerobot-cpu Docker Image build result: ${{ job.status }}\n${{ github.event.pull_request.html_url || github.event.head_commit.url }}"
+      #             }
+      #           }
+      #         ]
+      #       }
+      #   env:
+      #     SLACK_BOT_TOKEN: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}
   latest-cuda:
     name: GPU
@@ -74,8 +97,6 @@ jobs:
       - name: Check out code
         uses: actions/checkout@v4
-        with:
-          lfs: true
       - name: Login to DockerHub
         uses: docker/login-action@v3
@@ -92,40 +113,27 @@
           tags: huggingface/lerobot-gpu
           build-args: PYTHON_VERSION=${{ env.PYTHON_VERSION }}
+      # - name: Post to a Slack channel
+      #   id: slack
+      #   #uses: slackapi/slack-github-action@v1.25.0
+      #   uses: slackapi/slack-github-action@6c661ce58804a1a20f6dc5fbee7f0381b469e001
+      #   with:
+      #     # Slack channel id, channel name, or user id to post message.
+      #     # See also: https://api.slack.com/methods/chat.postMessage#channels
+      #     channel-id: ${{ env.CI_SLACK_CHANNEL }}
+      #     # For posting a rich message using Block Kit
+      #     payload: |
+      #       {
+      #         "text": "lerobot-gpu Docker Image build result: ${{ job.status }}\n${{ github.event.pull_request.html_url || github.event.head_commit.url }}",
+      #         "blocks": [
+      #           {
+      #             "type": "section",
+      #             "text": {
+      #               "type": "mrkdwn",
+      #               "text": "lerobot-gpu Docker Image build result: ${{ job.status }}\n${{ github.event.pull_request.html_url || github.event.head_commit.url }}"
+      #             }
+      #           }
+      #         ]
+      #       }
+      #   env:
+      #     SLACK_BOT_TOKEN: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}
-  latest-cuda-dev:
-    name: GPU Dev
-    runs-on: ubuntu-latest
-    steps:
-      - name: Cleanup disk
-        run: |
-          sudo df -h
-          # sudo ls -l /usr/local/lib/
-          # sudo ls -l /usr/share/
-          sudo du -sh /usr/local/lib/
-          sudo du -sh /usr/share/
-          sudo rm -rf /usr/local/lib/android
-          sudo rm -rf /usr/share/dotnet
-          sudo du -sh /usr/local/lib/
-          sudo du -sh /usr/share/
-          sudo df -h
-      - name: Set up Docker Buildx
-        uses: docker/setup-buildx-action@v3
-      - name: Check out code
-        uses: actions/checkout@v4
-      - name: Login to DockerHub
-        uses: docker/login-action@v3
-        with:
-          username: ${{ secrets.DOCKERHUB_USERNAME }}
-          password: ${{ secrets.DOCKERHUB_PASSWORD }}
-      - name: Build and Push GPU dev
-        uses: docker/build-push-action@v5
-        with:
-          context: .
-          file: ./docker/lerobot-gpu-dev/Dockerfile
-          push: true
-          tags: huggingface/lerobot-gpu:dev
-          build-args: PYTHON_VERSION=${{ env.PYTHON_VERSION }}

View File

@@ -70,8 +70,6 @@ jobs:
      #     files: ./coverage.xml
      #     verbose: true
      - name: Tests end-to-end
-        env:
-          DEVICE: cuda
        run: make test-end-to-end
      # - name: Generate Report

.gitignore vendored
View File

@@ -2,16 +2,11 @@
logs
tmp
wandb
-# Data
data
outputs
-# Apple
-.DS_Store
-# VS Code
.vscode
+rl
+.DS_Store
# HPC
nautilus/*.yaml
@@ -95,7 +90,6 @@ instance/
docs/_build/
# PyBuilder
-.pybuilder/
target/
# Jupyter Notebook
@@ -108,6 +102,13 @@ ipython_config.py
# pyenv
.python-version
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
@@ -118,14 +119,6 @@ celerybeat.pid
# SageMath parsed files
*.sage.py
-# Environments
-.env
-.venv
-venv/
-ENV/
-env.bak/
-venv.bak/
# Spyder project settings
.spyderproject
.spyproject
@@ -143,9 +136,3 @@ dmypy.json
# Pyre type checker
.pyre/
-# pytype static type analyzer
-.pytype/
-# Cython debug symbols
-cython_debug/

View File

@@ -10,7 +10,6 @@ endif
export PATH := $(dir $(PYTHON_PATH)):$(PATH)
-DEVICE ?= cpu
build-cpu:
    docker build -t lerobot:latest -f docker/lerobot-cpu/Dockerfile .
@@ -19,29 +18,25 @@ build-gpu:
    docker build -t lerobot:latest -f docker/lerobot-gpu/Dockerfile .
test-end-to-end:
-    ${MAKE} DEVICE=$(DEVICE) test-act-ete-train
-    ${MAKE} DEVICE=$(DEVICE) test-act-ete-eval
-    ${MAKE} DEVICE=$(DEVICE) test-act-ete-train-amp
-    ${MAKE} DEVICE=$(DEVICE) test-act-ete-eval-amp
-    ${MAKE} DEVICE=$(DEVICE) test-diffusion-ete-train
-    ${MAKE} DEVICE=$(DEVICE) test-diffusion-ete-eval
-    ${MAKE} DEVICE=$(DEVICE) test-tdmpc-ete-train
-    ${MAKE} DEVICE=$(DEVICE) test-tdmpc-ete-eval
-    ${MAKE} DEVICE=$(DEVICE) test-default-ete-eval
-    ${MAKE} DEVICE=$(DEVICE) test-act-pusht-tutorial
+    ${MAKE} test-act-ete-train
+    ${MAKE} test-act-ete-eval
+    ${MAKE} test-diffusion-ete-train
+    ${MAKE} test-diffusion-ete-eval
+    ${MAKE} test-tdmpc-ete-train
+    ${MAKE} test-tdmpc-ete-eval
+    ${MAKE} test-default-ete-eval
test-act-ete-train:
    python lerobot/scripts/train.py \
        policy=act \
-        policy.dim_model=64 \
        env=aloha \
        wandb.enable=False \
        training.offline_steps=2 \
        training.online_steps=0 \
        eval.n_episodes=1 \
        eval.batch_size=1 \
-        device=$(DEVICE) \
+        device=cpu \
-        training.save_checkpoint=true \
+        training.save_model=true \
        training.save_freq=2 \
        policy.n_action_steps=20 \
        policy.chunk_size=20 \
@@ -50,67 +45,35 @@ test-act-ete-train:
test-act-ete-eval:
    python lerobot/scripts/eval.py \
-        -p tests/outputs/act/checkpoints/000002/pretrained_model \
+        -p tests/outputs/act/checkpoints/000002 \
        eval.n_episodes=1 \
        eval.batch_size=1 \
        env.episode_length=8 \
-        device=$(DEVICE) \
+        device=cpu \
-test-act-ete-train-amp:
-    python lerobot/scripts/train.py \
-        policy=act \
-        policy.dim_model=64 \
-        env=aloha \
-        wandb.enable=False \
-        training.offline_steps=2 \
-        training.online_steps=0 \
-        eval.n_episodes=1 \
-        eval.batch_size=1 \
-        device=$(DEVICE) \
-        training.save_checkpoint=true \
-        training.save_freq=2 \
-        policy.n_action_steps=20 \
-        policy.chunk_size=20 \
-        training.batch_size=2 \
-        hydra.run.dir=tests/outputs/act_amp/ \
-        use_amp=true
-test-act-ete-eval-amp:
-    python lerobot/scripts/eval.py \
-        -p tests/outputs/act_amp/checkpoints/000002/pretrained_model \
-        eval.n_episodes=1 \
-        eval.batch_size=1 \
-        env.episode_length=8 \
-        device=$(DEVICE) \
-        use_amp=true
test-diffusion-ete-train:
    python lerobot/scripts/train.py \
        policy=diffusion \
-        policy.down_dims=\[64,128,256\] \
-        policy.diffusion_step_embed_dim=32 \
-        policy.num_inference_steps=10 \
        env=pusht \
        wandb.enable=False \
        training.offline_steps=2 \
        training.online_steps=0 \
        eval.n_episodes=1 \
        eval.batch_size=1 \
-        device=$(DEVICE) \
+        device=cpu \
-        training.save_checkpoint=true \
+        training.save_model=true \
        training.save_freq=2 \
        training.batch_size=2 \
        hydra.run.dir=tests/outputs/diffusion/
test-diffusion-ete-eval:
    python lerobot/scripts/eval.py \
-        -p tests/outputs/diffusion/checkpoints/000002/pretrained_model \
+        -p tests/outputs/diffusion/checkpoints/000002 \
        eval.n_episodes=1 \
        eval.batch_size=1 \
        env.episode_length=8 \
-        device=$(DEVICE) \
+        device=cpu \
-# TODO(alexander-soare): Restore online_steps to 2 when it is reinstated.
test-tdmpc-ete-train:
    python lerobot/scripts/train.py \
        policy=tdmpc \
@@ -119,23 +82,24 @@ test-tdmpc-ete-train:
        dataset_repo_id=lerobot/xarm_lift_medium \
        wandb.enable=False \
        training.offline_steps=2 \
-        training.online_steps=0 \
+        training.online_steps=2 \
        eval.n_episodes=1 \
        eval.batch_size=1 \
        env.episode_length=2 \
-        device=$(DEVICE) \
+        device=cpu \
-        training.save_checkpoint=true \
+        training.save_model=true \
        training.save_freq=2 \
        training.batch_size=2 \
        hydra.run.dir=tests/outputs/tdmpc/
test-tdmpc-ete-eval:
    python lerobot/scripts/eval.py \
-        -p tests/outputs/tdmpc/checkpoints/000002/pretrained_model \
+        -p tests/outputs/tdmpc/checkpoints/000002 \
        eval.n_episodes=1 \
        eval.batch_size=1 \
        env.episode_length=8 \
-        device=$(DEVICE) \
+        device=cpu \
test-default-ete-eval:
    python lerobot/scripts/eval.py \
@@ -143,21 +107,4 @@ test-default-ete-eval:
        eval.n_episodes=1 \
        eval.batch_size=1 \
        env.episode_length=8 \
-        device=$(DEVICE) \
+        device=cpu \
-test-act-pusht-tutorial:
-    cp examples/advanced/1_train_act_pusht/act_pusht.yaml lerobot/configs/policy/created_by_Makefile.yaml
-    python lerobot/scripts/train.py \
-        policy=created_by_Makefile.yaml \
-        env=pusht \
-        wandb.enable=False \
-        training.offline_steps=2 \
-        eval.n_episodes=1 \
-        eval.batch_size=1 \
-        env.episode_length=2 \
-        device=$(DEVICE) \
-        training.save_model=true \
-        training.save_freq=2 \
-        training.batch_size=2 \
-        hydra.run.dir=tests/outputs/act_pusht/
-    rm lerobot/configs/policy/created_by_Makefile.yaml

View File

@@ -77,10 +77,6 @@ Install 🤗 LeRobot:
pip install .
```
-> **NOTE:** Depending on your platform, If you encounter any build errors during this step
-you may need to install `cmake` and `build-essential` for building some of our dependencies.
-On linux: `sudo apt-get install cmake build-essential`
For simulations, 🤗 LeRobot comes with gymnasium environments that can be installed as extras:
- [aloha](https://github.com/huggingface/gym-aloha)
- [xarm](https://github.com/huggingface/gym-xarm)
@@ -103,7 +99,6 @@ wandb login
```
.
├── examples             # contains demonstration examples, start here to learn about LeRobot
-|   └── advanced         # contains even more examples for those who have mastered the basics
├── lerobot
|   ├── configs          # contains hydra yaml files with all options that you can override in the command line
|   |   ├── default.yaml # selected by default, it loads pusht environment and diffusion policy
@@ -154,19 +149,18 @@ python lerobot/scripts/eval.py \
```
Note: After training your own policy, you can re-evaluate the checkpoints with:
```bash
-python lerobot/scripts/eval.py -p {OUTPUT_DIR}/checkpoints/last/pretrained_model
+python lerobot/scripts/eval.py \
+    -p PATH/TO/TRAIN/OUTPUT/FOLDER
```
See `python lerobot/scripts/eval.py --help` for more instructions.
### Train your own policy
-Check out [example 3](./examples/3_train_policy.py) that illustrates how to train a model using our core library in python, and [example 4](./examples/4_train_policy_with_script.md) that shows how to use our training script from command line.
+Check out [example 3](./examples/3_train_policy.py) that illustrates how to start training a model.
In general, you can use our training script to easily train any policy. Here is an example of training the ACT policy on trajectories collected by humans on the Aloha simulation environment for the insertion task:
```bash
python lerobot/scripts/train.py \
    policy=act \
@@ -180,19 +174,6 @@ The experiment directory is automatically generated and will show up in yellow i
hydra.run.dir=your/new/experiment/dir
```
-In the experiment directory there will be a folder called `checkpoints` which will have the following structure:
-```bash
-checkpoints
-├── 000250                    # checkpoint_dir for training step 250
-│   ├── pretrained_model      # Hugging Face pretrained model dir
-│   │   ├── config.json       # Hugging Face pretrained model config
-│   │   ├── config.yaml       # consolidated Hydra config
-│   │   ├── model.safetensors # model weights
-│   │   └── README.md         # Hugging Face model card
-│   └── training_state.pth    # optimizer/scheduler/rng state and training step
-```
To use wandb for logging training and evaluation curves, make sure you've run `wandb login` as a one-time setup step. Then, when running the training command above, enable WandB in the configuration by adding:
```bash
@@ -203,19 +184,7 @@ A link to the wandb logs for the run will also show up in yellow in your termina
![](media/wandb.png)
-Note: For efficiency, during training every checkpoint is evaluated on a low number of episodes. You may use `eval.n_episodes=500` to evaluate on more episodes than the default. Or, after training, you may want to re-evaluate your best checkpoints on more episodes or change the evaluation settings. See `python lerobot/scripts/eval.py --help` for more instructions.
+Note: For efficiency, during training every checkpoint is evaluated on a low number of episodes. After training, you may want to re-evaluate your best checkpoints on more episodes or change the evaluation settings. See `python lerobot/scripts/eval.py --help` for more instructions.
-#### Reproduce state-of-the-art (SOTA)
-We have organized our configuration files (found under [`lerobot/configs`](./lerobot/configs)) such that they reproduce SOTA results from a given model variant in their respective original works. Simply running:
-```bash
-python lerobot/scripts/train.py policy=diffusion env=pusht
-```
-reproduces SOTA results for Diffusion Policy on the PushT task.
-Pretrained policies, along with reproduction details, can be found under the "Models" section of https://huggingface.co/lerobot.
## Contribute
@@ -246,14 +215,14 @@ If your dataset format is not supported, implement your own in `lerobot/common/d
Once you have trained a policy you may upload it to the Hugging Face hub using a hub id that looks like `${hf_user}/${repo_name}` (e.g. [lerobot/diffusion_pusht](https://huggingface.co/lerobot/diffusion_pusht)).
-You first need to find the checkpoint folder located inside your experiment directory (e.g. `outputs/train/2024-05-05/20-21-12_aloha_act_default/checkpoints/002500`). Within that there is a `pretrained_model` directory which should contain:
+You first need to find the checkpoint located inside your experiment directory (e.g. `outputs/train/2024-05-05/20-21-12_aloha_act_default/checkpoints/002500`). It should contain:
- `config.json`: A serialized version of the policy configuration (following the policy's dataclass config).
- `model.safetensors`: A set of `torch.nn.Module` parameters, saved in [Hugging Face Safetensors](https://huggingface.co/docs/safetensors/index) format.
- `config.yaml`: A consolidated Hydra training configuration containing the policy, environment, and dataset configs. The policy configuration should match `config.json` exactly. The environment config is useful for anyone who wants to evaluate your policy. The dataset config just serves as a paper trail for reproducibility.
To upload these to the hub, run the following:
```bash
-huggingface-cli upload ${hf_user}/${repo_name} path/to/pretrained_model
+huggingface-cli upload ${hf_user}/${repo_name} path/to/checkpoint/dir
```
See [eval.py](https://github.com/huggingface/lerobot/blob/main/lerobot/scripts/eval.py) for an example of how other people may use your policy.

View File

@@ -1,40 +0,0 @@
FROM nvidia/cuda:12.4.1-base-ubuntu22.04
# Configure image
ARG PYTHON_VERSION=3.10
ARG DEBIAN_FRONTEND=noninteractive
# Install apt dependencies
RUN apt-get update && apt-get install -y --no-install-recommends \
build-essential cmake \
git git-lfs openssh-client \
nano vim less util-linux \
htop atop nvtop \
sed gawk grep curl wget \
tcpdump sysstat screen tmux \
libglib2.0-0 libgl1-mesa-glx libegl1-mesa ffmpeg \
python${PYTHON_VERSION} python${PYTHON_VERSION}-venv \
&& apt-get clean && rm -rf /var/lib/apt/lists/*
# Install gh cli tool
RUN (type -p wget >/dev/null || (apt update && apt-get install wget -y)) \
&& mkdir -p -m 755 /etc/apt/keyrings \
&& wget -qO- https://cli.github.com/packages/githubcli-archive-keyring.gpg | tee /etc/apt/keyrings/githubcli-archive-keyring.gpg > /dev/null \
&& chmod go+r /etc/apt/keyrings/githubcli-archive-keyring.gpg \
&& echo "deb [arch=$(dpkg --print-architecture) signed-by=/etc/apt/keyrings/githubcli-archive-keyring.gpg] https://cli.github.com/packages stable main" | tee /etc/apt/sources.list.d/github-cli.list > /dev/null \
&& apt update \
&& apt install gh -y \
&& apt clean && rm -rf /var/lib/apt/lists/*
# Setup `python`
RUN ln -s /usr/bin/python3 /usr/bin/python
# Install poetry
RUN curl -sSL https://install.python-poetry.org | python -
ENV PATH="/root/.local/bin:$PATH"
RUN echo 'if [ "$HOME" != "/root" ]; then ln -sf /root/.local/bin/poetry $HOME/.local/bin/poetry; fi' >> /root/.bashrc
RUN poetry config virtualenvs.create false
RUN poetry config virtualenvs.in-project true
# Set EGL as the rendering backend for MuJoCo
ENV MUJOCO_GL="egl"

View File

@@ -4,15 +4,18 @@ FROM nvidia/cuda:12.4.1-base-ubuntu22.04
ARG PYTHON_VERSION=3.10
ARG DEBIAN_FRONTEND=noninteractive
# Install apt dependencies
RUN apt-get update && apt-get install -y --no-install-recommends \
    build-essential cmake \
+    git git-lfs openssh-client \
+    nano vim ffmpeg \
+    htop atop nvtop \
+    sed gawk grep curl wget \
+    tcpdump sysstat screen \
    libglib2.0-0 libgl1-mesa-glx libegl1-mesa \
    python${PYTHON_VERSION} python${PYTHON_VERSION}-venv \
    && apt-get clean && rm -rf /var/lib/apt/lists/*
# Create virtual environment
RUN ln -s /usr/bin/python${PYTHON_VERSION} /usr/bin/python
RUN python -m venv /opt/venv
@@ -20,7 +23,8 @@ ENV PATH="/opt/venv/bin:$PATH"
RUN echo "source /opt/venv/bin/activate" >> /root/.bashrc
# Install LeRobot
-COPY . /lerobot
+RUN git lfs install
+RUN git clone https://github.com/huggingface/lerobot.git
WORKDIR /lerobot
RUN pip install --upgrade --no-cache-dir pip
RUN pip install --no-cache-dir ".[test, aloha, xarm, pusht]"

View File

@@ -1,183 +0,0 @@
This tutorial will explain the training script, how to use it, and particularly the use of Hydra to configure everything needed for the training run.
## The training script
LeRobot offers a training script at [`lerobot/scripts/train.py`](../../lerobot/scripts/train.py). At a high level it does the following:
- Loads a Hydra configuration file for the following steps (more on Hydra in a moment).
- Makes a simulation environment.
- Makes a dataset corresponding to that simulation environment.
- Makes a policy.
- Runs a standard training loop with forward pass, backward pass, optimization step, and occasional logging, evaluation (of the policy on the environment), and checkpointing (see the schematic sketch just after this list).
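To make the shape of that loop concrete, here is a schematic sketch in plain PyTorch. The factory helpers (`make_env`, `make_dataset`, `make_policy`) and the config keys used below are illustrative stand-ins for what the real script wires together, not the actual `lerobot/scripts/train.py`:
```python
# Schematic sketch only: hypothetical factories and config keys, not the real lerobot code.
import torch


def train(cfg, make_env, make_dataset, make_policy):
    env = make_env(cfg)          # simulation environment, used for periodic evaluation
    dataset = make_dataset(cfg)  # offline dataset matching that environment
    policy = make_policy(cfg)    # a torch.nn.Module that returns a loss for a batch
    optimizer = torch.optim.Adam(policy.parameters(), lr=cfg.training.lr)
    loader = torch.utils.data.DataLoader(dataset, batch_size=cfg.training.batch_size, shuffle=True)

    step = 0
    while step < cfg.training.offline_steps:
        for batch in loader:
            loss = policy(batch)["loss"]  # forward pass
            loss.backward()               # backward pass
            optimizer.step()              # optimization step
            optimizer.zero_grad()
            step += 1
            if step % cfg.training.eval_freq == 0:
                pass  # occasional logging, evaluation rollouts in `env`, and checkpointing
            if step >= cfg.training.offline_steps:
                break
```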
## Basics of how we use Hydra
Explaining the ins and outs of [Hydra](https://hydra.cc/docs/intro/) is beyond the scope of this document, but here we'll share the main points you need to know.
First, `lerobot/configs` has a directory structure like this:
```
.
├── default.yaml
├── env
│ ├── aloha.yaml
│ ├── pusht.yaml
│ └── xarm.yaml
└── policy
├── act.yaml
├── diffusion.yaml
└── tdmpc.yaml
```
**_For brevity, in the rest of this document we'll drop the leading `lerobot/configs` path. So `default.yaml` really refers to `lerobot/configs/default.yaml`._**
When you run the training script with
```python
python lerobot/scripts/train.py
```
Hydra is set up to read `default.yaml` (via the `@hydra.main` decorator). If you take a look at `@hydra.main`'s arguments you will see `config_path="../configs", config_name="default"`. At the top of `default.yaml` is a `defaults` section which looks like this:
```yaml
defaults:
- _self_
- env: pusht
- policy: diffusion
```
This logic tells Hydra to incorporate configuration parameters from `env/pusht.yaml` and `policy/diffusion.yaml`. _Note: Be aware of the order as any configuration parameters with the same name will be overridden. Thus, `default.yaml` is overridden by `env/pusht.yaml`, which is overridden by `policy/diffusion.yaml`_.
Then, `default.yaml` also contains common configuration parameters such as `device: cuda` or `use_amp: false` (for enabling fp16 training). Some other parameters are set to `???` which indicates that they are expected to be set in additional yaml files. For instance, `training.offline_steps: ???` in `default.yaml` is set to `200000` in `diffusion.yaml`.
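For reference, the entry-point pattern described above looks roughly like this (a minimal sketch, assuming Hydra is installed and that the relative `config_path` resolves to `lerobot/configs`; the real decorator sits in `lerobot/scripts/train.py`):
```python
# Minimal sketch of a Hydra entry point; illustrative, not the actual training script.
import hydra
from omegaconf import DictConfig, OmegaConf


@hydra.main(config_path="../configs", config_name="default", version_base="1.2")
def train_cli(cfg: DictConfig):
    # By the time this function runs, Hydra has merged default.yaml with
    # env/pusht.yaml and policy/diffusion.yaml (per the defaults list),
    # plus any overrides passed on the command line.
    print(OmegaConf.to_yaml(cfg))
    print(cfg.device, cfg.training.offline_steps)


if __name__ == "__main__":
    train_cli()
```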
Thanks to this `defaults` section in `default.yaml`, if you want to train Diffusion Policy with PushT, you really only need to run:
```bash
python lerobot/scripts/train.py
```
However, you can be more explicit and launch the exact same Diffusion Policy training on PushT with:
```bash
python lerobot/scripts/train.py policy=diffusion env=pusht
```
This way of overriding defaults via the CLI is especially useful when you want to change the policy and/or environment. For instance, you can train ACT on the default Aloha environment with:
```bash
python lerobot/scripts/train.py policy=act env=aloha
```
There are two things to note here:
- Config overrides are passed as `param_name=param_value`.
- Here we have overridden the defaults section. `policy=act` tells Hydra to use `policy/act.yaml`, and `env=aloha` tells Hydra to use `env/aloha.yaml`.
_As an aside: we've set up all of our configurations so that they reproduce state-of-the-art results from papers in the literature._
## Overriding configuration parameters in the CLI
Now let's say that we want to train on a different task in the Aloha environment. If you look in `env/aloha.yaml` you will see something like:
```yaml
# lerobot/configs/env/aloha.yaml
env:
task: AlohaInsertion-v0
```
And if you look in `policy/act.yaml` you will see something like:
```yaml
# lerobot/configs/policy/act.yaml
dataset_repo_id: lerobot/aloha_sim_insertion_human
```
But our Aloha environment actually supports a cube transfer task as well. To train for this task, you could manually modify the two yaml configuration files respectively.
First, we'd need to switch to using the cube transfer task for the ALOHA environment.
```diff
# lerobot/configs/env/aloha.yaml
env:
- task: AlohaInsertion-v0
+ task: AlohaTransferCube-v0
```
Then, we'd also need to switch to using the cube transfer dataset.
```diff
# lerobot/configs/policy/act.yaml
-dataset_repo_id: lerobot/aloha_sim_insertion_human
+dataset_repo_id: lerobot/aloha_sim_transfer_cube_human
```
Then, you'd be able to run:
```bash
python lerobot/scripts/train.py policy=act env=aloha
```
and you'd be training and evaluating on the cube transfer task.
An alternative approach to editing the yaml configuration files would be to override the defaults via the command line:
```bash
python lerobot/scripts/train.py \
policy=act \
dataset_repo_id=lerobot/aloha_sim_transfer_cube_human \
env=aloha \
env.task=AlohaTransferCube-v0
```
There's something new here. Notice the `.` delimiter used to traverse the configuration hierarchy. _But be aware that the `defaults` section is an exception. As you saw above, we didn't need to write `defaults.policy=act` in the CLI. `policy=act` was enough._
Putting all that knowledge together, here's the command that was used to train https://huggingface.co/lerobot/act_aloha_sim_transfer_cube_human.
```bash
python lerobot/scripts/train.py \
hydra.run.dir=outputs/train/act_aloha_sim_transfer_cube_human \
device=cuda \
env=aloha \
env.task=AlohaTransferCube-v0 \
dataset_repo_id=lerobot/aloha_sim_transfer_cube_human \
policy=act \
training.eval_freq=10000 \
training.log_freq=250 \
training.offline_steps=100000 \
training.save_model=true \
training.save_freq=25000 \
eval.n_episodes=50 \
eval.batch_size=50 \
wandb.enable=false \
```
There's one new thing here: `hydra.run.dir=outputs/train/act_aloha_sim_transfer_cube_human`, which specifies where to save the training output.
## Using a configuration file not in `lerobot/configs`
Above we discussed how our training script is set up such that Hydra looks for `default.yaml` in `lerobot/configs`. But if you have a configuration file elsewhere in your filesystem, you may use:
```bash
python lerobot/scripts/train.py --config-dir PARENT/PATH --config-name FILE_NAME_WITHOUT_EXTENSION
```
Note: here we use regular syntax for providing CLI arguments to a Python script, not Hydra's `param_name=param_value` syntax.
As a concrete example, this becomes particularly handy when you have a folder with training outputs, and would like to re-run the training. For example, say you previously ran the training script with one of the earlier commands and have `outputs/train/my_experiment/checkpoints/pretrained_model/config.yaml`. This `config.yaml` file will have the full set of configuration parameters within it. To run the training with the same configuration again, do:
```bash
python lerobot/scripts/train.py --config-dir outputs/train/my_experiment/checkpoints/last/pretrained_model --config-name config
```
Note that you may still use the regular syntax for config parameter overrides (eg: by adding `training.offline_steps=200000`).
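For instance, a hypothetical command combining the two (the output path is the one from the example above):
```bash
python lerobot/scripts/train.py \
    --config-dir outputs/train/my_experiment/checkpoints/last/pretrained_model \
    --config-name config \
    training.offline_steps=200000
```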
---
So far we've seen how to train Diffusion Policy for PushT and ACT for ALOHA. Now, what if we want to train ACT for PushT? Well, there are aspects of the ACT configuration that are specific to the ALOHA environments, and these happen to be incompatible with PushT. Therefore, trying to run the following will almost certainly raise an exception of sorts (eg: feature dimension mismatch):
```bash
python lerobot/scripts/train.py policy=act env=pusht dataset_repo_id=lerobot/pusht
```
Please, head on over to our [advanced tutorial on adapting policy configuration to various environments](./advanced/train_act_pusht/train_act_pusht.md) to learn more.
Or in the meantime, happy coding! 🤗

View File

@@ -1,37 +0,0 @@
This tutorial explains how to resume a training run that you've started with the training script. If you don't know how our training script and configuration system works, please read [4_train_policy_with_script.md](./4_train_policy_with_script.md) first.
## Basic training resumption
Let's consider the example of training ACT for one of the ALOHA tasks. Here's a command that can achieve that:
```bash
python lerobot/scripts/train.py \
hydra.run.dir=outputs/train/run_resumption \
policy=act \
dataset_repo_id=lerobot/aloha_sim_transfer_cube_human \
env=aloha \
env.task=AlohaTransferCube-v0 \
training.log_freq=25 \
training.save_checkpoint=true \
training.save_freq=100
```
Here we're using the default dataset and environment for ACT, and we've taken care to set up the log frequency and checkpointing frequency to low numbers so we can test resumption. You should be able to see some logging and have a first checkpoint within 1 minute. Please interrupt the training after the first checkpoint.
To resume, all that we have to do is run the training script, providing the run directory, and the resume option:
```bash
python lerobot/scripts/train.py \
hydra.run.dir=outputs/train/run_resumption \
resume=true
```
You should see from the logging that your training picks up from where it left off.
Note that with `resume=true`, the configuration file from the last checkpoint in the training output directory is loaded. So it doesn't matter that we haven't provided all the other configuration parameters from our previous command (although there may be warnings to notify you that your command has a different configuration than the checkpoint).
---
Now you should know how to resume your training run in case it gets interrupted or you want to extend a finished training run.
Happy coding! 🤗

View File

@@ -1,87 +0,0 @@
# @package _global_
# Change the seed to match what PushT eval uses
# (to avoid evaluating on seeds used for generating the training data).
seed: 100000
# Change the dataset repository to the PushT one.
dataset_repo_id: lerobot/pusht
override_dataset_stats:
observation.image:
# stats from imagenet, since we use a pretrained vision model
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
training:
offline_steps: 80000
online_steps: 0
eval_freq: 10000
save_freq: 100000
log_freq: 250
save_model: true
batch_size: 8
lr: 1e-5
lr_backbone: 1e-5
weight_decay: 1e-4
grad_clip_norm: 10
online_steps_between_rollouts: 1
delta_timestamps:
action: "[i / ${fps} for i in range(${policy.chunk_size})]"
eval:
n_episodes: 50
batch_size: 50
# See `configuration_act.py` for more details.
policy:
name: act
# Input / output structure.
n_obs_steps: 1
chunk_size: 100 # chunk_size
n_action_steps: 100
input_shapes:
observation.image: [3, 96, 96]
observation.state: ["${env.state_dim}"]
output_shapes:
action: ["${env.action_dim}"]
# Normalization / Unnormalization
input_normalization_modes:
observation.image: mean_std
# Use min_max normalization just because it's more standard.
observation.state: min_max
output_normalization_modes:
# Use min_max normalization just because it's more standard.
action: min_max
# Architecture.
# Vision backbone.
vision_backbone: resnet18
pretrained_backbone_weights: ResNet18_Weights.IMAGENET1K_V1
replace_final_stride_with_dilation: false
# Transformer layers.
pre_norm: false
dim_model: 512
n_heads: 8
dim_feedforward: 3200
feedforward_activation: relu
n_encoder_layers: 4
# Note: Although the original ACT implementation has 7 for `n_decoder_layers`, there is a bug in the code
# that means only the first layer is used. Here we match the original implementation by setting this to 1.
# See this issue https://github.com/tonyzhaozh/act/issues/25#issue-2258740521.
n_decoder_layers: 1
# VAE.
use_vae: true
latent_dim: 32
n_vae_encoder_layers: 4
# Inference.
temporal_ensemble_momentum: null
# Training and loss computation.
dropout: 0.1
kl_weight: 10.0
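As an aside, the `delta_timestamps` entry above is a small Python expression that LeRobot evaluates into a list of time offsets (in seconds) relative to the current frame. A quick illustration of what it expands to, assuming `fps: 10` as for the PushT dataset and the `chunk_size: 100` set in this config:
```python
# Expansion of the delta_timestamps expression above (illustrative values).
fps = 10          # assumed PushT dataset frame rate
chunk_size = 100  # matches policy.chunk_size in this config
delta_timestamps = [i / fps for i in range(chunk_size)]
print(delta_timestamps[:5])  # [0.0, 0.1, 0.2, 0.3, 0.4]
# i.e. each sample comes with the current action plus the next 99 actions,
# spaced one frame apart, which is what ACT's action chunking trains on.
```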

View File

@@ -1,70 +0,0 @@
In this tutorial we will learn how to adapt a policy configuration to be compatible with a new environment and dataset. As a concrete example, we will adapt the default configuration for ACT to be compatible with the PushT environment and dataset.
If you haven't already read our tutorial on the [training script and configuration tooling](../4_train_policy_with_script.md) please do so prior to tackling this tutorial.
Let's get started!
Suppose we want to train ACT for PushT. Well, there are aspects of the ACT configuration that are specific to the ALOHA environments, and these happen to be incompatible with PushT. Therefore, trying to run the following will almost certainly raise an exception of sorts (eg: feature dimension mismatch):
```bash
python lerobot/scripts/train.py policy=act env=pusht dataset_repo_id=lerobot/pusht
```
We need to adapt the parameters of the ACT policy configuration to the PushT environment. The most important ones are the image keys.
ALOHA's datasets and environments typically use a variable number of cameras. In `lerobot/configs/policy/act.yaml` you may notice two relevant sections. Here we show you the minimal diff needed to adjust to PushT:
```diff
override_dataset_stats:
- observation.images.top:
+ observation.image:
# stats from imagenet, since we use a pretrained vision model
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
policy:
input_shapes:
- observation.images.top: [3, 480, 640]
+ observation.image: [3, 96, 96]
observation.state: ["${env.state_dim}"]
output_shapes:
action: ["${env.action_dim}"]
input_normalization_modes:
- observation.images.top: mean_std
+ observation.image: mean_std
observation.state: min_max
output_normalization_modes:
action: min_max
```
Here we've accounted for the following:
- PushT uses "observation.image" for its image key.
- PushT provides smaller images.
_Side note: technically we could override these via the CLI, but with many changes it gets a bit messy, and we also have a bit of a challenge in that we're using `.` in our observation keys which is treated by Hydra as a hierarchical separator_.
For your convenience, we provide [`act_pusht.yaml`](./act_pusht.yaml) in this directory. It contains the diff above, plus some other (optional) ones that are explained within. Please copy it into `lerobot/configs/policy` with:
```bash
cp examples/advanced/1_train_act_pusht/act_pusht.yaml lerobot/configs/policy/act_pusht.yaml
```
(remember from a [previous tutorial](../4_train_policy_with_script.md) that Hydra will look in the `lerobot/configs` directory). Now try running the following.
<!-- Note to contributor: are you changing this command? Note that it's tested in `Makefile`, so change it there too! -->
```bash
python lerobot/scripts/train.py policy=act_pusht env=pusht
```
Notice that this is much the same as the command that failed at the start of the tutorial, only:
- Now we are using `policy=act_pusht` to point to our new configuration file.
- We can drop `dataset_repo_id=lerobot/pusht` as the change is incorporated in our new configuration file.
Hurrah! You're now training ACT for the PushT environment.
---
The bottom line of this tutorial is that when training policies for different environments and datasets you will need to understand what parts of the policy configuration are specific to those and make changes accordingly.
Happy coding! 🤗

gym_dora/README.md Normal file
View File

@@ -0,0 +1 @@
# gym_dora

gym_dora/example.py Normal file
View File

@@ -0,0 +1,17 @@
import gymnasium as gym
import gym_dora # noqa: F401
env = gym.make("gym_dora/DoraAloha-v0", disable_env_checker=True)
obs, info = env.reset()
policy = ... # make_policy
done = False
while not done:
actions = policy.select_action(obs)
observation, reward, terminated, truncated, info = env.step(actions)
done = terminated | truncated | done
env.close()

View File

@@ -0,0 +1,17 @@
from gymnasium.envs.registration import register
register(
id="gym_dora/DoraAloha-v0",
entry_point="gym_dora.env:DoraEnv",
max_episode_steps=300,
nondeterministic=True,
kwargs={"model": "aloha"},
)
register(
id="gym_dora/DoraKoch-v0",
entry_point="gym_dora.env:DoraEnv",
max_episode_steps=300,
nondeterministic=True,
kwargs={"model": "koch"},
)

gym_dora/gym_dora/env.py Normal file
View File

@@ -0,0 +1,199 @@
import os
import gymnasium as gym
import numpy as np
import pyarrow as pa
from dora import Node
from gymnasium import spaces
FPS = int(os.getenv("FPS", "30"))
IMAGE_WIDTH = int(os.getenv("IMAGE_WIDTH", "640"))
IMAGE_HEIGHT = int(os.getenv("IMAGE_HEIGHT", "480"))
ALOHA_JOINTS = [
# absolute joint position
"left_arm_waist",
"left_arm_shoulder",
"left_arm_elbow",
"left_arm_forearm_roll",
"left_arm_wrist_angle",
"left_arm_wrist_rotate",
# normalized gripper position 0: close, 1: open
"left_arm_gripper",
# absolute joint position
"right_arm_waist",
"right_arm_shoulder",
"right_arm_elbow",
"right_arm_forearm_roll",
"right_arm_wrist_angle",
"right_arm_wrist_rotate",
# normalized gripper position 0: close, 1: open
"right_arm_gripper",
]
ALOHA_ACTIONS = [
# position and quaternion for end effector
"left_arm_waist",
"left_arm_shoulder",
"left_arm_elbow",
"left_arm_forearm_roll",
"left_arm_wrist_angle",
"left_arm_wrist_rotate",
# normalized gripper position (0: close, 1: open)
"left_arm_gripper",
"right_arm_waist",
"right_arm_shoulder",
"right_arm_elbow",
"right_arm_forearm_roll",
"right_arm_wrist_angle",
"right_arm_wrist_rotate",
# normalized gripper position (0: close, 1: open)
"right_arm_gripper",
]
class DoraEnv(gym.Env):
metadata = {"render_modes": ["rgb_array"], "render_fps": FPS}
def __init__(
self,
model="aloha",
observation_width=IMAGE_WIDTH,
observation_height=IMAGE_HEIGHT,
cameras_names=None,
num_joints=None,
num_actions=None,
):
"""Initializes the Dora environment.
Args:
model (str): The model to use. Either 'aloha' or 'custom'.
observation_width (int): The width of the observation image.
observation_height (int): The height of the observation image.
cameras_names (list): A list of camera names to use. If not provided, the default is ['cam_high', 'cam_low', 'cam_left_wrist', 'cam_right_wrist'].
num_joints (int): The number of joints in the model. If not provided, the default is 14 for 'aloha' and 6 for 'fivedof'.
num_actions (int): The number of actions in the model. If not provided, the default is 14 for 'aloha' and 6 for 'fivedof'.
"""
super().__init__()
# Initialize a new node
self.node = Node() if os.environ.get("DORA_NODE_CONFIG", None) is not None else None
self.observation = {"pixels": {}, "agent_pos": None}
self.terminated = False
self.observation_height = observation_height
self.observation_width = observation_width
# Observation space
if model == "aloha":
self.observation_space = spaces.Dict(
{
"pixels": spaces.Dict(
{
"cam_high": spaces.Box(
low=0,
high=255,
shape=(self.observation_height, self.observation_width, 3),
dtype=np.uint8,
),
"cam_low": spaces.Box(
low=0,
high=255,
shape=(self.observation_height, self.observation_width, 3),
dtype=np.uint8,
),
"cam_left_wrist": spaces.Box(
low=0,
high=255,
shape=(self.observation_height, self.observation_width, 3),
dtype=np.uint8,
),
"cam_right_wrist": spaces.Box(
low=0,
high=255,
shape=(self.observation_height, self.observation_width, 3),
dtype=np.uint8,
),
}
),
"agent_pos": spaces.Box(
low=-1000.0,
high=1000.0,
shape=(len(ALOHA_JOINTS),),
dtype=np.float64,
),
}
)
elif model == "custom":
pixel_dict = {}
for camera in cameras_names:
assert camera.startswith("cam"), "Camera names must start with 'cam'"
pixel_dict[camera] = spaces.Box(
low=0,
high=255,
shape=(self.observation_height, self.observation_width, 3),
dtype=np.uint8,
)
self.observation_space = spaces.Dict(
{
"pixels": spaces.Dict(pixel_dict),
"agent_pos": spaces.Box(
low=-1000.0,
high=1000.0,
shape=(num_joints,),
dtype=np.float64,
),
}
)
else:
raise ValueError("Model must be either 'aloha' or 'custom'.")
# Action space
if model == "aloha":
self.action_space = spaces.Box(low=-1, high=1, shape=(len(ALOHA_ACTIONS),), dtype=np.float32)
elif model == "custom":
self.action_space = spaces.Box(low=-1, high=1, shape=(num_actions,), dtype=np.float32)
def _get_obs(self):
while True:
event = self.node.next(timeout=0.001)
## If event is None, the node event stream is closed and we should terminate the env
if event is None:
self.terminated = True
break
if event["type"] == "INPUT":
# Map Image input into pixels key within Aloha environment
if "cam" in event["id"]:
self.observation["pixels"][event["id"]] = (
event["value"].to_numpy().reshape(self.observation_height, self.observation_width, 3)
)
else:
# Map other inputs into the observation dictionary using the event id as key
self.observation[event["id"]] = event["value"].to_numpy()
# If the event is a timeout error break the update loop.
elif event["type"] == "ERROR":
break
def reset(self, seed: int | None = None):
self.node.send_output("reset")
self._get_obs()
self.terminated = False
info = {}
return self.observation, info
def step(self, action: np.ndarray):
# Send the action to the dataflow as action key.
self.node.send_output("action", pa.array(action))
self._get_obs()
reward = 0
terminated = truncated = self.terminated
info = {}
return self.observation, reward, terminated, truncated, info
def render(self): ...
def close(self):
# Drop the node
del self.node
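The `custom` code path above is driven entirely by constructor kwargs (`cameras_names`, `num_joints`, `num_actions`). A hypothetical registration for a single-camera, six-joint setup could look like the following sketch (the id and numbers are made up for illustration; only `DoraAloha-v0` and `DoraKoch-v0` are registered in this change):
```python
# Hypothetical example only: wiring a custom single-camera, 6-DoF robot
# through the `custom` branch of DoraEnv. Not part of this change.
from gymnasium.envs.registration import register

register(
    id="gym_dora/DoraCustom-v0",
    entry_point="gym_dora.env:DoraEnv",
    max_episode_steps=300,
    nondeterministic=True,
    kwargs={
        "model": "custom",
        "cameras_names": ["cam_high"],  # DoraEnv asserts camera names start with "cam"
        "num_joints": 6,
        "num_actions": 6,
    },
)
```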

gym_dora/poetry.lock generated Normal file
View File

@@ -0,0 +1,182 @@
# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand.
[[package]]
name = "cloudpickle"
version = "3.0.0"
description = "Pickler class to extend the standard pickle.Pickler functionality"
optional = false
python-versions = ">=3.8"
files = [
{file = "cloudpickle-3.0.0-py3-none-any.whl", hash = "sha256:246ee7d0c295602a036e86369c77fecda4ab17b506496730f2f576d9016fd9c7"},
{file = "cloudpickle-3.0.0.tar.gz", hash = "sha256:996d9a482c6fb4f33c1a35335cf8afd065d2a56e973270364840712d9131a882"},
]
[[package]]
name = "dora-rs"
version = "0.3.4"
description = "`dora` goal is to be a low latency, composable, and distributed data flow."
optional = false
python-versions = "*"
files = [
{file = "dora_rs-0.3.4-cp37-abi3-macosx_10_12_x86_64.whl", hash = "sha256:d1b738eea5a4966d731c26c6b6a0a50a491a24f7e9e335475f983cfc6f0da19e"},
{file = "dora_rs-0.3.4-cp37-abi3-macosx_11_0_arm64.whl", hash = "sha256:80b724871618c78a4e5863938fa66724176cc40352771087aebe1e62a8141157"},
{file = "dora_rs-0.3.4-cp37-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3a3919e157b47dc1dbc74c040a73087a4485f0d1bee99b6adcdbc36559400fe2"},
{file = "dora_rs-0.3.4-cp37-abi3-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f7c95f6e5858fd651d6cd220e4f052e99db2944b9c37fb0b5402d60ac4b41a63"},
{file = "dora_rs-0.3.4-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:37d915fbbca282446235c98a9ca08389aa3ef3155d4e88c6c136326e9a830042"},
{file = "dora_rs-0.3.4-cp37-abi3-win32.whl", hash = "sha256:c9f7f22f65c884ec9bee0245ce98d0c7fad25dec0f982e566f844b5e8e58818f"},
{file = "dora_rs-0.3.4-cp37-abi3-win_amd64.whl", hash = "sha256:0a6a37f96a9f6e13b58b02a6ea75af192af5fbe4f456f6a67b1f239c3cee3276"},
{file = "dora_rs-0.3.4.tar.gz", hash = "sha256:05c5d0db0d23d7c4669995ae34db11cd636dbf91f5705d832669bd04e7452903"},
]
[package.dependencies]
pyarrow = "*"
[[package]]
name = "farama-notifications"
version = "0.0.4"
description = "Notifications for all Farama Foundation maintained libraries."
optional = false
python-versions = "*"
files = [
{file = "Farama-Notifications-0.0.4.tar.gz", hash = "sha256:13fceff2d14314cf80703c8266462ebf3733c7d165336eee998fc58e545efd18"},
{file = "Farama_Notifications-0.0.4-py3-none-any.whl", hash = "sha256:14de931035a41961f7c056361dc7f980762a143d05791ef5794a751a2caf05ae"},
]
[[package]]
name = "gymnasium"
version = "0.29.1"
description = "A standard API for reinforcement learning and a diverse set of reference environments (formerly Gym)."
optional = false
python-versions = ">=3.8"
files = [
{file = "gymnasium-0.29.1-py3-none-any.whl", hash = "sha256:61c3384b5575985bb7f85e43213bcb40f36fcdff388cae6bc229304c71f2843e"},
{file = "gymnasium-0.29.1.tar.gz", hash = "sha256:1a532752efcb7590478b1cc7aa04f608eb7a2fdad5570cd217b66b6a35274bb1"},
]
[package.dependencies]
cloudpickle = ">=1.2.0"
farama-notifications = ">=0.0.1"
numpy = ">=1.21.0"
typing-extensions = ">=4.3.0"
[package.extras]
accept-rom-license = ["autorom[accept-rom-license] (>=0.4.2,<0.5.0)"]
all = ["box2d-py (==2.3.5)", "cython (<3)", "imageio (>=2.14.1)", "jax (>=0.4.0)", "jaxlib (>=0.4.0)", "lz4 (>=3.1.0)", "matplotlib (>=3.0)", "moviepy (>=1.0.0)", "mujoco (>=2.3.3)", "mujoco-py (>=2.1,<2.2)", "opencv-python (>=3.0)", "pygame (>=2.1.3)", "shimmy[atari] (>=0.1.0,<1.0)", "swig (==4.*)", "torch (>=1.0.0)"]
atari = ["shimmy[atari] (>=0.1.0,<1.0)"]
box2d = ["box2d-py (==2.3.5)", "pygame (>=2.1.3)", "swig (==4.*)"]
classic-control = ["pygame (>=2.1.3)", "pygame (>=2.1.3)"]
jax = ["jax (>=0.4.0)", "jaxlib (>=0.4.0)"]
mujoco = ["imageio (>=2.14.1)", "mujoco (>=2.3.3)"]
mujoco-py = ["cython (<3)", "cython (<3)", "mujoco-py (>=2.1,<2.2)", "mujoco-py (>=2.1,<2.2)"]
other = ["lz4 (>=3.1.0)", "matplotlib (>=3.0)", "moviepy (>=1.0.0)", "opencv-python (>=3.0)", "torch (>=1.0.0)"]
testing = ["pytest (==7.1.3)", "scipy (>=1.7.3)"]
toy-text = ["pygame (>=2.1.3)", "pygame (>=2.1.3)"]
[[package]]
name = "numpy"
version = "1.26.4"
description = "Fundamental package for array computing in Python"
optional = false
python-versions = ">=3.9"
files = [
{file = "numpy-1.26.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:9ff0f4f29c51e2803569d7a51c2304de5554655a60c5d776e35b4a41413830d0"},
{file = "numpy-1.26.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:2e4ee3380d6de9c9ec04745830fd9e2eccb3e6cf790d39d7b98ffd19b0dd754a"},
{file = "numpy-1.26.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d209d8969599b27ad20994c8e41936ee0964e6da07478d6c35016bc386b66ad4"},
{file = "numpy-1.26.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ffa75af20b44f8dba823498024771d5ac50620e6915abac414251bd971b4529f"},
{file = "numpy-1.26.4-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:62b8e4b1e28009ef2846b4c7852046736bab361f7aeadeb6a5b89ebec3c7055a"},
{file = "numpy-1.26.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:a4abb4f9001ad2858e7ac189089c42178fcce737e4169dc61321660f1a96c7d2"},
{file = "numpy-1.26.4-cp310-cp310-win32.whl", hash = "sha256:bfe25acf8b437eb2a8b2d49d443800a5f18508cd811fea3181723922a8a82b07"},
{file = "numpy-1.26.4-cp310-cp310-win_amd64.whl", hash = "sha256:b97fe8060236edf3662adfc2c633f56a08ae30560c56310562cb4f95500022d5"},
{file = "numpy-1.26.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:4c66707fabe114439db9068ee468c26bbdf909cac0fb58686a42a24de1760c71"},
{file = "numpy-1.26.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:edd8b5fe47dab091176d21bb6de568acdd906d1887a4584a15a9a96a1dca06ef"},
{file = "numpy-1.26.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7ab55401287bfec946ced39700c053796e7cc0e3acbef09993a9ad2adba6ca6e"},
{file = "numpy-1.26.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:666dbfb6ec68962c033a450943ded891bed2d54e6755e35e5835d63f4f6931d5"},
{file = "numpy-1.26.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:96ff0b2ad353d8f990b63294c8986f1ec3cb19d749234014f4e7eb0112ceba5a"},
{file = "numpy-1.26.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:60dedbb91afcbfdc9bc0b1f3f402804070deed7392c23eb7a7f07fa857868e8a"},
{file = "numpy-1.26.4-cp311-cp311-win32.whl", hash = "sha256:1af303d6b2210eb850fcf03064d364652b7120803a0b872f5211f5234b399f20"},
{file = "numpy-1.26.4-cp311-cp311-win_amd64.whl", hash = "sha256:cd25bcecc4974d09257ffcd1f098ee778f7834c3ad767fe5db785be9a4aa9cb2"},
{file = "numpy-1.26.4-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:b3ce300f3644fb06443ee2222c2201dd3a89ea6040541412b8fa189341847218"},
{file = "numpy-1.26.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:03a8c78d01d9781b28a6989f6fa1bb2c4f2d51201cf99d3dd875df6fbd96b23b"},
{file = "numpy-1.26.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9fad7dcb1aac3c7f0584a5a8133e3a43eeb2fe127f47e3632d43d677c66c102b"},
{file = "numpy-1.26.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:675d61ffbfa78604709862923189bad94014bef562cc35cf61d3a07bba02a7ed"},
{file = "numpy-1.26.4-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:ab47dbe5cc8210f55aa58e4805fe224dac469cde56b9f731a4c098b91917159a"},
{file = "numpy-1.26.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:1dda2e7b4ec9dd512f84935c5f126c8bd8b9f2fc001e9f54af255e8c5f16b0e0"},
{file = "numpy-1.26.4-cp312-cp312-win32.whl", hash = "sha256:50193e430acfc1346175fcbdaa28ffec49947a06918b7b92130744e81e640110"},
{file = "numpy-1.26.4-cp312-cp312-win_amd64.whl", hash = "sha256:08beddf13648eb95f8d867350f6a018a4be2e5ad54c8d8caed89ebca558b2818"},
{file = "numpy-1.26.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:7349ab0fa0c429c82442a27a9673fc802ffdb7c7775fad780226cb234965e53c"},
{file = "numpy-1.26.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:52b8b60467cd7dd1e9ed082188b4e6bb35aa5cdd01777621a1658910745b90be"},
{file = "numpy-1.26.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d5241e0a80d808d70546c697135da2c613f30e28251ff8307eb72ba696945764"},
{file = "numpy-1.26.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f870204a840a60da0b12273ef34f7051e98c3b5961b61b0c2c1be6dfd64fbcd3"},
{file = "numpy-1.26.4-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:679b0076f67ecc0138fd2ede3a8fd196dddc2ad3254069bcb9faf9a79b1cebcd"},
{file = "numpy-1.26.4-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:47711010ad8555514b434df65f7d7b076bb8261df1ca9bb78f53d3b2db02e95c"},
{file = "numpy-1.26.4-cp39-cp39-win32.whl", hash = "sha256:a354325ee03388678242a4d7ebcd08b5c727033fcff3b2f536aea978e15ee9e6"},
{file = "numpy-1.26.4-cp39-cp39-win_amd64.whl", hash = "sha256:3373d5d70a5fe74a2c1bb6d2cfd9609ecf686d47a2d7b1d37a8f3b6bf6003aea"},
{file = "numpy-1.26.4-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:afedb719a9dcfc7eaf2287b839d8198e06dcd4cb5d276a3df279231138e83d30"},
{file = "numpy-1.26.4-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:95a7476c59002f2f6c590b9b7b998306fba6a5aa646b1e22ddfeaf8f78c3a29c"},
{file = "numpy-1.26.4-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:7e50d0a0cc3189f9cb0aeb3a6a6af18c16f59f004b866cd2be1c14b36134a4a0"},
{file = "numpy-1.26.4.tar.gz", hash = "sha256:2a02aba9ed12e4ac4eb3ea9421c420301a0c6460d9830d74a9df87efa4912010"},
]
[[package]]
name = "pyarrow"
version = "16.1.0"
description = "Python library for Apache Arrow"
optional = false
python-versions = ">=3.8"
files = [
{file = "pyarrow-16.1.0-cp310-cp310-macosx_10_15_x86_64.whl", hash = "sha256:17e23b9a65a70cc733d8b738baa6ad3722298fa0c81d88f63ff94bf25eaa77b9"},
{file = "pyarrow-16.1.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:4740cc41e2ba5d641071d0ab5e9ef9b5e6e8c7611351a5cb7c1d175eaf43674a"},
{file = "pyarrow-16.1.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:98100e0268d04e0eec47b73f20b39c45b4006f3c4233719c3848aa27a03c1aef"},
{file = "pyarrow-16.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f68f409e7b283c085f2da014f9ef81e885d90dcd733bd648cfba3ef265961848"},
{file = "pyarrow-16.1.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:a8914cd176f448e09746037b0c6b3a9d7688cef451ec5735094055116857580c"},
{file = "pyarrow-16.1.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:48be160782c0556156d91adbdd5a4a7e719f8d407cb46ae3bb4eaee09b3111bd"},
{file = "pyarrow-16.1.0-cp310-cp310-win_amd64.whl", hash = "sha256:9cf389d444b0f41d9fe1444b70650fea31e9d52cfcb5f818b7888b91b586efff"},
{file = "pyarrow-16.1.0-cp311-cp311-macosx_10_15_x86_64.whl", hash = "sha256:d0ebea336b535b37eee9eee31761813086d33ed06de9ab6fc6aaa0bace7b250c"},
{file = "pyarrow-16.1.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:2e73cfc4a99e796727919c5541c65bb88b973377501e39b9842ea71401ca6c1c"},
{file = "pyarrow-16.1.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bf9251264247ecfe93e5f5a0cd43b8ae834f1e61d1abca22da55b20c788417f6"},
{file = "pyarrow-16.1.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ddf5aace92d520d3d2a20031d8b0ec27b4395cab9f74e07cc95edf42a5cc0147"},
{file = "pyarrow-16.1.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:25233642583bf658f629eb230b9bb79d9af4d9f9229890b3c878699c82f7d11e"},
{file = "pyarrow-16.1.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:a33a64576fddfbec0a44112eaf844c20853647ca833e9a647bfae0582b2ff94b"},
{file = "pyarrow-16.1.0-cp311-cp311-win_amd64.whl", hash = "sha256:185d121b50836379fe012753cf15c4ba9638bda9645183ab36246923875f8d1b"},
{file = "pyarrow-16.1.0-cp312-cp312-macosx_10_15_x86_64.whl", hash = "sha256:2e51ca1d6ed7f2e9d5c3c83decf27b0d17bb207a7dea986e8dc3e24f80ff7d6f"},
{file = "pyarrow-16.1.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:06ebccb6f8cb7357de85f60d5da50e83507954af617d7b05f48af1621d331c9a"},
{file = "pyarrow-16.1.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b04707f1979815f5e49824ce52d1dceb46e2f12909a48a6a753fe7cafbc44a0c"},
{file = "pyarrow-16.1.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0d32000693deff8dc5df444b032b5985a48592c0697cb6e3071a5d59888714e2"},
{file = "pyarrow-16.1.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:8785bb10d5d6fd5e15d718ee1d1f914fe768bf8b4d1e5e9bf253de8a26cb1628"},
{file = "pyarrow-16.1.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:e1369af39587b794873b8a307cc6623a3b1194e69399af0efd05bb202195a5a7"},
{file = "pyarrow-16.1.0-cp312-cp312-win_amd64.whl", hash = "sha256:febde33305f1498f6df85e8020bca496d0e9ebf2093bab9e0f65e2b4ae2b3444"},
{file = "pyarrow-16.1.0-cp38-cp38-macosx_10_15_x86_64.whl", hash = "sha256:b5f5705ab977947a43ac83b52ade3b881eb6e95fcc02d76f501d549a210ba77f"},
{file = "pyarrow-16.1.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:0d27bf89dfc2576f6206e9cd6cf7a107c9c06dc13d53bbc25b0bd4556f19cf5f"},
{file = "pyarrow-16.1.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0d07de3ee730647a600037bc1d7b7994067ed64d0eba797ac74b2bc77384f4c2"},
{file = "pyarrow-16.1.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fbef391b63f708e103df99fbaa3acf9f671d77a183a07546ba2f2c297b361e83"},
{file = "pyarrow-16.1.0-cp38-cp38-manylinux_2_28_aarch64.whl", hash = "sha256:19741c4dbbbc986d38856ee7ddfdd6a00fc3b0fc2d928795b95410d38bb97d15"},
{file = "pyarrow-16.1.0-cp38-cp38-manylinux_2_28_x86_64.whl", hash = "sha256:f2c5fb249caa17b94e2b9278b36a05ce03d3180e6da0c4c3b3ce5b2788f30eed"},
{file = "pyarrow-16.1.0-cp38-cp38-win_amd64.whl", hash = "sha256:e6b6d3cd35fbb93b70ade1336022cc1147b95ec6af7d36906ca7fe432eb09710"},
{file = "pyarrow-16.1.0-cp39-cp39-macosx_10_15_x86_64.whl", hash = "sha256:18da9b76a36a954665ccca8aa6bd9f46c1145f79c0bb8f4f244f5f8e799bca55"},
{file = "pyarrow-16.1.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:99f7549779b6e434467d2aa43ab2b7224dd9e41bdde486020bae198978c9e05e"},
{file = "pyarrow-16.1.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f07fdffe4fd5b15f5ec15c8b64584868d063bc22b86b46c9695624ca3505b7b4"},
{file = "pyarrow-16.1.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ddfe389a08ea374972bd4065d5f25d14e36b43ebc22fc75f7b951f24378bf0b5"},
{file = "pyarrow-16.1.0-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:3b20bd67c94b3a2ea0a749d2a5712fc845a69cb5d52e78e6449bbd295611f3aa"},
{file = "pyarrow-16.1.0-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:ba8ac20693c0bb0bf4b238751d4409e62852004a8cf031c73b0e0962b03e45e3"},
{file = "pyarrow-16.1.0-cp39-cp39-win_amd64.whl", hash = "sha256:31a1851751433d89a986616015841977e0a188662fcffd1a5677453f1df2de0a"},
{file = "pyarrow-16.1.0.tar.gz", hash = "sha256:15fbb22ea96d11f0b5768504a3f961edab25eaf4197c341720c4a387f6c60315"},
]
[package.dependencies]
numpy = ">=1.16.6"
[[package]]
name = "typing-extensions"
version = "4.11.0"
description = "Backported and Experimental Type Hints for Python 3.8+"
optional = false
python-versions = ">=3.8"
files = [
{file = "typing_extensions-4.11.0-py3-none-any.whl", hash = "sha256:c1f94d72897edaf4ce775bb7558d5b79d8126906a14ea5ed1635921406c0387a"},
{file = "typing_extensions-4.11.0.tar.gz", hash = "sha256:83f085bd5ca59c80295fc2a82ab5dac679cbe02b9f33f7d83af68e241bea51b0"},
]
[metadata]
lock-version = "2.0"
python-versions = "^3.10"
content-hash = "7e437b5c547ebe11095f1ce4ff1851d636f8e707ad7de8a6224b0f9ad978240f"

17
gym_dora/pyproject.toml Normal file
View File

@@ -0,0 +1,17 @@
[tool.poetry]
name = "gym-dora"
version = "0.1.0"
description = ""
authors = ["Simon Alibert <alibert.sim@gmail.com>"]
readme = "README.md"
packages = [{ include = "gym_dora" }]
[tool.poetry.dependencies]
python = "^3.10"
gymnasium = ">=0.29.1"
dora-rs = ">=0.3.4"
pyarrow = ">=12.0.0"
[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"

View File

@@ -52,7 +52,6 @@ available_tasks_per_env = {
], ],
"pusht": ["PushT-v0"], "pusht": ["PushT-v0"],
"xarm": ["XarmLift-v0"], "xarm": ["XarmLift-v0"],
"dora": ["DoraAloha-v0", "DoraKoch-v0", "DoraReachy2-v0"],
} }
available_envs = list(available_tasks_per_env.keys()) available_envs = list(available_tasks_per_env.keys())
@@ -78,23 +77,6 @@ available_datasets_per_env = {
"lerobot/xarm_push_medium_image", "lerobot/xarm_push_medium_image",
"lerobot/xarm_push_medium_replay_image", "lerobot/xarm_push_medium_replay_image",
], ],
"dora": [
"lerobot/aloha_static_battery",
"lerobot/aloha_static_candy",
"lerobot/aloha_static_coffee",
"lerobot/aloha_static_coffee_new",
"lerobot/aloha_static_cups_open",
"lerobot/aloha_static_fork_pick_up",
"lerobot/aloha_static_pingpong_test",
"lerobot/aloha_static_pro_pencil",
"lerobot/aloha_static_screw_driver",
"lerobot/aloha_static_tape",
"lerobot/aloha_static_thread_velcro",
"lerobot/aloha_static_towel",
"lerobot/aloha_static_vinh_cup",
"lerobot/aloha_static_vinh_cup_left",
"lerobot/aloha_static_ziploc_slide",
],
} }
available_real_world_datasets = [ available_real_world_datasets = [
@@ -134,7 +116,6 @@ available_policies = [
available_policies_per_env = { available_policies_per_env = {
"aloha": ["act"], "aloha": ["act"],
"dora": ["act"],
"pusht": ["diffusion"], "pusht": ["diffusion"],
"xarm": ["tdmpc"], "xarm": ["tdmpc"],
} }
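Assuming these registries live in lerobot/__init__.py (as this hunk suggests), the dora entries shown above are what a user would query at the package level in the version that includes them, e.g.:

from lerobot import available_envs, available_tasks_per_env, available_policies_per_env

print("dora" in available_envs)            # True in the version with the dora entries
print(available_tasks_per_env["dora"])     # ['DoraAloha-v0', 'DoraKoch-v0', 'DoraReachy2-v0']
print(available_policies_per_env["dora"])  # ['act']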

View File

@@ -21,20 +21,6 @@ from omegaconf import OmegaConf
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
def resolve_delta_timestamps(cfg):
"""Resolves delta_timestamps config key (in-place) by using `eval`.
Doesn't do anything if delta_timestamps is not specified or has already been resolved (as evidenced by
the data type of its values).
"""
delta_timestamps = cfg.training.get("delta_timestamps")
if delta_timestamps is not None:
for key in delta_timestamps:
if isinstance(delta_timestamps[key], str):
# TODO(rcadene, alexander-soare): remove `eval` to avoid exploit
cfg.training.delta_timestamps[key] = eval(delta_timestamps[key])
def make_dataset( def make_dataset(
cfg, cfg,
split="train", split="train",
@@ -45,14 +31,18 @@ def make_dataset(
f"environment ({cfg.env.name=})." f"environment ({cfg.env.name=})."
) )
resolve_delta_timestamps(cfg) delta_timestamps = cfg.training.get("delta_timestamps")
if delta_timestamps is not None:
for key in delta_timestamps:
if isinstance(delta_timestamps[key], str):
delta_timestamps[key] = eval(delta_timestamps[key])
# TODO(rcadene): add data augmentations # TODO(rcadene): add data augmentations
dataset = LeRobotDataset( dataset = LeRobotDataset(
cfg.dataset_repo_id, cfg.dataset_repo_id,
split=split, split=split,
delta_timestamps=cfg.training.get("delta_timestamps"), delta_timestamps=delta_timestamps,
) )
if cfg.get("override_dataset_stats"): if cfg.get("override_dataset_stats"):

View File

@@ -18,7 +18,6 @@ Contains utilities to process raw data format from dora-record
""" """
import logging import logging
import re
from pathlib import Path from pathlib import Path
import pandas as pd import pandas as pd
@@ -41,7 +40,7 @@ def check_format(raw_dir) -> bool:
return True return True
def load_from_raw(raw_dir: Path, out_dir: Path, fps: int): def load_from_raw(raw_dir: Path, out_dir: Path):
# Load data stream that will be used as reference for the timestamps synchronization # Load data stream that will be used as reference for the timestamps synchronization
reference_files = list(raw_dir.glob("observation.images.cam_*.parquet")) reference_files = list(raw_dir.glob("observation.images.cam_*.parquet"))
if len(reference_files) == 0: if len(reference_files) == 0:
@@ -57,59 +56,24 @@ def load_from_raw(raw_dir: Path, out_dir: Path, fps: int):
key = path.stem # action or observation.state or ... key = path.stem # action or observation.state or ...
if key == reference_key: if key == reference_key:
continue continue
if "failed_episode_index" in key:
# TODO(rcadene): add support for removing episodes that are tagged as "failed"
continue
modality_df = pd.read_parquet(path) modality_df = pd.read_parquet(path)
modality_df = modality_df[["timestamp_utc", key]] modality_df = modality_df[["timestamp_utc", key]]
df = pd.merge_asof( df = pd.merge_asof(
df, df,
modality_df, modality_df,
on="timestamp_utc", on="timestamp_utc",
# "nearest" is the best option over "backward", since the latter can desynchronizes camera timestamps by direction="backward",
# matching timestamps that are too far appart, in order to fit the backward constraints. It's not the case for "nearest".
# However, note that "nearest" might synchronize the reference camera with other cameras on slightly future timestamps.
# are too far appart.
direction="nearest",
tolerance=pd.Timedelta(f"{1/fps} seconds"),
) )
# Remove rows with episode_index -1 which indicates data that correspond to in-between episodes
# Remove rows with a NaN in any column. It can happen during the first frames of an episode,
# because some cameras didn't start recording yet.
df = df.dropna(axis=0)
# Remove rows with episode_index -1 which indicates a failed episode
df = df[df["episode_index"] != -1] df = df[df["episode_index"] != -1]
image_keys = [key for key in df if "observation.images." in key]
num_unaligned_images = 0
max_episode = 0
def get_episode_index(row):
nonlocal num_unaligned_images
nonlocal max_episode
episode_index_per_cam = {}
for key in image_keys:
if isinstance(row[key], float):
num_unaligned_images += 1
return float("nan")
path = row[key][0]["path"]
match = re.search(r"_(\d{6}).mp4", path)
if not match:
raise ValueError(path)
episode_index = int(match.group(1))
episode_index_per_cam[key] = episode_index
if episode_index > max_episode:
assert episode_index - max_episode == 1
max_episode = episode_index
else:
assert episode_index == max_episode
if len(set(episode_index_per_cam.values())) != 1:
raise ValueError(
f"All cameras are expected to belong to the same episode, but getting {episode_index_per_cam}"
)
return episode_index
df["episode_index"] = df.apply(get_episode_index, axis=1)
# dora only uses arrays, so single values are encapsulated into a list # dora only uses arrays, so single values are encapsulated into a list
df["episode_index"] = df["episode_index"].map(lambda x: x[0])
df["frame_index"] = df.groupby("episode_index").cumcount() df["frame_index"] = df.groupby("episode_index").cumcount()
df = df.reset_index() df = df.reset_index()
df["index"] = df.index df["index"] = df.index
@@ -124,29 +88,10 @@ def load_from_raw(raw_dir: Path, out_dir: Path, fps: int):
del df["timestamp_utc"] del df["timestamp_utc"]
# sanity check
num_rows_with_nan = df.isna().any(axis=1).sum()
assert (
num_rows_with_nan == num_unaligned_images
), f"Found {num_rows_with_nan} rows with NaN values but {num_unaligned_images} unaligned images."
if num_unaligned_images > max_episode * 2:
# We allow a few unaligned images, typically at the beginning and end of the episodes for instance
# but if there are too many, we raise an error to avoid large chunks of missing data
raise ValueError(
f"Found {num_unaligned_images} unaligned images out of {max_episode} episodes. "
f"Check the timestamps of the cameras."
)
# Drop rows with NaN values now that we double checked and convert episode_index to int
df = df.dropna()
df["episode_index"] = df["episode_index"].astype(int)
# sanity check episode indices go from 0 to n-1 # sanity check episode indices go from 0 to n-1
assert df["episode_index"].max() == max_episode
ep_ids = [ep_idx for ep_idx, _ in df.groupby("episode_index")] ep_ids = [ep_idx for ep_idx, _ in df.groupby("episode_index")]
expected_ep_ids = list(range(df["episode_index"].max() + 1)) expected_ep_ids = list(range(df["episode_index"].max() + 1))
if ep_ids != expected_ep_ids: assert ep_ids == expected_ep_ids, f"Episodes indices go from {ep_ids} instead of {expected_ep_ids}"
raise ValueError(f"Episodes indices go from {ep_ids} instead of {expected_ep_ids}")
# Create symlink to raw videos directory (that needs to be absolute not relative) # Create symlink to raw videos directory (that needs to be absolute not relative)
out_dir.mkdir(parents=True, exist_ok=True) out_dir.mkdir(parents=True, exist_ok=True)
@@ -159,8 +104,7 @@ def load_from_raw(raw_dir: Path, out_dir: Path, fps: int):
continue continue
for ep_idx in ep_ids: for ep_idx in ep_ids:
video_path = videos_dir / f"{key}_episode_{ep_idx:06d}.mp4" video_path = videos_dir / f"{key}_episode_{ep_idx:06d}.mp4"
if not video_path.exists(): assert video_path.exists(), f"Video file not found in {video_path}"
raise ValueError(f"Video file not found in {video_path}")
data_dict = {} data_dict = {}
for key in df: for key in df:
@@ -172,8 +116,7 @@ def load_from_raw(raw_dir: Path, out_dir: Path, fps: int):
# sanity check the video path is well formatted # sanity check the video path is well formatted
video_path = videos_dir.parent / data_dict[key][0]["path"] video_path = videos_dir.parent / data_dict[key][0]["path"]
if not video_path.exists(): assert video_path.exists(), f"Video file not found in {video_path}"
raise ValueError(f"Video file not found in {video_path}")
# is number # is number
elif df[key].iloc[0].ndim == 0 or df[key].iloc[0].shape[0] == 1: elif df[key].iloc[0].ndim == 0 or df[key].iloc[0].shape[0] == 1:
data_dict[key] = torch.from_numpy(df[key].values) data_dict[key] = torch.from_numpy(df[key].values)
@@ -241,11 +184,13 @@ def from_raw_to_lerobot_format(raw_dir: Path, out_dir: Path, fps=None, video=Tru
if fps is None: if fps is None:
fps = 30 fps = 30
else:
raise NotImplementedError()
if not video: if not video:
raise NotImplementedError() raise NotImplementedError()
data_df, episode_data_index = load_from_raw(raw_dir, out_dir, fps) data_df, episode_data_index = load_from_raw(raw_dir, out_dir)
hf_dataset = to_hf_dataset(data_df, video) hf_dataset = to_hf_dataset(data_df, video)
info = { info = {
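The timestamp synchronization above hinges on pd.merge_asof with either direction="backward" or direction="nearest" plus a one-frame tolerance. A toy, self-contained illustration of the "nearest" behaviour (invented timestamps and values, not the dataset's real contents):

import pandas as pd

fps = 30
ref = pd.DataFrame({
    "timestamp_utc": pd.to_datetime(["2024-05-20 10:00:00.000", "2024-05-20 10:00:00.033", "2024-05-20 10:00:00.500"]),
    "observation.images.cam_high": [0, 1, 2],
})
other = pd.DataFrame({
    "timestamp_utc": pd.to_datetime(["2024-05-20 10:00:00.010", "2024-05-20 10:00:00.040"]),
    "observation.state": ["s0", "s1"],
})

merged = pd.merge_asof(
    ref, other, on="timestamp_utc",
    direction="nearest",                         # closest sample, past or future
    tolerance=pd.Timedelta(f"{1/fps} seconds"),  # drop matches more than one frame apart
)
# The third camera frame (t = 0.5 s) has no state sample within tolerance, gets NaN,
# and would later be removed by the dropna() step above.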

View File

@@ -27,6 +27,14 @@ def make_env(cfg: DictConfig, n_envs: int | None = None) -> gym.vector.VectorEnv
if n_envs is not None and n_envs < 1: if n_envs is not None and n_envs < 1:
raise ValueError("`n_envs must be at least 1") raise ValueError("`n_envs must be at least 1")
kwargs = {
# "obs_type": "pixels_agent_pos",
# "render_mode": "rgb_array",
"max_episode_steps": cfg.env.episode_length,
# "visualization_width": 384,
# "visualization_height": 384,
}
package_name = f"gym_{cfg.env.name}" package_name = f"gym_{cfg.env.name}"
try: try:
@@ -38,16 +46,12 @@ def make_env(cfg: DictConfig, n_envs: int | None = None) -> gym.vector.VectorEnv
raise e raise e
gym_handle = f"{package_name}/{cfg.env.task}" gym_handle = f"{package_name}/{cfg.env.task}"
gym_kwgs = dict(cfg.env.get("gym", {}))
if cfg.env.get("episode_length"):
gym_kwgs["max_episode_steps"] = cfg.env.episode_length
# batched version of the env that returns an observation of shape (b, c) # batched version of the env that returns an observation of shape (b, c)
env_cls = gym.vector.AsyncVectorEnv if cfg.eval.use_async_envs else gym.vector.SyncVectorEnv env_cls = gym.vector.AsyncVectorEnv if cfg.eval.use_async_envs else gym.vector.SyncVectorEnv
env = env_cls( env = env_cls(
[ [
lambda: gym.make(gym_handle, disable_env_checker=True, **gym_kwgs) lambda: gym.make(gym_handle, disable_env_checker=True, **kwargs)
for _ in range(n_envs if n_envs is not None else cfg.eval.batch_size) for _ in range(n_envs if n_envs is not None else cfg.eval.batch_size)
] ]
) )
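A minimal sketch of the batched-environment construction above, with made-up values standing in for cfg.env.* and cfg.eval.* (it assumes the gym_pusht package is installed; the kwargs mirror the env.gym section of the pusht config shown later in this compare):

import gymnasium as gym

gym_handle = "gym_pusht/PushT-v0"
gym_kwgs = {
    "obs_type": "pixels_agent_pos",
    "render_mode": "rgb_array",
    "max_episode_steps": 300,  # injected from cfg.env.episode_length
}
n_envs = 4

env = gym.vector.SyncVectorEnv(
    [lambda: gym.make(gym_handle, disable_env_checker=True, **gym_kwgs) for _ in range(n_envs)]
)
obs, info = env.reset(seed=0)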

View File

@@ -13,33 +13,25 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Borrowed from https://github.com/fyhMer/fowm/blob/main/src/logger.py
# TODO(rcadene, alexander-soare): clean this file # TODO(rcadene, alexander-soare): clean this file
""" """Borrowed from https://github.com/fyhMer/fowm/blob/main/src/logger.py"""
import logging import logging
import os import os
import re
from glob import glob
from pathlib import Path from pathlib import Path
import torch
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
from omegaconf import DictConfig, OmegaConf from omegaconf import OmegaConf
from termcolor import colored from termcolor import colored
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler
from lerobot.common.policies.policy_protocol import Policy from lerobot.common.policies.policy_protocol import Policy
from lerobot.common.utils.utils import get_global_random_state, set_global_random_state
def log_output_dir(out_dir): def log_output_dir(out_dir):
logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {out_dir}") logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {out_dir}")
def cfg_to_group(cfg: DictConfig, return_list: bool = False) -> list[str] | str: def cfg_to_group(cfg, return_list=False):
"""Return a group name for logging. Optionally returns group name as list.""" """Return a group name for logging. Optionally returns group name as list."""
lst = [ lst = [
f"policy:{cfg.policy.name}", f"policy:{cfg.policy.name}",
@@ -50,54 +42,22 @@ def cfg_to_group(cfg: DictConfig, return_list: bool = False) -> list[str] | str:
return lst if return_list else "-".join(lst) return lst if return_list else "-".join(lst)
def get_wandb_run_id_from_filesystem(checkpoint_dir: Path) -> str:
# Get the WandB run ID.
paths = glob(str(checkpoint_dir / "../wandb/latest-run/run-*"))
if len(paths) != 1:
raise RuntimeError("Couldn't get the previous WandB run ID for run resumption.")
match = re.search(r"run-([^\.]+).wandb", paths[0].split("/")[-1])
if match is None:
raise RuntimeError("Couldn't get the previous WandB run ID for run resumption.")
wandb_run_id = match.groups(0)[0]
return wandb_run_id
class Logger: class Logger:
"""Primary logger object. Logs either locally or using wandb. """Primary logger object. Logs either locally or using wandb."""
The logger creates the following directory structure: def __init__(self, log_dir, job_name, cfg):
self._log_dir = Path(log_dir)
provided_log_dir self._log_dir.mkdir(parents=True, exist_ok=True)
├── .hydra # hydra's configuration cache self._job_name = job_name
├── checkpoints self._model_dir = self._log_dir / "checkpoints"
├── specific_checkpoint_name self._buffer_dir = self._log_dir / "buffers"
│ ├── pretrained_model # Hugging Face pretrained model directory self._save_model = cfg.training.save_model
│ │ ├── ... self._disable_wandb_artifact = cfg.wandb.disable_artifact
│ └── training_state.pth # optimizer, scheduler, and random states + training step self._save_buffer = cfg.training.get("save_buffer", False)
| ├── another_specific_checkpoint_name
│ │ ├── ...
| ├── ...
│ └── last # a softlink to the last logged checkpoint
"""
pretrained_model_dir_name = "pretrained_model"
training_state_file_name = "training_state.pth"
def __init__(self, cfg: DictConfig, log_dir: str, wandb_job_name: str | None = None):
"""
Args:
log_dir: The directory to save all logs and training outputs to.
job_name: The WandB job name.
"""
self._cfg = cfg
self.log_dir = Path(log_dir)
self.log_dir.mkdir(parents=True, exist_ok=True)
self.checkpoints_dir = self.get_checkpoints_dir(log_dir)
self.last_checkpoint_dir = self.get_last_checkpoint_dir(log_dir)
self.last_pretrained_model_dir = self.get_last_pretrained_model_dir(log_dir)
# Set up WandB.
self._group = cfg_to_group(cfg) self._group = cfg_to_group(cfg)
self._seed = cfg.seed
self._cfg = cfg
self._eval = []
project = cfg.get("wandb", {}).get("project") project = cfg.get("wandb", {}).get("project")
entity = cfg.get("wandb", {}).get("entity") entity = cfg.get("wandb", {}).get("entity")
enable_wandb = cfg.get("wandb", {}).get("enable", False) enable_wandb = cfg.get("wandb", {}).get("enable", False)
@@ -109,127 +69,65 @@ class Logger:
os.environ["WANDB_SILENT"] = "true" os.environ["WANDB_SILENT"] = "true"
import wandb import wandb
wandb_run_id = None
if cfg.resume:
wandb_run_id = get_wandb_run_id_from_filesystem(self.checkpoints_dir)
wandb.init( wandb.init(
id=wandb_run_id,
project=project, project=project,
entity=entity, entity=entity,
name=wandb_job_name, name=job_name,
notes=cfg.get("wandb", {}).get("notes"), notes=cfg.get("wandb", {}).get("notes"),
# group=self._group,
tags=cfg_to_group(cfg, return_list=True), tags=cfg_to_group(cfg, return_list=True),
dir=log_dir, dir=self._log_dir,
config=OmegaConf.to_container(cfg, resolve=True), config=OmegaConf.to_container(cfg, resolve=True),
# TODO(rcadene): try set to True # TODO(rcadene): try set to True
save_code=False, save_code=False,
# TODO(rcadene): split train and eval, and run async eval with job_type="eval" # TODO(rcadene): split train and eval, and run async eval with job_type="eval"
job_type="train_eval", job_type="train_eval",
resume="must" if cfg.resume else None, # TODO(rcadene): add resume option
resume=None,
) )
print(colored("Logs will be synced with wandb.", "blue", attrs=["bold"])) print(colored("Logs will be synced with wandb.", "blue", attrs=["bold"]))
logging.info(f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}") logging.info(f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}")
self._wandb = wandb self._wandb = wandb
@classmethod def save_model(self, policy: Policy, identifier):
def get_checkpoints_dir(cls, log_dir: str | Path) -> Path: if self._save_model:
"""Given the log directory, get the sub-directory in which checkpoints will be saved.""" self._model_dir.mkdir(parents=True, exist_ok=True)
return Path(log_dir) / "checkpoints" save_dir = self._model_dir / str(identifier)
policy.save_pretrained(save_dir)
# Also save the full Hydra config for the env configuration.
OmegaConf.save(self._cfg, save_dir / "config.yaml")
if self._wandb and not self._disable_wandb_artifact:
# note wandb artifact does not accept ":" or "/" in its name
artifact = self._wandb.Artifact(
f"{self._group.replace(':', '_').replace('/', '_')}-{self._seed}-{identifier}",
type="model",
)
artifact.add_file(save_dir / SAFETENSORS_SINGLE_FILE)
self._wandb.log_artifact(artifact)
@classmethod def save_buffer(self, buffer, identifier):
def get_last_checkpoint_dir(cls, log_dir: str | Path) -> Path: self._buffer_dir.mkdir(parents=True, exist_ok=True)
"""Given the log directory, get the sub-directory in which the last checkpoint will be saved.""" fp = self._buffer_dir / f"{str(identifier)}.pkl"
return cls.get_checkpoints_dir(log_dir) / "last" buffer.save(fp)
if self._wandb and not self._disable_wandb_artifact:
@classmethod
def get_last_pretrained_model_dir(cls, log_dir: str | Path) -> Path:
"""
Given the log directory, get the sub-directory in which the last checkpoint's pretrained weights will
be saved.
"""
return cls.get_last_checkpoint_dir(log_dir) / cls.pretrained_model_dir_name
def save_model(self, save_dir: Path, policy: Policy, wandb_artifact_name: str | None = None):
"""Save the weights of the Policy model using PyTorchModelHubMixin.
The weights are saved in a folder called "pretrained_model" under the checkpoint directory.
Optionally also upload the model to WandB.
"""
self.checkpoints_dir.mkdir(parents=True, exist_ok=True)
policy.save_pretrained(save_dir)
# Also save the full Hydra config for the env configuration.
OmegaConf.save(self._cfg, save_dir / "config.yaml")
if self._wandb and not self._cfg.wandb.disable_artifact:
# note wandb artifact does not accept ":" or "/" in its name # note wandb artifact does not accept ":" or "/" in its name
artifact = self._wandb.Artifact(wandb_artifact_name, type="model") artifact = self._wandb.Artifact(
artifact.add_file(save_dir / SAFETENSORS_SINGLE_FILE) f"{self._group.replace(':', '_').replace('/', '_')}-{self._seed}-{identifier}",
self._wandb.log_artifact(artifact) type="buffer",
if self.last_checkpoint_dir.exists():
os.remove(self.last_checkpoint_dir)
def save_training_state(
self,
save_dir: Path,
train_step: int,
optimizer: Optimizer,
scheduler: LRScheduler | None,
):
"""Checkpoint the global training_step, optimizer state, scheduler state, and random state.
All of these are saved as "training_state.pth" under the checkpoint directory.
"""
training_state = {
"step": train_step,
"optimizer": optimizer.state_dict(),
**get_global_random_state(),
}
if scheduler is not None:
training_state["scheduler"] = scheduler.state_dict()
torch.save(training_state, save_dir / self.training_state_file_name)
def save_checkpont(
self,
train_step: int,
policy: Policy,
optimizer: Optimizer,
scheduler: LRScheduler | None,
identifier: str,
):
"""Checkpoint the model weights and the training state."""
checkpoint_dir = self.checkpoints_dir / str(identifier)
wandb_artifact_name = (
None
if self._wandb is None
else f"{self._group.replace(':', '_').replace('/', '_')}-{self._cfg.seed}-{identifier}"
)
self.save_model(
checkpoint_dir / self.pretrained_model_dir_name, policy, wandb_artifact_name=wandb_artifact_name
)
self.save_training_state(checkpoint_dir, train_step, optimizer, scheduler)
os.symlink(checkpoint_dir.absolute(), self.last_checkpoint_dir)
def load_last_training_state(self, optimizer: Optimizer, scheduler: LRScheduler | None) -> int:
"""
Given the last checkpoint in the logging directory, load the optimizer state, scheduler state, and
random state, and return the global training step.
"""
training_state = torch.load(self.last_checkpoint_dir / self.training_state_file_name)
optimizer.load_state_dict(training_state["optimizer"])
if scheduler is not None:
scheduler.load_state_dict(training_state["scheduler"])
elif "scheduler" in training_state:
raise ValueError(
"The checkpoint contains a scheduler state_dict, but no LRScheduler was provided."
) )
# Small hack to get the expected keys: use `get_global_random_state`. artifact.add_file(fp)
set_global_random_state({k: training_state[k] for k in get_global_random_state()}) self._wandb.log_artifact(artifact)
return training_state["step"]
def finish(self, agent, buffer):
if self._save_model:
self.save_model(agent, identifier="final")
if self._save_buffer:
self.save_buffer(buffer, identifier="buffer")
if self._wandb:
self._wandb.finish()
def log_dict(self, d, step, mode="train"): def log_dict(self, d, step, mode="train"):
assert mode in {"train", "eval"} assert mode in {"train", "eval"}
# TODO(alexander-soare): Add local text log.
if self._wandb is not None: if self._wandb is not None:
for k, v in d.items(): for k, v in d.items():
if not isinstance(v, (int, float, str)): if not isinstance(v, (int, float, str)):
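The checkpoint layout described in the interleaved docstring above (checkpoints/<identifier>/pretrained_model plus a "last" softlink) boils down to the following sketch; the paths are illustrative only:

import os
from pathlib import Path

checkpoints_dir = Path("outputs/train/example_run/checkpoints")
checkpoint_dir = checkpoints_dir / "005000"  # one sub-directory per checkpoint identifier
(checkpoint_dir / "pretrained_model").mkdir(parents=True, exist_ok=True)

last = checkpoints_dir / "last"
if last.is_symlink() or last.exists():
    os.remove(last)
os.symlink(checkpoint_dir.absolute(), last)  # "last" always points at the newest checkpoint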

View File

@@ -25,14 +25,6 @@ class ACTConfig:
The parameters you will most likely need to change are the ones which depend on the environment / sensors. The parameters you will most likely need to change are the ones which depend on the environment / sensors.
Those are: `input_shapes` and 'output_shapes`. Those are: `input_shapes` and 'output_shapes`.
Notes on the inputs and outputs:
- "observation.state" is required as an input key.
- At least one key starting with "observation.image is required as an input.
- If there are multiple keys beginning with "observation.image" they are treated as multiple camera
views.
Right now we only support all images having the same shape.
- "action" is required as an output key.
Args: Args:
n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the
current step and additional steps going back). current step and additional steps going back).
@@ -41,15 +33,15 @@ class ACTConfig:
This should be no greater than the chunk size. For example, if the chunk size is 100, you may This should be no greater than the chunk size. For example, if the chunk size is 100, you may
set this to 50. This would mean that the model predicts 100 steps worth of actions, runs 50 in the set this to 50. This would mean that the model predicts 100 steps worth of actions, runs 50 in the
environment, and throws the other 50 out. environment, and throws the other 50 out.
input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents input_shapes: A dictionary defining the shapes of the input data for the policy.
the input data name, and the value is a list indicating the dimensions of the corresponding data. The key represents the input data name, and the value is a list indicating the dimensions
For example, "observation.image" refers to an input from a camera with dimensions [3, 96, 96], of the corresponding data. For example, "observation.images.top" refers to an input from the
indicating it has three color channels and 96x96 resolution. Importantly, `input_shapes` doesn't "top" camera with dimensions [3, 96, 96], indicating it has three color channels and 96x96 resolution.
include batch dimension or temporal dimension. Importantly, shapes doesn't include batch dimension or temporal dimension.
output_shapes: A dictionary defining the shapes of the output data for the policy. The key represents output_shapes: A dictionary defining the shapes of the output data for the policy.
the output data name, and the value is a list indicating the dimensions of the corresponding data. The key represents the output data name, and the value is a list indicating the dimensions
For example, "action" refers to an output shape of [14], indicating 14-dimensional actions. of the corresponding data. For example, "action" refers to an output shape of [14], indicating
Importantly, `output_shapes` doesn't include batch dimension or temporal dimension. 14-dimensional actions. Importantly, shapes doesn't include batch dimension or temporal dimension.
input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"), input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"),
and the value specifies the normalization mode to apply. The two available modes are "mean_std" and the value specifies the normalization mode to apply. The two available modes are "mean_std"
which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a
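To make the input_shapes / output_shapes description above concrete, here is what those dictionaries look like for the 4-camera real-robot setup in the act_real-style configs later in this compare (with env.state_dim = env.action_dim = 14 as in the aloha configs):

input_shapes = {
    "observation.images.cam_high": [3, 480, 640],
    "observation.images.cam_low": [3, 480, 640],
    "observation.images.cam_left_wrist": [3, 480, 640],
    "observation.images.cam_right_wrist": [3, 480, 640],
    "observation.state": [14],
}
output_shapes = {
    "action": [14],  # 14-dimensional actions; no batch or temporal dimension
}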

View File

@@ -200,28 +200,25 @@ class ACT(nn.Module):
self.config = config self.config = config
# BERT style VAE encoder with input [cls, *joint_space_configuration, *action_sequence]. # BERT style VAE encoder with input [cls, *joint_space_configuration, *action_sequence].
# The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]). # The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]).
self.has_state = "observation.state" in config.input_shapes
self.latent_dim = config.latent_dim
if self.config.use_vae: if self.config.use_vae:
self.vae_encoder = ACTEncoder(config) self.vae_encoder = ACTEncoder(config)
self.vae_encoder_cls_embed = nn.Embedding(1, config.dim_model) self.vae_encoder_cls_embed = nn.Embedding(1, config.dim_model)
# Projection layer for joint-space configuration to hidden dimension. # Projection layer for joint-space configuration to hidden dimension.
if self.has_state: self.vae_encoder_robot_state_input_proj = nn.Linear(
self.vae_encoder_robot_state_input_proj = nn.Linear( config.input_shapes["observation.state"][0], config.dim_model
config.input_shapes["observation.state"][0], config.dim_model )
)
# Projection layer for action (joint-space target) to hidden dimension. # Projection layer for action (joint-space target) to hidden dimension.
self.vae_encoder_action_input_proj = nn.Linear( self.vae_encoder_action_input_proj = nn.Linear(
config.output_shapes["action"][0], config.dim_model config.input_shapes["observation.state"][0], config.dim_model
) )
self.latent_dim = config.latent_dim
# Projection layer from the VAE encoder's output to the latent distribution's parameter space. # Projection layer from the VAE encoder's output to the latent distribution's parameter space.
self.vae_encoder_latent_output_proj = nn.Linear(config.dim_model, self.latent_dim * 2) self.vae_encoder_latent_output_proj = nn.Linear(config.dim_model, self.latent_dim * 2)
# Fixed sinusoidal positional embedding the whole input to the VAE encoder. Unsqueeze for batch # Fixed sinusoidal positional embedding the whole input to the VAE encoder. Unsqueeze for batch
# dimension. # dimension.
num_input_token_encoder = 1 + 1 + config.chunk_size if self.has_state else 1 + config.chunk_size
self.register_buffer( self.register_buffer(
"vae_encoder_pos_enc", "vae_encoder_pos_enc",
create_sinusoidal_pos_embedding(num_input_token_encoder, config.dim_model).unsqueeze(0), create_sinusoidal_pos_embedding(1 + 1 + config.chunk_size, config.dim_model).unsqueeze(0),
) )
# Backbone for image feature extraction. # Backbone for image feature extraction.
@@ -241,17 +238,15 @@ class ACT(nn.Module):
# Transformer encoder input projections. The tokens will be structured like # Transformer encoder input projections. The tokens will be structured like
# [latent, robot_state, image_feature_map_pixels]. # [latent, robot_state, image_feature_map_pixels].
if self.has_state: self.encoder_robot_state_input_proj = nn.Linear(
self.encoder_robot_state_input_proj = nn.Linear( config.input_shapes["observation.state"][0], config.dim_model
config.input_shapes["observation.state"][0], config.dim_model )
)
self.encoder_latent_input_proj = nn.Linear(self.latent_dim, config.dim_model) self.encoder_latent_input_proj = nn.Linear(self.latent_dim, config.dim_model)
self.encoder_img_feat_input_proj = nn.Conv2d( self.encoder_img_feat_input_proj = nn.Conv2d(
backbone_model.fc.in_features, config.dim_model, kernel_size=1 backbone_model.fc.in_features, config.dim_model, kernel_size=1
) )
# Transformer encoder positional embeddings. # Transformer encoder positional embeddings.
num_input_token_decoder = 2 if self.has_state else 1 self.encoder_robot_and_latent_pos_embed = nn.Embedding(2, config.dim_model)
self.encoder_robot_and_latent_pos_embed = nn.Embedding(num_input_token_decoder, config.dim_model)
self.encoder_cam_feat_pos_embed = ACTSinusoidalPositionEmbedding2d(config.dim_model // 2) self.encoder_cam_feat_pos_embed = ACTSinusoidalPositionEmbedding2d(config.dim_model // 2)
# Transformer decoder. # Transformer decoder.
@@ -290,7 +285,7 @@ class ACT(nn.Module):
"action" in batch "action" in batch
), "actions must be provided when using the variational objective in training mode." ), "actions must be provided when using the variational objective in training mode."
batch_size = batch["observation.images"].shape[0] batch_size = batch["observation.state"].shape[0]
# Prepare the latent for input to the transformer encoder. # Prepare the latent for input to the transformer encoder.
if self.config.use_vae and "action" in batch: if self.config.use_vae and "action" in batch:
@@ -298,16 +293,11 @@ class ACT(nn.Module):
cls_embed = einops.repeat( cls_embed = einops.repeat(
self.vae_encoder_cls_embed.weight, "1 d -> b 1 d", b=batch_size self.vae_encoder_cls_embed.weight, "1 d -> b 1 d", b=batch_size
) # (B, 1, D) ) # (B, 1, D)
if self.has_state: robot_state_embed = self.vae_encoder_robot_state_input_proj(batch["observation.state"]).unsqueeze(
robot_state_embed = self.vae_encoder_robot_state_input_proj(batch["observation.state"]) 1
robot_state_embed = robot_state_embed.unsqueeze(1) # (B, 1, D) ) # (B, 1, D)
action_embed = self.vae_encoder_action_input_proj(batch["action"]) # (B, S, D) action_embed = self.vae_encoder_action_input_proj(batch["action"]) # (B, S, D)
vae_encoder_input = torch.cat([cls_embed, robot_state_embed, action_embed], axis=1) # (B, S+2, D)
if self.has_state:
vae_encoder_input = [cls_embed, robot_state_embed, action_embed] # (B, S+2, D)
else:
vae_encoder_input = [cls_embed, action_embed]
vae_encoder_input = torch.cat(vae_encoder_input, axis=1)
# Prepare fixed positional embedding. # Prepare fixed positional embedding.
# Note: detach() shouldn't be necessary but leaving it the same as the original code just in case. # Note: detach() shouldn't be necessary but leaving it the same as the original code just in case.
@@ -327,7 +317,6 @@ class ACT(nn.Module):
else: else:
# When not using the VAE encoder, we set the latent to be all zeros. # When not using the VAE encoder, we set the latent to be all zeros.
mu = log_sigma_x2 = None mu = log_sigma_x2 = None
# TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use buffer
latent_sample = torch.zeros([batch_size, self.latent_dim], dtype=torch.float32).to( latent_sample = torch.zeros([batch_size, self.latent_dim], dtype=torch.float32).to(
batch["observation.state"].device batch["observation.state"].device
) )
@@ -337,10 +326,8 @@ class ACT(nn.Module):
all_cam_features = [] all_cam_features = []
all_cam_pos_embeds = [] all_cam_pos_embeds = []
images = batch["observation.images"] images = batch["observation.images"]
for cam_index in range(images.shape[-4]): for cam_index in range(images.shape[-4]):
cam_features = self.backbone(images[:, cam_index])["feature_map"] cam_features = self.backbone(images[:, cam_index])["feature_map"]
# TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use buffer
cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype) cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype)
cam_features = self.encoder_img_feat_input_proj(cam_features) # (B, C, h, w) cam_features = self.encoder_img_feat_input_proj(cam_features) # (B, C, h, w)
all_cam_features.append(cam_features) all_cam_features.append(cam_features)
@@ -350,15 +337,13 @@ class ACT(nn.Module):
cam_pos_embed = torch.cat(all_cam_pos_embeds, axis=-1) cam_pos_embed = torch.cat(all_cam_pos_embeds, axis=-1)
# Get positional embeddings for robot state and latent. # Get positional embeddings for robot state and latent.
if self.has_state: robot_state_embed = self.encoder_robot_state_input_proj(batch["observation.state"]) # (B, C)
robot_state_embed = self.encoder_robot_state_input_proj(batch["observation.state"]) # (B, C)
latent_embed = self.encoder_latent_input_proj(latent_sample) # (B, C) latent_embed = self.encoder_latent_input_proj(latent_sample) # (B, C)
# Stack encoder input and positional embeddings moving to (S, B, C). # Stack encoder input and positional embeddings moving to (S, B, C).
encoder_in_feats = [latent_embed, robot_state_embed] if self.has_state else [latent_embed]
encoder_in = torch.cat( encoder_in = torch.cat(
[ [
torch.stack(encoder_in_feats, axis=0), torch.stack([latent_embed, robot_state_embed], axis=0),
einops.rearrange(encoder_in, "b c h w -> (h w) b c"), einops.rearrange(encoder_in, "b c h w -> (h w) b c"),
] ]
) )
@@ -372,7 +357,6 @@ class ACT(nn.Module):
# Forward pass through the transformer modules. # Forward pass through the transformer modules.
encoder_out = self.encoder(encoder_in, pos_embed=pos_embed) encoder_out = self.encoder(encoder_in, pos_embed=pos_embed)
# TODO(rcadene, alexander-soare): remove call to `device` ; precompute and use buffer
decoder_in = torch.zeros( decoder_in = torch.zeros(
(self.config.chunk_size, batch_size, self.config.dim_model), (self.config.chunk_size, batch_size, self.config.dim_model),
dtype=pos_embed.dtype, dtype=pos_embed.dtype,
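One side of this hunk makes "observation.state" optional via has_state, which changes how many tokens are fed to the VAE encoder. A shapes-only sketch of that token bookkeeping, with invented sizes:

import torch

batch_size, chunk_size, dim_model = 2, 100, 512
has_state = False  # e.g. a camera-only setup such as act_real_no_state

cls_embed = torch.zeros(batch_size, 1, dim_model)
action_embed = torch.zeros(batch_size, chunk_size, dim_model)
tokens = [cls_embed, action_embed]
if has_state:
    tokens.insert(1, torch.zeros(batch_size, 1, dim_model))  # robot_state_embed

vae_encoder_input = torch.cat(tokens, dim=1)
# (B, 1 + chunk_size, D) without state, (B, 2 + chunk_size, D) with it,
# matching num_input_token_encoder in the hunk above.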

View File

@@ -26,29 +26,21 @@ class DiffusionConfig:
The parameters you will most likely need to change are the ones which depend on the environment / sensors. The parameters you will most likely need to change are the ones which depend on the environment / sensors.
Those are: `input_shapes` and `output_shapes`. Those are: `input_shapes` and `output_shapes`.
Notes on the inputs and outputs:
- "observation.state" is required as an input key.
- At least one key starting with "observation.image is required as an input.
- If there are multiple keys beginning with "observation.image" they are treated as multiple camera
views.
Right now we only support all images having the same shape.
- "action" is required as an output key.
Args: Args:
n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the
current step and additional steps going back). current step and additional steps going back).
horizon: Diffusion model action prediction size as detailed in `DiffusionPolicy.select_action`. horizon: Diffusion model action prediction size as detailed in `DiffusionPolicy.select_action`.
n_action_steps: The number of action steps to run in the environment for one invocation of the policy. n_action_steps: The number of action steps to run in the environment for one invocation of the policy.
See `DiffusionPolicy.select_action` for more details. See `DiffusionPolicy.select_action` for more details.
input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents input_shapes: A dictionary defining the shapes of the input data for the policy.
the input data name, and the value is a list indicating the dimensions of the corresponding data. The key represents the input data name, and the value is a list indicating the dimensions
For example, "observation.image" refers to an input from a camera with dimensions [3, 96, 96], of the corresponding data. For example, "observation.image" refers to an input from
indicating it has three color channels and 96x96 resolution. Importantly, `input_shapes` doesn't a camera with dimensions [3, 96, 96], indicating it has three color channels and 96x96 resolution.
include batch dimension or temporal dimension. Importantly, shapes doesn't include batch dimension or temporal dimension.
output_shapes: A dictionary defining the shapes of the output data for the policy. The key represents output_shapes: A dictionary defining the shapes of the output data for the policy.
the output data name, and the value is a list indicating the dimensions of the corresponding data. The key represents the output data name, and the value is a list indicating the dimensions
For example, "action" refers to an output shape of [14], indicating 14-dimensional actions. of the corresponding data. For example, "action" refers to an output shape of [14], indicating
Importantly, `output_shapes` doesn't include batch dimension or temporal dimension. 14-dimensional actions. Importantly, shapes doesn't include batch dimension or temporal dimension.
input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"), input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"),
and the value specifies the normalization mode to apply. The two available modes are "mean_std" and the value specifies the normalization mode to apply. The two available modes are "mean_std"
which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a
@@ -163,7 +155,7 @@ class DiffusionConfig:
f"{self.__class__.__name__} only handles one image for now. Got image keys {image_keys}." f"{self.__class__.__name__} only handles one image for now. Got image keys {image_keys}."
) )
image_key = next(iter(image_keys)) image_key = next(iter(image_keys))
if self.crop_shape is not None and ( if (
self.crop_shape[0] > self.input_shapes[image_key][1] self.crop_shape[0] > self.input_shapes[image_key][1]
or self.crop_shape[1] > self.input_shapes[image_key][2] or self.crop_shape[1] > self.input_shapes[image_key][2]
): ):

View File

@@ -304,11 +304,7 @@ class DiffusionModel(nn.Module):
loss = F.mse_loss(pred, target, reduction="none") loss = F.mse_loss(pred, target, reduction="none")
# Mask loss wherever the action is padded with copies (edges of the dataset trajectory). # Mask loss wherever the action is padded with copies (edges of the dataset trajectory).
if self.config.do_mask_loss_for_padding: if self.config.do_mask_loss_for_padding and "action_is_pad" in batch:
if "action_is_pad" not in batch:
raise ValueError(
f"You need to provide 'action_is_pad' in the batch when {self.config.do_mask_loss_for_padding=}."
)
in_episode_bound = ~batch["action_is_pad"] in_episode_bound = ~batch["action_is_pad"]
loss = loss * in_episode_bound.unsqueeze(-1) loss = loss * in_episode_bound.unsqueeze(-1)
@@ -427,15 +423,11 @@ class DiffusionRgbEncoder(nn.Module):
# Set up pooling and final layers. # Set up pooling and final layers.
# Use a dry run to get the feature map shape. # Use a dry run to get the feature map shape.
# The dummy input should take the number of image channels from `config.input_shapes` and it should # The dummy input should take the number of image channels from `config.input_shapes` and it should
# use the height and width from `config.crop_shape` if it is provided, otherwise it should use the # use the height and width from `config.crop_shape`.
# height and width from `config.input_shapes`.
image_keys = [k for k in config.input_shapes if k.startswith("observation.image")] image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
assert len(image_keys) == 1 assert len(image_keys) == 1
image_key = image_keys[0] image_key = image_keys[0]
dummy_input_h_w = ( dummy_input = torch.zeros(size=(1, config.input_shapes[image_key][0], *config.crop_shape))
config.crop_shape if config.crop_shape is not None else config.input_shapes[image_key][1:]
)
dummy_input = torch.zeros(size=(1, config.input_shapes[image_key][0], *dummy_input_h_w))
with torch.inference_mode(): with torch.inference_mode():
dummy_feature_map = self.backbone(dummy_input) dummy_feature_map = self.backbone(dummy_input)
feature_map_shape = tuple(dummy_feature_map.shape[1:]) feature_map_shape = tuple(dummy_feature_map.shape[1:])
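One side of this hunk lets crop_shape be None and falls back to the full input resolution for the dry run. A compact sketch of just that shape computation (plain Python, no model required):

input_shapes = {"observation.image": [3, 96, 96]}
image_key = "observation.image"

for crop_shape in [(84, 84), None]:
    dummy_h_w = crop_shape if crop_shape is not None else input_shapes[image_key][1:]
    dummy_shape = (1, input_shapes[image_key][0], *dummy_h_w)
    print(crop_shape, "->", dummy_shape)  # (84, 84) -> (1, 3, 84, 84); None -> (1, 3, 96, 96)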

View File

@@ -31,15 +31,6 @@ class TDMPCConfig:
n_action_repeats: The number of times to repeat the action returned by the planning. (hint: Google n_action_repeats: The number of times to repeat the action returned by the planning. (hint: Google
action repeats in Q-learning or ask your favorite chatbot) action repeats in Q-learning or ask your favorite chatbot)
horizon: Horizon for model predictive control. horizon: Horizon for model predictive control.
input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents
the input data name, and the value is a list indicating the dimensions of the corresponding data.
For example, "observation.image" refers to an input from a camera with dimensions [3, 96, 96],
indicating it has three color channels and 96x96 resolution. Importantly, `input_shapes` doesn't
include batch dimension or temporal dimension.
output_shapes: A dictionary defining the shapes of the output data for the policy. The key represents
the output data name, and the value is a list indicating the dimensions of the corresponding data.
For example, "action" refers to an output shape of [14], indicating 14-dimensional actions.
Importantly, `output_shapes` doesn't include batch dimension or temporal dimension.
input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"), input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"),
and the value specifies the normalization mode to apply. The two available modes are "mean_std" and the value specifies the normalization mode to apply. The two available modes are "mean_std"
which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a

View File

@@ -19,7 +19,7 @@ import random
from contextlib import contextmanager from contextlib import contextmanager
from datetime import datetime from datetime import datetime
from pathlib import Path from pathlib import Path
from typing import Any, Generator from typing import Generator
import hydra import hydra
import numpy as np import numpy as np
@@ -48,38 +48,12 @@ def get_safe_torch_device(cfg_device: str, log: bool = False) -> torch.device:
return device return device
def get_global_random_state() -> dict[str, Any]:
"""Get the random state for `random`, `numpy`, and `torch`."""
random_state_dict = {
"random_state": random.getstate(),
"numpy_random_state": np.random.get_state(),
"torch_random_state": torch.random.get_rng_state(),
}
if torch.cuda.is_available():
random_state_dict["torch_cuda_random_state"] = torch.cuda.random.get_rng_state()
return random_state_dict
def set_global_random_state(random_state_dict: dict[str, Any]):
"""Set the random state for `random`, `numpy`, and `torch`.
Args:
random_state_dict: A dictionary of the form returned by `get_global_random_state`.
"""
random.setstate(random_state_dict["random_state"])
np.random.set_state(random_state_dict["numpy_random_state"])
torch.random.set_rng_state(random_state_dict["torch_random_state"])
if torch.cuda.is_available():
torch.cuda.random.set_rng_state(random_state_dict["torch_cuda_random_state"])
def set_global_seed(seed): def set_global_seed(seed):
"""Set seed for reproducibility.""" """Set seed for reproducibility."""
random.seed(seed) random.seed(seed)
np.random.seed(seed) np.random.seed(seed)
torch.manual_seed(seed) torch.manual_seed(seed)
if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed)
torch.cuda.manual_seed_all(seed)
@contextmanager @contextmanager
@@ -95,10 +69,16 @@ def seeded_context(seed: int) -> Generator[None, None, None]:
c = random.random() # produces yet another random number, but the same it would have if we never made `b` c = random.random() # produces yet another random number, but the same it would have if we never made `b`
``` ```
""" """
random_state_dict = get_global_random_state() random_state = random.getstate()
np_random_state = np.random.get_state()
torch_random_state = torch.random.get_rng_state()
torch_cuda_random_state = torch.cuda.random.get_rng_state()
set_global_seed(seed) set_global_seed(seed)
yield None yield None
set_global_random_state(random_state_dict) random.setstate(random_state)
np.random.set_state(np_random_state)
torch.random.set_rng_state(torch_random_state)
torch.cuda.random.set_rng_state(torch_cuda_random_state)
def init_logging(): def init_logging():
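One side of this hunk uses the get_global_random_state / set_global_random_state helpers, the other saves and restores each generator explicitly; both implement the same save/restore pattern behind seeded_context. A tiny CPU-only illustration (CUDA state omitted for brevity):

import random
import numpy as np
import torch

saved = {
    "random_state": random.getstate(),
    "numpy_random_state": np.random.get_state(),
    "torch_random_state": torch.random.get_rng_state(),
}

a = random.random()  # consume some randomness inside the "context"

random.setstate(saved["random_state"])
np.random.set_state(saved["numpy_random_state"])
torch.random.set_rng_state(saved["torch_random_state"])

assert random.random() == a  # the stream resumes exactly where it was saved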

View File

@@ -5,21 +5,11 @@ defaults:
hydra: hydra:
run: run:
# Set `dir` to where you would like to save all of the run outputs. If you run another training session
# with the same value for `dir` its contents will be overwritten unless you set `resume` to true.
dir: outputs/train/${now:%Y-%m-%d}/${now:%H-%M-%S}_${env.name}_${policy.name}_${hydra.job.name} dir: outputs/train/${now:%Y-%m-%d}/${now:%H-%M-%S}_${env.name}_${policy.name}_${hydra.job.name}
job: job:
name: default name: default
# Set `resume` to true to resume a previous run. In order for this to work, you will need to make sure
# `hydra.run.dir` is the directory of an existing run with at least one checkpoint in it.
# Note that when resuming a run, the default behavior is to use the configuration from the checkpoint,
# regardless of what's provided with the training command at the time of resumption.
resume: false
device: cuda # cpu device: cuda # cpu
# `use_amp` determines whether to use Automatic Mixed Precision (AMP) for training and evaluation. With AMP,
# automatic gradient scaling is used.
use_amp: false
# `seed` is used for training (eg: model initialization, dataset shuffling) # `seed` is used for training (eg: model initialization, dataset shuffling)
# AND for the evaluation environments. # AND for the evaluation environments.
seed: ??? seed: ???
@@ -27,7 +17,6 @@ dataset_repo_id: lerobot/pusht
training: training:
offline_steps: ??? offline_steps: ???
# NOTE: `online_steps` is not implemented yet. It's here as a placeholder.
online_steps: ??? online_steps: ???
online_steps_between_rollouts: ??? online_steps_between_rollouts: ???
online_sampling_ratio: 0.5 online_sampling_ratio: 0.5
@@ -36,9 +25,7 @@ training:
eval_freq: ??? eval_freq: ???
save_freq: ??? save_freq: ???
log_freq: 250 log_freq: 250
save_checkpoint: true save_model: true
num_workers: 4
batch_size: ???
eval: eval:
n_episodes: 1 n_episodes: 1
@@ -49,7 +36,7 @@ eval:
wandb: wandb:
enable: false enable: false
# Set to true to disable saving an artifact despite save_checkpoint == True # Set to true to disable saving an artifact despite save_model == True
disable_artifact: false disable_artifact: false
project: lerobot project: lerobot
notes: "" notes: ""

View File

@@ -5,10 +5,10 @@ fps: 50
env: env:
name: aloha name: aloha
task: AlohaInsertion-v0 task: AlohaInsertion-v0
from_pixels: True
pixels_only: False
image_size: [3, 480, 640]
episode_length: 400
fps: ${fps}
state_dim: 14 state_dim: 14
action_dim: 14 action_dim: 14
fps: ${fps}
episode_length: 400
gym:
obs_type: pixels_agent_pos
render_mode: rgb_array

14
lerobot/configs/env/dora.yaml vendored Normal file
View File

@@ -0,0 +1,14 @@
# @package _global_
fps: 30
env:
name: dora
task: DoraAloha-v0
# from_pixels: True
# pixels_only: False
# image_size: [3, 480, 640]
episode_length: 400
# fps: ${fps}
# state_dim: 14
# action_dim: 14

View File

@@ -1,13 +0,0 @@
# @package _global_
fps: 30
env:
name: dora
task: DoraAloha-v0
state_dim: 14
action_dim: 14
fps: ${fps}
episode_length: 400
gym:
fps: ${fps}

View File

@@ -5,13 +5,10 @@ fps: 10
env: env:
name: pusht name: pusht
task: PushT-v0 task: PushT-v0
from_pixels: True
pixels_only: False
image_size: 96 image_size: 96
episode_length: 300
fps: ${fps}
state_dim: 2 state_dim: 2
action_dim: 2 action_dim: 2
fps: ${fps}
episode_length: 300
gym:
obs_type: pixels_agent_pos
render_mode: rgb_array
visualization_width: 384
visualization_height: 384

View File

@@ -5,13 +5,10 @@ fps: 15
env: env:
name: xarm name: xarm
task: XarmLift-v0 task: XarmLift-v0
from_pixels: True
pixels_only: False
image_size: 84 image_size: 84
episode_length: 25
fps: ${fps}
state_dim: 4 state_dim: 4
action_dim: 4 action_dim: 4
fps: ${fps}
episode_length: 25
gym:
obs_type: pixels_agent_pos
render_mode: rgb_array
visualization_width: 384
visualization_height: 384

View File

@@ -15,7 +15,7 @@ training:
eval_freq: 10000 eval_freq: 10000
save_freq: 100000 save_freq: 100000
log_freq: 250 log_freq: 250
save_checkpoint: true save_model: true
batch_size: 8 batch_size: 8
lr: 1e-5 lr: 1e-5

View File

@@ -1,115 +0,0 @@
# @package _global_
# Use `act_real.yaml` to train on real-world Aloha/Aloha2 datasets.
# Compared to `act.yaml`, it contains 4 cameras (i.e. cam_right_wrist, cam_left_wrist, cam_high,
# cam_low) instead of 1 camera (i.e. top). Also, `training.eval_freq` is set to -1. This parameter controls
# how often (in training steps) checkpoints are evaluated; setting it to -1 deactivates evaluation.
# This is because real-world evaluation is done through [dora-lerobot](https://github.com/dora-rs/dora-lerobot).
# Look at its README for more information on how to evaluate a checkpoint in the real-world.
#
# Example of usage for training:
# ```bash
# python lerobot/scripts/train.py \
# policy=act_real \
# env=aloha_real
# ```
seed: 1000
dataset_repo_id: lerobot/aloha_static_vinh_cup
override_dataset_stats:
observation.images.cam_right_wrist:
# stats from imagenet, since we use a pretrained vision model
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
observation.images.cam_left_wrist:
# stats from imagenet, since we use a pretrained vision model
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
observation.images.cam_high:
# stats from imagenet, since we use a pretrained vision model
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
observation.images.cam_low:
# stats from imagenet, since we use a pretrained vision model
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
training:
offline_steps: 80000
online_steps: 0
eval_freq: -1
save_freq: 10000
log_freq: 100
save_checkpoint: true
batch_size: 8
lr: 1e-5
lr_backbone: 1e-5
weight_decay: 1e-4
grad_clip_norm: 10
online_steps_between_rollouts: 1
delta_timestamps:
action: "[i / ${fps} for i in range(${policy.chunk_size})]"
eval:
n_episodes: 50
batch_size: 50
# See `configuration_act.py` for more details.
policy:
name: act
# Input / output structure.
n_obs_steps: 1
chunk_size: 100 # chunk_size
n_action_steps: 100
input_shapes:
# TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
observation.images.cam_right_wrist: [3, 480, 640]
observation.images.cam_left_wrist: [3, 480, 640]
observation.images.cam_high: [3, 480, 640]
observation.images.cam_low: [3, 480, 640]
observation.state: ["${env.state_dim}"]
output_shapes:
action: ["${env.action_dim}"]
# Normalization / Unnormalization
input_normalization_modes:
observation.images.cam_right_wrist: mean_std
observation.images.cam_left_wrist: mean_std
observation.images.cam_high: mean_std
observation.images.cam_low: mean_std
observation.state: mean_std
output_normalization_modes:
action: mean_std
# Architecture.
# Vision backbone.
vision_backbone: resnet18
pretrained_backbone_weights: ResNet18_Weights.IMAGENET1K_V1
replace_final_stride_with_dilation: false
# Transformer layers.
pre_norm: false
dim_model: 512
n_heads: 8
dim_feedforward: 3200
feedforward_activation: relu
n_encoder_layers: 4
# Note: Although the original ACT implementation has 7 for `n_decoder_layers`, there is a bug in the code
# that means only the first layer is used. Here we match the original implementation by setting this to 1.
# See this issue https://github.com/tonyzhaozh/act/issues/25#issue-2258740521.
n_decoder_layers: 1
# VAE.
use_vae: true
latent_dim: 32
n_vae_encoder_layers: 4
# Inference.
temporal_ensemble_momentum: null
# Training and loss computation.
dropout: 0.1
kl_weight: 10.0
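The `delta_timestamps.action` entry above is a string that gets evaluated (after Hydra interpolates `${fps}` and `${policy.chunk_size}`) into the list of time offsets used to fetch the action targets. A small sketch of the resolved value, assuming an env fps of 30 as in the Aloha configs (the fps itself is not set in this file):

```python
# Sketch: resolving "[i / ${fps} for i in range(${policy.chunk_size})]".
fps = 30          # assumed env fps (set by the env config, not this file)
chunk_size = 100  # policy.chunk_size above

action_delta_timestamps = [i / fps for i in range(chunk_size)]
print(action_delta_timestamps[:4])  # [0.0, 0.0333..., 0.0666..., 0.1]
# i.e. each training frame is paired with the next `chunk_size` actions,
# spaced 1/fps seconds apart, starting at the current timestep.
```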

View File

@@ -1,19 +1,7 @@
# @package _global_ # @package _global_
# Use `act_real_no_state.yaml` to train on real-world Aloha/Aloha2 datasets when cameras are moving (e.g. wrist cameras).
# Compared to `act_real.yaml`, it is camera-only and does not use the state (a vector of robot joint positions) as input.
# We validated experimentally that not using the state reaches a better success rate. Our hypothesis is that `act_real.yaml`
# might overfit to the state, because the moving images are more complex to learn from.
#
# Example of usage for training:
# ```bash
# python lerobot/scripts/train.py \
# policy=act_real_no_state \
# env=aloha_real
# ```
seed: 1000 seed: 1000
dataset_repo_id: lerobot/aloha_static_vinh_cup dataset_repo_id: cadene/aloha_v2_static_dora_test
override_dataset_stats: override_dataset_stats:
observation.images.cam_right_wrist: observation.images.cam_right_wrist:
@@ -36,10 +24,10 @@ override_dataset_stats:
training: training:
offline_steps: 80000 offline_steps: 80000
online_steps: 0 online_steps: 0
eval_freq: -1 eval_freq: 99999999999999
save_freq: 10000 save_freq: 1000
log_freq: 100 log_freq: 100
save_checkpoint: true save_model: true
batch_size: 8 batch_size: 8
lr: 1e-5 lr: 1e-5
@@ -70,6 +58,7 @@ policy:
observation.images.cam_left_wrist: [3, 480, 640] observation.images.cam_left_wrist: [3, 480, 640]
observation.images.cam_high: [3, 480, 640] observation.images.cam_high: [3, 480, 640]
observation.images.cam_low: [3, 480, 640] observation.images.cam_low: [3, 480, 640]
observation.state: ["${env.state_dim}"]
output_shapes: output_shapes:
action: ["${env.action_dim}"] action: ["${env.action_dim}"]
@@ -79,6 +68,7 @@ policy:
observation.images.cam_left_wrist: mean_std observation.images.cam_left_wrist: mean_std
observation.images.cam_high: mean_std observation.images.cam_high: mean_std
observation.images.cam_low: mean_std observation.images.cam_low: mean_std
observation.state: mean_std
output_normalization_modes: output_normalization_modes:
action: mean_std action: mean_std

View File

@@ -27,7 +27,7 @@ training:
eval_freq: 5000 eval_freq: 5000
save_freq: 5000 save_freq: 5000
log_freq: 250 log_freq: 250
save_checkpoint: true save_model: true
batch_size: 64 batch_size: 64
grad_clip_norm: 10 grad_clip_norm: 10

View File

@@ -5,8 +5,7 @@ dataset_repo_id: lerobot/xarm_lift_medium
training: training:
offline_steps: 25000 offline_steps: 25000
# TODO(alexander-soare): uncomment when online training gets reinstated online_steps: 25000
online_steps: 0 # 25000 not implemented yet
eval_freq: 5000 eval_freq: 5000
online_steps_between_rollouts: 1 online_steps_between_rollouts: 1
online_sampling_ratio: 0.5 online_sampling_ratio: 0.5

View File

@@ -28,7 +28,7 @@ OR, you want to evaluate a model checkpoint from the LeRobot training script for
``` ```
python lerobot/scripts/eval.py \ python lerobot/scripts/eval.py \
-p outputs/train/diffusion_pusht/checkpoints/005000/pretrained_model \ -p outputs/train/diffusion_pusht/checkpoints/005000 \
eval.n_episodes=10 eval.n_episodes=10
``` ```
@@ -46,7 +46,6 @@ import json
import logging import logging
import threading import threading
import time import time
from contextlib import nullcontext
from copy import deepcopy from copy import deepcopy
from datetime import datetime as dt from datetime import datetime as dt
from pathlib import Path from pathlib import Path
@@ -521,7 +520,7 @@ def eval(
raise NotImplementedError() raise NotImplementedError()
# Check device is available # Check device is available
device = get_safe_torch_device(hydra_cfg.device, log=True) get_safe_torch_device(hydra_cfg.device, log=True)
torch.backends.cudnn.benchmark = True torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cuda.matmul.allow_tf32 = True
@@ -540,17 +539,16 @@ def eval(
policy = make_policy(hydra_cfg=hydra_cfg, dataset_stats=make_dataset(hydra_cfg).stats) policy = make_policy(hydra_cfg=hydra_cfg, dataset_stats=make_dataset(hydra_cfg).stats)
policy.eval() policy.eval()
with torch.no_grad(), torch.autocast(device_type=device.type) if hydra_cfg.use_amp else nullcontext(): info = eval_policy(
info = eval_policy( env,
env, policy,
policy, hydra_cfg.eval.n_episodes,
hydra_cfg.eval.n_episodes, max_episodes_rendered=10,
max_episodes_rendered=10, video_dir=Path(out_dir) / "eval",
video_dir=Path(out_dir) / "eval", start_seed=hydra_cfg.seed,
start_seed=hydra_cfg.seed, enable_progbar=True,
enable_progbar=True, enable_inner_progbar=True,
enable_inner_progbar=True, )
)
print(info["aggregated"]) print(info["aggregated"])
# Save info # Save info
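The `from contextlib import nullcontext` line in this hunk enables a compact conditional-autocast idiom used in the eval call: wrap evaluation in `torch.autocast` only when AMP is requested, and fall back to a no-op context manager otherwise. A minimal standalone sketch (the `use_amp` flag here is a stand-in, not the real Hydra config):

```python
# Sketch of the conditional-autocast idiom: autocast only when AMP is enabled.
from contextlib import nullcontext

import torch

use_amp = False  # stand-in for a config flag such as hydra_cfg.use_amp
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

with torch.no_grad(), torch.autocast(device_type=device.type) if use_amp else nullcontext():
    x = torch.randn(2, 3, device=device)
    y = x @ x.T  # runs in reduced precision under autocast, full precision otherwise
```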

View File

@@ -144,7 +144,8 @@ def push_videos_to_hub(repo_id, videos_dir, revision):
def push_dataset_to_hub( def push_dataset_to_hub(
data_dir: Path, input_data_dir: Path,
output_data_dir: Path,
dataset_id: str, dataset_id: str,
raw_format: str | None, raw_format: str | None,
community_id: str, community_id: str,
@@ -161,34 +162,33 @@ def push_dataset_to_hub(
): ):
repo_id = f"{community_id}/{dataset_id}" repo_id = f"{community_id}/{dataset_id}"
raw_dir = data_dir / f"{dataset_id}_raw" meta_data_dir = output_data_dir / "meta_data"
videos_dir = output_data_dir / "videos"
out_dir = data_dir / repo_id
meta_data_dir = out_dir / "meta_data"
videos_dir = out_dir / "videos"
tests_out_dir = tests_data_dir / repo_id tests_out_dir = tests_data_dir / repo_id
tests_meta_data_dir = tests_out_dir / "meta_data" tests_meta_data_dir = tests_out_dir / "meta_data"
tests_videos_dir = tests_out_dir / "videos" tests_videos_dir = tests_out_dir / "videos"
if out_dir.exists(): if output_data_dir.exists():
shutil.rmtree(out_dir) shutil.rmtree(output_data_dir)
if tests_out_dir.exists() and save_tests_to_disk: if tests_out_dir.exists() and save_tests_to_disk:
shutil.rmtree(tests_out_dir) shutil.rmtree(tests_out_dir)
if not raw_dir.exists(): if not input_data_dir.exists():
download_raw(raw_dir, dataset_id) download_raw(input_data_dir, dataset_id)
if raw_format is None: if raw_format is None:
# TODO(rcadene, adilzouitine): implement auto_find_raw_format # TODO(rcadene, adilzouitine): implement auto_find_raw_format
raise NotImplementedError() raise NotImplementedError()
# raw_format = auto_find_raw_format(raw_dir) # raw_format = auto_find_raw_format(input_data_dir)
from_raw_to_lerobot_format = get_from_raw_to_lerobot_format_fn(raw_format) from_raw_to_lerobot_format = get_from_raw_to_lerobot_format_fn(raw_format)
# convert dataset from original raw format to LeRobot format # convert dataset from original raw format to LeRobot format
hf_dataset, episode_data_index, info = from_raw_to_lerobot_format(raw_dir, out_dir, fps, video, debug) hf_dataset, episode_data_index, info = from_raw_to_lerobot_format(
input_data_dir, output_data_dir, fps, video, debug
)
lerobot_dataset = LeRobotDataset.from_preloaded( lerobot_dataset = LeRobotDataset.from_preloaded(
repo_id=repo_id, repo_id=repo_id,
@@ -202,7 +202,7 @@ def push_dataset_to_hub(
if save_to_disk: if save_to_disk:
hf_dataset = hf_dataset.with_format(None) # to remove transforms that cant be saved hf_dataset = hf_dataset.with_format(None) # to remove transforms that cant be saved
hf_dataset.save_to_disk(str(out_dir / "train")) hf_dataset.save_to_disk(str(output_data_dir / "train"))
if not dry_run or save_to_disk: if not dry_run or save_to_disk:
# mandatory for upload # mandatory for upload
@@ -236,19 +236,25 @@ def push_dataset_to_hub(
fname = f"{key}_episode_{episode_index:06d}.mp4" fname = f"{key}_episode_{episode_index:06d}.mp4"
shutil.copy(videos_dir / fname, tests_videos_dir / fname) shutil.copy(videos_dir / fname, tests_videos_dir / fname)
if not save_to_disk and out_dir.exists(): if not save_to_disk and output_data_dir.exists():
# remove possible temporary files remaining in the output directory # remove possible temporary files remaining in the output directory
shutil.rmtree(out_dir) shutil.rmtree(output_data_dir)
def main(): def main():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument( parser.add_argument(
"--data-dir", "--input-data-dir",
type=Path, type=Path,
required=True, required=True,
help="Root directory containing datasets (e.g. `data` or `tmp/data` or `/tmp/lerobot/data`).", help="Directory containing input raw datasets (e.g. `data/aloha_mobile_chair_raw` or `data/pusht_raw`).",
)
parser.add_argument(
"--output-data-dir",
type=Path,
required=True,
help="Root directory containing output dataset (e.g. `data/lerobot/aloha_mobile_chair` or `data/lerobot/pusht`).",
) )
parser.add_argument( parser.add_argument(
"--dataset-id", "--dataset-id",

View File

@@ -15,29 +15,25 @@
# limitations under the License. # limitations under the License.
import logging import logging
import time import time
from contextlib import nullcontext
from copy import deepcopy from copy import deepcopy
from pathlib import Path from pathlib import Path
from pprint import pformat
import datasets
import hydra import hydra
import torch import torch
from deepdiff import DeepDiff from datasets import concatenate_datasets
from omegaconf import DictConfig, OmegaConf from datasets.utils import disable_progress_bars, enable_progress_bars
from termcolor import colored from omegaconf import DictConfig
from torch.cuda.amp import GradScaler
from lerobot.common.datasets.factory import make_dataset, resolve_delta_timestamps from lerobot.common.datasets.factory import make_dataset
from lerobot.common.datasets.utils import cycle from lerobot.common.datasets.utils import cycle
from lerobot.common.envs.factory import make_env from lerobot.common.envs.factory import make_env
from lerobot.common.logger import Logger, log_output_dir from lerobot.common.logger import Logger, log_output_dir
from lerobot.common.policies.factory import make_policy from lerobot.common.policies.factory import make_policy
from lerobot.common.policies.policy_protocol import PolicyWithUpdate from lerobot.common.policies.policy_protocol import PolicyWithUpdate
from lerobot.common.policies.utils import get_device_from_parameters
from lerobot.common.utils.utils import ( from lerobot.common.utils.utils import (
format_big_number, format_big_number,
get_safe_torch_device, get_safe_torch_device,
init_hydra_config,
init_logging, init_logging,
set_global_seed, set_global_seed,
) )
@@ -73,6 +69,7 @@ def make_optimizer_and_scheduler(cfg, policy):
cfg.training.adam_eps, cfg.training.adam_eps,
cfg.training.adam_weight_decay, cfg.training.adam_weight_decay,
) )
assert cfg.training.online_steps == 0, "Diffusion Policy does not handle online training."
from diffusers.optimization import get_scheduler from diffusers.optimization import get_scheduler
lr_scheduler = get_scheduler( lr_scheduler = get_scheduler(
@@ -90,40 +87,21 @@ def make_optimizer_and_scheduler(cfg, policy):
return optimizer, lr_scheduler return optimizer, lr_scheduler
def update_policy( def update_policy(policy, batch, optimizer, grad_clip_norm, lr_scheduler=None):
policy,
batch,
optimizer,
grad_clip_norm,
grad_scaler: GradScaler,
lr_scheduler=None,
use_amp: bool = False,
):
"""Returns a dictionary of items for logging.""" """Returns a dictionary of items for logging."""
start_time = time.perf_counter() start_time = time.time()
device = get_device_from_parameters(policy)
policy.train() policy.train()
with torch.autocast(device_type=device.type) if use_amp else nullcontext(): output_dict = policy.forward(batch)
output_dict = policy.forward(batch) # TODO(rcadene): policy.unnormalize_outputs(out_dict)
# TODO(rcadene): policy.unnormalize_outputs(out_dict) loss = output_dict["loss"]
loss = output_dict["loss"] loss.backward()
grad_scaler.scale(loss).backward()
# Unscale the gradients of the optimizer's assigned params in-place **prior to gradient clipping**.
grad_scaler.unscale_(optimizer)
grad_norm = torch.nn.utils.clip_grad_norm_( grad_norm = torch.nn.utils.clip_grad_norm_(
policy.parameters(), policy.parameters(),
grad_clip_norm, grad_clip_norm,
error_if_nonfinite=False, error_if_nonfinite=False,
) )
# Optimizer's gradients are already unscaled, so scaler.step does not unscale them, optimizer.step()
# although it still skips optimizer.step() if the gradients contain infs or NaNs.
grad_scaler.step(optimizer)
# Updates the scale for next iteration.
grad_scaler.update()
optimizer.zero_grad() optimizer.zero_grad()
if lr_scheduler is not None: if lr_scheduler is not None:
@@ -137,13 +115,31 @@ def update_policy(
"loss": loss.item(), "loss": loss.item(),
"grad_norm": float(grad_norm), "grad_norm": float(grad_norm),
"lr": optimizer.param_groups[0]["lr"], "lr": optimizer.param_groups[0]["lr"],
"update_s": time.perf_counter() - start_time, "update_s": time.time() - start_time,
**{k: v for k, v in output_dict.items() if k != "loss"}, **{k: v for k, v in output_dict.items() if k != "loss"},
} }
return info return info
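One side of the `update_policy` hunk above threads a `torch.cuda.amp.GradScaler` through the optimization step. As a minimal standalone sketch of that ordering (scale, backward, unscale, clip, step, update), with a toy linear model standing in for the policy:

```python
# Sketch of the GradScaler ordering (toy model, not the actual policy class).
from contextlib import nullcontext

import torch
from torch.cuda.amp import GradScaler

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
use_amp = device.type == "cuda"

model = torch.nn.Linear(8, 1).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
grad_scaler = GradScaler(enabled=use_amp)

batch = torch.randn(4, 8, device=device)
with torch.autocast(device_type=device.type) if use_amp else nullcontext():
    loss = model(batch).pow(2).mean()

grad_scaler.scale(loss).backward()   # backward on the scaled loss
grad_scaler.unscale_(optimizer)      # unscale gradients *before* clipping
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10.0)
grad_scaler.step(optimizer)          # skipped internally if grads contain inf/NaN
grad_scaler.update()
optimizer.zero_grad()
```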
@hydra.main(version_base="1.2", config_name="default", config_path="../configs")
def train_cli(cfg: dict):
train(
cfg,
out_dir=hydra.core.hydra_config.HydraConfig.get().run.dir,
job_name=hydra.core.hydra_config.HydraConfig.get().job.name,
)
def train_notebook(out_dir=None, job_name=None, config_name="default", config_path="../configs"):
from hydra import compose, initialize
hydra.core.global_hydra.GlobalHydra.instance().clear()
initialize(config_path=config_path)
cfg = compose(config_name=config_name)
train(cfg, out_dir=out_dir, job_name=job_name)
def log_train_info(logger: Logger, info, step, cfg, dataset, is_offline): def log_train_info(logger: Logger, info, step, cfg, dataset, is_offline):
loss = info["loss"] loss = info["loss"]
grad_norm = info["grad_norm"] grad_norm = info["grad_norm"]
@@ -215,6 +211,103 @@ def log_eval_info(logger, info, step, cfg, dataset, is_offline):
logger.log_dict(info, step, mode="eval") logger.log_dict(info, step, mode="eval")
def calculate_online_sample_weight(n_off: int, n_on: int, pc_on: float):
"""
Calculate the sampling weight to be assigned to samples so that a specified percentage of the batch comes from the online dataset (on average).
Parameters:
- n_off (int): Number of offline samples, each with a sampling weight of 1.
- n_on (int): Number of online samples.
- pc_on (float): Desired percentage of online samples in decimal form (e.g., 50% as 0.5).
The total weight of offline samples is n_off * 1.0.
The total weight of online samples is n_on * w.
The total combined weight of all samples is n_off + n_on * w.
The fraction of the weight that is online is n_on * w / (n_off + n_on * w).
We want this fraction to equal pc_on, so we set up the equation n_on * w / (n_off + n_on * w) = pc_on.
The solution is w = - (n_off * pc_on) / (n_on * (pc_on - 1))
"""
assert 0.0 <= pc_on <= 1.0
return -(n_off * pc_on) / (n_on * (pc_on - 1))
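A quick numeric sanity check of the weight formula derived in the docstring above (the function is restated so the snippet is self-contained; the numbers are illustrative):

```python
# Sanity check: with this weight, online frames make up pc_on of the sampled batch on average.
def calculate_online_sample_weight(n_off: int, n_on: int, pc_on: float) -> float:
    assert 0.0 <= pc_on <= 1.0
    return -(n_off * pc_on) / (n_on * (pc_on - 1))

n_off, n_on, pc_on = 10_000, 500, 0.5
w = calculate_online_sample_weight(n_off, n_on, pc_on)  # 20.0

online_fraction = (n_on * w) / (n_off * 1.0 + n_on * w)
assert abs(online_fraction - pc_on) < 1e-9
```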
def add_episodes_inplace(
online_dataset: torch.utils.data.Dataset,
concat_dataset: torch.utils.data.ConcatDataset,
sampler: torch.utils.data.WeightedRandomSampler,
hf_dataset: datasets.Dataset,
episode_data_index: dict[str, torch.Tensor],
pc_online_samples: float,
):
"""
Modifies the online_dataset, concat_dataset, and sampler in place by integrating
new episodes from hf_dataset into the online_dataset, updating the concatenated
dataset's structure and adjusting the sampling strategy based on the specified
percentage of online samples.
Parameters:
- online_dataset (torch.utils.data.Dataset): The existing online dataset to be updated.
- concat_dataset (torch.utils.data.ConcatDataset): The concatenated dataset that combines
offline and online datasets, used for sampling purposes.
- sampler (torch.utils.data.WeightedRandomSampler): A sampler that will be updated to
reflect changes in the dataset sizes and specified sampling weights.
- hf_dataset (datasets.Dataset): A Hugging Face dataset containing the new episodes to be added.
- episode_data_index (dict): A dictionary containing two keys ("from" and "to") associated to dataset indices.
They indicate the start index and end index of each episode in the dataset.
- pc_online_samples (float): The target percentage of samples that should come from
the online dataset during sampling operations.
Raises:
- AssertionError: If the first episode_index or index in hf_dataset is not 0
"""
first_episode_idx = hf_dataset.select_columns("episode_index")[0]["episode_index"].item()
last_episode_idx = hf_dataset.select_columns("episode_index")[-1]["episode_index"].item()
first_index = hf_dataset.select_columns("index")[0]["index"].item()
last_index = hf_dataset.select_columns("index")[-1]["index"].item()
# sanity check
assert first_episode_idx == 0, f"{first_episode_idx=} is not 0"
assert first_index == 0, f"{first_index=} is not 0"
assert first_index == episode_data_index["from"][first_episode_idx].item()
assert last_index == episode_data_index["to"][last_episode_idx].item() - 1
if len(online_dataset) == 0:
# initialize online dataset
online_dataset.hf_dataset = hf_dataset
online_dataset.episode_data_index = episode_data_index
else:
# get the starting indices of the new episodes and frames to be added
start_episode_idx = last_episode_idx + 1
start_index = last_index + 1
def shift_indices(episode_index, index):
# note: we don't shift "frame_index" since it represents the index of the frame in the episode it belongs to
example = {"episode_index": episode_index + start_episode_idx, "index": index + start_index}
return example
disable_progress_bars() # map has a tqdm progress bar
hf_dataset = hf_dataset.map(shift_indices, input_columns=["episode_index", "index"])
enable_progress_bars()
episode_data_index["from"] += start_index
episode_data_index["to"] += start_index
# extend online dataset
online_dataset.hf_dataset = concatenate_datasets([online_dataset.hf_dataset, hf_dataset])
# update the concatenated dataset length used during sampling
concat_dataset.cumulative_sizes = concat_dataset.cumsum(concat_dataset.datasets)
# update the sampling weights for each frame so that online frames get sampled a certain percentage of times
len_online = len(online_dataset)
len_offline = len(concat_dataset) - len_online
weight_offline = 1.0
weight_online = calculate_online_sample_weight(len_offline, len_online, pc_online_samples)
sampler.weights = torch.tensor([weight_offline] * len_offline + [weight_online] * len(online_dataset))
# update the total number of samples used during sampling
sampler.num_samples = len(concat_dataset)
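The index-shifting step above is easier to follow with concrete numbers. A small sketch with plain dicts standing in for Hugging Face dataset rows (values are illustrative):

```python
# Sketch: re-indexing freshly collected episodes before appending them to the online dataset.
# "frame_index" is deliberately left untouched (it is relative to its own episode).
last_episode_idx, last_index = 4, 999        # last entries already in the online dataset
start_episode_idx, start_index = last_episode_idx + 1, last_index + 1

new_rows = [
    {"episode_index": 0, "frame_index": 0, "index": 0},
    {"episode_index": 0, "frame_index": 1, "index": 1},
    {"episode_index": 1, "frame_index": 0, "index": 2},
]

shifted = [
    {**row,
     "episode_index": row["episode_index"] + start_episode_idx,
     "index": row["index"] + start_index}
    for row in new_rows
]
# episode_index -> [5, 5, 6], index -> [1000, 1001, 1002]
```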
def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None): def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
if out_dir is None: if out_dir is None:
raise NotImplementedError() raise NotImplementedError()
@@ -223,91 +316,35 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
init_logging() init_logging()
# If we are resuming a run, we need to check that a checkpoint exists in the log directory, and we need if cfg.training.online_steps > 0 and cfg.eval.batch_size > 1:
# to check for any differences between the provided config and the checkpoint's config. logging.warning("eval.batch_size > 1 not supported for online training steps")
if cfg.resume:
if not Logger.get_last_checkpoint_dir(out_dir).exists():
raise RuntimeError(
"You have set resume=True, but there is no model checkpoint in "
f"{Logger.get_last_checkpoint_dir(out_dir)}"
)
checkpoint_cfg_path = str(Logger.get_last_pretrained_model_dir(out_dir) / "config.yaml")
logging.info(
colored(
"You have set resume=True, indicating that you wish to resume a run",
color="yellow",
attrs=["bold"],
)
)
# Get the configuration file from the last checkpoint.
checkpoint_cfg = init_hydra_config(checkpoint_cfg_path)
# Check for differences between the checkpoint configuration and provided configuration.
# Hack to resolve the delta_timestamps ahead of time in order to properly diff.
resolve_delta_timestamps(cfg)
diff = DeepDiff(OmegaConf.to_container(checkpoint_cfg), OmegaConf.to_container(cfg))
# Ignore the `resume` parameter.
if "values_changed" in diff and "root['resume']" in diff["values_changed"]:
del diff["values_changed"]["root['resume']"]
# Log a warning about differences between the checkpoint configuration and the provided
# configuration.
if len(diff) > 0:
logging.warning(
"At least one difference was detected between the checkpoint configuration and "
f"the provided configuration: \n{pformat(diff)}\nNote that the checkpoint configuration "
"takes precedence.",
)
# Use the checkpoint config instead of the provided config (but keep `resume` parameter).
cfg = checkpoint_cfg
cfg.resume = True
elif Logger.get_last_checkpoint_dir(out_dir).exists():
raise RuntimeError(
f"The configured output directory {Logger.get_last_checkpoint_dir(out_dir)} already exists."
)
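The resume check above relies on the shape of DeepDiff's output, in particular the `values_changed` bucket it indexes with `root['resume']`. A small sketch with two plain dicts (illustrative keys only):

```python
# Sketch: the DeepDiff structure consumed by the resume check.
from deepdiff import DeepDiff

checkpoint_cfg = {"resume": False, "training": {"batch_size": 8}}
provided_cfg = {"resume": True, "training": {"batch_size": 16}}

diff = DeepDiff(checkpoint_cfg, provided_cfg)
# {'values_changed': {"root['resume']": {'new_value': True, 'old_value': False},
#                     "root['training']['batch_size']": {'new_value': 16, 'old_value': 8}}}

# Drop the expected `resume` difference, as the training script does above:
if "values_changed" in diff and "root['resume']" in diff["values_changed"]:
    del diff["values_changed"]["root['resume']"]
```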
# log metrics to terminal and wandb
logger = Logger(cfg, out_dir, wandb_job_name=job_name)
if cfg.training.online_steps > 0:
raise NotImplementedError("Online training is not implemented yet.")
set_global_seed(cfg.seed)
# Check device is available # Check device is available
device = get_safe_torch_device(cfg.device, log=True) get_safe_torch_device(cfg.device, log=True)
torch.backends.cudnn.benchmark = True torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cuda.matmul.allow_tf32 = True
set_global_seed(cfg.seed)
logging.info("make_dataset") logging.info("make_dataset")
offline_dataset = make_dataset(cfg) offline_dataset = make_dataset(cfg)
# Create environment used for evaluating checkpoints during training on simulation data. logging.info("make_env")
# On real-world data, no need to create an environment as evaluations are done outside train.py, eval_env = make_env(cfg)
# using the eval.py instead, with gym_dora environment and dora-rs.
if cfg.training.eval_freq > 0:
logging.info("make_env")
eval_env = make_env(cfg)
logging.info("make_policy") logging.info("make_policy")
policy = make_policy( policy = make_policy(hydra_cfg=cfg, dataset_stats=offline_dataset.stats)
hydra_cfg=cfg,
dataset_stats=offline_dataset.stats if not cfg.resume else None,
pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir) if cfg.resume else None,
)
# Create optimizer and scheduler # Create optimizer and scheduler
# Temporary hack to move optimizer out of policy # Temporary hack to move optimizer out of policy
optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy) optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
grad_scaler = GradScaler(enabled=cfg.use_amp)
step = 0 # number of policy updates (forward + backward + optim)
if cfg.resume:
step = logger.load_last_training_state(optimizer, lr_scheduler)
num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad) num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
num_total_params = sum(p.numel() for p in policy.parameters()) num_total_params = sum(p.numel() for p in policy.parameters())
# log metrics to terminal and wandb
logger = Logger(out_dir, job_name, cfg)
log_output_dir(out_dir) log_output_dir(out_dir)
logging.info(f"{cfg.env.task=}") logging.info(f"{cfg.env.task=}")
logging.info(f"{cfg.training.offline_steps=} ({format_big_number(cfg.training.offline_steps)})") logging.info(f"{cfg.training.offline_steps=} ({format_big_number(cfg.training.offline_steps)})")
@@ -319,31 +356,27 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
# Note: this helper will be used in offline and online training loops. # Note: this helper will be used in offline and online training loops.
def evaluate_and_checkpoint_if_needed(step): def evaluate_and_checkpoint_if_needed(step):
if cfg.training.eval_freq > 0 and step % cfg.training.eval_freq == 0: if step % cfg.training.eval_freq == 0:
logging.info(f"Eval policy at step {step}") logging.info(f"Eval policy at step {step}")
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.use_amp else nullcontext(): eval_info = eval_policy(
eval_info = eval_policy( eval_env,
eval_env, policy,
policy, cfg.eval.n_episodes,
cfg.eval.n_episodes, video_dir=Path(out_dir) / "eval",
video_dir=Path(out_dir) / "eval", max_episodes_rendered=4,
max_episodes_rendered=4, start_seed=cfg.seed,
start_seed=cfg.seed, )
)
log_eval_info(logger, eval_info["aggregated"], step, cfg, offline_dataset, is_offline) log_eval_info(logger, eval_info["aggregated"], step, cfg, offline_dataset, is_offline)
if cfg.wandb.enable: if cfg.wandb.enable:
logger.log_video(eval_info["video_paths"][0], step, mode="eval") logger.log_video(eval_info["video_paths"][0], step, mode="eval")
logging.info("Resume training") logging.info("Resume training")
if cfg.training.save_checkpoint and step % cfg.training.save_freq == 0: if cfg.training.save_model and step % cfg.training.save_freq == 0:
logging.info(f"Checkpoint policy after step {step}") logging.info(f"Checkpoint policy after step {step}")
# Note: Save with step as the identifier, and format it to have at least 6 digits but more if # Note: Save with step as the identifier, and format it to have at least 6 digits but more if
# needed (choose 6 as a minimum for consistency without being overkill). # needed (choose 6 as a minimum for consistency without being overkill).
logger.save_checkpont( logger.save_model(
step,
policy, policy,
optimizer,
lr_scheduler,
identifier=str(step).zfill( identifier=str(step).zfill(
max(6, len(str(cfg.training.offline_steps + cfg.training.online_steps))) max(6, len(str(cfg.training.offline_steps + cfg.training.online_steps)))
), ),
@@ -353,34 +386,28 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
# create dataloader for offline training # create dataloader for offline training
dataloader = torch.utils.data.DataLoader( dataloader = torch.utils.data.DataLoader(
offline_dataset, offline_dataset,
num_workers=cfg.training.num_workers, num_workers=4,
batch_size=cfg.training.batch_size, batch_size=cfg.training.batch_size,
shuffle=True, shuffle=True,
pin_memory=device.type != "cpu", pin_memory=cfg.device != "cpu",
drop_last=False, drop_last=False,
) )
dl_iter = cycle(dataloader) dl_iter = cycle(dataloader)
policy.train() policy.train()
step = 0 # number of policy update (forward + backward + optim)
is_offline = True is_offline = True
for _ in range(step, cfg.training.offline_steps): for offline_step in range(cfg.training.offline_steps):
if step == 0: if offline_step == 0:
logging.info("Start offline training on a fixed dataset") logging.info("Start offline training on a fixed dataset")
batch = next(dl_iter) batch = next(dl_iter)
for key in batch: for key in batch:
batch[key] = batch[key].to(device, non_blocking=True) batch[key] = batch[key].to(cfg.device, non_blocking=True)
train_info = update_policy( train_info = update_policy(policy, batch, optimizer, cfg.training.grad_clip_norm, lr_scheduler)
policy,
batch,
optimizer,
cfg.training.grad_clip_norm,
grad_scaler=grad_scaler,
lr_scheduler=lr_scheduler,
use_amp=cfg.use_amp,
)
# TODO(rcadene): is it ok if step_t=0 = 0 and not 1 as previously done?
if step % cfg.training.log_freq == 0: if step % cfg.training.log_freq == 0:
log_train_info(logger, train_info, step, cfg, offline_dataset, is_offline) log_train_info(logger, train_info, step, cfg, offline_dataset, is_offline)
@@ -390,13 +417,6 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
step += 1 step += 1
logging.info("End of offline training")
if cfg.training.online_steps == 0:
if cfg.training.eval_freq > 0:
eval_env.close()
return
# create an env dedicated to online episodes collection from policy rollout # create an env dedicated to online episodes collection from policy rollout
online_training_env = make_env(cfg, n_envs=1) online_training_env = make_env(cfg, n_envs=1)
@@ -416,33 +436,59 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
num_workers=4, num_workers=4,
batch_size=cfg.training.batch_size, batch_size=cfg.training.batch_size,
sampler=sampler, sampler=sampler,
pin_memory=device.type != "cpu", pin_memory=cfg.device != "cpu",
drop_last=False, drop_last=False,
) )
dl_iter = cycle(dataloader)
logging.info("End of online training") online_step = 0
is_offline = False
for env_step in range(cfg.training.online_steps):
if env_step == 0:
logging.info("Start online training by interacting with environment")
if cfg.training.eval_freq > 0: policy.eval()
eval_env.close() with torch.no_grad():
eval_info = eval_policy(
online_training_env,
policy,
n_episodes=1,
return_episode_data=True,
start_seed=cfg.training.online_env_seed,
enable_progbar=True,
)
add_episodes_inplace(
online_dataset,
concat_dataset,
sampler,
hf_dataset=eval_info["episodes"]["hf_dataset"],
episode_data_index=eval_info["episodes"]["episode_data_index"],
pc_online_samples=cfg.training.online_sampling_ratio,
)
policy.train()
for _ in range(cfg.training.online_steps_between_rollouts):
batch = next(dl_iter)
for key in batch:
batch[key] = batch[key].to(cfg.device, non_blocking=True)
train_info = update_policy(policy, batch, optimizer, cfg.training.grad_clip_norm, lr_scheduler)
if step % cfg.training.log_freq == 0:
log_train_info(logger, train_info, step, cfg, online_dataset, is_offline)
# Note: evaluate_and_checkpoint_if_needed happens **after** the `step`th training update has completed,
# so we pass in step + 1.
evaluate_and_checkpoint_if_needed(step + 1)
step += 1
online_step += 1
eval_env.close()
online_training_env.close() online_training_env.close()
logging.info("End of training")
@hydra.main(version_base="1.2", config_name="default", config_path="../configs")
def train_cli(cfg: dict):
train(
cfg,
out_dir=hydra.core.hydra_config.HydraConfig.get().run.dir,
job_name=hydra.core.hydra_config.HydraConfig.get().job.name,
)
def train_notebook(out_dir=None, job_name=None, config_name="default", config_path="../configs"):
from hydra import compose, initialize
hydra.core.global_hydra.GlobalHydra.instance().clear()
initialize(config_path=config_path)
cfg = compose(config_name=config_name)
train(cfg, out_dir=out_dir, job_name=job_name)
if __name__ == "__main__": if __name__ == "__main__":

poetry.lock
View File

@@ -1,4 +1,4 @@
# This file is automatically @generated by Poetry 1.8.1 and should not be changed by hand. # This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand.
[[package]] [[package]]
name = "absl-py" name = "absl-py"
@@ -595,24 +595,6 @@ files = [
{file = "decorator-4.4.2.tar.gz", hash = "sha256:e3a62f0520172440ca0dcc823749319382e377f37f140a0b99ef45fecb84bfe7"}, {file = "decorator-4.4.2.tar.gz", hash = "sha256:e3a62f0520172440ca0dcc823749319382e377f37f140a0b99ef45fecb84bfe7"},
] ]
[[package]]
name = "deepdiff"
version = "7.0.1"
description = "Deep Difference and Search of any Python object/data. Recreate objects by adding adding deltas to each other."
optional = false
python-versions = ">=3.8"
files = [
{file = "deepdiff-7.0.1-py3-none-any.whl", hash = "sha256:447760081918216aa4fd4ca78a4b6a848b81307b2ea94c810255334b759e1dc3"},
{file = "deepdiff-7.0.1.tar.gz", hash = "sha256:260c16f052d4badbf60351b4f77e8390bee03a0b516246f6839bc813fb429ddf"},
]
[package.dependencies]
ordered-set = ">=4.1.0,<4.2.0"
[package.extras]
cli = ["click (==8.1.7)", "pyyaml (==6.0.1)"]
optimize = ["orjson"]
[[package]] [[package]]
name = "diffusers" name = "diffusers"
version = "0.27.2" version = "0.27.2"
@@ -1093,7 +1075,7 @@ description = ""
optional = true optional = true
python-versions = "^3.10" python-versions = "^3.10"
files = [] files = []
develop = false develop = true
[package.dependencies] [package.dependencies]
dora-rs = ">=0.3.4" dora-rs = ">=0.3.4"
@@ -1101,11 +1083,8 @@ gymnasium = ">=0.29.1"
pyarrow = ">=12.0.0" pyarrow = ">=12.0.0"
[package.source] [package.source]
type = "git" type = "directory"
url = "https://github.com/dora-rs/dora-lerobot.git" url = "gym_dora"
reference = "HEAD"
resolved_reference = "1c6c2a401c3a2967d41444be6286ca9a28893abf"
subdirectory = "gym_dora"
[[package]] [[package]]
name = "gym-pusht" name = "gym-pusht"
@@ -1310,13 +1289,13 @@ files = [
[[package]] [[package]]
name = "huggingface-hub" name = "huggingface-hub"
version = "0.23.1" version = "0.23.0"
description = "Client library to download and publish models, datasets and other repos on the huggingface.co hub" description = "Client library to download and publish models, datasets and other repos on the huggingface.co hub"
optional = false optional = false
python-versions = ">=3.8.0" python-versions = ">=3.8.0"
files = [ files = [
{file = "huggingface_hub-0.23.1-py3-none-any.whl", hash = "sha256:720a5bffd2b1b449deb793da8b0df7a9390a7e238534d5a08c9fbcdecb1dd3cb"}, {file = "huggingface_hub-0.23.0-py3-none-any.whl", hash = "sha256:075c30d48ee7db2bba779190dc526d2c11d422aed6f9044c5e2fdc2c432fdb91"},
{file = "huggingface_hub-0.23.1.tar.gz", hash = "sha256:4f62dbf6ae94f400c6d3419485e52bce510591432a5248a65d0cb72e4d479eb4"}, {file = "huggingface_hub-0.23.0.tar.gz", hash = "sha256:7126dedd10a4c6fac796ced4d87a8cf004efc722a5125c2c09299017fa366fa9"},
] ]
[package.dependencies] [package.dependencies]
@@ -2355,13 +2334,13 @@ files = [
[[package]] [[package]]
name = "nvidia-nvjitlink-cu12" name = "nvidia-nvjitlink-cu12"
version = "12.5.40" version = "12.4.127"
description = "Nvidia JIT LTO Library" description = "Nvidia JIT LTO Library"
optional = false optional = false
python-versions = ">=3" python-versions = ">=3"
files = [ files = [
{file = "nvidia_nvjitlink_cu12-12.5.40-py3-none-manylinux2014_x86_64.whl", hash = "sha256:d9714f27c1d0f0895cd8915c07a87a1d0029a0aa36acaf9156952ec2a8a12189"}, {file = "nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:06b3b9b25bf3f8af351d664978ca26a16d2c5127dbd53c0497e28d1fb9611d57"},
{file = "nvidia_nvjitlink_cu12-12.5.40-py3-none-win_amd64.whl", hash = "sha256:c3401dc8543b52d3a8158007a0c1ab4e9c768fcbd24153a48c86972102197ddd"}, {file = "nvidia_nvjitlink_cu12-12.4.127-py3-none-win_amd64.whl", hash = "sha256:fd9020c501d27d135f983c6d3e244b197a7ccad769e34df53a42e276b0e25fa1"},
] ]
[[package]] [[package]]
@@ -2414,20 +2393,6 @@ numpy = [
{version = ">=1.21.2", markers = "platform_system != \"Darwin\" and python_version >= \"3.10\" and python_version < \"3.11\""}, {version = ">=1.21.2", markers = "platform_system != \"Darwin\" and python_version >= \"3.10\" and python_version < \"3.11\""},
] ]
[[package]]
name = "ordered-set"
version = "4.1.0"
description = "An OrderedSet is a custom MutableSet that remembers its order, so that every"
optional = false
python-versions = ">=3.7"
files = [
{file = "ordered-set-4.1.0.tar.gz", hash = "sha256:694a8e44c87657c59292ede72891eb91d34131f6531463aab3009191c77364a8"},
{file = "ordered_set-4.1.0-py3-none-any.whl", hash = "sha256:046e1132c71fcf3330438a539928932caf51ddbc582496833e23de611de14562"},
]
[package.extras]
dev = ["black", "mypy", "pytest"]
[[package]] [[package]]
name = "packaging" name = "packaging"
version = "24.0" version = "24.0"
@@ -3231,13 +3196,13 @@ files = [
[[package]] [[package]]
name = "requests" name = "requests"
version = "2.32.2" version = "2.32.1"
description = "Python HTTP for Humans." description = "Python HTTP for Humans."
optional = false optional = false
python-versions = ">=3.8" python-versions = ">=3.8"
files = [ files = [
{file = "requests-2.32.2-py3-none-any.whl", hash = "sha256:fc06670dd0ed212426dfeb94fc1b983d917c4f9847c863f313c9dfaaffb7c23c"}, {file = "requests-2.32.1-py3-none-any.whl", hash = "sha256:21ac9465cdf8c1650fe1ecde8a71669a93d4e6f147550483a2967d08396a56a5"},
{file = "requests-2.32.2.tar.gz", hash = "sha256:dd951ff5ecf3e3b3aa26b40703ba77495dab41da839ae72ef3c8e5d8e2433289"}, {file = "requests-2.32.1.tar.gz", hash = "sha256:eb97e87e64c79e64e5b8ac75cee9dd1f97f49e289b083ee6be96268930725685"},
] ]
[package.dependencies] [package.dependencies]
@@ -3442,36 +3407,36 @@ test = ["asv", "numpydoc (>=1.7)", "pooch (>=1.6.0)", "pytest (>=7.0)", "pytest-
[[package]] [[package]]
name = "scipy" name = "scipy"
version = "1.13.1" version = "1.13.0"
description = "Fundamental algorithms for scientific computing in Python" description = "Fundamental algorithms for scientific computing in Python"
optional = true optional = true
python-versions = ">=3.9" python-versions = ">=3.9"
files = [ files = [
{file = "scipy-1.13.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:20335853b85e9a49ff7572ab453794298bcf0354d8068c5f6775a0eabf350aca"}, {file = "scipy-1.13.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:ba419578ab343a4e0a77c0ef82f088238a93eef141b2b8017e46149776dfad4d"},
{file = "scipy-1.13.1-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:d605e9c23906d1994f55ace80e0125c587f96c020037ea6aa98d01b4bd2e222f"}, {file = "scipy-1.13.0-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:22789b56a999265431c417d462e5b7f2b487e831ca7bef5edeb56efe4c93f86e"},
{file = "scipy-1.13.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cfa31f1def5c819b19ecc3a8b52d28ffdcc7ed52bb20c9a7589669dd3c250989"}, {file = "scipy-1.13.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:05f1432ba070e90d42d7fd836462c50bf98bd08bed0aa616c359eed8a04e3922"},
{file = "scipy-1.13.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f26264b282b9da0952a024ae34710c2aff7d27480ee91a2e82b7b7073c24722f"}, {file = "scipy-1.13.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b8434f6f3fa49f631fae84afee424e2483289dfc30a47755b4b4e6b07b2633a4"},
{file = "scipy-1.13.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:eccfa1906eacc02de42d70ef4aecea45415f5be17e72b61bafcfd329bdc52e94"}, {file = "scipy-1.13.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:dcbb9ea49b0167de4167c40eeee6e167caeef11effb0670b554d10b1e693a8b9"},
{file = "scipy-1.13.1-cp310-cp310-win_amd64.whl", hash = "sha256:2831f0dc9c5ea9edd6e51e6e769b655f08ec6db6e2e10f86ef39bd32eb11da54"}, {file = "scipy-1.13.0-cp310-cp310-win_amd64.whl", hash = "sha256:1d2f7bb14c178f8b13ebae93f67e42b0a6b0fc50eba1cd8021c9b6e08e8fb1cd"},
{file = "scipy-1.13.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:27e52b09c0d3a1d5b63e1105f24177e544a222b43611aaf5bc44d4a0979e32f9"}, {file = "scipy-1.13.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:0fbcf8abaf5aa2dc8d6400566c1a727aed338b5fe880cde64907596a89d576fa"},
{file = "scipy-1.13.1-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:54f430b00f0133e2224c3ba42b805bfd0086fe488835effa33fa291561932326"}, {file = "scipy-1.13.0-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:5e4a756355522eb60fcd61f8372ac2549073c8788f6114449b37e9e8104f15a5"},
{file = "scipy-1.13.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e89369d27f9e7b0884ae559a3a956e77c02114cc60a6058b4e5011572eea9299"}, {file = "scipy-1.13.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b5acd8e1dbd8dbe38d0004b1497019b2dbbc3d70691e65d69615f8a7292865d7"},
{file = "scipy-1.13.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a78b4b3345f1b6f68a763c6e25c0c9a23a9fd0f39f5f3d200efe8feda560a5fa"}, {file = "scipy-1.13.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9ff7dad5d24a8045d836671e082a490848e8639cabb3dbdacb29f943a678683d"},
{file = "scipy-1.13.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:45484bee6d65633752c490404513b9ef02475b4284c4cfab0ef946def50b3f59"}, {file = "scipy-1.13.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:4dca18c3ffee287ddd3bc8f1dabaf45f5305c5afc9f8ab9cbfab855e70b2df5c"},
{file = "scipy-1.13.1-cp311-cp311-win_amd64.whl", hash = "sha256:5713f62f781eebd8d597eb3f88b8bf9274e79eeabf63afb4a737abc6c84ad37b"}, {file = "scipy-1.13.0-cp311-cp311-win_amd64.whl", hash = "sha256:a2f471de4d01200718b2b8927f7d76b5d9bde18047ea0fa8bd15c5ba3f26a1d6"},
{file = "scipy-1.13.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:5d72782f39716b2b3509cd7c33cdc08c96f2f4d2b06d51e52fb45a19ca0c86a1"}, {file = "scipy-1.13.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:d0de696f589681c2802f9090fff730c218f7c51ff49bf252b6a97ec4a5d19e8b"},
{file = "scipy-1.13.1-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:017367484ce5498445aade74b1d5ab377acdc65e27095155e448c88497755a5d"}, {file = "scipy-1.13.0-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:b2a3ff461ec4756b7e8e42e1c681077349a038f0686132d623fa404c0bee2551"},
{file = "scipy-1.13.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:949ae67db5fa78a86e8fa644b9a6b07252f449dcf74247108c50e1d20d2b4627"}, {file = "scipy-1.13.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6bf9fe63e7a4bf01d3645b13ff2aa6dea023d38993f42aaac81a18b1bda7a82a"},
{file = "scipy-1.13.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:de3ade0e53bc1f21358aa74ff4830235d716211d7d077e340c7349bc3542e884"}, {file = "scipy-1.13.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1e7626dfd91cdea5714f343ce1176b6c4745155d234f1033584154f60ef1ff42"},
{file = "scipy-1.13.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:2ac65fb503dad64218c228e2dc2d0a0193f7904747db43014645ae139c8fad16"}, {file = "scipy-1.13.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:109d391d720fcebf2fbe008621952b08e52907cf4c8c7efc7376822151820820"},
{file = "scipy-1.13.1-cp312-cp312-win_amd64.whl", hash = "sha256:cdd7dacfb95fea358916410ec61bbc20440f7860333aee6d882bb8046264e949"}, {file = "scipy-1.13.0-cp312-cp312-win_amd64.whl", hash = "sha256:8930ae3ea371d6b91c203b1032b9600d69c568e537b7988a3073dfe4d4774f21"},
{file = "scipy-1.13.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:436bbb42a94a8aeef855d755ce5a465479c721e9d684de76bf61a62e7c2b81d5"}, {file = "scipy-1.13.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:5407708195cb38d70fd2d6bb04b1b9dd5c92297d86e9f9daae1576bd9e06f602"},
{file = "scipy-1.13.1-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:8335549ebbca860c52bf3d02f80784e91a004b71b059e3eea9678ba994796a24"}, {file = "scipy-1.13.0-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:ac38c4c92951ac0f729c4c48c9e13eb3675d9986cc0c83943784d7390d540c78"},
{file = "scipy-1.13.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d533654b7d221a6a97304ab63c41c96473ff04459e404b83275b60aa8f4b7004"}, {file = "scipy-1.13.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:09c74543c4fbeb67af6ce457f6a6a28e5d3739a87f62412e4a16e46f164f0ae5"},
{file = "scipy-1.13.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:637e98dcf185ba7f8e663e122ebf908c4702420477ae52a04f9908707456ba4d"}, {file = "scipy-1.13.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:28e286bf9ac422d6beb559bc61312c348ca9b0f0dae0d7c5afde7f722d6ea13d"},
{file = "scipy-1.13.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:a014c2b3697bde71724244f63de2476925596c24285c7a637364761f8710891c"}, {file = "scipy-1.13.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:33fde20efc380bd23a78a4d26d59fc8704e9b5fd9b08841693eb46716ba13d86"},
{file = "scipy-1.13.1-cp39-cp39-win_amd64.whl", hash = "sha256:392e4ec766654852c25ebad4f64e4e584cf19820b980bc04960bca0b0cd6eaa2"}, {file = "scipy-1.13.0-cp39-cp39-win_amd64.whl", hash = "sha256:45c08bec71d3546d606989ba6e7daa6f0992918171e2a6f7fbedfa7361c2de1e"},
{file = "scipy-1.13.1.tar.gz", hash = "sha256:095a87a0312b08dfd6a6155cbbd310a8c51800fc931b8c0b84003014b874ed3c"}, {file = "scipy-1.13.0.tar.gz", hash = "sha256:58569af537ea29d3f78e5abd18398459f195546bb3be23d16677fb26616cc11e"},
] ]
[package.dependencies] [package.dependencies]
@@ -3484,13 +3449,13 @@ test = ["array-api-strict", "asv", "gmpy2", "hypothesis (>=6.30)", "mpmath", "po
[[package]] [[package]]
name = "sentry-sdk" name = "sentry-sdk"
version = "2.3.1" version = "2.2.1"
description = "Python client for Sentry (https://sentry.io)" description = "Python client for Sentry (https://sentry.io)"
optional = false optional = false
python-versions = ">=3.6" python-versions = ">=3.6"
files = [ files = [
{file = "sentry_sdk-2.3.1-py2.py3-none-any.whl", hash = "sha256:c5aeb095ba226391d337dd42a6f9470d86c9fc236ecc71cfc7cd1942b45010c6"}, {file = "sentry_sdk-2.2.1-py2.py3-none-any.whl", hash = "sha256:7d617a1b30e80c41f3b542347651fcf90bb0a36f3a398be58b4f06b79c8d85bc"},
{file = "sentry_sdk-2.3.1.tar.gz", hash = "sha256:139a71a19f5e9eb5d3623942491ce03cf8ebc14ea2e39ba3e6fe79560d8a5b1f"}, {file = "sentry_sdk-2.2.1.tar.gz", hash = "sha256:8aa2ec825724d8d9d645cab68e6034928b1a6a148503af3e361db3fa6401183f"},
] ]
[package.dependencies] [package.dependencies]
@@ -3780,13 +3745,13 @@ tests = ["pytest", "pytest-cov"]
[[package]] [[package]]
name = "tifffile" name = "tifffile"
version = "2024.5.22" version = "2024.5.10"
description = "Read and write TIFF files" description = "Read and write TIFF files"
optional = true optional = true
python-versions = ">=3.9" python-versions = ">=3.9"
files = [ files = [
{file = "tifffile-2024.5.22-py3-none-any.whl", hash = "sha256:e281781c15d7d197d7e12749849c965651413aa905f97a48b0f84bd90a3b4c6f"}, {file = "tifffile-2024.5.10-py3-none-any.whl", hash = "sha256:4154f091aa24d4e75bfad9ab2d5424a68c70e67b8220188066dc61946d4551bd"},
{file = "tifffile-2024.5.22.tar.gz", hash = "sha256:3a105801d1b86d55692a98812a170c39d3f0447aeacb1d94635d38077cb328c4"}, {file = "tifffile-2024.5.10.tar.gz", hash = "sha256:aa1e1b12be952ab20717d6848bd6d4a5ee88d2aa319f1152bff4354ad728ec86"},
] ]
[package.dependencies] [package.dependencies]
@@ -3942,13 +3907,13 @@ tutorials = ["matplotlib", "pandas", "tabulate", "torch"]
[[package]] [[package]]
name = "typing-extensions" name = "typing-extensions"
version = "4.12.0" version = "4.11.0"
description = "Backported and Experimental Type Hints for Python 3.8+" description = "Backported and Experimental Type Hints for Python 3.8+"
optional = false optional = false
python-versions = ">=3.8" python-versions = ">=3.8"
files = [ files = [
{file = "typing_extensions-4.12.0-py3-none-any.whl", hash = "sha256:b349c66bea9016ac22978d800cfff206d5f9816951f12a7d0ec5578b0a819594"}, {file = "typing_extensions-4.11.0-py3-none-any.whl", hash = "sha256:c1f94d72897edaf4ce775bb7558d5b79d8126906a14ea5ed1635921406c0387a"},
{file = "typing_extensions-4.12.0.tar.gz", hash = "sha256:8cbcdc8606ebcb0d95453ad7dc5065e6237b6aa230a31e81d0f440c30fed5fd8"}, {file = "typing_extensions-4.11.0.tar.gz", hash = "sha256:83f085bd5ca59c80295fc2a82ab5dac679cbe02b9f33f7d83af68e241bea51b0"},
] ]
[[package]] [[package]]
@@ -4309,4 +4274,4 @@ xarm = ["gym-xarm"]
[metadata] [metadata]
lock-version = "2.0" lock-version = "2.0"
python-versions = ">=3.10,<3.13" python-versions = ">=3.10,<3.13"
content-hash = "23ddb8dd774a4faf85d08a07dfdf19badb7c370120834b71df4afca254520771" content-hash = "ea4e8207316a8ec8a4b95d6a89cf488c8733a8e7ab43e5f669c889ee87f3bef3"

View File

@@ -41,12 +41,12 @@ numba = ">=0.59.0"
torch = "^2.2.1" torch = "^2.2.1"
opencv-python = ">=4.9.0" opencv-python = ">=4.9.0"
diffusers = "^0.27.2" diffusers = "^0.27.2"
torchvision = ">=0.17.1" torchvision = ">=0.18.0"
h5py = ">=3.10.0" h5py = ">=3.10.0"
huggingface-hub = {extras = ["hf-transfer"], version = "^0.23.0"} huggingface-hub = {extras = ["hf-transfer"], version = "^0.23.0"}
gymnasium = ">=0.29.1" gymnasium = ">=0.29.1"
cmake = ">=3.29.0.1" cmake = ">=3.29.0.1"
gym-dora = { git = "https://github.com/dora-rs/dora-lerobot.git", subdirectory = "gym_dora", optional = true } gym-dora = { path = "gym_dora", optional = true, develop = true}
gym-pusht = { version = ">=0.1.3", optional = true} gym-pusht = { version = ">=0.1.3", optional = true}
gym-xarm = { version = ">=0.1.1", optional = true} gym-xarm = { version = ">=0.1.1", optional = true}
gym-aloha = { version = ">=0.1.1", optional = true} gym-aloha = { version = ">=0.1.1", optional = true}
@@ -59,7 +59,6 @@ imagecodecs = { version = ">=2024.1.1", optional = true }
pyav = ">=12.0.5" pyav = ">=12.0.5"
moviepy = ">=1.0.3" moviepy = ">=1.0.3"
rerun-sdk = ">=0.15.1" rerun-sdk = ">=0.15.1"
deepdiff = ">=7.0.1"
[tool.poetry.extras] [tool.poetry.extras]

View File

@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ebd21273f6048b66c806f92035352843a9069908b3296863fd55d34cf71cd0ef
size 51248

View File

@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b9bbf951891077320a5da27e77ddb580a6e833e8d3162b62a2f887a1989585cc
size 31688

View File

@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d4070bd1f1cd8c72bc2daf628088e42b8ef113f6df0bfd9e91be052bc90038c3
size 68

View File

@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:42f92239223bb4df32d5c3016bc67450159f1285a7ab046307b645f699ccc34e
size 34928

View File

@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:52f85d6262ad1dd0b66578b25829fed96aaaca3c7458cb73ac75111350d17fcf
size 51248

View File

@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:5ba7c910618f0f3ca69f82f3d70c880d2b2e432456524a2a63dfd5c50efa45f0
size 30808

View File

@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:97455b4360748c99905cd103473c1a52da6901d0a73ffbc51b5ea3eb250d1386
size 68

View File

@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:53ad410f43855254438790f54aa7c895a052776acdd922906ae430684f659b53
size 33608

View File

@@ -75,16 +75,15 @@ def get_policy_stats(env_name, policy_name, extra_overrides):
# HACK: We reload a batch with no delta_timestamps as `select_action` won't expect a timestamps dimension # HACK: We reload a batch with no delta_timestamps as `select_action` won't expect a timestamps dimension
dataset.delta_timestamps = None dataset.delta_timestamps = None
batch = next(iter(dataloader)) batch = next(iter(dataloader))
obs = {} obs = {
for k in batch: k: batch[k]
if "observation" in k: for k in batch
obs[k] = batch[k] if k in ["observation.image", "observation.images.top", "observation.state"]
}
if "n_action_steps" in cfg.policy:
actions_queue = cfg.policy.n_action_steps
else:
actions_queue = cfg.policy.n_action_repeats
actions_queue = (
cfg.policy.n_action_steps if "n_action_steps" in cfg.policy else cfg.policy.n_action_repeats
)
actions = {str(i): policy.select_action(obs).contiguous() for i in range(actions_queue)} actions = {str(i): policy.select_action(obs).contiguous() for i in range(actions_queue)}
return output_dict, grad_stats, param_stats, actions return output_dict, grad_stats, param_stats, actions
@@ -115,8 +114,6 @@ if __name__ == "__main__":
["policy.n_action_steps=8", "policy.num_inference_steps=10", "policy.down_dims=[128, 256, 512]"], ["policy.n_action_steps=8", "policy.num_inference_steps=10", "policy.down_dims=[128, 256, 512]"],
), ),
("aloha", "act", ["policy.n_action_steps=10"]), ("aloha", "act", ["policy.n_action_steps=10"]),
("dora_aloha_real", "act_real", []),
("dora_aloha_real", "act_real_no_state", []),
] ]
for env, policy, extra_overrides in env_policies: for env, policy, extra_overrides in env_policies:
save_policy_to_safetensors("tests/data/save_policy_to_safetensors", env, policy, extra_overrides) save_policy_to_safetensors("tests/data/save_policy_to_safetensors", env, policy, extra_overrides)

View File

@@ -45,11 +45,11 @@ def test_example_1():
@require_package("gym_pusht") @require_package("gym_pusht")
def test_examples_basic2_basic3_advanced1(): def test_examples_2_through_4():
""" """
Train a model with example 3, check the outputs. Train a model with example 3, check the outputs.
Evaluate the trained model with example 2, check the outputs. Evaluate the trained model with example 2, check the outputs.
Calculate the validation loss with advanced example 1, check the outputs. Calculate the validation loss with example 4, check the outputs.
""" """
### Test example 3 ### Test example 3
@@ -97,7 +97,7 @@ def test_examples_basic2_basic3_advanced1():
assert Path("outputs/eval/example_pusht_diffusion/rollout.mp4").exists() assert Path("outputs/eval/example_pusht_diffusion/rollout.mp4").exists()
## Test example 4 ## Test example 4
file_contents = _read_file("examples/advanced/2_calculate_validation_loss.py") file_contents = _read_file("examples/4_calculate_validation_loss.py")
# Run on a single example from the last episode, use CPU, and use the local model. # Run on a single example from the last episode, use CPU, and use the local model.
file_contents = _find_and_replace( file_contents = _find_and_replace(

View File

@@ -30,7 +30,7 @@ from lerobot.common.policies.factory import get_policy_and_config_classes, make_
from lerobot.common.policies.normalize import Normalize, Unnormalize from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.policy_protocol import Policy from lerobot.common.policies.policy_protocol import Policy
from lerobot.common.utils.utils import init_hydra_config from lerobot.common.utils.utils import init_hydra_config
from tests.scripts.save_policy_to_safetensors import get_policy_stats from tests.scripts.save_policy_to_safetensor import get_policy_stats
from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, require_cpu, require_env, require_x86_64_kernel from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, require_cpu, require_env, require_x86_64_kernel
@@ -72,8 +72,6 @@ def test_get_policy_and_config_classes(policy_name: str):
), ),
# Note: these parameters also need custom logic in the test function for overriding the Hydra config. # Note: these parameters also need custom logic in the test function for overriding the Hydra config.
("pusht", "act", ["env.task=PushT-v0", "dataset_repo_id=lerobot/pusht"]), ("pusht", "act", ["env.task=PushT-v0", "dataset_repo_id=lerobot/pusht"]),
("dora_aloha_real", "act_real", []),
("dora_aloha_real", "act_real_no_state", []),
], ],
) )
@require_env @require_env
@@ -293,8 +291,6 @@ def test_normalize(insert_temporal_dim):
["policy.n_action_steps=8", "policy.num_inference_steps=10", "policy.down_dims=[128, 256, 512]"], ["policy.n_action_steps=8", "policy.num_inference_steps=10", "policy.down_dims=[128, 256, 512]"],
), ),
("aloha", "act", ["policy.n_action_steps=10"]), ("aloha", "act", ["policy.n_action_steps=10"]),
("dora_aloha_real", "act_real", ["policy.n_action_steps=10"]),
("dora_aloha_real", "act_real_no_state", ["policy.n_action_steps=10"]),
], ],
) )
# As artifacts have been generated on an x86_64 kernel, this test won't # As artifacts have been generated on an x86_64 kernel, this test won't

View File

@@ -11,24 +11,22 @@ from lerobot.common.datasets.utils import (
hf_transform_to_torch, hf_transform_to_torch,
reset_episode_index, reset_episode_index,
) )
from lerobot.common.utils.utils import ( from lerobot.common.utils.utils import seeded_context, set_global_seed
get_global_random_state,
seeded_context,
set_global_random_state, @pytest.mark.parametrize(
set_global_seed, "rand_fn",
(
[
random.random,
np.random.random,
lambda: torch.rand(1).item(),
]
+ [lambda: torch.rand(1, device="cuda")]
if torch.cuda.is_available()
else []
),
) )
# Random generation functions for testing the seeding and random state get/set.
rand_fns = [
random.random,
np.random.random,
lambda: torch.rand(1).item(),
]
if torch.cuda.is_available():
rand_fns.append(lambda: torch.rand(1, device="cuda"))
@pytest.mark.parametrize("rand_fn", rand_fns)
def test_seeding(rand_fn: Callable[[], int]): def test_seeding(rand_fn: Callable[[], int]):
set_global_seed(0) set_global_seed(0)
a = rand_fn() a = rand_fn()
@@ -48,15 +46,6 @@ def test_seeding(rand_fn: Callable[[], int]):
assert c_ == c assert c_ == c
def test_get_set_random_state():
"""Check that getting the random state, then setting it results in the same random number generation."""
random_state_dict = get_global_random_state()
rand_numbers = [rand_fn() for rand_fn in rand_fns]
set_global_random_state(random_state_dict)
rand_numbers_ = [rand_fn() for rand_fn in rand_fns]
assert rand_numbers_ == rand_numbers
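The removed test above exercises a plain RNG-state round trip: capture the global random state, draw, restore, and draw again. A minimal standalone sketch of the same idea using the stdlib, NumPy, and torch generators directly (no lerobot helpers):

```python
# Sketch: restoring captured RNG state reproduces the same draws.
import random

import numpy as np
import torch

state = (random.getstate(), np.random.get_state(), torch.get_rng_state())
draws = (random.random(), np.random.random(), torch.rand(1).item())

random.setstate(state[0])
np.random.set_state(state[1])
torch.set_rng_state(state[2])

assert (random.random(), np.random.random(), torch.rand(1).item()) == draws
```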
def test_calculate_episode_data_index(): def test_calculate_episode_data_index():
dataset = Dataset.from_dict( dataset = Dataset.from_dict(
{ {