forked from tangger/lerobot
Compare commits
39 Commits
thom_arm ... user/rcade
b502a82005
12a1b8f55a
125bd93e29
c38f535c9f
ff8f6aa6cd
1cf050d412
54c9776bde
a06598678c
055a6f60c6
e54d6ea1eb
1eb4bfe2e4
21f222fa1d
33362dbd17
b0d954c6e1
bd3111f28b
cf15cba5fc
205e0c9dde
5b74205e16
042e193995
d585c73f9f
504d2aaf48
83f4f7f7e8
633115d861
57fb5fe8a6
0b51a335bc
111cd58f8a
265b0ec44d
2c2e4e14ed
13310681b1
3d625ae6d3
e3b9f1c19b
7ec76ee235
3b86050ab0
6d39b73399
aca424a481
35c1ce7a66
e67da1d7a6
b6c216b590
2b270d085b
98  .github/workflows/build-docker-images.yml  vendored

@@ -10,7 +10,6 @@ on:
 env:
   PYTHON_VERSION: "3.10"
-  # CI_SLACK_CHANNEL: ${{ secrets.CI_DOCKER_CHANNEL }}

 jobs:
   latest-cpu:
@@ -35,6 +34,8 @@ jobs:
       - name: Check out code
         uses: actions/checkout@v4
+        with:
+          lfs: true

       - name: Login to DockerHub
         uses: docker/login-action@v3
@@ -51,34 +52,50 @@ jobs:
           tags: huggingface/lerobot-cpu
           build-args: PYTHON_VERSION=${{ env.PYTHON_VERSION }}

       # - name: Post to a Slack channel
       #   id: slack
       #   #uses: slackapi/slack-github-action@v1.25.0
       #   uses: slackapi/slack-github-action@6c661ce58804a1a20f6dc5fbee7f0381b469e001
       #   with:
       #     # Slack channel id, channel name, or user id to post message.
       #     # See also: https://api.slack.com/methods/chat.postMessage#channels
       #     channel-id: ${{ env.CI_SLACK_CHANNEL }}
       #     # For posting a rich message using Block Kit
       #     payload: |
       #       {
       #         "text": "lerobot-cpu Docker Image build result: ${{ job.status }}\n${{ github.event.pull_request.html_url || github.event.head_commit.url }}",
       #         "blocks": [
       #           {
       #             "type": "section",
       #             "text": {
       #               "type": "mrkdwn",
       #               "text": "lerobot-cpu Docker Image build result: ${{ job.status }}\n${{ github.event.pull_request.html_url || github.event.head_commit.url }}"
       #             }
       #           }
       #         ]
       #       }
       #   env:
       #     SLACK_BOT_TOKEN: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}

   latest-cuda:
     name: GPU
     runs-on: ubuntu-latest
     steps:
+      - name: Cleanup disk
+        run: |
+          sudo df -h
+          # sudo ls -l /usr/local/lib/
+          # sudo ls -l /usr/share/
+          sudo du -sh /usr/local/lib/
+          sudo du -sh /usr/share/
+          sudo rm -rf /usr/local/lib/android
+          sudo rm -rf /usr/share/dotnet
+          sudo du -sh /usr/local/lib/
+          sudo du -sh /usr/share/
+          sudo df -h
       - name: Set up Docker Buildx
         uses: docker/setup-buildx-action@v3

       - name: Check out code
         uses: actions/checkout@v4
+        with:
+          lfs: true

       - name: Login to DockerHub
         uses: docker/login-action@v3
         with:
           username: ${{ secrets.DOCKERHUB_USERNAME }}
           password: ${{ secrets.DOCKERHUB_PASSWORD }}

       - name: Build and Push GPU
         uses: docker/build-push-action@v5
         with:
           context: .
           file: ./docker/lerobot-gpu/Dockerfile
           push: true
           tags: huggingface/lerobot-gpu
           build-args: PYTHON_VERSION=${{ env.PYTHON_VERSION }}

   latest-cuda-dev:
     name: GPU Dev
     runs-on: ubuntu-latest
     steps:
       - name: Cleanup disk
         run: |
@@ -104,36 +121,11 @@ jobs:
           username: ${{ secrets.DOCKERHUB_USERNAME }}
           password: ${{ secrets.DOCKERHUB_PASSWORD }}

-      - name: Build and Push GPU
+      - name: Build and Push GPU dev
         uses: docker/build-push-action@v5
         with:
           context: .
-          file: ./docker/lerobot-gpu/Dockerfile
+          file: ./docker/lerobot-gpu-dev/Dockerfile
           push: true
-          tags: huggingface/lerobot-gpu
+          tags: huggingface/lerobot-gpu:dev
           build-args: PYTHON_VERSION=${{ env.PYTHON_VERSION }}

-      # - name: Post to a Slack channel
-      #   id: slack
-      #   #uses: slackapi/slack-github-action@v1.25.0
-      #   uses: slackapi/slack-github-action@6c661ce58804a1a20f6dc5fbee7f0381b469e001
-      #   with:
-      #     # Slack channel id, channel name, or user id to post message.
-      #     # See also: https://api.slack.com/methods/chat.postMessage#channels
-      #     channel-id: ${{ env.CI_SLACK_CHANNEL }}
-      #     # For posting a rich message using Block Kit
-      #     payload: |
-      #       {
-      #         "text": "lerobot-gpu Docker Image build result: ${{ job.status }}\n${{ github.event.pull_request.html_url || github.event.head_commit.url }}",
-      #         "blocks": [
-      #           {
-      #             "type": "section",
-      #             "text": {
-      #               "type": "mrkdwn",
-      #               "text": "lerobot-gpu Docker Image build result: ${{ job.status }}\n${{ github.event.pull_request.html_url || github.event.head_commit.url }}"
-      #             }
-      #           }
-      #         ]
-      #       }
-      #   env:
-      #     SLACK_BOT_TOKEN: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}
2  .github/workflows/nightly-tests.yml  vendored

@@ -70,6 +70,8 @@ jobs:
       #     files: ./coverage.xml
       #     verbose: true
       - name: Tests end-to-end
+        env:
+          DEVICE: cuda
         run: make test-end-to-end

       # - name: Generate Report
11  .github/workflows/test.yml  vendored

@@ -10,6 +10,7 @@ on:
       - "examples/**"
       - ".github/**"
       - "poetry.lock"
+      - "Makefile"
   push:
     branches:
       - main
@@ -19,6 +20,7 @@ on:
       - "examples/**"
       - ".github/**"
       - "poetry.lock"
+      - "Makefile"

 jobs:
   pytest:
@@ -32,8 +34,8 @@ jobs:
         with:
           lfs: true  # Ensure LFS files are pulled

-      - name: Install EGL
-        run: sudo apt-get update && sudo apt-get install -y libegl1-mesa-dev
+      - name: Install apt dependencies
+        run: sudo apt-get update && sudo apt-get install -y libegl1-mesa-dev ffmpeg

       - name: Install poetry
         run: |
@@ -70,6 +72,9 @@ jobs:
         with:
           lfs: true  # Ensure LFS files are pulled

+      - name: Install apt dependencies
+        run: sudo apt-get update && sudo apt-get install -y ffmpeg
+
       - name: Install poetry
         run: |
           pipx install poetry && poetry config virtualenvs.in-project true
@@ -104,7 +109,7 @@ jobs:
         with:
           lfs: true  # Ensure LFS files are pulled

-      - name: Install EGL
+      - name: Install apt dependencies
         run: sudo apt-get update && sudo apt-get install -y libegl1-mesa-dev

       - name: Install poetry
18  .github/workflows/trufflehog.yml  vendored  (new file)

@@ -0,0 +1,18 @@
on:
  push:

name: Secret Leaks

permissions:
  contents: read

jobs:
  trufflehog:
    runs-on: ubuntu-latest
    steps:
      - name: Checkout code
        uses: actions/checkout@v4
        with:
          fetch-depth: 0
      - name: Secret Scanning
        uses: trufflesecurity/trufflehog@main
31  .gitignore  vendored

@@ -2,12 +2,17 @@
 logs
 tmp
 wandb

 # Data
 data
 outputs
-.vscode
+rl

+# Apple
+.DS_Store
+
+# VS Code
+.vscode
+
 # HPC
 nautilus/*.yaml
 *.key
@@ -90,6 +95,7 @@ instance/
 docs/_build/

 # PyBuilder
+.pybuilder/
 target/

 # Jupyter Notebook
@@ -102,13 +108,6 @@ ipython_config.py
 # pyenv
 .python-version

-# pipenv
-#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
-#   However, in case of collaboration, if having platform-specific dependencies or dependencies
-#   having no cross-platform support, pipenv may install dependencies that don't work, or not
-#   install all needed dependencies.
-#Pipfile.lock
-
 # PEP 582; used by e.g. github.com/David-OConnor/pyflow
 __pypackages__/

@@ -119,6 +118,14 @@ celerybeat.pid
 # SageMath parsed files
 *.sage.py

+# Environments
+.env
+.venv
+venv/
+ENV/
+env.bak/
+venv.bak/
+
 # Spyder project settings
 .spyderproject
 .spyproject
@@ -136,3 +143,9 @@ dmypy.json

 # Pyre type checker
 .pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
104  Makefile

@@ -5,11 +5,12 @@ PYTHON_PATH := $(shell which python)
 # If Poetry is installed, redefine PYTHON_PATH to use the Poetry-managed Python
 POETRY_CHECK := $(shell command -v poetry)
 ifneq ($(POETRY_CHECK),)
-PYTHON_PATH := $(shell poetry run which python)
+	PYTHON_PATH := $(shell poetry run which python)
 endif

 export PATH := $(dir $(PYTHON_PATH)):$(PATH)

+DEVICE ?= cpu

 build-cpu:
 	docker build -t lerobot:latest -f docker/lerobot-cpu/Dockerfile .
@@ -18,62 +19,101 @@ build-gpu:
 	docker build -t lerobot:latest -f docker/lerobot-gpu/Dockerfile .

 test-end-to-end:
-	${MAKE} test-act-ete-train
-	${MAKE} test-act-ete-eval
-	${MAKE} test-diffusion-ete-train
-	${MAKE} test-diffusion-ete-eval
-	${MAKE} test-tdmpc-ete-train
-	${MAKE} test-tdmpc-ete-eval
-	${MAKE} test-default-ete-eval
+	${MAKE} DEVICE=$(DEVICE) test-act-ete-train
+	${MAKE} DEVICE=$(DEVICE) test-act-ete-eval
+	${MAKE} DEVICE=$(DEVICE) test-act-ete-train-amp
+	${MAKE} DEVICE=$(DEVICE) test-act-ete-eval-amp
+	${MAKE} DEVICE=$(DEVICE) test-diffusion-ete-train
+	${MAKE} DEVICE=$(DEVICE) test-diffusion-ete-eval
+	${MAKE} DEVICE=$(DEVICE) test-tdmpc-ete-train
+	${MAKE} DEVICE=$(DEVICE) test-tdmpc-ete-eval
+	${MAKE} DEVICE=$(DEVICE) test-default-ete-eval
+	${MAKE} DEVICE=$(DEVICE) test-act-pusht-tutorial

 test-act-ete-train:
 	python lerobot/scripts/train.py \
 		policy=act \
 		policy.dim_model=64 \
 		env=aloha \
 		wandb.enable=False \
 		training.offline_steps=2 \
 		training.online_steps=0 \
 		eval.n_episodes=1 \
 		eval.batch_size=1 \
-		device=cpu \
-		training.save_model=true \
+		device=$(DEVICE) \
+		training.save_checkpoint=true \
 		training.save_freq=2 \
 		policy.n_action_steps=20 \
 		policy.chunk_size=20 \
 		training.batch_size=2 \
+		training.image_transforms.enable=true \
 		hydra.run.dir=tests/outputs/act/

 test-act-ete-eval:
 	python lerobot/scripts/eval.py \
-		-p tests/outputs/act/checkpoints/000002 \
+		-p tests/outputs/act/checkpoints/000002/pretrained_model \
 		eval.n_episodes=1 \
 		eval.batch_size=1 \
 		env.episode_length=8 \
-		device=cpu \
+		device=$(DEVICE) \

+test-act-ete-train-amp:
+	python lerobot/scripts/train.py \
+		policy=act \
+		policy.dim_model=64 \
+		env=aloha \
+		wandb.enable=False \
+		training.offline_steps=2 \
+		training.online_steps=0 \
+		eval.n_episodes=1 \
+		eval.batch_size=1 \
+		device=$(DEVICE) \
+		training.save_checkpoint=true \
+		training.save_freq=2 \
+		policy.n_action_steps=20 \
+		policy.chunk_size=20 \
+		training.batch_size=2 \
+		hydra.run.dir=tests/outputs/act_amp/ \
+		training.image_transforms.enable=true \
+		use_amp=true
+
+test-act-ete-eval-amp:
+	python lerobot/scripts/eval.py \
+		-p tests/outputs/act_amp/checkpoints/000002/pretrained_model \
+		eval.n_episodes=1 \
+		eval.batch_size=1 \
+		env.episode_length=8 \
+		device=$(DEVICE) \
+		use_amp=true

 test-diffusion-ete-train:
 	python lerobot/scripts/train.py \
 		policy=diffusion \
 		policy.down_dims=\[64,128,256\] \
 		policy.diffusion_step_embed_dim=32 \
 		policy.num_inference_steps=10 \
 		env=pusht \
 		wandb.enable=False \
 		training.offline_steps=2 \
 		training.online_steps=0 \
 		eval.n_episodes=1 \
 		eval.batch_size=1 \
-		device=cpu \
-		training.save_model=true \
+		device=$(DEVICE) \
+		training.save_checkpoint=true \
 		training.save_freq=2 \
 		training.batch_size=2 \
+		training.image_transforms.enable=true \
 		hydra.run.dir=tests/outputs/diffusion/

 test-diffusion-ete-eval:
 	python lerobot/scripts/eval.py \
-		-p tests/outputs/diffusion/checkpoints/000002 \
+		-p tests/outputs/diffusion/checkpoints/000002/pretrained_model \
 		eval.n_episodes=1 \
 		eval.batch_size=1 \
 		env.episode_length=8 \
-		device=cpu \
+		device=$(DEVICE) \

 # TODO(alexander-soare): Restore online_steps to 2 when it is reinstated.
 test-tdmpc-ete-train:
 	python lerobot/scripts/train.py \
 		policy=tdmpc \
@@ -82,24 +122,24 @@ test-tdmpc-ete-train:
 		dataset_repo_id=lerobot/xarm_lift_medium \
 		wandb.enable=False \
 		training.offline_steps=2 \
-		training.online_steps=2 \
+		training.online_steps=0 \
 		eval.n_episodes=1 \
 		eval.batch_size=1 \
 		env.episode_length=2 \
-		device=cpu \
-		training.save_model=true \
+		device=$(DEVICE) \
+		training.save_checkpoint=true \
 		training.save_freq=2 \
 		training.batch_size=2 \
+		training.image_transforms.enable=true \
 		hydra.run.dir=tests/outputs/tdmpc/

 test-tdmpc-ete-eval:
 	python lerobot/scripts/eval.py \
-		-p tests/outputs/tdmpc/checkpoints/000002 \
+		-p tests/outputs/tdmpc/checkpoints/000002/pretrained_model \
 		eval.n_episodes=1 \
 		eval.batch_size=1 \
 		env.episode_length=8 \
-		device=cpu \
+		device=$(DEVICE) \

 test-default-ete-eval:
 	python lerobot/scripts/eval.py \
@@ -107,4 +147,22 @@ test-default-ete-eval:
 		eval.n_episodes=1 \
 		eval.batch_size=1 \
 		env.episode_length=8 \
-		device=cpu \
+		device=$(DEVICE) \

+test-act-pusht-tutorial:
+	cp examples/advanced/1_train_act_pusht/act_pusht.yaml lerobot/configs/policy/created_by_Makefile.yaml
+	python lerobot/scripts/train.py \
+		policy=created_by_Makefile.yaml \
+		env=pusht \
+		wandb.enable=False \
+		training.offline_steps=2 \
+		eval.n_episodes=1 \
+		eval.batch_size=1 \
+		env.episode_length=2 \
+		device=$(DEVICE) \
+		training.save_model=true \
+		training.save_freq=2 \
+		training.batch_size=2 \
+		training.image_transforms.enable=true \
+		hydra.run.dir=tests/outputs/act_pusht/
+	rm lerobot/configs/policy/created_by_Makefile.yaml
53  README.md

@@ -77,6 +77,10 @@ Install 🤗 LeRobot:
 pip install .
 ```

+> **NOTE:** Depending on your platform, if you encounter any build errors during this step
+you may need to install `cmake` and `build-essential` for building some of our dependencies.
+On Linux: `sudo apt-get install cmake build-essential`
+
 For simulations, 🤗 LeRobot comes with gymnasium environments that can be installed as extras:
 - [aloha](https://github.com/huggingface/gym-aloha)
 - [xarm](https://github.com/huggingface/gym-xarm)
@@ -99,6 +103,7 @@ wandb login
 ```
 .
 ├── examples             # contains demonstration examples, start here to learn about LeRobot
+|   └── advanced         # contains even more examples for those who have mastered the basics
 ├── lerobot
 |   ├── configs          # contains hydra yaml files with all options that you can override in the command line
 |   |   ├── default.yaml # selected by default, it loads pusht environment and diffusion policy
@@ -149,18 +154,19 @@ python lerobot/scripts/eval.py \
 ```

 Note: After training your own policy, you can re-evaluate the checkpoints with:

 ```bash
-python lerobot/scripts/eval.py \
-    -p PATH/TO/TRAIN/OUTPUT/FOLDER
+python lerobot/scripts/eval.py -p {OUTPUT_DIR}/checkpoints/last/pretrained_model
 ```

 See `python lerobot/scripts/eval.py --help` for more instructions.

 ### Train your own policy

-Check out [example 3](./examples/3_train_policy.py) that illustrates how to start training a model.
+Check out [example 3](./examples/3_train_policy.py) that illustrates how to train a model using our core library in python, and [example 4](./examples/4_train_policy_with_script.md) that shows how to use our training script from command line.

 In general, you can use our training script to easily train any policy. Here is an example of training the ACT policy on trajectories collected by humans on the Aloha simulation environment for the insertion task:

 ```bash
 python lerobot/scripts/train.py \
     policy=act \
@@ -174,6 +180,19 @@ The experiment directory is automatically generated and will show up in yellow in your terminal.
     hydra.run.dir=your/new/experiment/dir
 ```

+In the experiment directory there will be a folder called `checkpoints` which will have the following structure:
+
+```bash
+checkpoints
+├── 000250                      # checkpoint_dir for training step 250
+│   ├── pretrained_model        # Hugging Face pretrained model dir
+│   │   ├── config.json         # Hugging Face pretrained model config
+│   │   ├── config.yaml         # consolidated Hydra config
+│   │   ├── model.safetensors   # model weights
+│   │   └── README.md           # Hugging Face model card
+│   └── training_state.pth      # optimizer/scheduler/rng state and training step
+```
+
 To use wandb for logging training and evaluation curves, make sure you've run `wandb login` as a one-time setup step. Then, when running the training command above, enable WandB in the configuration by adding:

 ```bash
@@ -184,7 +203,19 @@ A link to the wandb logs for the run will also show up in yellow in your terminal.

 ![](media/wandb.png)

-Note: For efficiency, during training every checkpoint is evaluated on a low number of episodes. After training, you may want to re-evaluate your best checkpoints on more episodes or change the evaluation settings. See `python lerobot/scripts/eval.py --help` for more instructions.
+Note: For efficiency, during training every checkpoint is evaluated on a low number of episodes. You may use `eval.n_episodes=500` to evaluate on more episodes than the default. Or, after training, you may want to re-evaluate your best checkpoints on more episodes or change the evaluation settings. See `python lerobot/scripts/eval.py --help` for more instructions.

+#### Reproduce state-of-the-art (SOTA)
+
+We have organized our configuration files (found under [`lerobot/configs`](./lerobot/configs)) such that they reproduce SOTA results from a given model variant in their respective original works. Simply running:
+
+```bash
+python lerobot/scripts/train.py policy=diffusion env=pusht
+```
+
+reproduces SOTA results for Diffusion Policy on the PushT task.
+
+Pretrained policies, along with reproduction details, can be found under the "Models" section of https://huggingface.co/lerobot.
+
 ## Contribute
@@ -197,13 +228,13 @@ To add a dataset to the hub, you need to login using a write-access token, which can be generated from the Hugging Face settings:
 huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
 ```

-Then move your dataset folder in `data` directory (e.g. `data/aloha_static_pingpong_test`), and push your dataset to the hub with:
+Then point to your raw dataset folder (e.g. `data/aloha_static_pingpong_test_raw`), and push your dataset to the hub with:
 ```bash
 python lerobot/scripts/push_dataset_to_hub.py \
---data-dir data \
---dataset-id aloha_static_pingpong_test \
---raw-format aloha_hdf5 \
---community-id lerobot
+--raw-dir data/aloha_static_pingpong_test_raw \
+--out-dir data \
+--repo-id lerobot/aloha_static_pingpong_test \
+--raw-format aloha_hdf5
 ```

 See `python lerobot/scripts/push_dataset_to_hub.py --help` for more instructions.
@@ -215,14 +246,14 @@ If your dataset format is not supported, implement your own in `lerobot/common/datasets/push_dataset_to_hub/${raw_format}_format.py` by copying examples.

 Once you have trained a policy you may upload it to the Hugging Face hub using a hub id that looks like `${hf_user}/${repo_name}` (e.g. [lerobot/diffusion_pusht](https://huggingface.co/lerobot/diffusion_pusht)).

-You first need to find the checkpoint located inside your experiment directory (e.g. `outputs/train/2024-05-05/20-21-12_aloha_act_default/checkpoints/002500`). It should contain:
+You first need to find the checkpoint folder located inside your experiment directory (e.g. `outputs/train/2024-05-05/20-21-12_aloha_act_default/checkpoints/002500`). Within that there is a `pretrained_model` directory which should contain:
 - `config.json`: A serialized version of the policy configuration (following the policy's dataclass config).
 - `model.safetensors`: A set of `torch.nn.Module` parameters, saved in [Hugging Face Safetensors](https://huggingface.co/docs/safetensors/index) format.
 - `config.yaml`: A consolidated Hydra training configuration containing the policy, environment, and dataset configs. The policy configuration should match `config.json` exactly. The environment config is useful for anyone who wants to evaluate your policy. The dataset config just serves as a paper trail for reproducibility.

 To upload these to the hub, run the following:
 ```bash
-huggingface-cli upload ${hf_user}/${repo_name} path/to/checkpoint/dir
+huggingface-cli upload ${hf_user}/${repo_name} path/to/pretrained_model
 ```

 See [eval.py](https://github.com/huggingface/lerobot/blob/main/lerobot/scripts/eval.py) for an example of how other people may use your policy.
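As an editor's aside on the README changes above: here is a hedged sketch of how such a `pretrained_model` directory can be loaded back in Python. It assumes (not stated in this diff) that LeRobot's policy classes expose `from_pretrained` via Hugging Face's `PyTorchModelHubMixin`; the checkpoint path reuses the README's own example.

```python
# Hedged sketch: loading a saved or uploaded pretrained_model directory.
# Assumes ACTPolicy inherits the HF hub mixin and so provides from_pretrained().
from lerobot.common.policies.act.modeling_act import ACTPolicy

# Works with a local checkpoint dir or a hub id such as
# "lerobot/act_aloha_sim_transfer_cube_human".
policy = ACTPolicy.from_pretrained(
    "outputs/train/2024-05-05/20-21-12_aloha_act_default/checkpoints/002500/pretrained_model"
)
policy.eval()  # inference mode for evaluation rollouts
```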
40  docker/lerobot-gpu-dev/Dockerfile  (new file)

@@ -0,0 +1,40 @@
FROM nvidia/cuda:12.4.1-base-ubuntu22.04

# Configure image
ARG PYTHON_VERSION=3.10
ARG DEBIAN_FRONTEND=noninteractive

# Install apt dependencies
RUN apt-get update && apt-get install -y --no-install-recommends \
    build-essential cmake \
    git git-lfs openssh-client \
    nano vim less util-linux \
    htop atop nvtop \
    sed gawk grep curl wget \
    tcpdump sysstat screen tmux \
    libglib2.0-0 libgl1-mesa-glx libegl1-mesa ffmpeg \
    python${PYTHON_VERSION} python${PYTHON_VERSION}-venv \
    && apt-get clean && rm -rf /var/lib/apt/lists/*

# Install gh cli tool
RUN (type -p wget >/dev/null || (apt update && apt-get install wget -y)) \
    && mkdir -p -m 755 /etc/apt/keyrings \
    && wget -qO- https://cli.github.com/packages/githubcli-archive-keyring.gpg | tee /etc/apt/keyrings/githubcli-archive-keyring.gpg > /dev/null \
    && chmod go+r /etc/apt/keyrings/githubcli-archive-keyring.gpg \
    && echo "deb [arch=$(dpkg --print-architecture) signed-by=/etc/apt/keyrings/githubcli-archive-keyring.gpg] https://cli.github.com/packages stable main" | tee /etc/apt/sources.list.d/github-cli.list > /dev/null \
    && apt update \
    && apt install gh -y \
    && apt clean && rm -rf /var/lib/apt/lists/*

# Setup `python`
RUN ln -s /usr/bin/python3 /usr/bin/python

# Install poetry
RUN curl -sSL https://install.python-poetry.org | python -
ENV PATH="/root/.local/bin:$PATH"
RUN echo 'if [ "$HOME" != "/root" ]; then ln -sf /root/.local/bin/poetry $HOME/.local/bin/poetry; fi' >> /root/.bashrc
RUN poetry config virtualenvs.create false
RUN poetry config virtualenvs.in-project true

# Set EGL as the rendering backend for MuJoCo
ENV MUJOCO_GL="egl"

docker/lerobot-gpu/Dockerfile

@@ -4,18 +4,15 @@ FROM nvidia/cuda:12.4.1-base-ubuntu22.04
 ARG PYTHON_VERSION=3.10
 ARG DEBIAN_FRONTEND=noninteractive

 # Install apt dependencies
 RUN apt-get update && apt-get install -y --no-install-recommends \
     build-essential cmake \
     git git-lfs openssh-client \
     nano vim ffmpeg \
     htop atop nvtop \
     sed gawk grep curl wget \
     tcpdump sysstat screen \
     libglib2.0-0 libgl1-mesa-glx libegl1-mesa \
     python${PYTHON_VERSION} python${PYTHON_VERSION}-venv \
     && apt-get clean && rm -rf /var/lib/apt/lists/*

 # Create virtual environment
 RUN ln -s /usr/bin/python${PYTHON_VERSION} /usr/bin/python
 RUN python -m venv /opt/venv
@@ -23,8 +20,7 @@ ENV PATH="/opt/venv/bin:$PATH"
 RUN echo "source /opt/venv/bin/activate" >> /root/.bashrc

 # Install LeRobot
-RUN git lfs install
-RUN git clone https://github.com/huggingface/lerobot.git
+COPY . /lerobot
 WORKDIR /lerobot
 RUN pip install --upgrade --no-cache-dir pip
 RUN pip install --no-cache-dir ".[test, aloha, xarm, pusht]"
183  examples/4_train_policy_with_script.md  (new file)

@@ -0,0 +1,183 @@
This tutorial will explain the training script, how to use it, and particularly the use of Hydra to configure everything needed for the training run.

## The training script

LeRobot offers a training script at [`lerobot/scripts/train.py`](../../lerobot/scripts/train.py). At a high level it does the following:

- Loads a Hydra configuration file for the following steps (more on Hydra in a moment).
- Makes a simulation environment.
- Makes a dataset corresponding to that simulation environment.
- Makes a policy.
- Runs a standard training loop with forward pass, backward pass, optimization step, and occasional logging, evaluation (of the policy on the environment), and checkpointing. A rough sketch of that loop follows this list.
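For orientation only, here is a minimal sketch of the loop structure described above, not the actual script: `make_policy`'s import path and signature are assumed, and logging, evaluation rollouts, AMP, and checkpointing are elided.

```python
# Rough conceptual sketch of lerobot/scripts/train.py, not the real code.
import hydra
import torch

from lerobot.common.datasets.factory import make_dataset
from lerobot.common.policies.factory import make_policy  # path/signature assumed

@hydra.main(version_base="1.2", config_path="../configs", config_name="default")
def train(cfg):
    dataset = make_dataset(cfg)       # e.g. a LeRobotDataset
    policy = make_policy(cfg)         # a torch.nn.Module (signature simplified)
    optimizer = torch.optim.Adam(policy.parameters(), lr=cfg.training.lr)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=cfg.training.batch_size)

    step = 0
    while step < cfg.training.offline_steps:
        for batch in dataloader:
            loss = policy.forward(batch)["loss"]  # forward pass
            loss.backward()                       # backward pass
            optimizer.step()                      # optimization step
            optimizer.zero_grad()
            step += 1
            # ... occasional logging, evaluation and checkpointing ...

if __name__ == "__main__":
    train()
```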
## Basics of how we use Hydra

Explaining the ins and outs of [Hydra](https://hydra.cc/docs/intro/) is beyond the scope of this document, but here we'll share the main points you need to know.

First, `lerobot/configs` has a directory structure like this:

```
.
├── default.yaml
├── env
│   ├── aloha.yaml
│   ├── pusht.yaml
│   └── xarm.yaml
└── policy
    ├── act.yaml
    ├── diffusion.yaml
    └── tdmpc.yaml
```

**_For brevity, in the rest of this document we'll drop the leading `lerobot/configs` path. So `default.yaml` really refers to `lerobot/configs/default.yaml`._**

When you run the training script with

```bash
python lerobot/scripts/train.py
```

Hydra is set up to read `default.yaml` (via the `@hydra.main` decorator). If you take a look at `@hydra.main`'s arguments you will see `config_path="../configs", config_name="default"`. At the top of `default.yaml` is a `defaults` section which looks like this:

```yaml
defaults:
  - _self_
  - env: pusht
  - policy: diffusion
```

This logic tells Hydra to incorporate configuration parameters from `env/pusht.yaml` and `policy/diffusion.yaml`. _Note: Be aware of the order, as any configuration parameters with the same name will be overridden. Thus, `default.yaml` is overridden by `env/pusht.yaml`, which is overridden by `policy/diffusion.yaml`._

Then, `default.yaml` also contains common configuration parameters such as `device: cuda` or `use_amp: false` (for enabling fp16 training). Some other parameters are set to `???`, which indicates that they are expected to be set in additional yaml files. For instance, `training.offline_steps: ???` in `default.yaml` is set to `200000` in `diffusion.yaml`. A standalone illustration of this mechanic follows.
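This is not LeRobot code, just a small demonstration of how OmegaConf (which Hydra builds on) treats `???` as a mandatory value that some other config must fill in before it is accessed:

```python
from omegaconf import OmegaConf

# `???` marks a mandatory value that another yaml (e.g. policy/diffusion.yaml)
# must provide before it is accessed.
cfg = OmegaConf.create({"training": {"offline_steps": "???"}})
print(OmegaConf.is_missing(cfg.training, "offline_steps"))  # True
# Accessing cfg.training.offline_steps here would raise MissingMandatoryValue.

merged = OmegaConf.merge(cfg, OmegaConf.create({"training": {"offline_steps": 200000}}))
print(merged.training.offline_steps)  # 200000
```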
Thanks to this `defaults` section in `default.yaml`, if you want to train Diffusion Policy with PushT, you really only need to run:

```bash
python lerobot/scripts/train.py
```

However, you can be more explicit and launch the exact same Diffusion Policy training on PushT with:

```bash
python lerobot/scripts/train.py policy=diffusion env=pusht
```

This way of overriding defaults via the CLI is especially useful when you want to change the policy and/or environment. For instance, you can train ACT on the default Aloha environment with:

```bash
python lerobot/scripts/train.py policy=act env=aloha
```

There are two things to note here:
- Config overrides are passed as `param_name=param_value`.
- Here we have overridden the defaults section. `policy=act` tells Hydra to use `policy/act.yaml`, and `env=aloha` tells Hydra to use `env/aloha.yaml`.

_As an aside: we've set up all of our configurations so that they reproduce state-of-the-art results from papers in the literature._

## Overriding configuration parameters in the CLI

Now let's say that we want to train on a different task in the Aloha environment. If you look in `env/aloha.yaml` you will see something like:

```yaml
# lerobot/configs/env/aloha.yaml
env:
  task: AlohaInsertion-v0
```

And if you look in `policy/act.yaml` you will see something like:

```yaml
# lerobot/configs/policy/act.yaml
dataset_repo_id: lerobot/aloha_sim_insertion_human
```

But our Aloha environment actually supports a cube transfer task as well. To train for this task, you could manually modify the two yaml configuration files respectively.

First, we'd need to switch to using the cube transfer task for the ALOHA environment.

```diff
# lerobot/configs/env/aloha.yaml
env:
-  task: AlohaInsertion-v0
+  task: AlohaTransferCube-v0
```

Then, we'd also need to switch to using the cube transfer dataset.

```diff
# lerobot/configs/policy/act.yaml
-dataset_repo_id: lerobot/aloha_sim_insertion_human
+dataset_repo_id: lerobot/aloha_sim_transfer_cube_human
```

Then, you'd be able to run:

```bash
python lerobot/scripts/train.py policy=act env=aloha
```

and you'd be training and evaluating on the cube transfer task.

An alternative approach to editing the yaml configuration files would be to override the defaults via the command line:

```bash
python lerobot/scripts/train.py \
    policy=act \
    dataset_repo_id=lerobot/aloha_sim_transfer_cube_human \
    env=aloha \
    env.task=AlohaTransferCube-v0
```

There's something new here. Notice the `.` delimiter used to traverse the configuration hierarchy. _But be aware that the `defaults` section is an exception. As you saw above, we didn't need to write `defaults.policy=act` in the CLI. `policy=act` was enough._
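To see, outside of Hydra's CLI, how such a dotted override maps onto the nested config tree, OmegaConf's dotlist API mirrors what Hydra does with CLI arguments (illustration only, not LeRobot code):

```python
from omegaconf import OmegaConf

base = OmegaConf.create({"env": {"name": "aloha", "task": "AlohaInsertion-v0"}})
# `env.task=AlohaTransferCube-v0` on the CLI is equivalent to merging:
override = OmegaConf.from_dotlist(["env.task=AlohaTransferCube-v0"])
cfg = OmegaConf.merge(base, override)
print(cfg.env.task)  # AlohaTransferCube-v0
```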
Putting all that knowledge together, here's the command that was used to train https://huggingface.co/lerobot/act_aloha_sim_transfer_cube_human.

```bash
python lerobot/scripts/train.py \
    hydra.run.dir=outputs/train/act_aloha_sim_transfer_cube_human \
    device=cuda \
    env=aloha \
    env.task=AlohaTransferCube-v0 \
    dataset_repo_id=lerobot/aloha_sim_transfer_cube_human \
    policy=act \
    training.eval_freq=10000 \
    training.log_freq=250 \
    training.offline_steps=100000 \
    training.save_model=true \
    training.save_freq=25000 \
    eval.n_episodes=50 \
    eval.batch_size=50 \
    wandb.enable=false
```

There's one new thing here: `hydra.run.dir=outputs/train/act_aloha_sim_transfer_cube_human`, which specifies where to save the training output.

## Using a configuration file not in `lerobot/configs`

Above we discussed that our training script is set up such that Hydra looks for `default.yaml` in `lerobot/configs`. But if you have a configuration file elsewhere in your filesystem you may use:

```bash
python lerobot/scripts/train.py --config-dir PARENT/PATH --config-name FILE_NAME_WITHOUT_EXTENSION
```

Note: here we use regular syntax for providing CLI arguments to a Python script, not Hydra's `param_name=param_value` syntax.

As a concrete example, this becomes particularly handy when you have a folder with training outputs and would like to re-run the training. For example, say you previously ran the training script with one of the earlier commands and have `outputs/train/my_experiment/checkpoints/pretrained_model/config.yaml`. This `config.yaml` file will have the full set of configuration parameters within it. To run the training with the same configuration again, do:

```bash
python lerobot/scripts/train.py --config-dir outputs/train/my_experiment/checkpoints/last/pretrained_model --config-name config
```

Note that you may still use the regular syntax for config parameter overrides (eg: by adding `training.offline_steps=200000`).

---

So far we've seen how to train Diffusion Policy for PushT and ACT for ALOHA. Now, what if we want to train ACT for PushT? Well, there are aspects of the ACT configuration that are specific to the ALOHA environments, and these happen to be incompatible with PushT. Therefore, trying to run the following will almost certainly raise an exception of sorts (eg: feature dimension mismatch):

```bash
python lerobot/scripts/train.py policy=act env=pusht dataset_repo_id=lerobot/pusht
```

Please, head on over to our [advanced tutorial on adapting policy configuration to various environments](./advanced/train_act_pusht/train_act_pusht.md) to learn more.

Or in the meantime, happy coding! 🤗
37  examples/5_resume_training.md  (new file)

@@ -0,0 +1,37 @@
This tutorial explains how to resume a training run that you've started with the training script. If you don't know how our training script and configuration system works, please read [4_train_policy_with_script.md](./4_train_policy_with_script.md) first.

## Basic training resumption

Let's consider the example of training ACT for one of the ALOHA tasks. Here's a command that can achieve that:

```bash
python lerobot/scripts/train.py \
    hydra.run.dir=outputs/train/run_resumption \
    policy=act \
    dataset_repo_id=lerobot/aloha_sim_transfer_cube_human \
    env=aloha \
    env.task=AlohaTransferCube-v0 \
    training.log_freq=25 \
    training.save_checkpoint=true \
    training.save_freq=100
```

Here we're using the default dataset and environment for ACT, and we've taken care to set the log frequency and checkpointing frequency to low numbers so we can test resumption. You should be able to see some logging and have a first checkpoint within 1 minute. Please interrupt the training after the first checkpoint.

To resume, all that we have to do is run the training script, providing the run directory and the resume option:

```bash
python lerobot/scripts/train.py \
    hydra.run.dir=outputs/train/run_resumption \
    resume=true
```

You should see from the logging that your training picks up from where it left off.

Note that with `resume=true`, the configuration file from the last checkpoint in the training output directory is loaded. So it doesn't matter that we haven't provided all the other configuration parameters from our previous command (although there may be warnings to notify you that your command has a different configuration than the checkpoint).
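Conceptually, resumption amounts to reloading the saved Hydra config and training state before re-entering the training loop. Here is a simplified, hypothetical sketch of that idea; the helper name `maybe_resume` is invented for illustration, and the real logic in `lerobot/scripts/train.py` differs in detail:

```python
# Hypothetical sketch of what resume=true implies; not the actual implementation.
from pathlib import Path
from omegaconf import OmegaConf

def maybe_resume(cfg, out_dir: str):
    last = Path(out_dir) / "checkpoints" / "last"
    if cfg.resume and last.exists():
        # Reload the full config saved alongside the checkpoint, so omitting
        # CLI parameters in the resume command doesn't matter.
        cfg = OmegaConf.load(last / "pretrained_model" / "config.yaml")
        # ...then restore model weights, optimizer/scheduler/rng state and the
        # training step (e.g. from training_state.pth) before continuing.
    return cfg
```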
---

Now you should know how to resume your training run in case it gets interrupted or you want to extend a finished training run.

Happy coding! 🤗
52  examples/6_add_image_transforms.py  (new file)

@@ -0,0 +1,52 @@
"""
This script demonstrates how to use torchvision's image transformation with LeRobotDataset for data
augmentation purposes. The transformations are passed to the dataset as an argument upon creation, and
transforms are applied to the observation images before they are returned in the dataset's __getitem__.
"""

from pathlib import Path

from torchvision.transforms import ToPILImage, v2

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

dataset_repo_id = "lerobot/aloha_static_tape"

# Create a LeRobotDataset with no transformations
dataset = LeRobotDataset(dataset_repo_id)
# This is equivalent to `dataset = LeRobotDataset(dataset_repo_id, image_transforms=None)`

# Get the index of the first observation in the first episode
first_idx = dataset.episode_data_index["from"][0].item()

# Get the frame corresponding to the first camera
frame = dataset[first_idx][dataset.camera_keys[0]]

# Define the transformations
transforms = v2.Compose(
    [
        v2.ColorJitter(brightness=(0.5, 1.5)),
        v2.ColorJitter(contrast=(0.5, 1.5)),
        v2.RandomAdjustSharpness(sharpness_factor=2, p=1),
    ]
)

# Create another LeRobotDataset with the defined transformations
transformed_dataset = LeRobotDataset(dataset_repo_id, image_transforms=transforms)

# Get a frame from the transformed dataset
transformed_frame = transformed_dataset[first_idx][transformed_dataset.camera_keys[0]]

# Create a directory to store output images
output_dir = Path("outputs/image_transforms")
output_dir.mkdir(parents=True, exist_ok=True)

# Save the original frame
to_pil = ToPILImage()
to_pil(frame).save(output_dir / "original_frame.png", quality=100)
print(f"Original frame saved to {output_dir / 'original_frame.png'}.")

# Save the transformed frame
to_pil(transformed_frame).save(output_dir / "transformed_frame.png", quality=100)
print(f"Transformed frame saved to {output_dir / 'transformed_frame.png'}.")
87  examples/advanced/1_train_act_pusht/act_pusht.yaml  (new file)

@@ -0,0 +1,87 @@
# @package _global_

# Change the seed to match what PushT eval uses
# (to avoid evaluating on seeds used for generating the training data).
seed: 100000
# Change the dataset repository to the PushT one.
dataset_repo_id: lerobot/pusht

override_dataset_stats:
  observation.image:
    # stats from imagenet, since we use a pretrained vision model
    mean: [[[0.485]], [[0.456]], [[0.406]]]  # (c,1,1)
    std: [[[0.229]], [[0.224]], [[0.225]]]  # (c,1,1)

training:
  offline_steps: 80000
  online_steps: 0
  eval_freq: 10000
  save_freq: 100000
  log_freq: 250
  save_model: true

  batch_size: 8
  lr: 1e-5
  lr_backbone: 1e-5
  weight_decay: 1e-4
  grad_clip_norm: 10
  online_steps_between_rollouts: 1

  delta_timestamps:
    action: "[i / ${fps} for i in range(${policy.chunk_size})]"

eval:
  n_episodes: 50
  batch_size: 50

# See `configuration_act.py` for more details.
policy:
  name: act

  # Input / output structure.
  n_obs_steps: 1
  chunk_size: 100 # chunk_size
  n_action_steps: 100

  input_shapes:
    observation.image: [3, 96, 96]
    observation.state: ["${env.state_dim}"]
  output_shapes:
    action: ["${env.action_dim}"]

  # Normalization / Unnormalization
  input_normalization_modes:
    observation.image: mean_std
    # Use min_max normalization just because it's more standard.
    observation.state: min_max
  output_normalization_modes:
    # Use min_max normalization just because it's more standard.
    action: min_max

  # Architecture.
  # Vision backbone.
  vision_backbone: resnet18
  pretrained_backbone_weights: ResNet18_Weights.IMAGENET1K_V1
  replace_final_stride_with_dilation: false
  # Transformer layers.
  pre_norm: false
  dim_model: 512
  n_heads: 8
  dim_feedforward: 3200
  feedforward_activation: relu
  n_encoder_layers: 4
  # Note: Although the original ACT implementation has 7 for `n_decoder_layers`, there is a bug in the code
  # that means only the first layer is used. Here we match the original implementation by setting this to 1.
  # See this issue https://github.com/tonyzhaozh/act/issues/25#issue-2258740521.
  n_decoder_layers: 1
  # VAE.
  use_vae: true
  latent_dim: 32
  n_vae_encoder_layers: 4

  # Inference.
  temporal_ensemble_momentum: null

  # Training and loss computation.
  dropout: 0.1
  kl_weight: 10.0
70  examples/advanced/1_train_act_pusht/train_act_pusht.md  (new file)

@@ -0,0 +1,70 @@
In this tutorial we will learn how to adapt a policy configuration to be compatible with a new environment and dataset. As a concrete example, we will adapt the default configuration for ACT to be compatible with the PushT environment and dataset.

If you haven't already read our tutorial on the [training script and configuration tooling](../4_train_policy_with_script.md) please do so prior to tackling this tutorial.

Let's get started!

Suppose we want to train ACT for PushT. Well, there are aspects of the ACT configuration that are specific to the ALOHA environments, and these happen to be incompatible with PushT. Therefore, trying to run the following will almost certainly raise an exception of sorts (eg: feature dimension mismatch):

```bash
python lerobot/scripts/train.py policy=act env=pusht dataset_repo_id=lerobot/pusht
```

We need to adapt the parameters of the ACT policy configuration to the PushT environment. The most important ones are the image keys.

ALOHA's datasets and environments typically use a variable number of cameras. In `lerobot/configs/policy/act.yaml` you may notice two relevant sections. Here we show you the minimal diff needed to adjust to PushT:

```diff
override_dataset_stats:
-  observation.images.top:
+  observation.image:
    # stats from imagenet, since we use a pretrained vision model
    mean: [[[0.485]], [[0.456]], [[0.406]]]  # (c,1,1)
    std: [[[0.229]], [[0.224]], [[0.225]]]  # (c,1,1)

policy:
  input_shapes:
-    observation.images.top: [3, 480, 640]
+    observation.image: [3, 96, 96]
    observation.state: ["${env.state_dim}"]
  output_shapes:
    action: ["${env.action_dim}"]

  input_normalization_modes:
-    observation.images.top: mean_std
+    observation.image: mean_std
    observation.state: min_max
  output_normalization_modes:
    action: min_max
```

Here we've accounted for the following:
- PushT uses "observation.image" for its image key.
- PushT provides smaller images.

_Side note: technically we could override these via the CLI, but with many changes it gets a bit messy, and we also have a bit of a challenge in that we're using `.` in our observation keys which is treated by Hydra as a hierarchical separator_.

For your convenience, we provide [`act_pusht.yaml`](./act_pusht.yaml) in this directory. It contains the diff above, plus some other (optional) ones that are explained within. Please copy it into `lerobot/configs/policy` with:

```bash
cp examples/advanced/1_train_act_pusht/act_pusht.yaml lerobot/configs/policy/act_pusht.yaml
```

(remember from a [previous tutorial](../4_train_policy_with_script.md) that Hydra will look in the `lerobot/configs` directory). Now try running the following.

<!-- Note to contributor: are you changing this command? Note that it's tested in `Makefile`, so change it there too! -->
```bash
python lerobot/scripts/train.py policy=act_pusht env=pusht
```

Notice that this is much the same as the command that failed at the start of the tutorial, only:
- Now we are using `policy=act_pusht` to point to our new configuration file.
- We can drop `dataset_repo_id=lerobot/pusht` as the change is incorporated in our new configuration file.

Hurrah! You're now training ACT for the PushT environment.

---

The bottom line of this tutorial is that when training policies for different environments and datasets you will need to understand what parts of the policy configuration are specific to those and make changes accordingly.

Happy coding! 🤗
lerobot/__init__.py

@@ -45,6 +45,9 @@ import itertools
 from lerobot.__version__ import __version__  # noqa: F401

+# TODO(rcadene): Improve policies and envs. As of now, an item in `available_policies`
+# refers to a yaml file AND a modeling name. Same for `available_envs` which refers to
+# a yaml file AND an environment name. The difference should be more obvious.
 available_tasks_per_env = {
     "aloha": [
         "AlohaInsertion-v0",
@@ -52,6 +55,7 @@ available_tasks_per_env = {
     ],
     "pusht": ["PushT-v0"],
     "xarm": ["XarmLift-v0"],
+    "dora_aloha_real": ["DoraAloha-v0", "DoraKoch-v0", "DoraReachy2-v0"],
 }
 available_envs = list(available_tasks_per_env.keys())

@@ -77,6 +81,23 @@ available_datasets_per_env = {
         "lerobot/xarm_push_medium_image",
         "lerobot/xarm_push_medium_replay_image",
     ],
+    "dora_aloha_real": [
+        "lerobot/aloha_static_battery",
+        "lerobot/aloha_static_candy",
+        "lerobot/aloha_static_coffee",
+        "lerobot/aloha_static_coffee_new",
+        "lerobot/aloha_static_cups_open",
+        "lerobot/aloha_static_fork_pick_up",
+        "lerobot/aloha_static_pingpong_test",
+        "lerobot/aloha_static_pro_pencil",
+        "lerobot/aloha_static_screw_driver",
+        "lerobot/aloha_static_tape",
+        "lerobot/aloha_static_thread_velcro",
+        "lerobot/aloha_static_towel",
+        "lerobot/aloha_static_vinh_cup",
+        "lerobot/aloha_static_vinh_cup_left",
+        "lerobot/aloha_static_ziploc_slide",
+    ],
 }

 available_real_world_datasets = [
@@ -108,16 +129,19 @@ available_datasets = list(
     itertools.chain(*available_datasets_per_env.values(), available_real_world_datasets)
 )

+# lists all available policies from `lerobot/common/policies` by their class attribute: `name`.
 available_policies = [
     "act",
     "diffusion",
     "tdmpc",
 ]

+# keys and values refer to yaml files
 available_policies_per_env = {
     "aloha": ["act"],
     "pusht": ["diffusion"],
     "xarm": ["tdmpc"],
+    "dora_aloha_real": ["act_real"],
 }

 env_task_pairs = [(env, task) for env, tasks in available_tasks_per_env.items() for task in tasks]
lerobot/common/datasets/compute_stats.py

@@ -16,17 +16,15 @@
 from copy import deepcopy
 from math import ceil

-import datasets
 import einops
 import torch
 import tqdm
-from datasets import Image

 from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
 from lerobot.common.datasets.video_utils import VideoFrame


-def get_stats_einops_patterns(dataset: LeRobotDataset | datasets.Dataset, num_workers=0):
+def get_stats_einops_patterns(dataset, num_workers=0):
     """These einops patterns will be used to aggregate batches and compute statistics.

     Note: We assume the images are in channel first format
@@ -66,9 +64,8 @@
     return stats_patterns


-def compute_stats(
-    dataset: LeRobotDataset | datasets.Dataset, batch_size=32, num_workers=16, max_num_samples=None
-):
+def compute_stats(dataset, batch_size=32, num_workers=16, max_num_samples=None):
     """Compute mean/std and min/max statistics of all data keys in a LeRobotDataset."""
     if max_num_samples is None:
         max_num_samples = len(dataset)

@@ -159,3 +156,54 @@
         "min": min[key],
     }
     return stats


+def aggregate_stats(ls_datasets) -> dict[str, torch.Tensor]:
+    """Aggregate stats of multiple LeRobot datasets into one set of stats without recomputing from scratch.
+
+    The final stats will have the union of all data keys from each of the datasets. For instance:
+    - new_max = max(max_dataset_0, max_dataset_1, ...)
+    - new_min = min(min_dataset_0, min_dataset_1, ...)
+    - new_mean = (mean of all data)
+    - new_std = (std of all data)
+    """
+    data_keys = set()
+    for dataset in ls_datasets:
+        data_keys.update(dataset.stats.keys())
+    stats = {k: {} for k in data_keys}
+    for data_key in data_keys:
+        for stat_key in ["min", "max"]:
+            # compute `max(dataset_0["max"], dataset_1["max"], ...)`
+            stats[data_key][stat_key] = einops.reduce(
+                torch.stack([d.stats[data_key][stat_key] for d in ls_datasets if data_key in d.stats], dim=0),
+                "n ... -> ...",
+                stat_key,
+            )
+        total_samples = sum(d.num_samples for d in ls_datasets if data_key in d.stats)
+        # Compute the "sum" statistic by multiplying each mean by the number of samples in the respective
+        # dataset, then divide by total_samples to get the overall "mean".
+        # NOTE: the brackets around (d.num_samples / total_samples) are needed to minimize the risk of
+        # numerical overflow!
+        stats[data_key]["mean"] = sum(
+            d.stats[data_key]["mean"] * (d.num_samples / total_samples)
+            for d in ls_datasets
+            if data_key in d.stats
+        )
+        # The derivation for standard deviation is a little more involved but is much in the same spirit as
+        # the computation of the mean.
+        # Given two sets of data where the statistics are known:
+        # σ_combined = sqrt[ (n1 * (σ1^2 + d1^2) + n2 * (σ2^2 + d2^2)) / (n1 + n2) ]
+        # where d1 = μ1 - μ_combined, d2 = μ2 - μ_combined
+        # NOTE: the brackets around (d.num_samples / total_samples) are needed to minimize the risk of
+        # numerical overflow!
+        stats[data_key]["std"] = torch.sqrt(
+            sum(
+                (d.stats[data_key]["std"] ** 2 + (d.stats[data_key]["mean"] - stats[data_key]["mean"]) ** 2)
+                * (d.num_samples / total_samples)
+                for d in ls_datasets
+                if data_key in d.stats
+            )
+        )
+    return stats
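Written out, the aggregation added above is the standard pooling of per-dataset statistics. With per-dataset sample counts $n_i$, means $\mu_i$, standard deviations $\sigma_i$, and $N = \sum_i n_i$:

$$
\mu = \sum_i \frac{n_i}{N}\,\mu_i,
\qquad
\sigma = \sqrt{\sum_i \frac{n_i}{N}\left(\sigma_i^2 + (\mu_i - \mu)^2\right)}
$$

which reduces to the `σ_combined` formula in the code comment for the two-dataset case.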
@@ -16,34 +16,94 @@
|
||||
import logging
|
||||
|
||||
import torch
|
||||
from omegaconf import OmegaConf
|
||||
from omegaconf import ListConfig, OmegaConf
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, MultiLeRobotDataset
|
||||
from lerobot.common.datasets.transforms import get_image_transforms
|
||||
|
||||
|
||||
def make_dataset(
|
||||
cfg,
|
||||
split="train",
|
||||
):
|
||||
if cfg.env.name not in cfg.dataset_repo_id:
|
||||
logging.warning(
|
||||
f"There might be a mismatch between your training dataset ({cfg.dataset_repo_id=}) and your "
|
||||
f"environment ({cfg.env.name=})."
|
||||
)
|
||||
def resolve_delta_timestamps(cfg):
|
||||
"""Resolves delta_timestamps config key (in-place) by using `eval`.
|
||||
|
||||
Doesn't do anything if delta_timestamps is not specified or has already been resolve (as evidenced by
|
||||
the data type of its values).
|
||||
"""
|
||||
delta_timestamps = cfg.training.get("delta_timestamps")
|
||||
if delta_timestamps is not None:
|
||||
for key in delta_timestamps:
|
||||
if isinstance(delta_timestamps[key], str):
|
||||
delta_timestamps[key] = eval(delta_timestamps[key])
|
||||
# TODO(rcadene, alexander-soare): remove `eval` to avoid exploit
|
||||
cfg.training.delta_timestamps[key] = eval(delta_timestamps[key])
|
||||
|
||||
# TODO(rcadene): add data augmentations
|
||||
|
||||
dataset = LeRobotDataset(
|
||||
cfg.dataset_repo_id,
|
||||
split=split,
|
||||
delta_timestamps=delta_timestamps,
|
||||
)
|
||||
def make_dataset(cfg, split: str = "train") -> LeRobotDataset | MultiLeRobotDataset:
|
||||
"""
|
||||
Args:
|
||||
cfg: A Hydra config as per the LeRobot config scheme.
|
||||
split: Select the data subset used to create an instance of LeRobotDataset.
|
||||
All datasets hosted on [lerobot](https://huggingface.co/lerobot) contain only one subset: "train".
|
||||
Thus, by default, `split="train"` selects all the available data. `split` aims to work like the
|
||||
slicer in the hugging face datasets:
|
||||
https://huggingface.co/docs/datasets/v2.19.0/loading#slice-splits
|
||||
As of now, it only supports `split="train[:n]"` to load the first n frames of the dataset or
|
||||
`split="train[n:]"` to load the last n frames. For instance `split="train[:1000]"`.
|
||||
Returns:
|
||||
The LeRobotDataset.
|
||||
"""
|
||||
if not isinstance(cfg.dataset_repo_id, (str, ListConfig)):
|
||||
raise ValueError(
|
||||
"Expected cfg.dataset_repo_id to be either a single string to load one dataset or a list of "
|
||||
"strings to load multiple datasets."
|
||||
)
|
||||
|
||||
# A soft check to warn if the environment matches the dataset. Don't check if we are using a real world env (dora).
|
||||
if cfg.env.name != "dora":
|
||||
if isinstance(cfg.dataset_repo_id, str):
|
||||
dataset_repo_ids = [cfg.dataset_repo_id] # single dataset
|
||||
else:
|
||||
dataset_repo_ids = cfg.dataset_repo_id # multiple datasets
|
||||
|
||||
for dataset_repo_id in dataset_repo_ids:
|
||||
if cfg.env.name not in dataset_repo_id:
|
||||
logging.warning(
|
||||
f"There might be a mismatch between your training dataset ({dataset_repo_id=}) and your "
|
||||
f"environment ({cfg.env.name=})."
|
||||
)
|
||||
|
||||
resolve_delta_timestamps(cfg)
|
||||
|
||||
image_transforms = None
|
||||
if cfg.training.image_transforms.enable:
|
||||
cfg_tf = cfg.training.image_transforms
|
||||
image_transforms = get_image_transforms(
|
||||
brightness_weight=cfg_tf.brightness.weight,
|
||||
brightness_min_max=cfg_tf.brightness.min_max,
|
||||
contrast_weight=cfg_tf.contrast.weight,
|
||||
contrast_min_max=cfg_tf.contrast.min_max,
|
||||
saturation_weight=cfg_tf.saturation.weight,
|
||||
saturation_min_max=cfg_tf.saturation.min_max,
|
||||
hue_weight=cfg_tf.hue.weight,
|
||||
hue_min_max=cfg_tf.hue.min_max,
|
||||
sharpness_weight=cfg_tf.sharpness.weight,
|
||||
sharpness_min_max=cfg_tf.sharpness.min_max,
|
||||
max_num_transforms=cfg_tf.max_num_transforms,
|
||||
random_order=cfg_tf.random_order,
|
||||
)
|
||||
|
||||
if isinstance(cfg.dataset_repo_id, str):
|
||||
dataset = LeRobotDataset(
|
||||
cfg.dataset_repo_id,
|
||||
split=split,
|
||||
delta_timestamps=cfg.training.get("delta_timestamps"),
|
||||
image_transforms=image_transforms,
|
||||
)
|
||||
else:
|
||||
dataset = MultiLeRobotDataset(
|
||||
cfg.dataset_repo_id,
|
||||
split=split,
|
||||
delta_timestamps=cfg.training.get("delta_timestamps"),
|
||||
image_transforms=image_transforms,
|
||||
)
|
||||
|
||||
if cfg.get("override_dataset_stats"):
|
||||
for key, stats_dict in cfg.override_dataset_stats.items():
|
||||
|
||||
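To make the `eval`-based resolution above concrete, here is a minimal sketch of what `resolve_delta_timestamps` does to string-valued entries, using a hypothetical OmegaConf config (the keys and offsets below are made up, not taken from an actual LeRobot config):

```python
from omegaconf import OmegaConf

# Hypothetical config: each delta_timestamps value is a Python expression
# encoded as a string, exactly the shape resolve_delta_timestamps expects.
cfg = OmegaConf.create(
    {
        "training": {
            "delta_timestamps": {
                "observation.image": "[i / 10 for i in range(-1, 1)]",
                "action": "[i / 10 for i in range(8)]",
            }
        }
    }
)

for key, value in cfg.training.delta_timestamps.items():
    if isinstance(value, str):
        # eval turns the string into a list of float offsets in seconds.
        cfg.training.delta_timestamps[key] = eval(value)

print(list(cfg.training.delta_timestamps["observation.image"]))  # [-0.1, 0.0]
```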
@@ -13,12 +13,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
from pathlib import Path
from typing import Callable

import datasets
import torch
import torch.utils

from lerobot.common.datasets.compute_stats import aggregate_stats
from lerobot.common.datasets.utils import (
    calculate_episode_data_index,
    load_episode_data_index,
@@ -42,7 +46,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
        version: str | None = CODEBASE_VERSION,
        root: Path | None = DATA_DIR,
        split: str = "train",
        transform: callable = None,
        image_transforms: Callable | None = None,
        delta_timestamps: dict[list[float]] | None = None,
    ):
        super().__init__()
@@ -50,7 +54,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
        self.version = version
        self.root = root
        self.split = split
        self.transform = transform
        self.image_transforms = image_transforms
        self.delta_timestamps = delta_timestamps
        # load data from hub or locally when root is provided
        # TODO(rcadene, aliberts): implement faster transfer
@@ -147,8 +151,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
                self.tolerance_s,
            )

        if self.transform is not None:
            item = self.transform(item)
        if self.image_transforms is not None:
            for cam in self.camera_keys:
                item[cam] = self.image_transforms(item[cam])

        return item

@@ -164,14 +169,14 @@ class LeRobotDataset(torch.utils.data.Dataset):
            f"  Recorded Frames per Second: {self.fps},\n"
            f"  Camera Keys: {self.camera_keys},\n"
            f"  Video Frame Keys: {self.video_frame_keys if self.video else 'N/A'},\n"
            f"  Transformations: {self.transform},\n"
            f"  Transformations: {self.image_transforms},\n"
            f")"
        )

    @classmethod
    def from_preloaded(
        cls,
        repo_id: str,
        repo_id: str = "from_preloaded",
        version: str | None = CODEBASE_VERSION,
        root: Path | None = None,
        split: str = "train",
@@ -183,18 +188,214 @@ class LeRobotDataset(torch.utils.data.Dataset):
        stats=None,
        info=None,
        videos_dir=None,
    ):
    ) -> "LeRobotDataset":
        """Create a LeRobot Dataset from existing data and attributes instead of loading from the filesystem.

        It is especially useful when converting raw data into LeRobotDataset before saving the dataset
        on the filesystem or uploading to the hub.

        Note: Meta-data attributes like `repo_id`, `version`, `root`, etc. are optional and potentially
        meaningless depending on the downstream usage of the returned dataset.
        """
        # create an empty object of type LeRobotDataset
        obj = cls.__new__(cls)
        obj.repo_id = repo_id
        obj.version = version
        obj.root = root
        obj.split = split
        obj.transform = transform
        obj.image_transforms = transform
        obj.delta_timestamps = delta_timestamps
        obj.hf_dataset = hf_dataset
        obj.episode_data_index = episode_data_index
        obj.stats = stats
        obj.info = info
        obj.info = info if info is not None else {}
        obj.videos_dir = videos_dir
        return obj


class MultiLeRobotDataset(torch.utils.data.Dataset):
    """A dataset consisting of multiple underlying `LeRobotDataset`s.

    The underlying `LeRobotDataset`s are effectively concatenated, and this class adopts much of the API
    structure of `LeRobotDataset`.
    """

    def __init__(
        self,
        repo_ids: list[str],
        version: str | None = CODEBASE_VERSION,
        root: Path | None = DATA_DIR,
        split: str = "train",
        image_transforms: Callable | None = None,
        delta_timestamps: dict[list[float]] | None = None,
    ):
        super().__init__()
        self.repo_ids = repo_ids
        # Construct the underlying datasets passing everything but `transform` and `delta_timestamps` which
        # are handled by this class.
        self._datasets = [
            LeRobotDataset(
                repo_id,
                version=version,
                root=root,
                split=split,
                delta_timestamps=delta_timestamps,
                image_transforms=image_transforms,
            )
            for repo_id in repo_ids
        ]
        # Check that some properties are consistent across datasets. Note: We may relax some of these
        # consistency requirements in future iterations of this class.
        for repo_id, dataset in zip(self.repo_ids, self._datasets, strict=True):
            if dataset.info != self._datasets[0].info:
                raise ValueError(
                    f"Detected a mismatch in dataset info between {self.repo_ids[0]} and {repo_id}. This is "
                    "not yet supported."
                )
        # Disable any data keys that are not common across all of the datasets. Note: we may relax this
        # restriction in future iterations of this class. For now, this is necessary at least for being able
        # to use PyTorch's default DataLoader collate function.
        self.disabled_data_keys = set()
        intersection_data_keys = set(self._datasets[0].hf_dataset.features)
        for dataset in self._datasets:
            intersection_data_keys.intersection_update(dataset.hf_dataset.features)
        if len(intersection_data_keys) == 0:
            raise RuntimeError(
                "Multiple datasets were provided but they had no keys common to all of them. The "
                "multi-dataset functionality currently only keeps common keys."
            )
        for repo_id, dataset in zip(self.repo_ids, self._datasets, strict=True):
            extra_keys = set(dataset.hf_dataset.features).difference(intersection_data_keys)
            logging.warning(
                f"keys {extra_keys} of {repo_id} were disabled as they are not contained in all the "
                "other datasets."
            )
            self.disabled_data_keys.update(extra_keys)

        self.version = version
        self.root = root
        self.split = split
        self.image_transforms = image_transforms
        self.delta_timestamps = delta_timestamps
        self.stats = aggregate_stats(self._datasets)

    @property
    def repo_id_to_index(self):
        """Return a mapping from dataset repo_id to a dataset index automatically created by this class.

        This index is incorporated as a data key in the dictionary returned by `__getitem__`.
        """
        return {repo_id: i for i, repo_id in enumerate(self.repo_ids)}

    @property
    def repo_index_to_id(self):
        """Return the inverse mapping of `repo_id_to_index`."""
        return {v: k for k, v in self.repo_id_to_index.items()}

    @property
    def fps(self) -> int:
        """Frames per second used during data collection.

        NOTE: For now, this relies on a check in __init__ to make sure all sub-datasets have the same info.
        """
        return self._datasets[0].info["fps"]

    @property
    def video(self) -> bool:
        """Returns True if this dataset loads video frames from mp4 files.

        Returns False if it only loads images from png files.

        NOTE: For now, this relies on a check in __init__ to make sure all sub-datasets have the same info.
        """
        return self._datasets[0].info.get("video", False)

    @property
    def features(self) -> datasets.Features:
        features = {}
        for dataset in self._datasets:
            features.update({k: v for k, v in dataset.features.items() if k not in self.disabled_data_keys})
        return features

    @property
    def camera_keys(self) -> list[str]:
        """Keys to access image and video streams from cameras."""
        keys = []
        for key, feats in self.features.items():
            if isinstance(feats, (datasets.Image, VideoFrame)):
                keys.append(key)
        return keys

    @property
    def video_frame_keys(self) -> list[str]:
        """Keys to access video frames that require decoding into images.

        Note: It is empty if the dataset contains images only,
        or equal to `self.camera_keys` if the dataset contains videos only,
        or can even be a subset of `self.camera_keys` in the case of a mixed image/video dataset.
        """
        video_frame_keys = []
        for key, feats in self.features.items():
            if isinstance(feats, VideoFrame):
                video_frame_keys.append(key)
        return video_frame_keys

    @property
    def num_samples(self) -> int:
        """Number of samples/frames."""
        return sum(d.num_samples for d in self._datasets)

    @property
    def num_episodes(self) -> int:
        """Number of episodes."""
        return sum(d.num_episodes for d in self._datasets)

    @property
    def tolerance_s(self) -> float:
        """Tolerance in seconds used to discard loaded frames when their timestamps
        are not close enough to the requested frames. It is only used when `delta_timestamps`
        is provided or when loading video frames from mp4 files.
        """
        # 1e-4 to account for possible numerical error
        return 1 / self.fps - 1e-4

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
        if idx >= len(self):
            raise IndexError(f"Index {idx} out of bounds.")
        # Determine which dataset to get an item from based on the index.
        start_idx = 0
        dataset_idx = 0
        for dataset in self._datasets:
            if idx >= start_idx + dataset.num_samples:
                start_idx += dataset.num_samples
                dataset_idx += 1
                continue
            break
        else:
            raise AssertionError("We expect the loop to break out as long as the index is within bounds.")
        item = self._datasets[dataset_idx][idx - start_idx]
        item["dataset_index"] = torch.tensor(dataset_idx)
        for data_key in self.disabled_data_keys:
            if data_key in item:
                del item[data_key]

        return item

    def __repr__(self):
        return (
            f"{self.__class__.__name__}(\n"
            f"  Repository IDs: '{self.repo_ids}',\n"
            f"  Version: '{self.version}',\n"
            f"  Split: '{self.split}',\n"
            f"  Number of Samples: {self.num_samples},\n"
            f"  Number of Episodes: {self.num_episodes},\n"
            f"  Type: {'video (.mp4)' if self.video else 'image (.png)'},\n"
            f"  Recorded Frames per Second: {self.fps},\n"
            f"  Camera Keys: {self.camera_keys},\n"
            f"  Video Frame Keys: {self.video_frame_keys if self.video else 'N/A'},\n"
            f"  Transformations: {self.image_transforms},\n"
            f")"
        )
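To see how `MultiLeRobotDataset.__getitem__` maps a global index onto one of the concatenated datasets, here is a minimal self-contained sketch of the same scan over hypothetical per-dataset sample counts:

```python
# Hypothetical sample counts, standing in for d.num_samples of three sub-datasets.
num_samples_per_dataset = [100, 250, 50]

def locate(idx: int) -> tuple[int, int]:
    """Return (dataset_idx, local_idx) for a global index, mirroring the
    scan over self._datasets in MultiLeRobotDataset.__getitem__."""
    start_idx = 0
    for dataset_idx, num_samples in enumerate(num_samples_per_dataset):
        if idx < start_idx + num_samples:
            return dataset_idx, idx - start_idx
        start_idx += num_samples
    raise IndexError(f"Index {idx} out of bounds.")

assert locate(0) == (0, 0)
assert locate(120) == (1, 20)   # 20th frame of the second dataset
assert locate(399) == (2, 49)   # last frame of the third dataset
```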
@@ -14,156 +14,119 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file contains all obsolete download scripts. They are centralized here to not have to load
useless dependencies when using datasets.
This file contains download scripts for raw datasets.

Example of usage:
```
python lerobot/common/datasets/push_dataset_to_hub/_download_raw.py \
    --raw-dir data/cadene/pusht_raw \
    --repo-id cadene/pusht_raw
```
"""

import io
import argparse
import logging
import shutil
import warnings
from pathlib import Path

import tqdm
from huggingface_hub import snapshot_download


def download_raw(raw_dir, dataset_id):
    if "aloha" in dataset_id or "image" in dataset_id:
        download_hub(raw_dir, dataset_id)
    elif "pusht" in dataset_id:
        download_pusht(raw_dir)
    elif "xarm" in dataset_id:
        download_xarm(raw_dir)
    elif "umi" in dataset_id:
        download_umi(raw_dir)
    else:
        raise ValueError(dataset_id)
def download_raw(raw_dir: Path, repo_id: str):
    # Check repo_id is well formatted
    if len(repo_id.split("/")) != 2:
        raise ValueError(
            f"`repo_id` is expected to contain a community or user id, a '/', and the name of the dataset (e.g. 'lerobot/pusht'), but contains '{repo_id}'."
        )
    user_id, dataset_id = repo_id.split("/")


def download_and_extract_zip(url: str, destination_folder: Path) -> bool:
    import zipfile

    import requests

    print(f"downloading from {url}")
    response = requests.get(url, stream=True)
    if response.status_code == 200:
        total_size = int(response.headers.get("content-length", 0))
        progress_bar = tqdm.tqdm(total=total_size, unit="B", unit_scale=True)

        zip_file = io.BytesIO()
        for chunk in response.iter_content(chunk_size=1024):
            if chunk:
                zip_file.write(chunk)
                progress_bar.update(len(chunk))

        progress_bar.close()

        zip_file.seek(0)

        with zipfile.ZipFile(zip_file, "r") as zip_ref:
            zip_ref.extractall(destination_folder)


def download_pusht(raw_dir: str):
    pusht_url = "https://diffusion-policy.cs.columbia.edu/data/training/pusht.zip"
    if not dataset_id.endswith("_raw"):
        warnings.warn(
            f"`dataset_id` ({dataset_id}) doesn't end with '_raw' (e.g. 'lerobot/pusht_raw'). Following this naming convention by renaming your repository is advised, but not mandatory.",
            stacklevel=1,
        )

    raw_dir = Path(raw_dir)
    raw_dir.mkdir(parents=True, exist_ok=True)
    download_and_extract_zip(pusht_url, raw_dir)
    # the file is created inside a useless "pusht" directory, so we move it out and delete the dir
    zarr_path = raw_dir / "pusht_cchi_v7_replay.zarr"
    shutil.move(raw_dir / "pusht" / "pusht_cchi_v7_replay.zarr", zarr_path)
    shutil.rmtree(raw_dir / "pusht")


def download_xarm(raw_dir: Path):
    """Download all xarm datasets at once"""
    import zipfile

    import gdown

    raw_dir = Path(raw_dir)
    raw_dir.mkdir(parents=True, exist_ok=True)
    # from https://github.com/fyhMer/fowm/blob/main/scripts/download_datasets.py
    url = "https://drive.google.com/uc?id=1nhxpykGtPDhmQKm-_B8zBSywVRdgeVya"
    zip_path = raw_dir / "data.zip"
    gdown.download(url, str(zip_path), quiet=False)
    print("Extracting...")
    with zipfile.ZipFile(str(zip_path), "r") as zip_f:
        for pkl_path in zip_f.namelist():
            if pkl_path.startswith("data/xarm") and pkl_path.endswith(".pkl"):
                zip_f.extract(member=pkl_path)
                # move to corresponding raw directory
                extract_dir = pkl_path.replace("/buffer.pkl", "")
                raw_pkl_path = raw_dir / "buffer.pkl"
                shutil.move(pkl_path, raw_pkl_path)
                shutil.rmtree(extract_dir)
    zip_path.unlink()


def download_hub(raw_dir: Path, dataset_id: str):
    raw_dir = Path(raw_dir)
    # Warn if raw_dir isn't well formatted
    if raw_dir.parts[-2] != user_id or raw_dir.parts[-1] != dataset_id:
        warnings.warn(
            f"`raw_dir` ({raw_dir}) doesn't contain a community or user id `/` the name of the dataset that matches the `repo_id` (e.g. 'data/lerobot/pusht_raw'). Following this naming convention is advised, but not mandatory.",
            stacklevel=1,
        )
    raw_dir.mkdir(parents=True, exist_ok=True)

    logging.info(f"Start downloading from huggingface.co/cadene for {dataset_id}")
    snapshot_download(f"cadene/{dataset_id}_raw", repo_type="dataset", local_dir=raw_dir)
    logging.info(f"Finish downloading from huggingface.co/cadene for {dataset_id}")
    logging.info(f"Start downloading from huggingface.co/{user_id} for {dataset_id}")
    snapshot_download(f"{repo_id}", repo_type="dataset", local_dir=raw_dir)
    logging.info(f"Finish downloading from huggingface.co/{user_id} for {dataset_id}")


def download_umi(raw_dir: Path):
    url_cup_in_the_wild = "https://real.stanford.edu/umi/data/zarr_datasets/cup_in_the_wild.zarr.zip"
    zarr_path = raw_dir / "cup_in_the_wild.zarr"
def download_all_raw_datasets():
    data_dir = Path("data")
    repo_ids = [
        "cadene/pusht_image_raw",
        "cadene/xarm_lift_medium_image_raw",
        "cadene/xarm_lift_medium_replay_image_raw",
        "cadene/xarm_push_medium_image_raw",
        "cadene/xarm_push_medium_replay_image_raw",
        "cadene/aloha_sim_insertion_human_image_raw",
        "cadene/aloha_sim_insertion_scripted_image_raw",
        "cadene/aloha_sim_transfer_cube_human_image_raw",
        "cadene/aloha_sim_transfer_cube_scripted_image_raw",
        "cadene/pusht_raw",
        "cadene/xarm_lift_medium_raw",
        "cadene/xarm_lift_medium_replay_raw",
        "cadene/xarm_push_medium_raw",
        "cadene/xarm_push_medium_replay_raw",
        "cadene/aloha_sim_insertion_human_raw",
        "cadene/aloha_sim_insertion_scripted_raw",
        "cadene/aloha_sim_transfer_cube_human_raw",
        "cadene/aloha_sim_transfer_cube_scripted_raw",
        "cadene/aloha_mobile_cabinet_raw",
        "cadene/aloha_mobile_chair_raw",
        "cadene/aloha_mobile_elevator_raw",
        "cadene/aloha_mobile_shrimp_raw",
        "cadene/aloha_mobile_wash_pan_raw",
        "cadene/aloha_mobile_wipe_wine_raw",
        "cadene/aloha_static_battery_raw",
        "cadene/aloha_static_candy_raw",
        "cadene/aloha_static_coffee_raw",
        "cadene/aloha_static_coffee_new_raw",
        "cadene/aloha_static_cups_open_raw",
        "cadene/aloha_static_fork_pick_up_raw",
        "cadene/aloha_static_pingpong_test_raw",
        "cadene/aloha_static_pro_pencil_raw",
        "cadene/aloha_static_screw_driver_raw",
        "cadene/aloha_static_tape_raw",
        "cadene/aloha_static_thread_velcro_raw",
        "cadene/aloha_static_towel_raw",
        "cadene/aloha_static_vinh_cup_raw",
        "cadene/aloha_static_vinh_cup_left_raw",
        "cadene/aloha_static_ziploc_slide_raw",
        "cadene/umi_cup_in_the_wild_raw",
    ]
    for repo_id in repo_ids:
        raw_dir = data_dir / repo_id
        download_raw(raw_dir, repo_id)

    raw_dir = Path(raw_dir)
    raw_dir.mkdir(parents=True, exist_ok=True)
    download_and_extract_zip(url_cup_in_the_wild, zarr_path)

def main():
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--raw-dir",
        type=Path,
        required=True,
        help="Directory containing input raw datasets (e.g. `data/aloha_mobile_chair_raw` or `data/pusht_raw`).",
    )
    parser.add_argument(
        "--repo-id",
        type=str,
        required=True,
        help="Repository identifier on Hugging Face: a community or a user name `/` the name of the dataset (e.g. `lerobot/pusht_raw`, `cadene/aloha_sim_insertion_human_raw`).",
    )
    args = parser.parse_args()
    download_raw(**vars(args))


if __name__ == "__main__":
    data_dir = Path("data")
    dataset_ids = [
        "pusht_image",
        "xarm_lift_medium_image",
        "xarm_lift_medium_replay_image",
        "xarm_push_medium_image",
        "xarm_push_medium_replay_image",
        "aloha_sim_insertion_human_image",
        "aloha_sim_insertion_scripted_image",
        "aloha_sim_transfer_cube_human_image",
        "aloha_sim_transfer_cube_scripted_image",
        "pusht",
        "xarm_lift_medium",
        "xarm_lift_medium_replay",
        "xarm_push_medium",
        "xarm_push_medium_replay",
        "aloha_sim_insertion_human",
        "aloha_sim_insertion_scripted",
        "aloha_sim_transfer_cube_human",
        "aloha_sim_transfer_cube_scripted",
        "aloha_mobile_cabinet",
        "aloha_mobile_chair",
        "aloha_mobile_elevator",
        "aloha_mobile_shrimp",
        "aloha_mobile_wash_pan",
        "aloha_mobile_wipe_wine",
        "aloha_static_battery",
        "aloha_static_candy",
        "aloha_static_coffee",
        "aloha_static_coffee_new",
        "aloha_static_cups_open",
        "aloha_static_fork_pick_up",
        "aloha_static_pingpong_test",
        "aloha_static_pro_pencil",
        "aloha_static_screw_driver",
        "aloha_static_tape",
        "aloha_static_thread_velcro",
        "aloha_static_towel",
        "aloha_static_vinh_cup",
        "aloha_static_vinh_cup_left",
        "aloha_static_ziploc_slide",
        "umi_cup_in_the_wild",
    ]
    for dataset_id in dataset_ids:
        raw_dir = data_dir / f"{dataset_id}_raw"
        download_raw(raw_dir, dataset_id)
    main()
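Besides the CLI shown in the module docstring, `download_raw` can be called directly from Python; a minimal sketch, assuming the raw dataset repository exists on the Hub:

```python
from pathlib import Path

from lerobot.common.datasets.push_dataset_to_hub._download_raw import download_raw

# Following the advised `data/<user or community>/<dataset>` layout so that
# download_raw doesn't emit a naming-convention warning.
raw_dir = Path("data/cadene/pusht_raw")
download_raw(raw_dir, repo_id="cadene/pusht_raw")
```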
@@ -30,6 +30,7 @@ from PIL import Image as PILImage

from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes, save_images_concurrently
from lerobot.common.datasets.utils import (
    calculate_episode_data_index,
    hf_transform_to_torch,
)
from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames
@@ -70,16 +71,17 @@ def check_format(raw_dir) -> bool:
            assert c < h and c < w, f"Expect (h,w,c) image format but ({h=},{w=},{c=}) provided."


def load_from_raw(raw_dir, out_dir, fps, video, debug):
def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episodes: list[int] | None = None):
    # only frames from simulation are uncompressed
    compressed_images = "sim" not in raw_dir.name

    hdf5_files = list(raw_dir.glob("*.hdf5"))
    ep_dicts = []
    episode_data_index = {"from": [], "to": []}
    hdf5_files = sorted(raw_dir.glob("episode_*.hdf5"))
    num_episodes = len(hdf5_files)

    id_from = 0
    for ep_idx, ep_path in tqdm.tqdm(enumerate(hdf5_files), total=len(hdf5_files)):
    ep_dicts = []
    ep_ids = episodes if episodes else range(num_episodes)
    for ep_idx in tqdm.tqdm(ep_ids):
        ep_path = hdf5_files[ep_idx]
        with h5py.File(ep_path, "r") as ep:
            num_frames = ep["/action"].shape[0]

@@ -114,12 +116,12 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug):

            if video:
                # save png images in temporary directory
                tmp_imgs_dir = out_dir / "tmp_images"
                tmp_imgs_dir = videos_dir / "tmp_images"
                save_images_concurrently(imgs_array, tmp_imgs_dir)

                # encode images to an mp4 video
                fname = f"{img_key}_episode_{ep_idx:06d}.mp4"
                video_path = out_dir / "videos" / fname
                video_path = videos_dir / fname
                encode_video_frames(tmp_imgs_dir, video_path, fps)

                # clean temporary images directory
@@ -147,19 +149,13 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug):
            assert isinstance(ep_idx, int)
            ep_dicts.append(ep_dict)

            episode_data_index["from"].append(id_from)
            episode_data_index["to"].append(id_from + num_frames)

            id_from += num_frames

        gc.collect()

        # process first episode only
        if debug:
            break

    data_dict = concatenate_episodes(ep_dicts)
    return data_dict, episode_data_index

    total_frames = data_dict["frame_index"].shape[0]
    data_dict["index"] = torch.arange(0, total_frames, 1)
    return data_dict


def to_hf_dataset(data_dict, video) -> Dataset:
@@ -197,16 +193,22 @@ def to_hf_dataset(data_dict, video) -> Dataset:
    return hf_dataset


def from_raw_to_lerobot_format(raw_dir: Path, out_dir: Path, fps=None, video=True, debug=False):
def from_raw_to_lerobot_format(
    raw_dir: Path,
    videos_dir: Path,
    fps: int | None = None,
    video: bool = True,
    episodes: list[int] | None = None,
):
    # sanity check
    check_format(raw_dir)

    if fps is None:
        fps = 50

    data_dir, episode_data_index = load_from_raw(raw_dir, out_dir, fps, video, debug)
    hf_dataset = to_hf_dataset(data_dir, video)

    data_dict = load_from_raw(raw_dir, videos_dir, fps, video, episodes)
    hf_dataset = to_hf_dataset(data_dict, video)
    episode_data_index = calculate_episode_data_index(hf_dataset)
    info = {
        "fps": fps,
        "video": video,
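The new `episodes` parameter lets callers convert only a subset of recordings; a hedged usage sketch, assuming this ALOHA loader is importable as `aloha_hdf5_format` (the module name and paths below are assumptions, not taken from this diff):

```python
from pathlib import Path

from lerobot.common.datasets.push_dataset_to_hub.aloha_hdf5_format import from_raw_to_lerobot_format

raw_dir = Path("data/lerobot/aloha_sim_insertion_human_raw")  # hypothetical local path
videos_dir = Path("outputs/videos")  # hypothetical output path

# Convert only episodes 0 and 2; fps defaults to 50 for this format.
hf_dataset, episode_data_index, info = from_raw_to_lerobot_format(
    raw_dir, videos_dir, video=True, episodes=[0, 2]
)
```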
@@ -0,0 +1,220 @@
#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Contains utilities to process the raw data format from dora-record
"""

import re
from pathlib import Path

import pandas as pd
import torch
from datasets import Dataset, Features, Image, Sequence, Value

from lerobot.common.datasets.utils import (
    calculate_episode_data_index,
    hf_transform_to_torch,
)
from lerobot.common.datasets.video_utils import VideoFrame


def check_format(raw_dir) -> bool:
    assert raw_dir.exists()

    leader_file = list(raw_dir.glob("*.parquet"))
    if len(leader_file) == 0:
        raise ValueError(f"Missing parquet files in '{raw_dir}'")
    return True


def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episodes: list[int] | None = None):
    # Load the data stream that will be used as the reference for timestamp synchronization
    reference_files = list(raw_dir.glob("observation.images.cam_*.parquet"))
    if len(reference_files) == 0:
        raise ValueError(f"Missing camera reference files starting with 'observation.images.cam_' in '{raw_dir}'")
    # select first camera in alphanumeric order
    reference_key = sorted(reference_files)[0].stem
    reference_df = pd.read_parquet(raw_dir / f"{reference_key}.parquet")
    reference_df = reference_df[["timestamp_utc", reference_key]]

    # Merge all data streams using the nearest-timestamp strategy
    df = reference_df
    for path in raw_dir.glob("*.parquet"):
        key = path.stem  # action or observation.state or ...
        if key == reference_key:
            continue
        if "failed_episode_index" in key:
            # TODO(rcadene): add support for removing episodes that are tagged as "failed"
            continue
        modality_df = pd.read_parquet(path)
        modality_df = modality_df[["timestamp_utc", key]]
        df = pd.merge_asof(
            df,
            modality_df,
            on="timestamp_utc",
            # "nearest" is the best option over "backward", since the latter can desynchronize camera timestamps by
            # matching timestamps that are too far apart, in order to fit the backward constraint. This is not the
            # case for "nearest". However, note that "nearest" might synchronize the reference camera with other
            # cameras on slightly future timestamps.
            direction="nearest",
            tolerance=pd.Timedelta(f"{1/fps} seconds"),
        )
    # Remove rows with episode_index -1, which indicates data recorded in between episodes
    df = df[df["episode_index"] != -1]

    image_keys = [key for key in df if "observation.images." in key]

    def get_episode_index(row):
        episode_index_per_cam = {}
        for key in image_keys:
            path = row[key][0]["path"]
            match = re.search(r"_(\d{6}).mp4", path)
            if not match:
                raise ValueError(path)
            episode_index = int(match.group(1))
            episode_index_per_cam[key] = episode_index
        if len(set(episode_index_per_cam.values())) != 1:
            raise ValueError(
                f"All cameras are expected to belong to the same episode, but got {episode_index_per_cam}"
            )
        return episode_index

    df["episode_index"] = df.apply(get_episode_index, axis=1)

    # dora only uses arrays, so single values are encapsulated into a list
    df["frame_index"] = df.groupby("episode_index").cumcount()
    df = df.reset_index()
    df["index"] = df.index

    # set 'next.done' to True for the last frame of each episode
    df["next.done"] = False
    df.loc[df.groupby("episode_index").tail(1).index, "next.done"] = True

    df["timestamp"] = df["timestamp_utc"].map(lambda x: x.timestamp())
    # each episode starts with timestamp 0 to match the ones from the video
    df["timestamp"] = df.groupby("episode_index")["timestamp"].transform(lambda x: x - x.iloc[0])

    del df["timestamp_utc"]

    # sanity check
    has_nan = df.isna().any().any()
    if has_nan:
        raise ValueError("Dataset contains NaN values.")

    # sanity check that episode indices go from 0 to n-1
    ep_ids = [ep_idx for ep_idx, _ in df.groupby("episode_index")]
    expected_ep_ids = list(range(df["episode_index"].max() + 1))
    if ep_ids != expected_ep_ids:
        raise ValueError(f"Episode indices are {ep_ids} instead of {expected_ep_ids}")

    # Create symlink to raw videos directory (that needs to be absolute not relative)
    videos_dir.parent.mkdir(parents=True, exist_ok=True)
    videos_dir.symlink_to((raw_dir / "videos").absolute())

    # sanity check the video paths are well formatted
    for key in df:
        if "observation.images." not in key:
            continue
        for ep_idx in ep_ids:
            video_path = videos_dir / f"{key}_episode_{ep_idx:06d}.mp4"
            if not video_path.exists():
                raise ValueError(f"Video file not found at {video_path}")

    data_dict = {}
    for key in df:
        # is video frame
        if "observation.images." in key:
            # we need `[0]` because dora only uses arrays, so single values are encapsulated into a list.
            # This is the case for the video_frame dictionary = [{"path": ..., "timestamp": ...}]
            data_dict[key] = [video_frame[0] for video_frame in df[key].values]

            # sanity check the video path is well formatted
            video_path = videos_dir.parent / data_dict[key][0]["path"]
            if not video_path.exists():
                raise ValueError(f"Video file not found at {video_path}")
        # is number
        elif df[key].iloc[0].ndim == 0 or df[key].iloc[0].shape[0] == 1:
            data_dict[key] = torch.from_numpy(df[key].values)
        # is vector
        elif df[key].iloc[0].shape[0] > 1:
            data_dict[key] = torch.stack([torch.from_numpy(x.copy()) for x in df[key].values])
        else:
            raise ValueError(key)

    return data_dict


def to_hf_dataset(data_dict, video) -> Dataset:
    features = {}

    keys = [key for key in data_dict if "observation.images." in key]
    for key in keys:
        if video:
            features[key] = VideoFrame()
        else:
            features[key] = Image()

    features["observation.state"] = Sequence(
        length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
    )
    if "observation.velocity" in data_dict:
        features["observation.velocity"] = Sequence(
            length=data_dict["observation.velocity"].shape[1], feature=Value(dtype="float32", id=None)
        )
    if "observation.effort" in data_dict:
        features["observation.effort"] = Sequence(
            length=data_dict["observation.effort"].shape[1], feature=Value(dtype="float32", id=None)
        )
    features["action"] = Sequence(
        length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)
    )
    features["episode_index"] = Value(dtype="int64", id=None)
    features["frame_index"] = Value(dtype="int64", id=None)
    features["timestamp"] = Value(dtype="float32", id=None)
    features["next.done"] = Value(dtype="bool", id=None)
    features["index"] = Value(dtype="int64", id=None)

    hf_dataset = Dataset.from_dict(data_dict, features=Features(features))
    hf_dataset.set_transform(hf_transform_to_torch)
    return hf_dataset


def from_raw_to_lerobot_format(
    raw_dir: Path,
    videos_dir: Path,
    fps: int | None = None,
    video: bool = True,
    episodes: list[int] | None = None,
):
    # sanity check
    check_format(raw_dir)

    if fps is None:
        fps = 30
    else:
        raise NotImplementedError()

    if not video:
        raise NotImplementedError()

    data_df = load_from_raw(raw_dir, videos_dir, fps, video, episodes)
    hf_dataset = to_hf_dataset(data_df, video)
    episode_data_index = calculate_episode_data_index(hf_dataset)
    info = {
        "fps": fps,
        "video": video,
    }
    return hf_dataset, episode_data_index, info
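The synchronization above hinges on `pd.merge_asof` with `direction="nearest"`; a minimal self-contained sketch with made-up timestamps, showing how an off-clock state stream gets matched to the reference camera within a one-frame tolerance:

```python
import pandas as pd

fps = 30

# Reference camera stream (hypothetical timestamps, one row per frame).
reference = pd.DataFrame(
    {
        "timestamp_utc": pd.to_datetime(
            ["2024-01-01 00:00:00.000", "2024-01-01 00:00:00.033", "2024-01-01 00:00:00.066"]
        ),
        "cam": [0, 1, 2],
    }
)
# A state stream sampled slightly off the camera clock.
state = pd.DataFrame(
    {
        "timestamp_utc": pd.to_datetime(["2024-01-01 00:00:00.010", "2024-01-01 00:00:00.040"]),
        "state": [10.0, 11.0],
    }
)

merged = pd.merge_asof(
    reference,
    state,
    on="timestamp_utc",
    direction="nearest",  # match the closest state row, past or future
    tolerance=pd.Timedelta(f"{1/fps} seconds"),  # drop matches further than one frame
)
print(merged)  # every camera frame gets the nearest state within ~33 ms
```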
@@ -27,6 +27,7 @@ from PIL import Image as PILImage

from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes, save_images_concurrently
from lerobot.common.datasets.utils import (
    calculate_episode_data_index,
    hf_transform_to_torch,
)
from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames
@@ -53,7 +54,7 @@ def check_format(raw_dir):
    assert all(nb_frames == zarr_data[dataset].shape[0] for dataset in required_datasets)


def load_from_raw(raw_dir, out_dir, fps, video, debug):
def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episodes: list[int] | None = None):
    try:
        import pymunk
        from gym_pusht.envs.pusht import PushTEnv, pymunk_to_shapely
@@ -71,7 +72,6 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug):
    zarr_data = DiffusionPolicyReplayBuffer.copy_from_path(zarr_path)

    episode_ids = torch.from_numpy(zarr_data.get_episode_idxs())
    num_episodes = zarr_data.meta["episode_ends"].shape[0]
    assert len(
        {zarr_data[key].shape[0] for key in zarr_data.keys()}  # noqa: SIM118
    ) == 1, "Some data types don't have the same number of total frames."
@@ -84,25 +84,34 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug):
    states = torch.from_numpy(zarr_data["state"])
    actions = torch.from_numpy(zarr_data["action"])

    ep_dicts = []
    episode_data_index = {"from": [], "to": []}
    # load data indices from which each episode starts and ends
    from_ids, to_ids = [], []
    from_idx = 0
    for to_idx in zarr_data.meta["episode_ends"]:
        from_ids.append(from_idx)
        to_ids.append(to_idx)
        from_idx = to_idx

    id_from = 0
    for ep_idx in tqdm.tqdm(range(num_episodes)):
        id_to = zarr_data.meta["episode_ends"][ep_idx]
        num_frames = id_to - id_from
    num_episodes = len(from_ids)

    ep_dicts = []
    ep_ids = episodes if episodes else range(num_episodes)
    for ep_idx, selected_ep_idx in tqdm.tqdm(enumerate(ep_ids)):
        from_idx = from_ids[selected_ep_idx]
        to_idx = to_ids[selected_ep_idx]
        num_frames = to_idx - from_idx

        # sanity check
        assert (episode_ids[id_from:id_to] == ep_idx).all()
        assert (episode_ids[from_idx:to_idx] == ep_idx).all()

        # get image
        image = imgs[id_from:id_to]
        image = imgs[from_idx:to_idx]
        assert image.min() >= 0.0
        assert image.max() <= 255.0
        image = image.type(torch.uint8)

        # get state
        state = states[id_from:id_to]
        state = states[from_idx:to_idx]
        agent_pos = state[:, :2]
        block_pos = state[:, 2:4]
        block_angle = state[:, 4]
@@ -143,12 +152,12 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug):
        img_key = "observation.image"
        if video:
            # save png images in temporary directory
            tmp_imgs_dir = out_dir / "tmp_images"
            tmp_imgs_dir = videos_dir / "tmp_images"
            save_images_concurrently(imgs_array, tmp_imgs_dir)

            # encode images to an mp4 video
            fname = f"{img_key}_episode_{ep_idx:06d}.mp4"
            video_path = out_dir / "videos" / fname
            video_path = videos_dir / fname
            encode_video_frames(tmp_imgs_dir, video_path, fps)

            # clean temporary images directory
@@ -160,7 +169,7 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug):
            ep_dict[img_key] = [PILImage.fromarray(x) for x in imgs_array]

        ep_dict["observation.state"] = agent_pos
        ep_dict["action"] = actions[id_from:id_to]
        ep_dict["action"] = actions[from_idx:to_idx]
        ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames, dtype=torch.int64)
        ep_dict["frame_index"] = torch.arange(0, num_frames, 1)
        ep_dict["timestamp"] = torch.arange(0, num_frames, 1) / fps
@@ -172,17 +181,11 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug):
        ep_dict["next.success"] = torch.cat([success[1:], success[[-1]]])
        ep_dicts.append(ep_dict)

        episode_data_index["from"].append(id_from)
        episode_data_index["to"].append(id_from + num_frames)

        id_from += num_frames

        # process first episode only
        if debug:
            break

    data_dict = concatenate_episodes(ep_dicts)
    return data_dict, episode_data_index

    total_frames = data_dict["frame_index"].shape[0]
    data_dict["index"] = torch.arange(0, total_frames, 1)
    return data_dict


def to_hf_dataset(data_dict, video):
@@ -212,16 +215,22 @@ def to_hf_dataset(data_dict, video):
    return hf_dataset


def from_raw_to_lerobot_format(raw_dir: Path, out_dir: Path, fps=None, video=True, debug=False):
def from_raw_to_lerobot_format(
    raw_dir: Path,
    videos_dir: Path,
    fps: int | None = None,
    video: bool = True,
    episodes: list[int] | None = None,
):
    # sanity check
    check_format(raw_dir)

    if fps is None:
        fps = 10

    data_dict, episode_data_index = load_from_raw(raw_dir, out_dir, fps, video, debug)
    data_dict = load_from_raw(raw_dir, videos_dir, fps, video, episodes)
    hf_dataset = to_hf_dataset(data_dict, video)

    episode_data_index = calculate_episode_data_index(hf_dataset)
    info = {
        "fps": fps,
        "video": video,
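A recurring change across these loaders: `load_from_raw` no longer assembles `episode_data_index` by hand, since it can be recomputed from the `episode_index` column of the resulting dataset. The actual helper is `calculate_episode_data_index` from `lerobot.common.datasets.utils`; the sketch below shows the idea on a plain list (an illustration of the computation, not the library implementation):

```python
import torch

def episode_index_to_ranges(episode_ids: list[int]) -> dict[str, torch.Tensor]:
    """Derive per-episode [from, to) frame ranges from a flat episode_index column."""
    froms, tos = [], []
    prev = None
    for i, ep_idx in enumerate(episode_ids):
        if ep_idx != prev:
            if prev is not None:
                tos.append(i)
            froms.append(i)
            prev = ep_idx
    if prev is not None:
        tos.append(len(episode_ids))
    return {"from": torch.tensor(froms), "to": torch.tensor(tos)}

# Three episodes of lengths 2, 3 and 1:
index = episode_index_to_ranges([0, 0, 1, 1, 1, 2])
assert index["from"].tolist() == [0, 2, 5]
assert index["to"].tolist() == [2, 5, 6]
```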
@@ -19,7 +19,6 @@ import logging
import shutil
from pathlib import Path

import numpy as np
import torch
import tqdm
import zarr
@@ -29,6 +28,7 @@ from PIL import Image as PILImage
from lerobot.common.datasets.push_dataset_to_hub._umi_imagecodecs_numcodecs import register_codecs
from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes, save_images_concurrently
from lerobot.common.datasets.utils import (
    calculate_episode_data_index,
    hf_transform_to_torch,
)
from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames
@@ -59,23 +59,7 @@ def check_format(raw_dir) -> bool:
    assert all(nb_frames == zarr_data[dataset].shape[0] for dataset in required_datasets)


def get_episode_idxs(episode_ends: np.ndarray) -> np.ndarray:
    # Optimized and simplified version of this function: https://github.com/real-stanford/universal_manipulation_interface/blob/298776ce251f33b6b3185a98d6e7d1f9ad49168b/diffusion_policy/common/replay_buffer.py#L374
    from numba import jit

    @jit(nopython=True)
    def _get_episode_idxs(episode_ends):
        result = np.zeros((episode_ends[-1],), dtype=np.int64)
        start_idx = 0
        for episode_number, end_idx in enumerate(episode_ends):
            result[start_idx:end_idx] = episode_number
            start_idx = end_idx
        return result

    return _get_episode_idxs(episode_ends)


def load_from_raw(raw_dir, out_dir, fps, video, debug):
def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episodes: list[int] | None = None):
    zarr_path = raw_dir / "cup_in_the_wild.zarr"
    zarr_data = zarr.open(zarr_path, mode="r")

@@ -92,39 +76,41 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug):
    episode_ends = zarr_data["meta/episode_ends"][:]
    num_episodes = episode_ends.shape[0]

    episode_ids = torch.from_numpy(get_episode_idxs(episode_ends))

    # We convert it to a torch tensor later because the jit function does not support torch tensors
    episode_ends = torch.from_numpy(episode_ends)

    # load data indices from which each episode starts and ends
    from_ids, to_ids = [], []
    from_idx = 0
    for to_idx in episode_ends:
        from_ids.append(from_idx)
        to_ids.append(to_idx)
        from_idx = to_idx

    ep_dicts = []
    episode_data_index = {"from": [], "to": []}

    id_from = 0
    for ep_idx in tqdm.tqdm(range(num_episodes)):
        id_to = episode_ends[ep_idx]
        num_frames = id_to - id_from

        # sanity check
        assert (episode_ids[id_from:id_to] == ep_idx).all()
    ep_ids = episodes if episodes else range(num_episodes)
    for ep_idx, selected_ep_idx in tqdm.tqdm(enumerate(ep_ids)):
        from_idx = from_ids[selected_ep_idx]
        to_idx = to_ids[selected_ep_idx]
        num_frames = to_idx - from_idx

        # TODO(rcadene): save temporary images of the episode?

        state = states[id_from:id_to]
        state = states[from_idx:to_idx]

        ep_dict = {}

        # load 57MB of images in RAM (400x224x224x3 uint8)
        imgs_array = zarr_data["data/camera0_rgb"][id_from:id_to]
        imgs_array = zarr_data["data/camera0_rgb"][from_idx:to_idx]
        img_key = "observation.image"
        if video:
            # save png images in temporary directory
            tmp_imgs_dir = out_dir / "tmp_images"
            tmp_imgs_dir = videos_dir / "tmp_images"
            save_images_concurrently(imgs_array, tmp_imgs_dir)

            # encode images to an mp4 video
            fname = f"{img_key}_episode_{ep_idx:06d}.mp4"
            video_path = out_dir / "videos" / fname
            video_path = videos_dir / fname
            encode_video_frames(tmp_imgs_dir, video_path, fps)

            # clean temporary images directory
@@ -139,27 +125,18 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug):
        ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames, dtype=torch.int64)
        ep_dict["frame_index"] = torch.arange(0, num_frames, 1)
        ep_dict["timestamp"] = torch.arange(0, num_frames, 1) / fps
        ep_dict["episode_data_index_from"] = torch.tensor([id_from] * num_frames)
        ep_dict["episode_data_index_to"] = torch.tensor([id_from + num_frames] * num_frames)
        ep_dict["end_pose"] = end_pose[id_from:id_to]
        ep_dict["start_pos"] = start_pos[id_from:id_to]
        ep_dict["gripper_width"] = gripper_width[id_from:id_to]
        ep_dict["episode_data_index_from"] = torch.tensor([from_idx] * num_frames)
        ep_dict["episode_data_index_to"] = torch.tensor([from_idx + num_frames] * num_frames)
        ep_dict["end_pose"] = end_pose[from_idx:to_idx]
        ep_dict["start_pos"] = start_pos[from_idx:to_idx]
        ep_dict["gripper_width"] = gripper_width[from_idx:to_idx]
        ep_dicts.append(ep_dict)

        episode_data_index["from"].append(id_from)
        episode_data_index["to"].append(id_from + num_frames)
        id_from += num_frames

        # process first episode only
        if debug:
            break

    data_dict = concatenate_episodes(ep_dicts)

    total_frames = id_from
    total_frames = data_dict["frame_index"].shape[0]
    data_dict["index"] = torch.arange(0, total_frames, 1)

    return data_dict, episode_data_index
    return data_dict


def to_hf_dataset(data_dict, video):
@@ -199,7 +176,13 @@ def to_hf_dataset(data_dict, video):
    return hf_dataset


def from_raw_to_lerobot_format(raw_dir: Path, out_dir: Path, fps=None, video=True, debug=False):
def from_raw_to_lerobot_format(
    raw_dir: Path,
    videos_dir: Path,
    fps: int | None = None,
    video: bool = True,
    episodes: list[int] | None = None,
):
    # sanity check
    check_format(raw_dir)

@@ -212,9 +195,9 @@ def from_raw_to_lerobot_format(raw_dir: Path, out_dir: Path, fps=None, video=Tru
        "Generating UMI dataset without `video=True` creates ~150GB on disk and requires ~80GB in RAM."
    )

    data_dict, episode_data_index = load_from_raw(raw_dir, out_dir, fps, video, debug)
    data_dict = load_from_raw(raw_dir, videos_dir, fps, video, episodes)
    hf_dataset = to_hf_dataset(data_dict, video)

    episode_data_index = calculate_episode_data_index(hf_dataset)
    info = {
        "fps": fps,
        "video": video,
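The `from_ids`/`to_ids` bookkeeping that replaces the numba-jitted `get_episode_idxs` is a simple cumulative scan over `episode_ends`; a minimal sketch with made-up boundaries:

```python
# Hypothetical episode_ends array: episode i spans frames [from_ids[i], to_ids[i]).
episode_ends = [50, 120, 200]

from_ids, to_ids = [], []
from_idx = 0
for to_idx in episode_ends:
    from_ids.append(from_idx)
    to_ids.append(to_idx)
    from_idx = to_idx

assert from_ids == [0, 50, 120]
assert to_ids == [50, 120, 200]
# With episodes=[0, 2], only frames 0:50 and 120:200 would be converted.
```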
@@ -27,6 +27,7 @@ from PIL import Image as PILImage

from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes, save_images_concurrently
from lerobot.common.datasets.utils import (
    calculate_episode_data_index,
    hf_transform_to_torch,
)
from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames
@@ -54,37 +55,42 @@ def check_format(raw_dir):
    assert all(len(nested_dict[subkey]) == expected_len for subkey in subkeys if subkey in nested_dict)


def load_from_raw(raw_dir, out_dir, fps, video, debug):
def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episodes: list[int] | None = None):
    pkl_path = raw_dir / "buffer.pkl"

    with open(pkl_path, "rb") as f:
        pkl_data = pickle.load(f)

    ep_dicts = []
    episode_data_index = {"from": [], "to": []}

    id_from = 0
    id_to = 0
    ep_idx = 0
    total_frames = pkl_data["actions"].shape[0]
    for i in tqdm.tqdm(range(total_frames)):
        id_to += 1

        if not pkl_data["dones"][i]:
    # load data indices from which each episode starts and ends
    from_ids, to_ids = [], []
    from_idx, to_idx = 0, 0
    for done in pkl_data["dones"]:
        to_idx += 1
        if not done:
            continue
        from_ids.append(from_idx)
        to_ids.append(to_idx)
        from_idx = to_idx

        num_frames = id_to - id_from
    num_episodes = len(from_ids)

        image = torch.tensor(pkl_data["observations"]["rgb"][id_from:id_to])
    ep_dicts = []
    ep_ids = episodes if episodes else range(num_episodes)
    for ep_idx, selected_ep_idx in tqdm.tqdm(enumerate(ep_ids)):
        from_idx = from_ids[selected_ep_idx]
        to_idx = to_ids[selected_ep_idx]
        num_frames = to_idx - from_idx

        image = torch.tensor(pkl_data["observations"]["rgb"][from_idx:to_idx])
        image = einops.rearrange(image, "b c h w -> b h w c")
        state = torch.tensor(pkl_data["observations"]["state"][id_from:id_to])
        action = torch.tensor(pkl_data["actions"][id_from:id_to])
        state = torch.tensor(pkl_data["observations"]["state"][from_idx:to_idx])
        action = torch.tensor(pkl_data["actions"][from_idx:to_idx])
        # TODO(rcadene): we have a missing last frame which is the observation when the env is done
        # it is critical to have this frame for tdmpc to predict a "done observation/state"
        # next_image = torch.tensor(pkl_data["next_observations"]["rgb"][id_from:id_to])
        # next_state = torch.tensor(pkl_data["next_observations"]["state"][id_from:id_to])
        next_reward = torch.tensor(pkl_data["rewards"][id_from:id_to])
        next_done = torch.tensor(pkl_data["dones"][id_from:id_to])
        # next_image = torch.tensor(pkl_data["next_observations"]["rgb"][from_idx:to_idx])
        # next_state = torch.tensor(pkl_data["next_observations"]["state"][from_idx:to_idx])
        next_reward = torch.tensor(pkl_data["rewards"][from_idx:to_idx])
        next_done = torch.tensor(pkl_data["dones"][from_idx:to_idx])

        ep_dict = {}

@@ -92,12 +98,12 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug):
        img_key = "observation.image"
        if video:
            # save png images in temporary directory
            tmp_imgs_dir = out_dir / "tmp_images"
            tmp_imgs_dir = videos_dir / "tmp_images"
            save_images_concurrently(imgs_array, tmp_imgs_dir)

            # encode images to an mp4 video
            fname = f"{img_key}_episode_{ep_idx:06d}.mp4"
            video_path = out_dir / "videos" / fname
            video_path = videos_dir / fname
            encode_video_frames(tmp_imgs_dir, video_path, fps)

            # clean temporary images directory
@@ -119,18 +125,11 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug):
        ep_dict["next.done"] = next_done
        ep_dicts.append(ep_dict)

        episode_data_index["from"].append(id_from)
        episode_data_index["to"].append(id_from + num_frames)

        id_from = id_to
        ep_idx += 1

        # process first episode only
        if debug:
            break

    data_dict = concatenate_episodes(ep_dicts)
    return data_dict, episode_data_index

    total_frames = data_dict["frame_index"].shape[0]
    data_dict["index"] = torch.arange(0, total_frames, 1)
    return data_dict


def to_hf_dataset(data_dict, video):
@@ -161,16 +160,22 @@ def to_hf_dataset(data_dict, video):
    return hf_dataset


def from_raw_to_lerobot_format(raw_dir: Path, out_dir: Path, fps=None, video=True, debug=False):
def from_raw_to_lerobot_format(
    raw_dir: Path,
    videos_dir: Path,
    fps: int | None = None,
    video: bool = True,
    episodes: list[int] | None = None,
):
    # sanity check
    check_format(raw_dir)

    if fps is None:
        fps = 15

    data_dict, episode_data_index = load_from_raw(raw_dir, out_dir, fps, video, debug)
    data_dict = load_from_raw(raw_dir, videos_dir, fps, video, episodes)
    hf_dataset = to_hf_dataset(data_dict, video)

    episode_data_index = calculate_episode_data_index(hf_dataset)
    info = {
        "fps": fps,
        "video": video,
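Unlike the zarr-based loaders, the xarm loader derives episode boundaries from per-frame `dones` flags rather than an `episode_ends` array; a minimal sketch of that scan with a made-up `dones` list:

```python
# Hypothetical done flags: True marks the last frame of an episode.
dones = [False, False, True, False, True]

from_ids, to_ids = [], []
from_idx, to_idx = 0, 0
for done in dones:
    to_idx += 1
    if not done:
        continue
    from_ids.append(from_idx)
    to_ids.append(to_idx)
    from_idx = to_idx

assert from_ids == [0, 3]
assert to_ids == [3, 5]  # episode 0 is frames 0:3, episode 1 is frames 3:5
```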
61
lerobot/common/datasets/sampler.py
Normal file
@@ -0,0 +1,61 @@
#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Iterator, Union

import torch


class EpisodeAwareSampler:
    def __init__(
        self,
        episode_data_index: dict,
        episode_indices_to_use: Union[list, None] = None,
        drop_n_first_frames: int = 0,
        drop_n_last_frames: int = 0,
        shuffle: bool = False,
    ):
        """Sampler that optionally incorporates episode boundary information.

        Args:
            episode_data_index: Dictionary with keys 'from' and 'to' containing the start and end indices of each episode.
            episode_indices_to_use: List of episode indices to use. If None, all episodes are used.
                Assumes that episodes are indexed from 0 to N-1.
            drop_n_first_frames: Number of frames to drop from the start of each episode.
            drop_n_last_frames: Number of frames to drop from the end of each episode.
            shuffle: Whether to shuffle the indices.
        """
        indices = []
        for episode_idx, (start_index, end_index) in enumerate(
            zip(episode_data_index["from"], episode_data_index["to"], strict=True)
        ):
            if episode_indices_to_use is None or episode_idx in episode_indices_to_use:
                indices.extend(
                    range(start_index.item() + drop_n_first_frames, end_index.item() - drop_n_last_frames)
                )

        self.indices = indices
        self.shuffle = shuffle

    def __iter__(self) -> Iterator[int]:
        if self.shuffle:
            for i in torch.randperm(len(self.indices)):
                yield self.indices[i]
        else:
            for i in self.indices:
                yield i

    def __len__(self) -> int:
        return len(self.indices)
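A hedged usage sketch for the new sampler: paired with a PyTorch `DataLoader`, it restricts sampling to frames that stay clear of episode boundaries (the boundary values below are made up):

```python
import torch
from torch.utils.data import DataLoader

from lerobot.common.datasets.sampler import EpisodeAwareSampler

# Two hypothetical episodes covering frames [0, 10) and [10, 25).
episode_data_index = {
    "from": torch.tensor([0, 10]),
    "to": torch.tensor([10, 25]),
}

sampler = EpisodeAwareSampler(
    episode_data_index,
    drop_n_last_frames=2,  # e.g. so that action chunks never cross an episode end
    shuffle=True,
)
assert len(sampler) == 21  # (10 - 2) + (25 - 10 - 2) frames remain eligible

# `dataset` must be indexable by the frame indices the sampler yields, e.g.:
# dataloader = DataLoader(dataset, batch_size=8, sampler=sampler)
```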
197
lerobot/common/datasets/transforms.py
Normal file
@@ -0,0 +1,197 @@
|
||||
#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
from typing import Any, Callable, Dict, Sequence

import torch
from torchvision.transforms import v2
from torchvision.transforms.v2 import Transform
from torchvision.transforms.v2 import functional as F  # noqa: N812


class RandomSubsetApply(Transform):
    """Apply a random subset of N transformations from a list of transformations.

    Args:
        transforms: list of transformations.
        p: represents the multinomial probabilities (with no replacement) used for sampling the transform.
            If the sum of the weights is not 1, they will be normalized. If ``None`` (default), all transforms
            have the same probability.
        n_subset: number of transformations to apply. If ``None``, all transforms are applied.
            Must be in [1, len(transforms)].
        random_order: apply transformations in a random order.
    """

    def __init__(
        self,
        transforms: Sequence[Callable],
        p: list[float] | None = None,
        n_subset: int | None = None,
        random_order: bool = False,
    ) -> None:
        super().__init__()
        if not isinstance(transforms, Sequence):
            raise TypeError("Argument transforms should be a sequence of callables")
        if p is None:
            p = [1] * len(transforms)
        elif len(p) != len(transforms):
            raise ValueError(
                f"Length of p doesn't match the number of transforms: {len(p)} != {len(transforms)}"
            )

        if n_subset is None:
            n_subset = len(transforms)
        elif not isinstance(n_subset, int):
            raise TypeError("n_subset should be an int or None")
        elif not (1 <= n_subset <= len(transforms)):
            raise ValueError(f"n_subset should be in the interval [1, {len(transforms)}]")

        self.transforms = transforms
        total = sum(p)
        self.p = [prob / total for prob in p]
        self.n_subset = n_subset
        self.random_order = random_order

    def forward(self, *inputs: Any) -> Any:
        needs_unpacking = len(inputs) > 1

        selected_indices = torch.multinomial(torch.tensor(self.p), self.n_subset)
        if not self.random_order:
            selected_indices = selected_indices.sort().values

        selected_transforms = [self.transforms[i] for i in selected_indices]

        for transform in selected_transforms:
            outputs = transform(*inputs)
            inputs = outputs if needs_unpacking else (outputs,)

        return outputs

    def extra_repr(self) -> str:
        return (
            f"transforms={self.transforms}, "
            f"p={self.p}, "
            f"n_subset={self.n_subset}, "
            f"random_order={self.random_order}"
        )
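As a quick illustration (an editor's sketch, not part of the diff), `RandomSubsetApply` can be exercised on its own; the weights below are illustrative. Two of the three jitters are applied per call, chosen with probabilities proportional to `[2, 1, 1]`:

```python
import torch
from torchvision.transforms import v2

transform = RandomSubsetApply(
    transforms=[
        v2.ColorJitter(brightness=(0.8, 1.2)),
        v2.ColorJitter(contrast=(0.8, 1.2)),
        v2.ColorJitter(saturation=(0.5, 1.5)),
    ],
    p=[2, 1, 1],          # normalized internally to [0.5, 0.25, 0.25]
    n_subset=2,           # apply exactly two of the three transforms
    random_order=False,   # keep torchvision's suggested order
)
img = torch.rand(3, 96, 96)  # (c, h, w) float image in [0, 1]
augmented = transform(img)
```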
class SharpnessJitter(Transform):
    """Randomly change the sharpness of an image or video.

    Similar to a v2.RandomAdjustSharpness with p=1 and a sharpness_factor sampled randomly.
    While v2.RandomAdjustSharpness applies a fixed sharpness_factor with a given probability,
    SharpnessJitter applies a random sharpness_factor each time, which yields a more diverse set
    of augmentations.

    A sharpness_factor of 0 gives a blurred image, 1 gives the original image while 2 increases the sharpness
    by a factor of 2.

    If the input is a :class:`torch.Tensor`, it is expected to have [..., 1 or 3, H, W] shape,
    where ... means an arbitrary number of leading dimensions.

    Args:
        sharpness: How much to jitter sharpness. sharpness_factor is chosen uniformly from
            [max(0, 1 - sharpness), 1 + sharpness] or the given [min, max]. Should be non-negative numbers.
    """

    def __init__(self, sharpness: float | Sequence[float]) -> None:
        super().__init__()
        self.sharpness = self._check_input(sharpness)

    def _check_input(self, sharpness):
        if isinstance(sharpness, (int, float)):
            if sharpness < 0:
                raise ValueError("If sharpness is a single number, it must be non-negative.")
            sharpness = [1.0 - sharpness, 1.0 + sharpness]
            sharpness[0] = max(sharpness[0], 0.0)
        elif isinstance(sharpness, collections.abc.Sequence) and len(sharpness) == 2:
            sharpness = [float(v) for v in sharpness]
        else:
            raise TypeError(f"{sharpness=} should be a single number or a sequence with length 2.")

        if not 0.0 <= sharpness[0] <= sharpness[1]:
            raise ValueError(f"sharpness values should be in [0., inf), but got {sharpness}.")

        return float(sharpness[0]), float(sharpness[1])

    def _generate_value(self, left: float, right: float) -> float:
        return torch.empty(1).uniform_(left, right).item()

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        sharpness_factor = self._generate_value(self.sharpness[0], self.sharpness[1])
        return self._call_kernel(F.adjust_sharpness, inpt, sharpness_factor=sharpness_factor)
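A short sketch of the two accepted argument forms (editor's addition; the factor ranges follow from `_check_input` above):

```python
import torch

jitter = SharpnessJitter(sharpness=0.2)            # factor sampled uniformly from [0.8, 1.2]
jitter_ranged = SharpnessJitter(sharpness=(0.5, 2.0))  # explicit [min, max] bounds
sharpened = jitter(torch.rand(3, 96, 96))
```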
def get_image_transforms(
    brightness_weight: float = 1.0,
    brightness_min_max: tuple[float, float] | None = None,
    contrast_weight: float = 1.0,
    contrast_min_max: tuple[float, float] | None = None,
    saturation_weight: float = 1.0,
    saturation_min_max: tuple[float, float] | None = None,
    hue_weight: float = 1.0,
    hue_min_max: tuple[float, float] | None = None,
    sharpness_weight: float = 1.0,
    sharpness_min_max: tuple[float, float] | None = None,
    max_num_transforms: int | None = None,
    random_order: bool = False,
):
    def check_value(name, weight, min_max):
        if min_max is not None:
            if len(min_max) != 2:
                raise ValueError(
                    f"`{name}_min_max` is expected to be a tuple of 2 values, but {min_max} was provided."
                )
            if weight < 0.0:
                raise ValueError(
                    f"`{name}_weight` is expected to be 0 or positive, but is negative ({weight})."
                )

    check_value("brightness", brightness_weight, brightness_min_max)
    check_value("contrast", contrast_weight, contrast_min_max)
    check_value("saturation", saturation_weight, saturation_min_max)
    check_value("hue", hue_weight, hue_min_max)
    check_value("sharpness", sharpness_weight, sharpness_min_max)

    weights = []
    transforms = []
    if brightness_min_max is not None and brightness_weight > 0.0:
        weights.append(brightness_weight)
        transforms.append(v2.ColorJitter(brightness=brightness_min_max))
    if contrast_min_max is not None and contrast_weight > 0.0:
        weights.append(contrast_weight)
        transforms.append(v2.ColorJitter(contrast=contrast_min_max))
    if saturation_min_max is not None and saturation_weight > 0.0:
        weights.append(saturation_weight)
        transforms.append(v2.ColorJitter(saturation=saturation_min_max))
    if hue_min_max is not None and hue_weight > 0.0:
        weights.append(hue_weight)
        transforms.append(v2.ColorJitter(hue=hue_min_max))
    if sharpness_min_max is not None and sharpness_weight > 0.0:
        weights.append(sharpness_weight)
        transforms.append(SharpnessJitter(sharpness=sharpness_min_max))

    n_subset = len(transforms)
    if max_num_transforms is not None:
        n_subset = min(n_subset, max_num_transforms)

    if n_subset == 0:
        return v2.Identity()
    else:
        # TODO(rcadene, aliberts): add v2.ToDtype float16?
        return RandomSubsetApply(transforms, p=weights, n_subset=n_subset, random_order=random_order)
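Tying the pieces together (editor's sketch): a call mirroring the YAML defaults that appear later in this diff builds the full train-time augmentation pipeline:

```python
# The min_max pairs mirror the training.image_transforms defaults shown later.
tf = get_image_transforms(
    brightness_weight=1.0, brightness_min_max=(0.8, 1.2),
    contrast_weight=1.0, contrast_min_max=(0.8, 1.2),
    saturation_weight=1.0, saturation_min_max=(0.5, 1.5),
    hue_weight=1.0, hue_min_max=(-0.05, 0.05),
    sharpness_weight=1.0, sharpness_min_max=(0.8, 1.2),
    max_num_transforms=3,
    random_order=False,
)
```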
@@ -59,7 +59,7 @@ def unflatten_dict(d, sep="/"):
    return outdict


-def hf_transform_to_torch(items_dict):
+def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
    """Get a transform function that converts items from Hugging Face dataset (pyarrow)
    to torch tensors. Importantly, images are converted from PIL, which corresponds to
    a channel last representation (h w c) of uint8 type, to a torch image representation
@@ -73,6 +73,8 @@ def hf_transform_to_torch(items_dict):
        elif isinstance(first_item, dict) and "path" in first_item and "timestamp" in first_item:
            # video frame will be processed downstream
            pass
+        elif first_item is None:
+            pass
        else:
            items_dict[key] = [torch.tensor(x) for x in items_dict[key]]
    return items_dict
@@ -318,8 +320,7 @@ def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> Dict[str, torc


def reset_episode_index(hf_dataset: datasets.Dataset) -> datasets.Dataset:
-    """
-    Reset the `episode_index` of the provided HuggingFace Dataset.
+    """Reset the `episode_index` of the provided HuggingFace Dataset.

    `episode_data_index` (and related functionality such as `load_previous_and_future_frames`) requires the
    `episode_index` to be sorted, continuous (1,1,1 and not 1,2,1) and start at 0.
@@ -338,6 +339,7 @@ def reset_episode_index(hf_dataset: datasets.Dataset) -> datasets.Dataset:
        return example

    hf_dataset = hf_dataset.map(modify_ep_idx_func)
+
    return hf_dataset
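For context, a sketch of how this transform is typically attached (the `load_dataset` call and repo id are illustrative assumptions):

```python
from datasets import load_dataset

hf_dataset = load_dataset("lerobot/pusht", split="train")  # repo id is an assumption
hf_dataset.set_transform(hf_transform_to_torch)
item = hf_dataset[0]  # images now come back as channel-first torch tensors
```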
@@ -27,14 +27,6 @@ def make_env(cfg: DictConfig, n_envs: int | None = None) -> gym.vector.VectorEnv
    if n_envs is not None and n_envs < 1:
        raise ValueError("`n_envs` must be at least 1")

-    kwargs = {
-        "obs_type": "pixels_agent_pos",
-        "render_mode": "rgb_array",
-        "max_episode_steps": cfg.env.episode_length,
-        "visualization_width": 384,
-        "visualization_height": 384,
-    }
-
    package_name = f"gym_{cfg.env.name}"

    try:
@@ -46,12 +38,16 @@ def make_env(cfg: DictConfig, n_envs: int | None = None) -> gym.vector.VectorEnv
        raise e

    gym_handle = f"{package_name}/{cfg.env.task}"
+    gym_kwgs = dict(cfg.env.get("gym", {}))
+
+    if cfg.env.get("episode_length"):
+        gym_kwgs["max_episode_steps"] = cfg.env.episode_length

    # batched version of the env that returns an observation of shape (b, c)
    env_cls = gym.vector.AsyncVectorEnv if cfg.eval.use_async_envs else gym.vector.SyncVectorEnv
    env = env_cls(
        [
-            lambda: gym.make(gym_handle, disable_env_checker=True, **kwargs)
+            lambda: gym.make(gym_handle, disable_env_checker=True, **gym_kwgs)
            for _ in range(n_envs if n_envs is not None else cfg.eval.batch_size)
        ]
    )
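To ground the change: the hard-coded `kwargs` dict is replaced by values read from the env config's `gym` block plus `episode_length`. With the pusht config shown later in this diff, each worker env would be created roughly as follows (editor's sketch, not code from the diff):

```python
gym_kwgs = {
    "obs_type": "pixels_agent_pos",
    "render_mode": "rgb_array",
    "visualization_width": 384,
    "visualization_height": 384,
    "max_episode_steps": 300,  # injected from cfg.env.episode_length
}
env = gym.make("gym_pusht/PushT-v0", disable_env_checker=True, **gym_kwgs)
```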
@@ -13,25 +13,33 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-"""Borrowed from https://github.com/fyhMer/fowm/blob/main/src/logger.py"""
+"""Borrowed from https://github.com/fyhMer/fowm/blob/main/src/logger.py

+# TODO(rcadene, alexander-soare): clean this file
+"""

import logging
import os
+import re
+from glob import glob
from pathlib import Path

import torch
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
-from omegaconf import OmegaConf
+from omegaconf import DictConfig, OmegaConf
from termcolor import colored
+from torch.optim import Optimizer
+from torch.optim.lr_scheduler import LRScheduler

from lerobot.common.policies.policy_protocol import Policy
+from lerobot.common.utils.utils import get_global_random_state, set_global_random_state


def log_output_dir(out_dir):
    logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {out_dir}")


-def cfg_to_group(cfg, return_list=False):
+def cfg_to_group(cfg: DictConfig, return_list: bool = False) -> list[str] | str:
    """Return a group name for logging. Optionally returns group name as list."""
    lst = [
        f"policy:{cfg.policy.name}",
@@ -42,22 +50,54 @@ def cfg_to_group(cfg, return_list=False):
    return lst if return_list else "-".join(lst)


+def get_wandb_run_id_from_filesystem(checkpoint_dir: Path) -> str:
+    # Get the WandB run ID.
+    paths = glob(str(checkpoint_dir / "../wandb/latest-run/run-*"))
+    if len(paths) != 1:
+        raise RuntimeError("Couldn't get the previous WandB run ID for run resumption.")
+    match = re.search(r"run-([^\.]+).wandb", paths[0].split("/")[-1])
+    if match is None:
+        raise RuntimeError("Couldn't get the previous WandB run ID for run resumption.")
+    wandb_run_id = match.groups(0)[0]
+    return wandb_run_id
+
+
class Logger:
-    """Primary logger object. Logs either locally or using wandb."""
+    """Primary logger object. Logs either locally or using wandb.
+
+    The logger creates the following directory structure:
+
+    provided_log_dir
+    ├── .hydra  # hydra's configuration cache
+    ├── checkpoints
+    │   ├── specific_checkpoint_name
+    │   │   ├── pretrained_model  # Hugging Face pretrained model directory
+    │   │   │   ├── ...
+    │   │   └── training_state.pth  # optimizer, scheduler, and random states + training step
+    │   ├── another_specific_checkpoint_name
+    │   │   ├── ...
+    │   ├── ...
+    │   └── last  # a softlink to the last logged checkpoint
+    """
+
+    pretrained_model_dir_name = "pretrained_model"
+    training_state_file_name = "training_state.pth"

-    def __init__(self, log_dir, job_name, cfg):
-        self._log_dir = Path(log_dir)
-        self._log_dir.mkdir(parents=True, exist_ok=True)
-        self._job_name = job_name
-        self._model_dir = self._log_dir / "checkpoints"
-        self._buffer_dir = self._log_dir / "buffers"
-        self._save_model = cfg.training.save_model
-        self._disable_wandb_artifact = cfg.wandb.disable_artifact
-        self._save_buffer = cfg.training.get("save_buffer", False)
-        self._group = cfg_to_group(cfg)
-        self._seed = cfg.seed
+    def __init__(self, cfg: DictConfig, log_dir: str, wandb_job_name: str | None = None):
+        """
+        Args:
+            log_dir: The directory to save all logs and training outputs to.
+            wandb_job_name: The WandB job name.
+        """
+        self._cfg = cfg
+        self._eval = []
+        self.log_dir = Path(log_dir)
+        self.log_dir.mkdir(parents=True, exist_ok=True)
+        self.checkpoints_dir = self.get_checkpoints_dir(log_dir)
+        self.last_checkpoint_dir = self.get_last_checkpoint_dir(log_dir)
+        self.last_pretrained_model_dir = self.get_last_pretrained_model_dir(log_dir)

        # Set up WandB.
        self._group = cfg_to_group(cfg)
        project = cfg.get("wandb", {}).get("project")
        entity = cfg.get("wandb", {}).get("entity")
        enable_wandb = cfg.get("wandb", {}).get("enable", False)
@@ -69,75 +109,135 @@ class Logger:
            os.environ["WANDB_SILENT"] = "true"
            import wandb

+            wandb_run_id = None
+            if cfg.resume:
+                wandb_run_id = get_wandb_run_id_from_filesystem(self.checkpoints_dir)
+
            wandb.init(
+                id=wandb_run_id,
                project=project,
                entity=entity,
-                name=job_name,
+                name=wandb_job_name,
                notes=cfg.get("wandb", {}).get("notes"),
                # group=self._group,
                tags=cfg_to_group(cfg, return_list=True),
-                dir=self._log_dir,
+                dir=log_dir,
                config=OmegaConf.to_container(cfg, resolve=True),
                # TODO(rcadene): try set to True
                save_code=False,
                # TODO(rcadene): split train and eval, and run async eval with job_type="eval"
                job_type="train_eval",
-                # TODO(rcadene): add resume option
-                resume=None,
+                resume="must" if cfg.resume else None,
            )
            print(colored("Logs will be synced with wandb.", "blue", attrs=["bold"]))
            logging.info(f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}")
            self._wandb = wandb

-    def save_model(self, policy: Policy, identifier):
-        if self._save_model:
-            self._model_dir.mkdir(parents=True, exist_ok=True)
-            save_dir = self._model_dir / str(identifier)
-            policy.save_pretrained(save_dir)
-            # Also save the full Hydra config for the env configuration.
-            OmegaConf.save(self._cfg, save_dir / "config.yaml")
-            if self._wandb and not self._disable_wandb_artifact:
-                # note wandb artifact does not accept ":" or "/" in its name
-                artifact = self._wandb.Artifact(
-                    f"{self._group.replace(':', '_').replace('/', '_')}-{self._seed}-{identifier}",
-                    type="model",
-                )
-                artifact.add_file(save_dir / SAFETENSORS_SINGLE_FILE)
-                self._wandb.log_artifact(artifact)
+    @classmethod
+    def get_checkpoints_dir(cls, log_dir: str | Path) -> Path:
+        """Given the log directory, get the sub-directory in which checkpoints will be saved."""
+        return Path(log_dir) / "checkpoints"

-    def save_buffer(self, buffer, identifier):
-        self._buffer_dir.mkdir(parents=True, exist_ok=True)
-        fp = self._buffer_dir / f"{str(identifier)}.pkl"
-        buffer.save(fp)
-        if self._wandb and not self._disable_wandb_artifact:
-            # note wandb artifact does not accept ":" or "/" in its name
-            artifact = self._wandb.Artifact(
-                f"{self._group.replace(':', '_').replace('/', '_')}-{self._seed}-{identifier}",
-                type="buffer",
-            )
-            artifact.add_file(fp)
-            self._wandb.log_artifact(artifact)
+    @classmethod
+    def get_last_checkpoint_dir(cls, log_dir: str | Path) -> Path:
+        """Given the log directory, get the sub-directory in which the last checkpoint will be saved."""
+        return cls.get_checkpoints_dir(log_dir) / "last"
+
+    @classmethod
+    def get_last_pretrained_model_dir(cls, log_dir: str | Path) -> Path:
+        """
+        Given the log directory, get the sub-directory in which the last checkpoint's pretrained weights will
+        be saved.
+        """
+        return cls.get_last_checkpoint_dir(log_dir) / cls.pretrained_model_dir_name
+
+    def save_model(self, save_dir: Path, policy: Policy, wandb_artifact_name: str | None = None):
+        """Save the weights of the Policy model using PyTorchModelHubMixin.
+
+        The weights are saved in a folder called "pretrained_model" under the checkpoint directory.
+
+        Optionally also upload the model to WandB.
+        """
+        self.checkpoints_dir.mkdir(parents=True, exist_ok=True)
+        policy.save_pretrained(save_dir)
+        # Also save the full Hydra config for the env configuration.
+        OmegaConf.save(self._cfg, save_dir / "config.yaml")
+        if self._wandb and not self._cfg.wandb.disable_artifact:
+            # note wandb artifact does not accept ":" or "/" in its name
+            artifact = self._wandb.Artifact(wandb_artifact_name, type="model")
+            artifact.add_file(save_dir / SAFETENSORS_SINGLE_FILE)
+            self._wandb.log_artifact(artifact)
+        if self.last_checkpoint_dir.exists():
+            os.remove(self.last_checkpoint_dir)

-    def finish(self, agent, buffer):
-        if self._save_model:
-            self.save_model(agent, identifier="final")
-        if self._save_buffer:
-            self.save_buffer(buffer, identifier="buffer")
-        if self._wandb:
-            self._wandb.finish()
+    def save_training_state(
+        self,
+        save_dir: Path,
+        train_step: int,
+        optimizer: Optimizer,
+        scheduler: LRScheduler | None,
+    ):
+        """Checkpoint the global training_step, optimizer state, scheduler state, and random state.
+
+        All of these are saved as "training_state.pth" under the checkpoint directory.
+        """
+        training_state = {
+            "step": train_step,
+            "optimizer": optimizer.state_dict(),
+            **get_global_random_state(),
+        }
+        if scheduler is not None:
+            training_state["scheduler"] = scheduler.state_dict()
+        torch.save(training_state, save_dir / self.training_state_file_name)
+
+    def save_checkpoint(
+        self,
+        train_step: int,
+        policy: Policy,
+        optimizer: Optimizer,
+        scheduler: LRScheduler | None,
+        identifier: str,
+    ):
+        """Checkpoint the model weights and the training state."""
+        checkpoint_dir = self.checkpoints_dir / str(identifier)
+        wandb_artifact_name = (
+            None
+            if self._wandb is None
+            else f"{self._group.replace(':', '_').replace('/', '_')}-{self._cfg.seed}-{identifier}"
+        )
+        self.save_model(
+            checkpoint_dir / self.pretrained_model_dir_name, policy, wandb_artifact_name=wandb_artifact_name
+        )
+        self.save_training_state(checkpoint_dir, train_step, optimizer, scheduler)
+        os.symlink(checkpoint_dir.absolute(), self.last_checkpoint_dir)
+
+    def load_last_training_state(self, optimizer: Optimizer, scheduler: LRScheduler | None) -> int:
+        """
+        Given the last checkpoint in the logging directory, load the optimizer state, scheduler state, and
+        random state, and return the global training step.
+        """
+        training_state = torch.load(self.last_checkpoint_dir / self.training_state_file_name)
+        optimizer.load_state_dict(training_state["optimizer"])
+        if scheduler is not None:
+            scheduler.load_state_dict(training_state["scheduler"])
+        elif "scheduler" in training_state:
+            raise ValueError(
+                "The checkpoint contains a scheduler state_dict, but no LRScheduler was provided."
+            )
+        # Small hack to get the expected keys: use `get_global_random_state`.
+        set_global_random_state({k: training_state[k] for k in get_global_random_state()})
+        return training_state["step"]

    def log_dict(self, d, step, mode="train"):
        assert mode in {"train", "eval"}
        # TODO(alexander-soare): Add local text log.
        if self._wandb is not None:
            for k, v in d.items():
                if not isinstance(v, (int, float, str)):
                    logging.warning(
                        f'WandB logging of key "{k}" was ignored as its type is not handled by this wrapper.'
                    )
                    continue
                self._wandb.log({f"{mode}/{k}": v}, step=step)

    def log_video(self, video_path: str, step: int, mode: str = "train"):
        assert mode in {"train", "eval"}
        assert self._wandb is not None
        wandb_video = self._wandb.Video(video_path, fps=self._cfg.fps, format="mp4")
        self._wandb.log({f"{mode}/video": wandb_video}, step=step)
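A sketch of the checkpoint/resume cycle these new methods enable (editor's addition; `cfg`, `policy`, `optimizer`, `scheduler` are assumed, and `train_one_step` is a hypothetical stand-in for the real training step):

```python
logger = Logger(cfg, log_dir, wandb_job_name="train_pusht")  # names illustrative
step = logger.load_last_training_state(optimizer, scheduler) if cfg.resume else 0
while step < cfg.training.offline_steps:
    train_one_step(policy, optimizer)  # hypothetical helper
    step += 1
    if step % cfg.training.save_freq == 0:
        logger.save_checkpoint(step, policy, optimizer, scheduler, identifier=str(step))
```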
@@ -25,6 +25,13 @@ class ACTConfig:
    The parameters you will most likely need to change are the ones which depend on the environment / sensors.
    Those are: `input_shapes` and `output_shapes`.

+    Notes on the inputs and outputs:
+        - At least one key starting with "observation.image" is required as an input.
+        - If there are multiple keys beginning with "observation.images." they are treated as multiple camera
+          views. Right now we only support all images having the same shape.
+        - May optionally work without an "observation.state" key for the proprioceptive robot state.
+        - "action" is required as an output key.
+
    Args:
        n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the
            current step and additional steps going back).
@@ -33,15 +40,15 @@ class ACTConfig:
            This should be no greater than the chunk size. For example, if the chunk size is 100, you may
            set this to 50. This would mean that the model predicts 100 steps worth of actions, runs 50 in the
            environment, and throws the other 50 out.
-        input_shapes: A dictionary defining the shapes of the input data for the policy.
-            The key represents the input data name, and the value is a list indicating the dimensions
-            of the corresponding data. For example, "observation.images.top" refers to an input from the
-            "top" camera with dimensions [3, 96, 96], indicating it has three color channels and 96x96 resolution.
-            Importantly, shapes doesn't include batch dimension or temporal dimension.
-        output_shapes: A dictionary defining the shapes of the output data for the policy.
-            The key represents the output data name, and the value is a list indicating the dimensions
-            of the corresponding data. For example, "action" refers to an output shape of [14], indicating
-            14-dimensional actions. Importantly, shapes doesn't include batch dimension or temporal dimension.
+        input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents
+            the input data name, and the value is a list indicating the dimensions of the corresponding data.
+            For example, "observation.image" refers to an input from a camera with dimensions [3, 96, 96],
+            indicating it has three color channels and 96x96 resolution. Importantly, `input_shapes` doesn't
+            include batch dimension or temporal dimension.
+        output_shapes: A dictionary defining the shapes of the output data for the policy. The key represents
+            the output data name, and the value is a list indicating the dimensions of the corresponding data.
+            For example, "action" refers to an output shape of [14], indicating 14-dimensional actions.
+            Importantly, `output_shapes` doesn't include batch dimension or temporal dimension.
        input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"),
            and the value specifies the normalization mode to apply. The two available modes are "mean_std"
            which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a
@@ -139,25 +139,26 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
        batch = self.normalize_targets(batch)
        actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)

-        l1_loss = (
-            F.l1_loss(batch["action"], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1)
-        ).mean()
+        bsize = actions_hat.shape[0]
+        l1_loss = F.l1_loss(batch["action"], actions_hat, reduction="none")
+        l1_loss = l1_loss * ~batch["action_is_pad"].unsqueeze(-1)
+        l1_loss = l1_loss.view(bsize, -1).mean(dim=1)
+
+        out_dict = {}
+        out_dict["l1_loss"] = l1_loss

-        loss_dict = {"l1_loss": l1_loss.item()}
        if self.config.use_vae:
            # Calculate Dₖₗ(latent_pdf || standard_normal). Note: After computing the KL-divergence for
            # each dimension independently, we sum over the latent dimension to get the total
            # KL-divergence per batch element, then take the mean over the batch.
            # (See App. B of https://arxiv.org/abs/1312.6114 for more details).
-            mean_kld = (
-                (-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean()
-            )
-            loss_dict["kld_loss"] = mean_kld.item()
-            loss_dict["loss"] = l1_loss + mean_kld * self.config.kl_weight
+            kld_loss = (-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1)
+            out_dict["loss"] = l1_loss + kld_loss * self.config.kl_weight
        else:
-            loss_dict["loss"] = l1_loss
+            out_dict["loss"] = l1_loss

-        return loss_dict
+        out_dict["action"] = self.unnormalize_outputs({"action": actions_hat})["action"]
+        return out_dict
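For reference, the KL term computed above is the closed form for a diagonal Gaussian against a standard normal (App. B of the VAE paper cited in the comment); in the code, `log_sigma_x2_hat` plays the role of $\log \sigma^2$:

$$D_{\mathrm{KL}}\big(\mathcal{N}(\mu, \operatorname{diag}(\sigma^2)) \,\|\, \mathcal{N}(0, I)\big) = -\tfrac{1}{2} \sum_{j} \left(1 + \log \sigma_j^2 - \mu_j^2 - \sigma_j^2\right)$$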
class ACT(nn.Module):
@@ -198,27 +199,31 @@ class ACT(nn.Module):
    def __init__(self, config: ACTConfig):
        super().__init__()
        self.config = config
-        # BERT style VAE encoder with input [cls, *joint_space_configuration, *action_sequence].
+        # BERT style VAE encoder with input tokens [cls, robot_state, *action_sequence].
        # The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]).
+        self.use_input_state = "observation.state" in config.input_shapes
        if self.config.use_vae:
            self.vae_encoder = ACTEncoder(config)
            self.vae_encoder_cls_embed = nn.Embedding(1, config.dim_model)
            # Projection layer for joint-space configuration to hidden dimension.
-            self.vae_encoder_robot_state_input_proj = nn.Linear(
-                config.input_shapes["observation.state"][0], config.dim_model
-            )
+            if self.use_input_state:
+                self.vae_encoder_robot_state_input_proj = nn.Linear(
+                    config.input_shapes["observation.state"][0], config.dim_model
+                )
            # Projection layer for action (joint-space target) to hidden dimension.
            self.vae_encoder_action_input_proj = nn.Linear(
-                config.input_shapes["observation.state"][0], config.dim_model
+                config.output_shapes["action"][0], config.dim_model
            )
-            self.latent_dim = config.latent_dim
            # Projection layer from the VAE encoder's output to the latent distribution's parameter space.
-            self.vae_encoder_latent_output_proj = nn.Linear(config.dim_model, self.latent_dim * 2)
-            # Fixed sinusoidal positional embedding the whole input to the VAE encoder. Unsqueeze for batch
+            self.vae_encoder_latent_output_proj = nn.Linear(config.dim_model, config.latent_dim * 2)
+            # Fixed sinusoidal positional embedding for the input to the VAE encoder. Unsqueeze for batch
            # dimension.
+            num_input_token_encoder = 1 + config.chunk_size
+            if self.use_input_state:
+                num_input_token_encoder += 1
            self.register_buffer(
                "vae_encoder_pos_enc",
-                create_sinusoidal_pos_embedding(1 + 1 + config.chunk_size, config.dim_model).unsqueeze(0),
+                create_sinusoidal_pos_embedding(num_input_token_encoder, config.dim_model).unsqueeze(0),
            )

        # Backbone for image feature extraction.
@@ -238,15 +243,17 @@ class ACT(nn.Module):

        # Transformer encoder input projections. The tokens will be structured like
        # [latent, robot_state, image_feature_map_pixels].
-        self.encoder_robot_state_input_proj = nn.Linear(
-            config.input_shapes["observation.state"][0], config.dim_model
-        )
+        if self.use_input_state:
+            self.encoder_robot_state_input_proj = nn.Linear(
+                config.input_shapes["observation.state"][0], config.dim_model
+            )
-        self.encoder_latent_input_proj = nn.Linear(self.latent_dim, config.dim_model)
+        self.encoder_latent_input_proj = nn.Linear(config.latent_dim, config.dim_model)
        self.encoder_img_feat_input_proj = nn.Conv2d(
            backbone_model.fc.in_features, config.dim_model, kernel_size=1
        )
        # Transformer encoder positional embeddings.
-        self.encoder_robot_and_latent_pos_embed = nn.Embedding(2, config.dim_model)
+        num_input_token_decoder = 2 if self.use_input_state else 1
+        self.encoder_robot_and_latent_pos_embed = nn.Embedding(num_input_token_decoder, config.dim_model)
        self.encoder_cam_feat_pos_embed = ACTSinusoidalPositionEmbedding2d(config.dim_model // 2)

        # Transformer decoder.
@@ -285,7 +292,7 @@ class ACT(nn.Module):
            "action" in batch
        ), "actions must be provided when using the variational objective in training mode."

-        batch_size = batch["observation.state"].shape[0]
+        batch_size = batch["observation.images"].shape[0]

        # Prepare the latent for input to the transformer encoder.
        if self.config.use_vae and "action" in batch:
@@ -293,11 +300,16 @@ class ACT(nn.Module):
            cls_embed = einops.repeat(
                self.vae_encoder_cls_embed.weight, "1 d -> b 1 d", b=batch_size
            )  # (B, 1, D)
-            robot_state_embed = self.vae_encoder_robot_state_input_proj(batch["observation.state"]).unsqueeze(
-                1
-            )  # (B, 1, D)
+            if self.use_input_state:
+                robot_state_embed = self.vae_encoder_robot_state_input_proj(batch["observation.state"])
+                robot_state_embed = robot_state_embed.unsqueeze(1)  # (B, 1, D)
            action_embed = self.vae_encoder_action_input_proj(batch["action"])  # (B, S, D)
-            vae_encoder_input = torch.cat([cls_embed, robot_state_embed, action_embed], axis=1)  # (B, S+2, D)
+
+            if self.use_input_state:
+                vae_encoder_input = [cls_embed, robot_state_embed, action_embed]  # (B, S+2, D)
+            else:
+                vae_encoder_input = [cls_embed, action_embed]
+            vae_encoder_input = torch.cat(vae_encoder_input, axis=1)

            # Prepare fixed positional embedding.
            # Note: detach() shouldn't be necessary but leaving it the same as the original code just in case.
@@ -308,16 +320,17 @@ class ACT(nn.Module):
                vae_encoder_input.permute(1, 0, 2), pos_embed=pos_embed.permute(1, 0, 2)
            )[0]  # select the class token, with shape (B, D)
            latent_pdf_params = self.vae_encoder_latent_output_proj(cls_token_out)
-            mu = latent_pdf_params[:, : self.latent_dim]
+            mu = latent_pdf_params[:, : self.config.latent_dim]
            # This is 2log(sigma). Done this way to match the original implementation.
-            log_sigma_x2 = latent_pdf_params[:, self.latent_dim :]
+            log_sigma_x2 = latent_pdf_params[:, self.config.latent_dim :]

            # Sample the latent with the reparameterization trick.
            latent_sample = mu + log_sigma_x2.div(2).exp() * torch.randn_like(mu)
        else:
            # When not using the VAE encoder, we set the latent to be all zeros.
            mu = log_sigma_x2 = None
-            latent_sample = torch.zeros([batch_size, self.latent_dim], dtype=torch.float32).to(
+            # TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use buffer
+            latent_sample = torch.zeros([batch_size, self.config.latent_dim], dtype=torch.float32).to(
                batch["observation.state"].device
            )

@@ -326,8 +339,10 @@ class ACT(nn.Module):
        all_cam_features = []
        all_cam_pos_embeds = []
        images = batch["observation.images"]

        for cam_index in range(images.shape[-4]):
            cam_features = self.backbone(images[:, cam_index])["feature_map"]
+            # TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use buffer
            cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype)
            cam_features = self.encoder_img_feat_input_proj(cam_features)  # (B, C, h, w)
            all_cam_features.append(cam_features)
@@ -337,13 +352,15 @@ class ACT(nn.Module):
        cam_pos_embed = torch.cat(all_cam_pos_embeds, axis=-1)

        # Get positional embeddings for robot state and latent.
-        robot_state_embed = self.encoder_robot_state_input_proj(batch["observation.state"])  # (B, C)
+        if self.use_input_state:
+            robot_state_embed = self.encoder_robot_state_input_proj(batch["observation.state"])  # (B, C)
        latent_embed = self.encoder_latent_input_proj(latent_sample)  # (B, C)

        # Stack encoder input and positional embeddings moving to (S, B, C).
+        encoder_in_feats = [latent_embed, robot_state_embed] if self.use_input_state else [latent_embed]
        encoder_in = torch.cat(
            [
-                torch.stack([latent_embed, robot_state_embed], axis=0),
+                torch.stack(encoder_in_feats, axis=0),
                einops.rearrange(encoder_in, "b c h w -> (h w) b c"),
            ]
        )
@@ -357,6 +374,7 @@ class ACT(nn.Module):

        # Forward pass through the transformer modules.
        encoder_out = self.encoder(encoder_in, pos_embed=pos_embed)
+        # TODO(rcadene, alexander-soare): remove call to `device` ; precompute and use buffer
        decoder_in = torch.zeros(
            (self.config.chunk_size, batch_size, self.config.dim_model),
            dtype=pos_embed.dtype,
@@ -26,21 +26,26 @@ class DiffusionConfig:
    The parameters you will most likely need to change are the ones which depend on the environment / sensors.
    Those are: `input_shapes` and `output_shapes`.

+    Notes on the inputs and outputs:
+        - "observation.state" is required as an input key.
+        - A key starting with "observation.image" is required as an input.
+        - "action" is required as an output key.
+
    Args:
        n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the
            current step and additional steps going back).
        horizon: Diffusion model action prediction size as detailed in `DiffusionPolicy.select_action`.
        n_action_steps: The number of action steps to run in the environment for one invocation of the policy.
            See `DiffusionPolicy.select_action` for more details.
-        input_shapes: A dictionary defining the shapes of the input data for the policy.
-            The key represents the input data name, and the value is a list indicating the dimensions
-            of the corresponding data. For example, "observation.image" refers to an input from
-            a camera with dimensions [3, 96, 96], indicating it has three color channels and 96x96 resolution.
-            Importantly, shapes doesnt include batch dimension or temporal dimension.
-        output_shapes: A dictionary defining the shapes of the output data for the policy.
-            The key represents the output data name, and the value is a list indicating the dimensions
-            of the corresponding data. For example, "action" refers to an output shape of [14], indicating
-            14-dimensional actions. Importantly, shapes doesnt include batch dimension or temporal dimension.
+        input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents
+            the input data name, and the value is a list indicating the dimensions of the corresponding data.
+            For example, "observation.image" refers to an input from a camera with dimensions [3, 96, 96],
+            indicating it has three color channels and 96x96 resolution. Importantly, `input_shapes` doesn't
+            include batch dimension or temporal dimension.
+        output_shapes: A dictionary defining the shapes of the output data for the policy. The key represents
+            the output data name, and the value is a list indicating the dimensions of the corresponding data.
+            For example, "action" refers to an output shape of [14], indicating 14-dimensional actions.
+            Importantly, `output_shapes` doesn't include batch dimension or temporal dimension.
        input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"),
            and the value specifies the normalization mode to apply. The two available modes are "mean_std"
            which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a
@@ -155,7 +160,7 @@ class DiffusionConfig:
                f"{self.__class__.__name__} only handles one image for now. Got image keys {image_keys}."
            )
        image_key = next(iter(image_keys))
-        if (
+        if self.crop_shape is not None and (
            self.crop_shape[0] > self.input_shapes[image_key][1]
            or self.crop_shape[1] > self.input_shapes[image_key][2]
        ):
@@ -239,10 +239,8 @@ class DiffusionModel(nn.Module):
        global_cond = torch.cat([batch["observation.state"], img_features], dim=-1).flatten(start_dim=1)

        # run sampling
-        sample = self.conditional_sample(batch_size, global_cond=global_cond)
+        actions = self.conditional_sample(batch_size, global_cond=global_cond)

-        # `horizon` steps worth of actions (from the first observation).
-        actions = sample[..., : self.config.output_shapes["action"][0]]
+        # Extract `n_action_steps` steps worth of actions (from the current observation).
        start = n_obs_steps - 1
        end = start + self.config.n_action_steps
@@ -304,7 +302,11 @@ class DiffusionModel(nn.Module):
        loss = F.mse_loss(pred, target, reduction="none")

        # Mask loss wherever the action is padded with copies (edges of the dataset trajectory).
-        if self.config.do_mask_loss_for_padding and "action_is_pad" in batch:
+        if self.config.do_mask_loss_for_padding:
+            if "action_is_pad" not in batch:
+                raise ValueError(
+                    f"You need to provide 'action_is_pad' in the batch when {self.config.do_mask_loss_for_padding=}."
+                )
            in_episode_bound = ~batch["action_is_pad"]
            loss = loss * in_episode_bound.unsqueeze(-1)
@@ -423,11 +425,15 @@ class DiffusionRgbEncoder(nn.Module):
        # Set up pooling and final layers.
        # Use a dry run to get the feature map shape.
        # The dummy input should take the number of image channels from `config.input_shapes` and it should
-        # use the height and width from `config.crop_shape`.
+        # use the height and width from `config.crop_shape` if it is provided, otherwise it should use the
+        # height and width from `config.input_shapes`.
        image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
        assert len(image_keys) == 1
        image_key = image_keys[0]
-        dummy_input = torch.zeros(size=(1, config.input_shapes[image_key][0], *config.crop_shape))
+        dummy_input_h_w = (
+            config.crop_shape if config.crop_shape is not None else config.input_shapes[image_key][1:]
+        )
+        dummy_input = torch.zeros(size=(1, config.input_shapes[image_key][0], *dummy_input_h_w))
        with torch.inference_mode():
            dummy_feature_map = self.backbone(dummy_input)
        feature_map_shape = tuple(dummy_feature_map.shape[1:])
@@ -147,7 +147,7 @@ class Normalize(nn.Module):
                assert not torch.isinf(min).any(), _no_stats_error_str("min")
                assert not torch.isinf(max).any(), _no_stats_error_str("max")
                # normalize to [0,1]
-                batch[key] = (batch[key] - min) / (max - min)
+                batch[key] = (batch[key] - min) / (max - min + 1e-8)
                # normalize to [-1, 1]
                batch[key] = batch[key] * 2 - 1
            else:
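In effect, min_max normalization now computes the following, with the small epsilon guarding against division by zero when the per-dimension max equals the min:

$$x_{\text{norm}} = 2 \cdot \frac{x - x_{\min}}{x_{\max} - x_{\min} + 10^{-8}} - 1 \;\in\; [-1, 1]$$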
@@ -57,7 +57,7 @@ class Policy(Protocol):
        other items should be logging-friendly, native Python types.
        """

-    def select_action(self, batch: dict[str, Tensor]):
+    def select_action(self, batch: dict[str, Tensor]) -> Tensor:
        """Return one action to run in the environment (potentially in batch mode).

        When the model uses a history of observations, or outputs a sequence of actions, this method deals
@@ -31,6 +31,15 @@ class TDMPCConfig:
        n_action_repeats: The number of times to repeat the action returned by the planning. (hint: Google
            action repeats in Q-learning or ask your favorite chatbot)
        horizon: Horizon for model predictive control.
+        input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents
+            the input data name, and the value is a list indicating the dimensions of the corresponding data.
+            For example, "observation.image" refers to an input from a camera with dimensions [3, 96, 96],
+            indicating it has three color channels and 96x96 resolution. Importantly, `input_shapes` doesn't
+            include batch dimension or temporal dimension.
+        output_shapes: A dictionary defining the shapes of the output data for the policy. The key represents
+            the output data name, and the value is a list indicating the dimensions of the corresponding data.
+            For example, "action" refers to an output shape of [14], indicating 14-dimensional actions.
+            Importantly, `output_shapes` doesn't include batch dimension or temporal dimension.
        input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"),
            and the value specifies the normalization mode to apply. The two available modes are "mean_std"
            which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a
@@ -134,7 +134,7 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
        self._prev_mean: torch.Tensor | None = None

    @torch.no_grad()
-    def select_action(self, batch: dict[str, Tensor]):
+    def select_action(self, batch: dict[str, Tensor]) -> Tensor:
        """Select a single action given environment observations."""
        batch = self.normalize_inputs(batch)
        batch["observation.image"] = batch[self.input_image_key]
@@ -19,7 +19,7 @@ import random
from contextlib import contextmanager
from datetime import datetime
from pathlib import Path
-from typing import Generator
+from typing import Any, Generator

import hydra
import numpy as np
@@ -48,12 +48,38 @@ def get_safe_torch_device(cfg_device: str, log: bool = False) -> torch.device:
    return device


+def get_global_random_state() -> dict[str, Any]:
+    """Get the random state for `random`, `numpy`, and `torch`."""
+    random_state_dict = {
+        "random_state": random.getstate(),
+        "numpy_random_state": np.random.get_state(),
+        "torch_random_state": torch.random.get_rng_state(),
+    }
+    if torch.cuda.is_available():
+        random_state_dict["torch_cuda_random_state"] = torch.cuda.random.get_rng_state()
+    return random_state_dict
+
+
+def set_global_random_state(random_state_dict: dict[str, Any]):
+    """Set the random state for `random`, `numpy`, and `torch`.
+
+    Args:
+        random_state_dict: A dictionary of the form returned by `get_global_random_state`.
+    """
+    random.setstate(random_state_dict["random_state"])
+    np.random.set_state(random_state_dict["numpy_random_state"])
+    torch.random.set_rng_state(random_state_dict["torch_random_state"])
+    if torch.cuda.is_available():
+        torch.cuda.random.set_rng_state(random_state_dict["torch_cuda_random_state"])
+
+
def set_global_seed(seed):
    """Set seed for reproducibility."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
-    torch.cuda.manual_seed_all(seed)
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed_all(seed)


@contextmanager
@@ -69,16 +95,10 @@ def seeded_context(seed: int) -> Generator[None, None, None]:
        c = random.random()  # produces yet another random number, but the same it would have if we never made `b`
    ```
    """
-    random_state = random.getstate()
-    np_random_state = np.random.get_state()
-    torch_random_state = torch.random.get_rng_state()
-    torch_cuda_random_state = torch.cuda.random.get_rng_state()
+    random_state_dict = get_global_random_state()
    set_global_seed(seed)
    yield None
-    random.setstate(random_state)
-    np.random.set_state(np_random_state)
-    torch.random.set_rng_state(torch_random_state)
-    torch.cuda.random.set_rng_state(torch_cuda_random_state)
+    set_global_random_state(random_state_dict)


def init_logging():
@@ -100,13 +120,13 @@ def init_logging():
    logging.getLogger().addHandler(console_handler)


-def format_big_number(num):
+def format_big_number(num, precision=0):
    suffixes = ["", "K", "M", "B", "T", "Q"]
    divisor = 1000.0

    for suffix in suffixes:
        if abs(num) < divisor:
-            return f"{num:.0f}{suffix}"
+            return f"{num:.{precision}f}{suffix}"
        num /= divisor

    return num
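A couple of illustrative calls (values worked out by hand from the loop above):

```python
format_big_number(123)                     # "123"
format_big_number(82_000)                  # "82K"
format_big_number(1_234_567, precision=1)  # "1.2M"
```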
@@ -5,18 +5,33 @@ defaults:

hydra:
  run:
+    # Set `dir` to where you would like to save all of the run outputs. If you run another training session
+    # with the same value for `dir` its contents will be overwritten unless you set `resume` to true.
    dir: outputs/train/${now:%Y-%m-%d}/${now:%H-%M-%S}_${env.name}_${policy.name}_${hydra.job.name}
  job:
    name: default

+# Set `resume` to true to resume a previous run. In order for this to work, you will need to make sure
+# `hydra.run.dir` is the directory of an existing run with at least one checkpoint in it.
+# Note that when resuming a run, the default behavior is to use the configuration from the checkpoint,
+# regardless of what's provided with the training command at the time of resumption.
+resume: false
device: cuda  # cpu
# `use_amp` determines whether to use Automatic Mixed Precision (AMP) for training and evaluation. With AMP,
# automatic gradient scaling is used.
use_amp: false
# `seed` is used for training (eg: model initialization, dataset shuffling)
# AND for the evaluation environments.
seed: ???
# You may provide a list of datasets here. `train.py` creates them all and concatenates them. Note: only data
# keys common between the datasets are kept. Each dataset gets an additional transform that inserts the
# "dataset_index" into the returned item. The index mapping is made according to the order in which the
# datasets are provided.
dataset_repo_id: lerobot/pusht

training:
  offline_steps: ???
  # NOTE: `online_steps` is not implemented yet. It's here as a placeholder.
  online_steps: ???
  online_steps_between_rollouts: ???
  online_sampling_ratio: 0.5
@@ -25,7 +40,43 @@ training:
  eval_freq: ???
  save_freq: ???
  log_freq: 250
-  save_model: true
+  save_checkpoint: true
  num_workers: 4
  batch_size: ???
+  image_transforms:
+    # These transforms are all using standard torchvision.transforms.v2
+    # You can find out how these transformations affect images here:
+    # https://pytorch.org/vision/0.18/auto_examples/transforms/plot_transforms_illustrations.html
+    # We use a custom RandomSubsetApply container to sample them.
+    # For each transform, the following parameters are available:
+    #   weight: This represents the multinomial probability (with no replacement)
+    #           used for sampling the transform. If the sum of the weights is not 1,
+    #           they will be normalized.
+    #   min_max: Lower & upper bound respectively used for sampling the transform's parameter
+    #            (following uniform distribution) when it's applied.
+    # Set this flag to `true` to enable transforms during training
+    enable: false
+    # This is the maximum number of transforms (sampled from these below) that will be applied to each frame.
+    # It's an integer in the interval [1, number of available transforms].
+    max_num_transforms: 3
+    # By default, transforms are applied in Torchvision's suggested order (shown below).
+    # Set this to True to apply them in a random order.
+    random_order: false
+    brightness:
+      weight: 1
+      min_max: [0.8, 1.2]
+    contrast:
+      weight: 1
+      min_max: [0.8, 1.2]
+    saturation:
+      weight: 1
+      min_max: [0.5, 1.5]
+    hue:
+      weight: 1
+      min_max: [-0.05, 0.05]
+    sharpness:
+      weight: 1
+      min_max: [0.8, 1.2]

eval:
  n_episodes: 1
@@ -36,7 +87,7 @@ eval:

wandb:
  enable: false
-  # Set to true to disable saving an artifact despite save_model == True
+  # Set to true to disable saving an artifact despite save_checkpoint == True
  disable_artifact: false
  project: lerobot
  notes: ""
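With these options in place, the augmentations can be switched on for a single run without editing the file, e.g. by appending the Hydra overrides `training.image_transforms.enable=true` and, optionally, `training.image_transforms.max_num_transforms=2` to the `python lerobot/scripts/train.py` command (standard Hydra override syntax; the exact values here are illustrative).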
lerobot/configs/env/aloha.yaml
@@ -5,10 +5,10 @@ fps: 50

env:
  name: aloha
  task: AlohaInsertion-v0
-  from_pixels: True
-  pixels_only: False
-  image_size: [3, 480, 640]
-  episode_length: 400
-  fps: ${fps}
  state_dim: 14
  action_dim: 14
+  fps: ${fps}
+  episode_length: 400
+  gym:
+    obs_type: pixels_agent_pos
+    render_mode: rgb_array
lerobot/configs/env/dora_aloha_real.yaml (new file)
@@ -0,0 +1,13 @@
# @package _global_

fps: 30

env:
  name: dora
  task: DoraAloha-v0
  state_dim: 14
  action_dim: 14
  fps: ${fps}
  episode_length: 400
  gym:
    fps: ${fps}
lerobot/configs/env/pusht.yaml
@@ -5,10 +5,13 @@ fps: 10

env:
  name: pusht
  task: PushT-v0
-  from_pixels: True
-  pixels_only: False
-  image_size: 96
-  episode_length: 300
-  fps: ${fps}
  state_dim: 2
  action_dim: 2
+  fps: ${fps}
+  episode_length: 300
+  gym:
+    obs_type: pixels_agent_pos
+    render_mode: rgb_array
+    visualization_width: 384
+    visualization_height: 384
lerobot/configs/env/xarm.yaml
@@ -5,10 +5,13 @@ fps: 15

env:
  name: xarm
  task: XarmLift-v0
-  from_pixels: True
-  pixels_only: False
-  image_size: 84
-  episode_length: 25
-  fps: ${fps}
  state_dim: 4
  action_dim: 4
+  fps: ${fps}
+  episode_length: 25
+  gym:
+    obs_type: pixels_agent_pos
+    render_mode: rgb_array
+    visualization_width: 384
+    visualization_height: 384
@@ -15,7 +15,7 @@ training:
  eval_freq: 10000
  save_freq: 100000
  log_freq: 250
-  save_model: true
+  save_checkpoint: true

  batch_size: 8
  lr: 1e-5
@@ -25,7 +25,7 @@ training:
  online_steps_between_rollouts: 1

  delta_timestamps:
-    action: "[i / ${fps} for i in range(${policy.chunk_size})]"
+    action: "[i / ${fps} for i in range(1, ${policy.chunk_size} + 1)]"

eval:
  n_episodes: 50
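The `delta_timestamps` change shifts the action targets by one frame: the predicted chunk now starts at the frame after the current observation rather than at the observation itself. Concretely, at 50 fps with a chunk size of 100 (the ACT defaults shown below), the timestamps work out as follows (editor's sketch):

```python
fps, chunk_size = 50, 100
old = [i / fps for i in range(chunk_size)]         # [0.00, 0.02, ..., 1.98]
new = [i / fps for i in range(1, chunk_size + 1)]  # [0.02, 0.04, ..., 2.00]
```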
lerobot/configs/policy/act_real.yaml (new file)
@@ -0,0 +1,115 @@
# @package _global_

# Use `act_real.yaml` to train on real-world Aloha/Aloha2 datasets.
# Compared to `act.yaml`, it contains 4 cameras (i.e. cam_right_wrist, cam_left_wrist, cam_high,
# cam_low) instead of 1 camera (i.e. top). Also, `training.eval_freq` is set to -1: this parameter
# controls how frequently (in training steps) checkpoints are evaluated, and -1 deactivates evaluation.
# This is because real-world evaluation is done through [dora-lerobot](https://github.com/dora-rs/dora-lerobot).
# Look at its README for more information on how to evaluate a checkpoint in the real-world.
#
# Example of usage for training:
# ```bash
# python lerobot/scripts/train.py \
#   policy=act_real \
#   env=dora_aloha_real
# ```

seed: 1000
dataset_repo_id: lerobot/aloha_static_vinh_cup

override_dataset_stats:
  observation.images.cam_right_wrist:
    # stats from imagenet, since we use a pretrained vision model
    mean: [[[0.485]], [[0.456]], [[0.406]]]  # (c,1,1)
    std: [[[0.229]], [[0.224]], [[0.225]]]  # (c,1,1)
  observation.images.cam_left_wrist:
    # stats from imagenet, since we use a pretrained vision model
    mean: [[[0.485]], [[0.456]], [[0.406]]]  # (c,1,1)
    std: [[[0.229]], [[0.224]], [[0.225]]]  # (c,1,1)
  observation.images.cam_high:
    # stats from imagenet, since we use a pretrained vision model
    mean: [[[0.485]], [[0.456]], [[0.406]]]  # (c,1,1)
    std: [[[0.229]], [[0.224]], [[0.225]]]  # (c,1,1)
  observation.images.cam_low:
    # stats from imagenet, since we use a pretrained vision model
    mean: [[[0.485]], [[0.456]], [[0.406]]]  # (c,1,1)
    std: [[[0.229]], [[0.224]], [[0.225]]]  # (c,1,1)

training:
  offline_steps: 80000
  online_steps: 0
  eval_freq: -1
  save_freq: 10000
  log_freq: 100
  save_checkpoint: true

  batch_size: 8
  lr: 1e-5
  lr_backbone: 1e-5
  weight_decay: 1e-4
  grad_clip_norm: 10
  online_steps_between_rollouts: 1

  delta_timestamps:
    action: "[i / ${fps} for i in range(1, ${policy.chunk_size} + 1)]"

eval:
  n_episodes: 50
  batch_size: 50

# See `configuration_act.py` for more details.
policy:
  name: act

  # Input / output structure.
  n_obs_steps: 1
  chunk_size: 100 # chunk_size
  n_action_steps: 100

  input_shapes:
    # TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
    observation.images.cam_right_wrist: [3, 480, 640]
    observation.images.cam_left_wrist: [3, 480, 640]
    observation.images.cam_high: [3, 480, 640]
    observation.images.cam_low: [3, 480, 640]
    observation.state: ["${env.state_dim}"]
  output_shapes:
    action: ["${env.action_dim}"]

  # Normalization / Unnormalization
  input_normalization_modes:
    observation.images.cam_right_wrist: mean_std
    observation.images.cam_left_wrist: mean_std
    observation.images.cam_high: mean_std
    observation.images.cam_low: mean_std
    observation.state: mean_std
  output_normalization_modes:
    action: mean_std

  # Architecture.
  # Vision backbone.
  vision_backbone: resnet18
  pretrained_backbone_weights: ResNet18_Weights.IMAGENET1K_V1
  replace_final_stride_with_dilation: false
  # Transformer layers.
  pre_norm: false
  dim_model: 512
  n_heads: 8
  dim_feedforward: 3200
  feedforward_activation: relu
  n_encoder_layers: 4
  # Note: Although the original ACT implementation has 7 for `n_decoder_layers`, there is a bug in the code
  # that means only the first layer is used. Here we match the original implementation by setting this to 1.
  # See this issue https://github.com/tonyzhaozh/act/issues/25#issue-2258740521.
  n_decoder_layers: 1
  # VAE.
  use_vae: true
  latent_dim: 32
  n_vae_encoder_layers: 4

  # Inference.
  temporal_ensemble_momentum: null

  # Training and loss computation.
  dropout: 0.1
  kl_weight: 10.0
lerobot/configs/policy/act_real_no_state.yaml (new file)
@@ -0,0 +1,111 @@
# @package _global_

# Use `act_real_no_state.yaml` to train on real-world Aloha/Aloha2 datasets when cameras are moving (e.g. wrist cameras).
# Compared to `act_real.yaml`, it is camera-only and does not use the state as input, which is a vector of robot joint positions.
# We validated experimentally that not using the state reaches a better success rate. Our hypothesis is that `act_real.yaml` might
# overfit to the state, because the images are more complex to learn from since they are moving.
#
# Example of usage for training:
# ```bash
# python lerobot/scripts/train.py \
#   policy=act_real_no_state \
#   env=dora_aloha_real
# ```

seed: 1000
dataset_repo_id: lerobot/aloha_static_vinh_cup

override_dataset_stats:
  observation.images.cam_right_wrist:
    # stats from imagenet, since we use a pretrained vision model
    mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
    std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
  observation.images.cam_left_wrist:
    # stats from imagenet, since we use a pretrained vision model
    mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
    std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
  observation.images.cam_high:
    # stats from imagenet, since we use a pretrained vision model
    mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
    std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
  observation.images.cam_low:
    # stats from imagenet, since we use a pretrained vision model
    mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
    std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)

training:
  offline_steps: 80000
  online_steps: 0
  eval_freq: -1
  save_freq: 10000
  log_freq: 100
  save_checkpoint: true

  batch_size: 8
  lr: 1e-5
  lr_backbone: 1e-5
  weight_decay: 1e-4
  grad_clip_norm: 10
  online_steps_between_rollouts: 1

  delta_timestamps:
    action: "[i / ${fps} for i in range(1, ${policy.chunk_size} + 1)]"

eval:
  n_episodes: 50
  batch_size: 50

# See `configuration_act.py` for more details.
policy:
  name: act

  # Input / output structure.
  n_obs_steps: 1
  chunk_size: 100 # chunk_size
  n_action_steps: 100

  input_shapes:
    # TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
    observation.images.cam_right_wrist: [3, 480, 640]
    observation.images.cam_left_wrist: [3, 480, 640]
    observation.images.cam_high: [3, 480, 640]
    observation.images.cam_low: [3, 480, 640]
  output_shapes:
    action: ["${env.action_dim}"]

  # Normalization / Unnormalization
  input_normalization_modes:
    observation.images.cam_right_wrist: mean_std
    observation.images.cam_left_wrist: mean_std
    observation.images.cam_high: mean_std
    observation.images.cam_low: mean_std
  output_normalization_modes:
    action: mean_std

  # Architecture.
  # Vision backbone.
  vision_backbone: resnet18
  pretrained_backbone_weights: ResNet18_Weights.IMAGENET1K_V1
  replace_final_stride_with_dilation: false
  # Transformer layers.
  pre_norm: false
  dim_model: 512
  n_heads: 8
  dim_feedforward: 3200
  feedforward_activation: relu
  n_encoder_layers: 4
  # Note: Although the original ACT implementation has 7 for `n_decoder_layers`, there is a bug in the code
  # that means only the first layer is used. Here we match the original implementation by setting this to 1.
  # See this issue https://github.com/tonyzhaozh/act/issues/25#issue-2258740521.
  n_decoder_layers: 1
  # VAE.
  use_vae: true
  latent_dim: 32
  n_vae_encoder_layers: 4

  # Inference.
  temporal_ensemble_momentum: null

  # Training and loss computation.
  dropout: 0.1
  kl_weight: 10.0
@@ -27,7 +27,7 @@ training:
  eval_freq: 5000
  save_freq: 5000
  log_freq: 250
  save_model: true
  save_checkpoint: true

  batch_size: 64
  grad_clip_norm: 10
@@ -44,6 +44,10 @@ training:
    observation.state: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1)]"
    action: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1 - ${policy.n_obs_steps} + ${policy.horizon})]"

  # The original implementation doesn't sample frames for the last 7 steps,
  # which avoids excessive padding and leads to improved training results.
  drop_n_last_frames: 7 # ${policy.horizon} - ${policy.n_action_steps} - ${policy.n_obs_steps} + 1

eval:
  n_episodes: 50
  batch_size: 50
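Editor's note: the trailing comment on `drop_n_last_frames` is the general formula. As a quick check with the values this diffusion config uses elsewhere (assumed here: horizon 16, n_action_steps 8, n_obs_steps 2), the formula indeed yields 7:

```python
# Assumed values for illustration: horizon=16, n_action_steps=8, n_obs_steps=2.
horizon, n_action_steps, n_obs_steps = 16, 8, 2
drop_n_last_frames = horizon - n_action_steps - n_obs_steps + 1
print(drop_n_last_frames)  # 7
```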
@@ -5,7 +5,8 @@ dataset_repo_id: lerobot/xarm_lift_medium

training:
  offline_steps: 25000
  online_steps: 25000
  # TODO(alexander-soare): uncomment when online training gets reinstated
  online_steps: 0 # 25000 not implemented yet
  eval_freq: 5000
  online_steps_between_rollouts: 1
  online_sampling_ratio: 0.5
@@ -13,39 +13,71 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Use this script to get a quick summary of your system config.
It should be able to run without any of LeRobot's dependencies or LeRobot itself installed.
"""

import platform

import huggingface_hub
HAS_HF_HUB = True
HAS_HF_DATASETS = True
HAS_NP = True
HAS_TORCH = True
HAS_LEROBOT = True

# import dataset
import numpy as np
import torch
try:
    import huggingface_hub
except ImportError:
    HAS_HF_HUB = False

from lerobot import __version__ as version
try:
    import datasets
except ImportError:
    HAS_HF_DATASETS = False

pt_version = torch.__version__
pt_cuda_available = torch.cuda.is_available()
cuda_version = torch._C._cuda_getCompiledVersion() if torch.version.cuda is not None else "N/A"
try:
    import numpy as np
except ImportError:
    HAS_NP = False

try:
    import torch
except ImportError:
    HAS_TORCH = False

try:
    import lerobot
except ImportError:
    HAS_LEROBOT = False


lerobot_version = lerobot.__version__ if HAS_LEROBOT else "N/A"
hf_hub_version = huggingface_hub.__version__ if HAS_HF_HUB else "N/A"
hf_datasets_version = datasets.__version__ if HAS_HF_DATASETS else "N/A"
np_version = np.__version__ if HAS_NP else "N/A"

torch_version = torch.__version__ if HAS_TORCH else "N/A"
torch_cuda_available = torch.cuda.is_available() if HAS_TORCH else "N/A"
cuda_version = torch._C._cuda_getCompiledVersion() if HAS_TORCH and torch.version.cuda is not None else "N/A"


# TODO(aliberts): refactor into an actual command `lerobot env`
def display_sys_info() -> dict:
    """Run this to get basic system info to help for tracking issues & bugs."""
    info = {
        "`lerobot` version": version,
        "`lerobot` version": lerobot_version,
        "Platform": platform.platform(),
        "Python version": platform.python_version(),
        "Huggingface_hub version": huggingface_hub.__version__,
        # TODO(aliberts): Add dataset when https://github.com/huggingface/lerobot/pull/73 is merged
        # "Dataset version": dataset.__version__,
        "Numpy version": np.__version__,
        "PyTorch version (GPU?)": f"{pt_version} ({pt_cuda_available})",
        "Huggingface_hub version": hf_hub_version,
        "Dataset version": hf_datasets_version,
        "Numpy version": np_version,
        "PyTorch version (GPU?)": f"{torch_version} ({torch_cuda_available})",
        "Cuda version": cuda_version,
        "Using GPU in script?": "<fill in>",
        "Using distributed or parallel set-up in script?": "<fill in>",
        # "Using distributed or parallel set-up in script?": "<fill in>",
    }
    print("\nCopy-and-paste the text below in your GitHub issue and FILL OUT the two last points.\n")
    print("\nCopy-and-paste the text below in your GitHub issue and FILL OUT the last point.\n")
    print(format_dict(info))
    return info
@@ -28,7 +28,7 @@ OR, you want to evaluate a model checkpoint from the LeRobot training script for

```
python lerobot/scripts/eval.py \
    -p outputs/train/diffusion_pusht/checkpoints/005000 \
    -p outputs/train/diffusion_pusht/checkpoints/005000/pretrained_model \
    eval.n_episodes=10
```
@@ -46,6 +46,7 @@ import json
import logging
import threading
import time
from contextlib import nullcontext
from copy import deepcopy
from datetime import datetime as dt
from pathlib import Path
@@ -60,7 +61,7 @@ from huggingface_hub import snapshot_download
from huggingface_hub.utils._errors import RepositoryNotFoundError
from huggingface_hub.utils._validators import HFValidationError
from PIL import Image as PILImage
from torch import Tensor
from torch import Tensor, nn
from tqdm import trange

from lerobot.common.datasets.factory import make_dataset
@@ -98,13 +99,13 @@ def rollout(
        "reward": A (batch, sequence) tensor of rewards received for applying the actions.
        "success": A (batch, sequence) tensor of success conditions (the only time this can be True is upon
            environment termination/truncation).
        "don": A (batch, sequence) tensor of **cumulative** done conditions. For any given batch element,
        "done": A (batch, sequence) tensor of **cumulative** done conditions. For any given batch element,
            the first True is followed by True's all the way till the end. This can be used for masking
            extraneous elements from the sequences above.

    Args:
        env: The batch of environments.
        policy: The policy.
        policy: The policy. Must be a PyTorch nn module.
        seeds: The environments are seeded once at the start of the rollout. If provided, this argument
            specifies the seeds for each of the environments.
        return_observations: Whether to include all observations in the returned rollout data. Observations
@@ -115,6 +116,7 @@ def rollout(
    Returns:
        The dictionary described above.
    """
    assert isinstance(policy, nn.Module), "Policy must be a PyTorch nn module."
    device = get_device_from_parameters(policy)

    # Reset the policy and environments.
@@ -208,7 +210,7 @@ def eval_policy(
    policy: torch.nn.Module,
    n_episodes: int,
    max_episodes_rendered: int = 0,
    video_dir: Path | None = None,
    videos_dir: Path | None = None,
    return_episode_data: bool = False,
    start_seed: int | None = None,
    enable_progbar: bool = False,
@@ -220,7 +222,7 @@ def eval_policy(
        policy: The policy.
        n_episodes: The number of episodes to evaluate.
        max_episodes_rendered: Maximum number of episodes to render into videos.
        video_dir: Where to save rendered videos.
        videos_dir: Where to save rendered videos.
        return_episode_data: Whether to return episode data for online training. Incorporates the data into
            the "episodes" key of the returned dictionary.
        start_seed: The first seed to use for the first individual rollout. For all subsequent rollouts the
@@ -230,6 +232,10 @@ def eval_policy(
    Returns:
        Dictionary with metrics and data regarding the rollouts.
    """
    if max_episodes_rendered > 0 and not videos_dir:
        raise ValueError("If max_episodes_rendered > 0, videos_dir must be provided.")

    assert isinstance(policy, Policy)
    start = time.time()
    policy.eval()
@@ -270,11 +276,16 @@
        if max_episodes_rendered > 0:
            ep_frames: list[np.ndarray] = []

        seeds = range(start_seed + (batch_ix * env.num_envs), start_seed + ((batch_ix + 1) * env.num_envs))
        if start_seed is None:
            seeds = None
        else:
            seeds = range(
                start_seed + (batch_ix * env.num_envs), start_seed + ((batch_ix + 1) * env.num_envs)
            )
        rollout_data = rollout(
            env,
            policy,
            seeds=seeds,
            seeds=list(seeds) if seeds else None,
            return_observations=return_episode_data,
            render_callback=render_frame if max_episodes_rendered > 0 else None,
            enable_progbar=enable_inner_progbar,
@@ -284,7 +295,8 @@ def eval_policy(
        # this won't be included).
        n_steps = rollout_data["done"].shape[1]
        # Note: this relies on a property of argmax: that it returns the first occurrence as a tiebreaker.
        done_indices = torch.argmax(rollout_data["done"].to(int), axis=1)  # (batch_size, rollout_steps)
        done_indices = torch.argmax(rollout_data["done"].to(int), dim=1)

        # Make a mask with shape (batch, n_steps) to mask out rollout data after the first done
        # (batch-element-wise). Note the `done_indices + 1` to make sure to keep the data from the done step.
        mask = (torch.arange(n_steps) <= einops.repeat(done_indices + 1, "b -> b s", s=n_steps)).int()
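Editor's note: the argmax comment above is worth unpacking. Because the "done" flags are cumulative, the index of the first maximal value is exactly the step at which each episode first terminated. A tiny self-contained check of that behavior (PyTorch documents that `argmax` returns the first occurrence on ties):

```python
import torch

# Cumulative done flags for two rollouts of 4 steps each: once a rollout is
# done, it stays done.
done = torch.tensor([[0, 0, 1, 1],
                     [0, 1, 1, 1]])
# argmax over the step dimension picks the first 1 in each row, i.e. the step
# at which each rollout first terminated.
print(torch.argmax(done, dim=1))  # tensor([2, 1])
```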
@@ -295,8 +307,12 @@ def eval_policy(
        max_rewards.extend(batch_max_rewards.tolist())
        batch_successes = einops.reduce((rollout_data["success"] * mask), "b n -> b", "any")
        all_successes.extend(batch_successes.tolist())
        all_seeds.extend(seeds)
        if seeds:
            all_seeds.extend(seeds)
        else:
            all_seeds.append(None)

        # FIXME: episode_data is either None or it doesn't exist
        if return_episode_data:
            this_episode_data = _compile_episode_data(
                rollout_data,
@@ -346,8 +362,9 @@ def eval_policy(
        ):
            if n_episodes_rendered >= max_episodes_rendered:
                break
            video_dir.mkdir(parents=True, exist_ok=True)
            video_path = video_dir / f"eval_episode_{n_episodes_rendered}.mp4"

            videos_dir.mkdir(parents=True, exist_ok=True)
            video_path = videos_dir / f"eval_episode_{n_episodes_rendered}.mp4"
            video_paths.append(str(video_path))
            thread = threading.Thread(
                target=write_video,
@@ -502,25 +519,23 @@ def _compile_episode_data(
    }


def eval(
    pretrained_policy_path: str | None = None,
def main(
    pretrained_policy_path: Path | None = None,
    hydra_cfg_path: str | None = None,
    out_dir: str | None = None,
    config_overrides: list[str] | None = None,
):
    assert (pretrained_policy_path is None) ^ (hydra_cfg_path is None)
    if hydra_cfg_path is None:
        hydra_cfg = init_hydra_config(pretrained_policy_path / "config.yaml", config_overrides)
    if pretrained_policy_path is not None:
        hydra_cfg = init_hydra_config(str(pretrained_policy_path / "config.yaml"), config_overrides)
    else:
        hydra_cfg = init_hydra_config(hydra_cfg_path, config_overrides)
    out_dir = (
        f"outputs/eval/{dt.now().strftime('%Y-%m-%d/%H-%M-%S')}_{hydra_cfg.env.name}_{hydra_cfg.policy.name}"
    )

    if out_dir is None:
        raise NotImplementedError()
    out_dir = f"outputs/eval/{dt.now().strftime('%Y-%m-%d/%H-%M-%S')}_{hydra_cfg.env.name}_{hydra_cfg.policy.name}"

    # Check device is available
    get_safe_torch_device(hydra_cfg.device, log=True)
    device = get_safe_torch_device(hydra_cfg.device, log=True)

    torch.backends.cudnn.benchmark = True
    torch.backends.cuda.matmul.allow_tf32 = True
@@ -533,22 +548,25 @@ def eval(

    logging.info("Making policy.")
    if hydra_cfg_path is None:
        policy = make_policy(hydra_cfg=hydra_cfg, pretrained_policy_name_or_path=pretrained_policy_path)
        policy = make_policy(hydra_cfg=hydra_cfg, pretrained_policy_name_or_path=str(pretrained_policy_path))
    else:
        # Note: We need the dataset stats to pass to the policy's normalization modules.
        policy = make_policy(hydra_cfg=hydra_cfg, dataset_stats=make_dataset(hydra_cfg).stats)

    assert isinstance(policy, nn.Module)
    policy.eval()

    info = eval_policy(
        env,
        policy,
        hydra_cfg.eval.n_episodes,
        max_episodes_rendered=10,
        video_dir=Path(out_dir) / "eval",
        start_seed=hydra_cfg.seed,
        enable_progbar=True,
        enable_inner_progbar=True,
    )
    with torch.no_grad(), torch.autocast(device_type=device.type) if hydra_cfg.use_amp else nullcontext():
        info = eval_policy(
            env,
            policy,
            hydra_cfg.eval.n_episodes,
            max_episodes_rendered=10,
            videos_dir=Path(out_dir) / "videos",
            start_seed=hydra_cfg.seed,
            enable_progbar=True,
            enable_inner_progbar=True,
        )
    print(info["aggregated"])

    # Save info
@@ -584,6 +602,13 @@ if __name__ == "__main__":
        ),
    )
    parser.add_argument("--revision", help="Optionally provide the Hugging Face Hub revision ID.")
    parser.add_argument(
        "--out-dir",
        help=(
            "Where to save the evaluation outputs. If not provided, outputs are saved in "
            "outputs/eval/{timestamp}_{env_name}_{policy_name}"
        ),
    )
    parser.add_argument(
        "overrides",
        nargs="*",
@@ -592,7 +617,7 @@ if __name__ == "__main__":
    args = parser.parse_args()

    if args.pretrained_policy_name_or_path is None:
        eval(hydra_cfg_path=args.config, config_overrides=args.overrides)
        main(hydra_cfg_path=args.config, out_dir=args.out_dir, config_overrides=args.overrides)
    else:
        try:
            pretrained_policy_path = Path(
@@ -616,4 +641,8 @@ if __name__ == "__main__":
                "repo ID, nor is it an existing local directory."
            )

        eval(pretrained_policy_path=pretrained_policy_path, config_overrides=args.overrides)
        main(
            pretrained_policy_path=pretrained_policy_path,
            out_dir=args.out_dir,
            config_overrides=args.overrides,
        )
@@ -18,81 +18,69 @@ Use this script to convert your dataset into LeRobot dataset format and upload i
or store it locally. LeRobot dataset format is lightweight, fast to load from, and does not require any
installation of neural net specific packages like pytorch, tensorflow, jax.

Example:
Example of how to download raw datasets, convert them into LeRobotDataset format, and push them to the hub:
```
python lerobot/scripts/push_dataset_to_hub.py \
    --data-dir data \
    --dataset-id pusht \
    --raw-dir data/pusht_raw \
    --raw-format pusht_zarr \
    --community-id lerobot \
    --dry-run 1 \
    --save-to-disk 1 \
    --save-tests-to-disk 0 \
    --debug 1
    --repo-id lerobot/pusht

python lerobot/scripts/push_dataset_to_hub.py \
    --data-dir data \
    --dataset-id xarm_lift_medium \
    --raw-dir data/xarm_lift_medium_raw \
    --raw-format xarm_pkl \
    --community-id lerobot \
    --dry-run 1 \
    --save-to-disk 1 \
    --save-tests-to-disk 0 \
    --debug 1
    --repo-id lerobot/xarm_lift_medium

python lerobot/scripts/push_dataset_to_hub.py \
    --data-dir data \
    --dataset-id aloha_sim_insertion_scripted \
    --raw-dir data/aloha_sim_insertion_scripted_raw \
    --raw-format aloha_hdf5 \
    --community-id lerobot \
    --dry-run 1 \
    --save-to-disk 1 \
    --save-tests-to-disk 0 \
    --debug 1
    --repo-id lerobot/aloha_sim_insertion_scripted

python lerobot/scripts/push_dataset_to_hub.py \
    --data-dir data \
    --dataset-id umi_cup_in_the_wild \
    --raw-dir data/umi_cup_in_the_wild_raw \
    --raw-format umi_zarr \
    --community-id lerobot \
    --dry-run 1 \
    --save-to-disk 1 \
    --save-tests-to-disk 0 \
    --debug 1
    --repo-id lerobot/umi_cup_in_the_wild
```
"""

import argparse
import json
import shutil
import warnings
from pathlib import Path
from typing import Any

import torch
from huggingface_hub import HfApi
from huggingface_hub import HfApi, create_branch
from safetensors.torch import save_file

from lerobot.common.datasets.compute_stats import compute_stats
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
from lerobot.common.datasets.push_dataset_to_hub._download_raw import download_raw
from lerobot.common.datasets.push_dataset_to_hub.compute_stats import compute_stats
from lerobot.common.datasets.utils import flatten_dict
def get_from_raw_to_lerobot_format_fn(raw_format):
def get_from_raw_to_lerobot_format_fn(raw_format: str):
    if raw_format == "pusht_zarr":
        from lerobot.common.datasets.push_dataset_to_hub.pusht_zarr_format import from_raw_to_lerobot_format
    elif raw_format == "umi_zarr":
        from lerobot.common.datasets.push_dataset_to_hub.umi_zarr_format import from_raw_to_lerobot_format
    elif raw_format == "aloha_hdf5":
        from lerobot.common.datasets.push_dataset_to_hub.aloha_hdf5_format import from_raw_to_lerobot_format
    elif raw_format == "dora_parquet":
        from lerobot.common.datasets.push_dataset_to_hub.dora_parquet_format import from_raw_to_lerobot_format
    elif raw_format == "xarm_pkl":
        from lerobot.common.datasets.push_dataset_to_hub.xarm_pkl_format import from_raw_to_lerobot_format
    else:
        raise ValueError(raw_format)
        raise ValueError(
            f"The selected {raw_format} can't be found. Did you add it to `lerobot/scripts/push_dataset_to_hub.py::get_from_raw_to_lerobot_format_fn`?"
        )

    return from_raw_to_lerobot_format
def save_meta_data(info, stats, episode_data_index, meta_data_dir):
def save_meta_data(
    info: dict[str, Any], stats: dict, episode_data_index: dict[str, list], meta_data_dir: Path
):
    meta_data_dir.mkdir(parents=True, exist_ok=True)

    # save info
@@ -110,7 +98,7 @@ def save_meta_data(info, stats, episode_data_index, meta_data_dir):
    save_file(episode_data_index, ep_data_idx_path)


def push_meta_data_to_hub(repo_id, meta_data_dir, revision):
def push_meta_data_to_hub(repo_id: str, meta_data_dir: str | Path, revision: str | None):
    """Expect all meta data files to be stored in a single "meta_data" directory.
    On the Hugging Face repository, they will be uploaded in a "meta_data" directory at the root.
    """
@@ -124,7 +112,7 @@ def push_meta_data_to_hub(repo_id, meta_data_dir, revision):
    )


def push_videos_to_hub(repo_id, videos_dir, revision):
def push_videos_to_hub(repo_id: str, videos_dir: str | Path, revision: str | None):
    """Expect all mp4 files to be stored in a single "videos" directory.
    On the Hugging Face repository, they will be uploaded in a "videos" directory at the root.
    """
@@ -140,39 +128,61 @@ def push_videos_to_hub(repo_id, videos_dir, revision):


def push_dataset_to_hub(
    data_dir: Path,
    dataset_id: str,
    raw_format: str | None,
    community_id: str,
    revision: str,
    dry_run: bool,
    save_to_disk: bool,
    tests_data_dir: Path,
    save_tests_to_disk: bool,
    fps: int | None,
    video: bool,
    batch_size: int,
    num_workers: int,
    debug: bool,
    raw_dir: Path,
    raw_format: str,
    repo_id: str,
    push_to_hub: bool = True,
    local_dir: Path | None = None,
    fps: int | None = None,
    video: bool = True,
    batch_size: int = 32,
    num_workers: int = 8,
    episodes: list[int] | None = None,
    force_override: bool = False,
    cache_dir: Path = Path("/tmp"),
    tests_data_dir: Path | None = None,
):
    repo_id = f"{community_id}/{dataset_id}"
    # Check repo_id is well formatted
    if len(repo_id.split("/")) != 2:
        raise ValueError(
            f"`repo_id` is expected to contain a community or user id `/` the name of the dataset (e.g. 'lerobot/pusht'), but instead contains '{repo_id}'."
        )
    user_id, dataset_id = repo_id.split("/")

    raw_dir = data_dir / f"{dataset_id}_raw"
    # Robustify when `raw_dir` is str instead of Path
    raw_dir = Path(raw_dir)
    if not raw_dir.exists():
        raise NotADirectoryError(
            f"{raw_dir} does not exist. Check your paths or run this command to download an existing raw dataset on the hub: "
            f"python lerobot/common/datasets/push_dataset_to_hub/_download_raw.py --raw-dir your/raw/dir --repo-id your/repo/id_raw"
        )

    out_dir = data_dir / repo_id
    meta_data_dir = out_dir / "meta_data"
    videos_dir = out_dir / "videos"
    if local_dir:
        # Robustify when `local_dir` is str instead of Path
        local_dir = Path(local_dir)

    tests_out_dir = tests_data_dir / repo_id
    tests_meta_data_dir = tests_out_dir / "meta_data"
    tests_videos_dir = tests_out_dir / "videos"
        # Send warning if local_dir isn't well formatted
        if local_dir.parts[-2] != user_id or local_dir.parts[-1] != dataset_id:
            warnings.warn(
                f"`local_dir` ({local_dir}) doesn't contain a community or user id `/` the name of the dataset that matches the `repo_id` (e.g. 'data/lerobot/pusht'). Following this naming convention is advised, but not mandatory.",
                stacklevel=1,
            )

    if out_dir.exists():
        shutil.rmtree(out_dir)
        # Check we don't override an existing `local_dir` by mistake
        if local_dir.exists():
            if force_override:
                shutil.rmtree(local_dir)
            else:
                raise ValueError(f"`local_dir` already exists ({local_dir}). Use `--force-override 1`.")

    if tests_out_dir.exists() and save_tests_to_disk:
        shutil.rmtree(tests_out_dir)
        meta_data_dir = local_dir / "meta_data"
        videos_dir = local_dir / "videos"
    else:
        # Temporary directory used to store images, videos, meta_data
        meta_data_dir = Path(cache_dir) / "meta_data"
        videos_dir = Path(cache_dir) / "videos"

    # Download the raw dataset if available
    if not raw_dir.exists():
        download_raw(raw_dir, dataset_id)
@@ -181,14 +191,14 @@ def push_dataset_to_hub(
        raise NotImplementedError()
        # raw_format = auto_find_raw_format(raw_dir)

    from_raw_to_lerobot_format = get_from_raw_to_lerobot_format_fn(raw_format)

    # convert dataset from original raw format to LeRobot format
    hf_dataset, episode_data_index, info = from_raw_to_lerobot_format(raw_dir, out_dir, fps, video, debug)
    from_raw_to_lerobot_format = get_from_raw_to_lerobot_format_fn(raw_format)
    hf_dataset, episode_data_index, info = from_raw_to_lerobot_format(
        raw_dir, videos_dir, fps, video, episodes
    )

    lerobot_dataset = LeRobotDataset.from_preloaded(
        repo_id=repo_id,
        version=revision,
        hf_dataset=hf_dataset,
        episode_data_index=episode_data_index,
        info=info,
@@ -196,102 +206,80 @@ def push_dataset_to_hub(
    )
    stats = compute_stats(lerobot_dataset, batch_size, num_workers)

    if save_to_disk:
    if local_dir:
        hf_dataset = hf_dataset.with_format(None)  # to remove transforms that can't be saved
        hf_dataset.save_to_disk(str(out_dir / "train"))
        hf_dataset.save_to_disk(str(local_dir / "train"))

    if not dry_run or save_to_disk:
    if push_to_hub or local_dir:
        # mandatory for upload
        save_meta_data(info, stats, episode_data_index, meta_data_dir)

    if not dry_run:
        hf_dataset.push_to_hub(repo_id, token=True, revision="main")
        hf_dataset.push_to_hub(repo_id, token=True, revision=revision)

    if push_to_hub:
        hf_dataset.push_to_hub(repo_id, revision="main")
        push_meta_data_to_hub(repo_id, meta_data_dir, revision="main")
        push_meta_data_to_hub(repo_id, meta_data_dir, revision=revision)

        if video:
            push_videos_to_hub(repo_id, videos_dir, revision="main")
            push_videos_to_hub(repo_id, videos_dir, revision=revision)
        create_branch(repo_id, repo_type="dataset", branch=CODEBASE_VERSION)

    if save_tests_to_disk:
    if tests_data_dir:
        # get the first episode
        num_items_first_ep = episode_data_index["to"][0] - episode_data_index["from"][0]
        test_hf_dataset = hf_dataset.select(range(num_items_first_ep))

        test_hf_dataset = test_hf_dataset.with_format(None)
        test_hf_dataset.save_to_disk(str(tests_out_dir / "train"))
        test_hf_dataset.save_to_disk(str(tests_data_dir / repo_id / "train"))

        save_meta_data(info, stats, episode_data_index, tests_meta_data_dir)
        tests_meta_data = tests_data_dir / repo_id / "meta_data"
        save_meta_data(info, stats, episode_data_index, tests_meta_data)

        # copy videos of first episode to tests directory
        episode_index = 0
        tests_videos_dir = tests_data_dir / repo_id / "videos"
        tests_videos_dir.mkdir(parents=True, exist_ok=True)
        for key in lerobot_dataset.video_frame_keys:
            fname = f"{key}_episode_{episode_index:06d}.mp4"
            shutil.copy(videos_dir / fname, tests_videos_dir / fname)

    if not save_to_disk and out_dir.exists():
        # remove possible temporary files remaining in the output directory
        shutil.rmtree(out_dir)
    if local_dir is None:
        # clear cache
        shutil.rmtree(meta_data_dir)
        shutil.rmtree(videos_dir)

    return lerobot_dataset
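Editor's note: with the new signature above, a hypothetical programmatic call mirroring the CLI examples in the module docstring might look like this (paths and repo id are illustrative; `push_to_hub=False` keeps the run local):

```python
from pathlib import Path

# Hypothetical invocation; paths and repo id are illustrative only.
lerobot_dataset = push_dataset_to_hub(
    raw_dir=Path("data/pusht_raw"),
    raw_format="pusht_zarr",
    repo_id="lerobot/pusht",
    local_dir=Path("data/lerobot/pusht"),
    push_to_hub=False,
)
```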
def main():
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--data-dir",
        "--raw-dir",
        type=Path,
        required=True,
        help="Root directory containing datasets (e.g. `data` or `tmp/data` or `/tmp/lerobot/data`).",
    )
    parser.add_argument(
        "--dataset-id",
        type=str,
        required=True,
        help="Name of the dataset (e.g. `pusht`, `aloha_sim_insertion_human`), which matches the folder where the data is stored (e.g. `data/pusht`).",
        help="Directory containing input raw datasets (e.g. `data/aloha_mobile_chair_raw` or `data/pusht_raw`).",
    )
    # TODO(rcadene): add automatic detection of the format
    parser.add_argument(
        "--raw-format",
        type=str,
        help="Dataset type (e.g. `pusht_zarr`, `umi_zarr`, `aloha_hdf5`, `xarm_pkl`). If not provided, will be detected automatically.",
        required=True,
        help="Dataset type (e.g. `pusht_zarr`, `umi_zarr`, `aloha_hdf5`, `xarm_pkl`, `dora_parquet`).",
    )
    parser.add_argument(
        "--community-id",
        "--repo-id",
        type=str,
        default="lerobot",
        help="Community or user ID under which the dataset will be hosted on the Hub.",
        required=True,
        help="Repository identifier on Hugging Face: a community or a user name `/` the name of the dataset (e.g. `lerobot/pusht`, `cadene/aloha_sim_insertion_human`).",
    )
    parser.add_argument(
        "--revision",
        type=str,
        default=CODEBASE_VERSION,
        help="Codebase version used to generate the dataset.",
    )
    parser.add_argument(
        "--dry-run",
        type=int,
        default=0,
        help="Run everything without uploading to hub, for testing purposes or storing a dataset locally.",
    )
    parser.add_argument(
        "--save-to-disk",
        type=int,
        default=1,
        help="Save the dataset in the directory specified by `--data-dir`.",
    )
    parser.add_argument(
        "--tests-data-dir",
        "--local-dir",
        type=Path,
        default="tests/data",
        help="Directory containing tests artifacts datasets.",
        help="When provided, writes the dataset converted to LeRobotDataset format in this directory (e.g. `data/lerobot/aloha_mobile_chair`).",
    )
    parser.add_argument(
        "--save-tests-to-disk",
        "--push-to-hub",
        type=int,
        default=1,
        help="Save the dataset with 1 episode used for unit tests in the directory specified by `--tests-data-dir`.",
        help="Upload to hub.",
    )
    parser.add_argument(
        "--fps",
@@ -317,10 +305,21 @@
        help="Number of processes of Dataloader for computing the dataset statistics.",
    )
    parser.add_argument(
        "--debug",
        "--episodes",
        type=int,
        nargs="*",
        help="When provided, only converts the provided episodes (e.g. `--episodes 2 3 4`). Useful to test the code on 1 episode.",
    )
    parser.add_argument(
        "--force-override",
        type=int,
        default=0,
        help="Debug mode process the first episode only.",
        help="When set to 1, removes provided output directory if it already exists. By default, raises a ValueError exception.",
    )
    parser.add_argument(
        "--tests-data-dir",
        type=Path,
        help="When provided, save tests artifacts into the given directory (e.g. `--tests-data-dir tests/data/lerobot/pusht`).",
    )

    args = parser.parse_args()
@@ -15,25 +15,31 @@
# limitations under the License.
import logging
import time
from copy import deepcopy
from contextlib import nullcontext
from pathlib import Path
from pprint import pformat

import datasets
import hydra
import torch
from datasets import concatenate_datasets
from datasets.utils import disable_progress_bars, enable_progress_bars
from omegaconf import DictConfig
from deepdiff import DeepDiff
from omegaconf import DictConfig, OmegaConf
from termcolor import colored
from torch import nn
from torch.cuda.amp import GradScaler

from lerobot.common.datasets.factory import make_dataset
from lerobot.common.datasets.factory import make_dataset, resolve_delta_timestamps
from lerobot.common.datasets.lerobot_dataset import MultiLeRobotDataset
from lerobot.common.datasets.sampler import EpisodeAwareSampler
from lerobot.common.datasets.utils import cycle
from lerobot.common.envs.factory import make_env
from lerobot.common.logger import Logger, log_output_dir
from lerobot.common.policies.factory import make_policy
from lerobot.common.policies.policy_protocol import PolicyWithUpdate
from lerobot.common.policies.utils import get_device_from_parameters
from lerobot.common.utils.utils import (
    format_big_number,
    get_safe_torch_device,
    init_hydra_config,
    init_logging,
    set_global_seed,
)
@@ -69,7 +75,6 @@ def make_optimizer_and_scheduler(cfg, policy):
        cfg.training.adam_eps,
        cfg.training.adam_weight_decay,
    )
    assert cfg.training.online_steps == 0, "Diffusion Policy does not handle online training."
    from diffusers.optimization import get_scheduler

    lr_scheduler = get_scheduler(
@@ -87,21 +92,40 @@ def make_optimizer_and_scheduler(cfg, policy):
    return optimizer, lr_scheduler

def update_policy(policy, batch, optimizer, grad_clip_norm, lr_scheduler=None):
def update_policy(
    policy,
    batch,
    optimizer,
    grad_clip_norm,
    grad_scaler: GradScaler,
    lr_scheduler=None,
    use_amp: bool = False,
):
    """Returns a dictionary of items for logging."""
    start_time = time.time()
    start_time = time.perf_counter()
    device = get_device_from_parameters(policy)
    policy.train()
    output_dict = policy.forward(batch)
    # TODO(rcadene): policy.unnormalize_outputs(out_dict)
    loss = output_dict["loss"]
    loss.backward()
    with torch.autocast(device_type=device.type) if use_amp else nullcontext():
        output_dict = policy.forward(batch)
        # TODO(rcadene): policy.unnormalize_outputs(out_dict)
        loss = output_dict["loss"].mean()
    grad_scaler.scale(loss).backward()

    # Unscale the gradient of the optimizer's assigned params in-place **prior to gradient clipping**.
    grad_scaler.unscale_(optimizer)

    grad_norm = torch.nn.utils.clip_grad_norm_(
        policy.parameters(),
        grad_clip_norm,
        error_if_nonfinite=False,
    )

    optimizer.step()
    # Optimizer's gradients are already unscaled, so scaler.step does not unscale them,
    # although it still skips optimizer.step() if the gradients contain infs or NaNs.
    grad_scaler.step(optimizer)
    # Updates the scale for next iteration.
    grad_scaler.update()

    optimizer.zero_grad()

    if lr_scheduler is not None:
@@ -115,36 +139,19 @@ def update_policy(policy, batch, optimizer, grad_clip_norm, lr_scheduler=None):
        "loss": loss.item(),
        "grad_norm": float(grad_norm),
        "lr": optimizer.param_groups[0]["lr"],
        "update_s": time.time() - start_time,
        "update_s": time.perf_counter() - start_time,
        **{k: v for k, v in output_dict.items() if k != "loss"},
    }

    return info
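Editor's note: the ordering in the new `update_policy` matters. Gradients must be unscaled before clipping, otherwise `grad_clip_norm` would be compared against loss-scaled values. A minimal, self-contained sketch of the same scale -> backward -> unscale -> clip -> step -> update pattern (toy model and data; assumes a CUDA device is available):

```python
import torch
from torch import nn
from torch.cuda.amp import GradScaler

# Toy setup, purely illustrative.
model = nn.Linear(8, 1).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scaler = GradScaler(enabled=True)

batch = torch.randn(4, 8, device="cuda")
with torch.autocast(device_type="cuda"):
    loss = model(batch).pow(2).mean()
scaler.scale(loss).backward()  # backward on the scaled loss
scaler.unscale_(optimizer)     # unscale in place, *before* clipping
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10.0)
scaler.step(optimizer)         # skipped internally if grads contain inf/NaN
scaler.update()                # adjust the loss scale for the next iteration
optimizer.zero_grad()
```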
@hydra.main(version_base="1.2", config_name="default", config_path="../configs")
def train_cli(cfg: dict):
    train(
        cfg,
        out_dir=hydra.core.hydra_config.HydraConfig.get().run.dir,
        job_name=hydra.core.hydra_config.HydraConfig.get().job.name,
    )


def train_notebook(out_dir=None, job_name=None, config_name="default", config_path="../configs"):
    from hydra import compose, initialize

    hydra.core.global_hydra.GlobalHydra.instance().clear()
    initialize(config_path=config_path)
    cfg = compose(config_name=config_name)
    train(cfg, out_dir=out_dir, job_name=job_name)


def log_train_info(logger: Logger, info, step, cfg, dataset, is_offline):
    loss = info["loss"]
    grad_norm = info["grad_norm"]
    lr = info["lr"]
    update_s = info["update_s"]
    dataloading_s = info["dataloading_s"]

    # A sample is an (observation, action) pair, where observation and action
    # can be on multiple timestamps. In a batch, we have `batch_size` number of samples.
@@ -165,6 +172,7 @@ def log_train_info(logger: Logger, info, step, cfg, dataset, is_offline):
        f"lr:{lr:0.1e}",
        # in seconds
        f"updt_s:{update_s:.3f}",
        f"data_s:{dataloading_s:.3f}",  # if not ~0, you are bottlenecked by cpu or io
    ]
    logging.info(" ".join(log_items))
@@ -211,103 +219,6 @@ def log_eval_info(logger, info, step, cfg, dataset, is_offline):
    logger.log_dict(info, step, mode="eval")


def calculate_online_sample_weight(n_off: int, n_on: int, pc_on: float):
    """
    Calculate the sampling weight to be assigned to samples so that a specified percentage of the batch comes from the online dataset (on average).

    Parameters:
    - n_off (int): Number of offline samples, each with a sampling weight of 1.
    - n_on (int): Number of online samples.
    - pc_on (float): Desired percentage of online samples in decimal form (e.g., 50% as 0.5).

    The total weight of offline samples is n_off * 1.0.
    The total weight of online samples is n_on * w.
    The total combined weight of all samples is n_off + n_on * w.
    The fraction of the weight that is online is n_on * w / (n_off + n_on * w).
    We want this fraction to equal pc_on, so we set up the equation n_on * w / (n_off + n_on * w) = pc_on.
    The solution is w = - (n_off * pc_on) / (n_on * (pc_on - 1)).
    """
    assert 0.0 <= pc_on <= 1.0
    return -(n_off * pc_on) / (n_on * (pc_on - 1))
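Editor's note: a quick numeric check of that derivation, using the function as written:

```python
# 100 offline frames (weight 1.0 each) and 25 online frames, targeting 50%
# online samples per batch on average:
w = calculate_online_sample_weight(n_off=100, n_on=25, pc_on=0.5)
print(w)  # 4.0
# Sanity check: online share of total weight = 25 * 4 / (100 + 25 * 4) = 0.5.
```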
def add_episodes_inplace(
    online_dataset: torch.utils.data.Dataset,
    concat_dataset: torch.utils.data.ConcatDataset,
    sampler: torch.utils.data.WeightedRandomSampler,
    hf_dataset: datasets.Dataset,
    episode_data_index: dict[str, torch.Tensor],
    pc_online_samples: float,
):
    """
    Modifies the online_dataset, concat_dataset, and sampler in place by integrating
    new episodes from hf_dataset into the online_dataset, updating the concatenated
    dataset's structure and adjusting the sampling strategy based on the specified
    percentage of online samples.

    Parameters:
    - online_dataset (torch.utils.data.Dataset): The existing online dataset to be updated.
    - concat_dataset (torch.utils.data.ConcatDataset): The concatenated dataset that combines
      offline and online datasets, used for sampling purposes.
    - sampler (torch.utils.data.WeightedRandomSampler): A sampler that will be updated to
      reflect changes in the dataset sizes and specified sampling weights.
    - hf_dataset (datasets.Dataset): A Hugging Face dataset containing the new episodes to be added.
    - episode_data_index (dict): A dictionary containing two keys ("from" and "to") associated with dataset indices.
      They indicate the start index and end index of each episode in the dataset.
    - pc_online_samples (float): The target percentage of samples that should come from
      the online dataset during sampling operations.

    Raises:
    - AssertionError: If the first episode_id or index in hf_dataset is not 0
    """
    first_episode_idx = hf_dataset.select_columns("episode_index")[0]["episode_index"].item()
    last_episode_idx = hf_dataset.select_columns("episode_index")[-1]["episode_index"].item()
    first_index = hf_dataset.select_columns("index")[0]["index"].item()
    last_index = hf_dataset.select_columns("index")[-1]["index"].item()
    # sanity check
    assert first_episode_idx == 0, f"{first_episode_idx=} is not 0"
    assert first_index == 0, f"{first_index=} is not 0"
    assert first_index == episode_data_index["from"][first_episode_idx].item()
    assert last_index == episode_data_index["to"][last_episode_idx].item() - 1

    if len(online_dataset) == 0:
        # initialize online dataset
        online_dataset.hf_dataset = hf_dataset
        online_dataset.episode_data_index = episode_data_index
    else:
        # get the starting indices of the new episodes and frames to be added
        start_episode_idx = last_episode_idx + 1
        start_index = last_index + 1

        def shift_indices(episode_index, index):
            # note: we don't shift "frame_index" since it represents the index of the frame in the episode it belongs to
            example = {"episode_index": episode_index + start_episode_idx, "index": index + start_index}
            return example

        disable_progress_bars()  # map has a tqdm progress bar
        hf_dataset = hf_dataset.map(shift_indices, input_columns=["episode_index", "index"])
        enable_progress_bars()

        episode_data_index["from"] += start_index
        episode_data_index["to"] += start_index

        # extend online dataset
        online_dataset.hf_dataset = concatenate_datasets([online_dataset.hf_dataset, hf_dataset])

    # update the concatenated dataset length used during sampling
    concat_dataset.cumulative_sizes = concat_dataset.cumsum(concat_dataset.datasets)

    # update the sampling weights for each frame so that online frames get sampled a certain percentage of times
    len_online = len(online_dataset)
    len_offline = len(concat_dataset) - len_online
    weight_offline = 1.0
    weight_online = calculate_online_sample_weight(len_offline, len_online, pc_online_samples)
    sampler.weights = torch.tensor([weight_offline] * len_offline + [weight_online] * len(online_dataset))

    # update the total number of samples used during sampling
    sampler.num_samples = len(concat_dataset)
def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
    if out_dir is None:
        raise NotImplementedError()
@@ -316,35 +227,97 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No

    init_logging()

    if cfg.training.online_steps > 0 and cfg.eval.batch_size > 1:
        logging.warning("eval.batch_size > 1 not supported for online training steps")
    # If we are resuming a run, we need to check that a checkpoint exists in the log directory, and we need
    # to check for any differences between the provided config and the checkpoint's config.
    if cfg.resume:
        if not Logger.get_last_checkpoint_dir(out_dir).exists():
            raise RuntimeError(
                "You have set resume=True, but there is no model checkpoint in "
                f"{Logger.get_last_checkpoint_dir(out_dir)}"
            )
        checkpoint_cfg_path = str(Logger.get_last_pretrained_model_dir(out_dir) / "config.yaml")
        logging.info(
            colored(
                "You have set resume=True, indicating that you wish to resume a run",
                color="yellow",
                attrs=["bold"],
            )
        )
        # Get the configuration file from the last checkpoint.
        checkpoint_cfg = init_hydra_config(checkpoint_cfg_path)
        # Check for differences between the checkpoint configuration and provided configuration.
        # Hack to resolve the delta_timestamps ahead of time in order to properly diff.
        resolve_delta_timestamps(cfg)
        diff = DeepDiff(OmegaConf.to_container(checkpoint_cfg), OmegaConf.to_container(cfg))
        # Ignore the `resume` parameter.
        if "values_changed" in diff and "root['resume']" in diff["values_changed"]:
            del diff["values_changed"]["root['resume']"]
        # Log a warning about differences between the checkpoint configuration and the provided
        # configuration.
        if len(diff) > 0:
            logging.warning(
                "At least one difference was detected between the checkpoint configuration and "
                f"the provided configuration: \n{pformat(diff)}\nNote that the checkpoint configuration "
                "takes precedence.",
            )
        # Use the checkpoint config instead of the provided config (but keep `resume` parameter).
        cfg = checkpoint_cfg
        cfg.resume = True
    elif Logger.get_last_checkpoint_dir(out_dir).exists():
        raise RuntimeError(
            f"The configured output directory {Logger.get_last_checkpoint_dir(out_dir)} already exists."
        )
    # log metrics to terminal and wandb
    logger = Logger(cfg, out_dir, wandb_job_name=job_name)

    if cfg.training.online_steps > 0:
        raise NotImplementedError("Online training is not implemented yet.")

    set_global_seed(cfg.seed)

    # Check device is available
    get_safe_torch_device(cfg.device, log=True)
    device = get_safe_torch_device(cfg.device, log=True)

    torch.backends.cudnn.benchmark = True
    torch.backends.cuda.matmul.allow_tf32 = True
    set_global_seed(cfg.seed)

    logging.info("make_dataset")
    offline_dataset = make_dataset(cfg)
    if isinstance(offline_dataset, MultiLeRobotDataset):
        logging.info(
            "Multiple datasets were provided. Applied the following index mapping to the provided datasets: "
            f"{pformat(offline_dataset.repo_id_to_index, indent=2)}"
        )

    logging.info("make_env")
    eval_env = make_env(cfg)
    # Create environment used for evaluating checkpoints during training on simulation data.
    # On real-world data, no need to create an environment as evaluations are done outside train.py,
    # using eval.py instead, with the gym_dora environment and dora-rs.
    eval_env = None
    if cfg.training.eval_freq > 0:
        logging.info("make_env")
        eval_env = make_env(cfg)

    logging.info("make_policy")
    policy = make_policy(hydra_cfg=cfg, dataset_stats=offline_dataset.stats)

    policy = make_policy(
        hydra_cfg=cfg,
        dataset_stats=offline_dataset.stats if not cfg.resume else None,
        pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir) if cfg.resume else None,
    )
    assert isinstance(policy, nn.Module)
    # Create optimizer and scheduler
    # Temporary hack to move optimizer out of policy
    optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
    grad_scaler = GradScaler(enabled=cfg.use_amp)

    step = 0  # number of policy updates (forward + backward + optim)

    if cfg.resume:
        step = logger.load_last_training_state(optimizer, lr_scheduler)

    num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
    num_total_params = sum(p.numel() for p in policy.parameters())

    # log metrics to terminal and wandb
    logger = Logger(out_dir, job_name, cfg)

    log_output_dir(out_dir)
    logging.info(f"{cfg.env.task=}")
    logging.info(f"{cfg.training.offline_steps=} ({format_big_number(cfg.training.offline_steps)})")
@@ -356,60 +329,87 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No

    # Note: this helper will be used in offline and online training loops.
    def evaluate_and_checkpoint_if_needed(step):
        if step % cfg.training.eval_freq == 0:
        _num_digits = max(6, len(str(cfg.training.offline_steps + cfg.training.online_steps)))
        step_identifier = f"{step:0{_num_digits}d}"

        if cfg.training.eval_freq > 0 and step % cfg.training.eval_freq == 0:
            logging.info(f"Eval policy at step {step}")
            eval_info = eval_policy(
                eval_env,
                policy,
                cfg.eval.n_episodes,
                video_dir=Path(out_dir) / "eval",
                max_episodes_rendered=4,
                start_seed=cfg.seed,
            )
            log_eval_info(logger, eval_info["aggregated"], step, cfg, offline_dataset, is_offline)
            with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.use_amp else nullcontext():
                assert eval_env is not None
                eval_info = eval_policy(
                    eval_env,
                    policy,
                    cfg.eval.n_episodes,
                    videos_dir=Path(out_dir) / "eval" / f"videos_step_{step_identifier}",
                    max_episodes_rendered=4,
                    start_seed=cfg.seed,
                )
            log_eval_info(logger, eval_info["aggregated"], step, cfg, offline_dataset, is_offline=True)
            if cfg.wandb.enable:
                logger.log_video(eval_info["video_paths"][0], step, mode="eval")
            logging.info("Resume training")

        if cfg.training.save_model and step % cfg.training.save_freq == 0:
        if cfg.training.save_checkpoint and step % cfg.training.save_freq == 0:
            logging.info(f"Checkpoint policy after step {step}")
            # Note: Save with step as the identifier, and format it to have at least 6 digits but more if
            # needed (choose 6 as a minimum for consistency without being overkill).
            logger.save_model(
            logger.save_checkpont(
                step,
                policy,
                identifier=str(step).zfill(
                    max(6, len(str(cfg.training.offline_steps + cfg.training.online_steps)))
                ),
                optimizer,
                lr_scheduler,
                identifier=step_identifier,
            )
            logging.info("Resume training")
    # create dataloader for offline training
    if cfg.training.get("drop_n_last_frames"):
        shuffle = False
        sampler = EpisodeAwareSampler(
            offline_dataset.episode_data_index,
            drop_n_last_frames=cfg.training.drop_n_last_frames,
            shuffle=True,
        )
    else:
        shuffle = True
        sampler = None
    dataloader = torch.utils.data.DataLoader(
        offline_dataset,
        num_workers=4,
        num_workers=cfg.training.num_workers,
        batch_size=cfg.training.batch_size,
        shuffle=True,
        pin_memory=cfg.device != "cpu",
        shuffle=shuffle,
        sampler=sampler,
        pin_memory=device.type != "cpu",
        drop_last=False,
    )
    dl_iter = cycle(dataloader)

    policy.train()
    step = 0  # number of policy update (forward + backward + optim)
    is_offline = True
    for offline_step in range(cfg.training.offline_steps):
        if offline_step == 0:
    for _ in range(step, cfg.training.offline_steps):
        if step == 0:
            logging.info("Start offline training on a fixed dataset")

        start_time = time.perf_counter()
        batch = next(dl_iter)
        dataloading_s = time.perf_counter() - start_time

        for key in batch:
            batch[key] = batch[key].to(cfg.device, non_blocking=True)
            batch[key] = batch[key].to(device, non_blocking=True)

        train_info = update_policy(policy, batch, optimizer, cfg.training.grad_clip_norm, lr_scheduler)
        train_info = update_policy(
            policy,
            batch,
            optimizer,
            cfg.training.grad_clip_norm,
            grad_scaler=grad_scaler,
            lr_scheduler=lr_scheduler,
            use_amp=cfg.use_amp,
        )

        train_info["dataloading_s"] = dataloading_s

        # TODO(rcadene): is it ok if step_t=0 = 0 and not 1 as previously done?
        if step % cfg.training.log_freq == 0:
            log_train_info(logger, train_info, step, cfg, offline_dataset, is_offline)
            log_train_info(logger, train_info, step, cfg, offline_dataset, is_offline=True)

        # Note: evaluate_and_checkpoint_if_needed happens **after** the `step`th training update has completed,
        # so we pass in step + 1.
@@ -417,79 +417,28 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No

        step += 1
    # create an env dedicated to online episodes collection from policy rollout
    online_training_env = make_env(cfg, n_envs=1)

    # create an empty online dataset similar to offline dataset
    online_dataset = deepcopy(offline_dataset)
    online_dataset.hf_dataset = {}
    online_dataset.episode_data_index = {}

    # create dataloader for online training
    concat_dataset = torch.utils.data.ConcatDataset([offline_dataset, online_dataset])
    weights = [1.0] * len(concat_dataset)
    sampler = torch.utils.data.WeightedRandomSampler(
        weights, num_samples=len(concat_dataset), replacement=True
    )
    dataloader = torch.utils.data.DataLoader(
        concat_dataset,
        num_workers=4,
        batch_size=cfg.training.batch_size,
        sampler=sampler,
        pin_memory=cfg.device != "cpu",
        drop_last=False,
    )
    dl_iter = cycle(dataloader)

    online_step = 0
    is_offline = False
    for env_step in range(cfg.training.online_steps):
        if env_step == 0:
            logging.info("Start online training by interacting with environment")

        policy.eval()
        with torch.no_grad():
            eval_info = eval_policy(
                online_training_env,
                policy,
                n_episodes=1,
                return_episode_data=True,
                start_seed=cfg.training.online_env_seed,
                enable_progbar=True,
            )

        add_episodes_inplace(
            online_dataset,
            concat_dataset,
            sampler,
            hf_dataset=eval_info["episodes"]["hf_dataset"],
            episode_data_index=eval_info["episodes"]["episode_data_index"],
            pc_online_samples=cfg.training.online_sampling_ratio,
        )

        policy.train()
        for _ in range(cfg.training.online_steps_between_rollouts):
            batch = next(dl_iter)

            for key in batch:
                batch[key] = batch[key].to(cfg.device, non_blocking=True)

            train_info = update_policy(policy, batch, optimizer, cfg.training.grad_clip_norm, lr_scheduler)

            if step % cfg.training.log_freq == 0:
                log_train_info(logger, train_info, step, cfg, online_dataset, is_offline)

            # Note: evaluate_and_checkpoint_if_needed happens **after** the `step`th training update has completed,
            # so we pass in step + 1.
            evaluate_and_checkpoint_if_needed(step + 1)

            step += 1
            online_step += 1

    eval_env.close()
    online_training_env.close()
    if eval_env:
        eval_env.close()
    logging.info("End of training")

@hydra.main(version_base="1.2", config_name="default", config_path="../configs")
|
||||
def train_cli(cfg: dict):
|
||||
train(
|
||||
cfg,
|
||||
out_dir=hydra.core.hydra_config.HydraConfig.get().run.dir,
|
||||
job_name=hydra.core.hydra_config.HydraConfig.get().job.name,
|
||||
)
|
||||
|
||||
|
||||
def train_notebook(out_dir=None, job_name=None, config_name="default", config_path="../configs"):
|
||||
from hydra import compose, initialize
|
||||
|
||||
hydra.core.global_hydra.GlobalHydra.instance().clear()
|
||||
initialize(config_path=config_path)
|
||||
cfg = compose(config_name=config_name)
|
||||
train(cfg, out_dir=out_dir, job_name=job_name)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
train_cli()
|
||||
|
||||
@@ -66,28 +66,31 @@ import gc
|
||||
import logging
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Iterator
|
||||
|
||||
import numpy as np
|
||||
import rerun as rr
|
||||
import torch
|
||||
import torch.utils.data
|
||||
import tqdm
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
|
||||
class EpisodeSampler(torch.utils.data.Sampler):
|
||||
def __init__(self, dataset, episode_index):
|
||||
def __init__(self, dataset: LeRobotDataset, episode_index: int):
|
||||
from_idx = dataset.episode_data_index["from"][episode_index].item()
|
||||
to_idx = dataset.episode_data_index["to"][episode_index].item()
|
||||
self.frame_ids = range(from_idx, to_idx)
|
||||
|
||||
def __iter__(self):
|
||||
def __iter__(self) -> Iterator:
|
||||
return iter(self.frame_ids)
|
||||
|
||||
def __len__(self):
|
||||
def __len__(self) -> int:
|
||||
return len(self.frame_ids)
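A sketch of how this sampler is meant to be wired up (hypothetical usage, not taken from the diff): it restricts a DataLoader to the frames of a single episode, returned in order.

```python
# Hypothetical usage of EpisodeSampler: iterate over episode 0 frame by frame.
dataset = LeRobotDataset("lerobot/pusht")
episode_sampler = EpisodeSampler(dataset, episode_index=0)
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=32,
    sampler=episode_sampler,  # no shuffling; frames come back in episode order
)
for batch in dataloader:
    ...  # e.g. log frames to rerun
```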


-def to_hwc_uint8_numpy(chw_float32_torch):
+def to_hwc_uint8_numpy(chw_float32_torch: torch.Tensor) -> np.ndarray:
     assert chw_float32_torch.dtype == torch.float32
     assert chw_float32_torch.ndim == 3
     c, h, w = chw_float32_torch.shape
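The diff cuts the function off here. A plausible completion, assuming the conventional CHW float32 in [0, 1] to HWC uint8 conversion (the actual body may differ):

```python
    # Assumed continuation: verify channel-first layout, then scale to [0, 255],
    # cast to uint8, and move channels last for image libraries.
    assert c < h and c < w, f"expect channel-first images, got {chw_float32_torch.shape}"
    hwc_uint8_numpy = (chw_float32_torch * 255).type(torch.uint8).permute(1, 2, 0).numpy()
    return hwc_uint8_numpy
```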

@@ -106,6 +109,7 @@ def visualize_dataset(
     ws_port: int = 9087,
     save: bool = False,
     output_dir: Path | None = None,
+    root: Path | None = None,
 ) -> Path | None:
     if save:
         assert (
@@ -113,7 +117,7 @@ def visualize_dataset(
         ), "Set an output directory where to write .rrd files with `--output-dir path/to/directory`."

     logging.info("Loading dataset")
-    dataset = LeRobotDataset(repo_id)
+    dataset = LeRobotDataset(repo_id, root=root)

     logging.info("Loading dataloader")
     episode_sampler = EpisodeSampler(dataset, episode_index)
@@ -224,7 +228,8 @@ def main():
         help=(
             "Mode of viewing between 'local' or 'distant'. "
             "'local' requires data to be on a local machine. It spawns a viewer to visualize the data locally. "
-            "'distant' creates a server on the distant machine where the data is stored. Visualize the data by connecting to the server with `rerun ws://localhost:PORT` on the local machine."
+            "'distant' creates a server on the distant machine where the data is stored. "
+            "Visualize the data by connecting to the server with `rerun ws://localhost:PORT` on the local machine."
         ),
     )
     parser.add_argument(
@@ -245,8 +250,8 @@ def main():
         default=0,
         help=(
             "Save a .rrd file in the directory provided by `--output-dir`. "
-            "It also deactivates the spawning of a viewer. ",
-            "Visualize the data by running `rerun path/to/file.rrd` on your local machine.",
+            "It also deactivates the spawning of a viewer. "
+            "Visualize the data by running `rerun path/to/file.rrd` on your local machine."
         ),
     )
     parser.add_argument(
@@ -255,6 +260,12 @@ def main():
         help="Directory path to write a .rrd file when `--save 1` is set.",
     )

+    parser.add_argument(
+        "--root",
+        type=str,
+        help="Root directory for a dataset stored on a local machine.",
+    )

     args = parser.parse_args()
     visualize_dataset(**vars(args))
142 lerobot/scripts/visualize_image_transforms.py Normal file
@@ -0,0 +1,142 @@
#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Visualize the effects of image transforms for a given configuration.

This script generates examples of transformed images as they are output by a LeRobot dataset.
Additionally, each individual transform can be visualized separately, as well as examples of combined transforms.


--- Usage Examples ---

Increase hue jitter:
```
python lerobot/scripts/visualize_image_transforms.py \
    dataset_repo_id=lerobot/aloha_mobile_shrimp \
    training.image_transforms.hue.min_max=[-0.25,0.25]
```

Increase brightness & brightness weight:
```
python lerobot/scripts/visualize_image_transforms.py \
    dataset_repo_id=lerobot/aloha_mobile_shrimp \
    training.image_transforms.brightness.weight=10.0 \
    training.image_transforms.brightness.min_max=[1.0,2.0]
```

Blur images and disable saturation & hue:
```
python lerobot/scripts/visualize_image_transforms.py \
    dataset_repo_id=lerobot/aloha_mobile_shrimp \
    training.image_transforms.sharpness.weight=10.0 \
    training.image_transforms.sharpness.min_max=[0.0,1.0] \
    training.image_transforms.saturation.weight=0.0 \
    training.image_transforms.hue.weight=0.0
```

Use all transforms with random order:
```
python lerobot/scripts/visualize_image_transforms.py \
    dataset_repo_id=lerobot/aloha_mobile_shrimp \
    training.image_transforms.max_num_transforms=5 \
    training.image_transforms.random_order=true
```

"""

from pathlib import Path

import hydra
from torchvision.transforms import ToPILImage

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.transforms import get_image_transforms

OUTPUT_DIR = Path("outputs/image_transforms")
N_EXAMPLES = 5
to_pil = ToPILImage()


def save_config_all_transforms(cfg, original_frame, output_dir):
    tf = get_image_transforms(
        brightness_weight=cfg.brightness.weight,
        brightness_min_max=cfg.brightness.min_max,
        contrast_weight=cfg.contrast.weight,
        contrast_min_max=cfg.contrast.min_max,
        saturation_weight=cfg.saturation.weight,
        saturation_min_max=cfg.saturation.min_max,
        hue_weight=cfg.hue.weight,
        hue_min_max=cfg.hue.min_max,
        sharpness_weight=cfg.sharpness.weight,
        sharpness_min_max=cfg.sharpness.min_max,
        max_num_transforms=cfg.max_num_transforms,
        random_order=cfg.random_order,
    )

    output_dir_all = output_dir / "all"
    output_dir_all.mkdir(parents=True, exist_ok=True)

    for i in range(1, N_EXAMPLES + 1):
        transformed_frame = tf(original_frame)
        to_pil(transformed_frame).save(output_dir_all / f"{i}.png", quality=100)

    print("Combined transforms examples saved to:")
    print(f"    {output_dir_all}")


def save_config_single_transforms(cfg, original_frame, output_dir):
    transforms = [
        "brightness",
        "contrast",
        "saturation",
        "hue",
        "sharpness",
    ]
    print("Individual transforms examples saved to:")
    for transform in transforms:
        kwargs = {
            f"{transform}_weight": cfg[f"{transform}"].weight,
            f"{transform}_min_max": cfg[f"{transform}"].min_max,
        }
        tf = get_image_transforms(**kwargs)
        output_dir_single = output_dir / f"{transform}"
        output_dir_single.mkdir(parents=True, exist_ok=True)

        for i in range(1, N_EXAMPLES + 1):
            transformed_frame = tf(original_frame)
            to_pil(transformed_frame).save(output_dir_single / f"{i}.png", quality=100)

        print(f"    {output_dir_single}")


@hydra.main(version_base="1.2", config_name="default", config_path="../configs")
def visualize_transforms(cfg):
    dataset = LeRobotDataset(cfg.dataset_repo_id)

    output_dir = Path(OUTPUT_DIR) / cfg.dataset_repo_id.split("/")[-1]
    output_dir.mkdir(parents=True, exist_ok=True)

    # Get 1st frame from 1st camera of 1st episode
    original_frame = dataset[0][dataset.camera_keys[0]]
    to_pil(original_frame).save(output_dir / "original_frame.png", quality=100)
    print("\nOriginal frame saved to:")
    print(f"    {output_dir / 'original_frame.png'}.")

    save_config_all_transforms(cfg.training.image_transforms, original_frame, output_dir)
    save_config_single_transforms(cfg.training.image_transforms, original_frame, output_dir)


if __name__ == "__main__":
    visualize_transforms()
932 poetry.lock generated
File diff suppressed because it is too large
@@ -41,11 +41,12 @@ numba = ">=0.59.0"
 torch = "^2.2.1"
 opencv-python = ">=4.9.0"
 diffusers = "^0.27.2"
-torchvision = ">=0.18.0"
+torchvision = ">=0.17.1"
 h5py = ">=3.10.0"
 huggingface-hub = {extras = ["hf-transfer"], version = "^0.23.0"}
 gymnasium = ">=0.29.1"
 cmake = ">=3.29.0.1"
 gym-dora = { git = "https://github.com/dora-rs/dora-lerobot.git", subdirectory = "gym_dora", optional = true }
 gym-pusht = { version = ">=0.1.3", optional = true}
 gym-xarm = { version = ">=0.1.1", optional = true}
 gym-aloha = { version = ">=0.1.1", optional = true}
@@ -58,9 +59,11 @@ imagecodecs = { version = ">=2024.1.1", optional = true }
 pyav = ">=12.0.5"
 moviepy = ">=1.0.3"
 rerun-sdk = ">=0.15.1"
+deepdiff = ">=7.0.1"


 [tool.poetry.extras]
 dora = ["gym-dora"]
 pusht = ["gym-pusht"]
 xarm = ["gym-xarm"]
 aloha = ["gym-aloha"]

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:36f50697dacc82d52d1799dbc53c6c2fb722b9c0bd5bfa90a92dfa336591c74a
size 3686488
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d0e3b4bde97c34606536b655c1e6a23316c9157bd21dcbc73a97500fb985607f
size 40551392
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2fff6294b94cf42d4dd1249dcc5c3b0269d6d9c697f894e61b867d7ab81a94e4
size 5104
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:4aa23e51607604a18b70fa42edbbe1af34f119d985628fc27cc1bbb0efbc8901
size 31688
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:6fd368406c93cb562a69ff11cf7adf34a4b223507dcb2b9e9b8f44ee1036988a
size 68
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:5663ee79a13bb70a1604b887dd21bf89d18482287442419c6cc6c5bf0e753e99
size 34928
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:fb1a45463efd860af2ca22c16c77d55a18bd96fef080ae77978845a2f22ef716
size 5104
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:aa5a43e22f01d8e2f8d19f31753608794f1edbd74aaf71660091ab80ea58dc9b
size 30808
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:97455b4360748c99905cd103473c1a52da6901d0a73ffbc51b5ea3eb250d1386
size 68
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:54d1f75cf67a7b1d7a7c6865ecb9b1cc86a2f032d1890245f8996789ab6e0df6
size 33608
86 tests/scripts/save_image_transforms_to_safetensors.py Normal file
@@ -0,0 +1,86 @@
#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pathlib import Path

import torch
from safetensors.torch import save_file

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.transforms import get_image_transforms
from lerobot.common.utils.utils import init_hydra_config, seeded_context
from tests.test_image_transforms import ARTIFACT_DIR, DATASET_REPO_ID
from tests.utils import DEFAULT_CONFIG_PATH


def save_default_config_transform(original_frame: torch.Tensor, output_dir: Path):
    cfg = init_hydra_config(DEFAULT_CONFIG_PATH)
    cfg_tf = cfg.training.image_transforms
    default_tf = get_image_transforms(
        brightness_weight=cfg_tf.brightness.weight,
        brightness_min_max=cfg_tf.brightness.min_max,
        contrast_weight=cfg_tf.contrast.weight,
        contrast_min_max=cfg_tf.contrast.min_max,
        saturation_weight=cfg_tf.saturation.weight,
        saturation_min_max=cfg_tf.saturation.min_max,
        hue_weight=cfg_tf.hue.weight,
        hue_min_max=cfg_tf.hue.min_max,
        sharpness_weight=cfg_tf.sharpness.weight,
        sharpness_min_max=cfg_tf.sharpness.min_max,
        max_num_transforms=cfg_tf.max_num_transforms,
        random_order=cfg_tf.random_order,
    )

    with seeded_context(1337):
        img_tf = default_tf(original_frame)

    save_file({"default": img_tf}, output_dir / "default_transforms.safetensors")


def save_single_transforms(original_frame: torch.Tensor, output_dir: Path):
    transforms = {
        "brightness": [(0.5, 0.5), (2.0, 2.0)],
        "contrast": [(0.5, 0.5), (2.0, 2.0)],
        "saturation": [(0.5, 0.5), (2.0, 2.0)],
        "hue": [(-0.25, -0.25), (0.25, 0.25)],
        "sharpness": [(0.5, 0.5), (2.0, 2.0)],
    }

    frames = {"original_frame": original_frame}
    for transform, values in transforms.items():
        for min_max in values:
            kwargs = {
                f"{transform}_weight": 1.0,
                f"{transform}_min_max": min_max,
            }
            tf = get_image_transforms(**kwargs)
            key = f"{transform}_{min_max[0]}_{min_max[1]}"
            frames[key] = tf(original_frame)

    save_file(frames, output_dir / "single_transforms.safetensors")


def main():
    dataset = LeRobotDataset(DATASET_REPO_ID, image_transforms=None)
    output_dir = Path(ARTIFACT_DIR)
    output_dir.mkdir(parents=True, exist_ok=True)
    original_frame = dataset[0][dataset.camera_keys[0]]

    save_single_transforms(original_frame, output_dir)
    save_default_config_transform(original_frame, output_dir)


if __name__ == "__main__":
    main()
@@ -75,15 +75,16 @@ def get_policy_stats(env_name, policy_name, extra_overrides):
     # HACK: We reload a batch with no delta_timestamps as `select_action` won't expect a timestamps dimension
     dataset.delta_timestamps = None
     batch = next(iter(dataloader))
-    obs = {
-        k: batch[k]
-        for k in batch
-        if k in ["observation.image", "observation.images.top", "observation.state"]
-    }
+    obs = {}
+    for k in batch:
+        if k.startswith("observation"):
+            obs[k] = batch[k]

-    if "n_action_steps" in cfg.policy:
-        actions_queue = cfg.policy.n_action_steps
-    else:
-        actions_queue = cfg.policy.n_action_repeats
-
+    actions_queue = (
+        cfg.policy.n_action_steps if "n_action_steps" in cfg.policy else cfg.policy.n_action_repeats
+    )
     actions = {str(i): policy.select_action(obs).contiguous() for i in range(actions_queue)}
     return output_dict, grad_stats, param_stats, actions

@@ -114,6 +115,8 @@ if __name__ == "__main__":
             ["policy.n_action_steps=8", "policy.num_inference_steps=10", "policy.down_dims=[128, 256, 512]"],
         ),
         ("aloha", "act", ["policy.n_action_steps=10"]),
+        ("dora_aloha_real", "act_real", ["policy.n_action_steps=10"]),
+        ("dora_aloha_real", "act_real_no_state", ["policy.n_action_steps=10"]),
     ]
     for env, policy, extra_overrides in env_policies:
         save_policy_to_safetensors("tests/data/save_policy_to_safetensors", env, policy, extra_overrides)
@@ -16,6 +16,7 @@
 import json
 import logging
 from copy import deepcopy
+from itertools import chain
 from pathlib import Path

 import einops
@@ -25,26 +26,34 @@ from datasets import Dataset
 from safetensors.torch import load_file

 import lerobot
-from lerobot.common.datasets.factory import make_dataset
-from lerobot.common.datasets.lerobot_dataset import (
-    LeRobotDataset,
-)
-from lerobot.common.datasets.push_dataset_to_hub.compute_stats import (
+from lerobot.common.datasets.compute_stats import (
     aggregate_stats,
     compute_stats,
     get_stats_einops_patterns,
 )
+from lerobot.common.datasets.factory import make_dataset
+from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, MultiLeRobotDataset
 from lerobot.common.datasets.utils import (
     flatten_dict,
     hf_transform_to_torch,
     load_previous_and_future_frames,
     unflatten_dict,
 )
-from lerobot.common.utils.utils import init_hydra_config
+from lerobot.common.utils.utils import init_hydra_config, seeded_context
 from tests.utils import DEFAULT_CONFIG_PATH, DEVICE


-@pytest.mark.parametrize("env_name, repo_id, policy_name", lerobot.env_dataset_policy_triplets)
+@pytest.mark.parametrize(
+    "env_name, repo_id, policy_name",
+    lerobot.env_dataset_policy_triplets
+    + [("aloha", ["lerobot/aloha_sim_insertion_human", "lerobot/aloha_sim_transfer_cube_human"], "act")],
+)
 def test_factory(env_name, repo_id, policy_name):
     """
     Tests that:
     - we can create a dataset with the factory.
     - for a commonly used set of data keys, the data dimensions are correct.
     """
     cfg = init_hydra_config(
         DEFAULT_CONFIG_PATH,
         overrides=[
@@ -105,6 +114,39 @@ def test_factory(env_name, repo_id, policy_name):
         assert key in item, f"{key}"


+# TODO(alexander-soare): If you're hunting for savings on testing time, this takes about 5 seconds.
+def test_multilerobotdataset_frames():
+    """Check that all dataset frames are incorporated."""
+    # Note: use the image variants of the dataset to make the test approx 3x faster.
+    # Note: We really do need three repo_ids here as at some point this caught an issue with the chaining
+    # logic that wouldn't be caught with two repo IDs.
+    repo_ids = [
+        "lerobot/aloha_sim_insertion_human_image",
+        "lerobot/aloha_sim_transfer_cube_human_image",
+        "lerobot/aloha_sim_insertion_scripted_image",
+    ]
+    sub_datasets = [LeRobotDataset(repo_id) for repo_id in repo_ids]
+    dataset = MultiLeRobotDataset(repo_ids)
+    assert len(dataset) == sum(len(d) for d in sub_datasets)
+    assert dataset.num_samples == sum(d.num_samples for d in sub_datasets)
+    assert dataset.num_episodes == sum(d.num_episodes for d in sub_datasets)
+
+    # Run through all items of the LeRobotDatasets in parallel with the items of the MultiLerobotDataset and
+    # check they match.
+    expected_dataset_indices = []
+    for i, sub_dataset in enumerate(sub_datasets):
+        expected_dataset_indices.extend([i] * len(sub_dataset))
+
+    for expected_dataset_index, sub_dataset_item, dataset_item in zip(
+        expected_dataset_indices, chain(*sub_datasets), dataset, strict=True
+    ):
+        dataset_index = dataset_item.pop("dataset_index")
+        assert dataset_index == expected_dataset_index
+        assert sub_dataset_item.keys() == dataset_item.keys()
+        for k in sub_dataset_item:
+            assert torch.equal(sub_dataset_item[k], dataset_item[k])
+
+
 def test_compute_stats_on_xarm():
     """Check that the statistics are computed correctly according to the stats_patterns property.

@@ -315,3 +357,31 @@ def test_backward_compatibility(repo_id):
     # i = dataset.episode_data_index["to"][-1].item()
     # load_and_compare(i - 2)
     # load_and_compare(i - 1)
+
+
+def test_aggregate_stats():
+    """Makes 3 basic datasets and checks that aggregate stats are computed correctly."""
+    with seeded_context(0):
+        data_a = torch.rand(30, dtype=torch.float32)
+        data_b = torch.rand(20, dtype=torch.float32)
+        data_c = torch.rand(20, dtype=torch.float32)
+
+    hf_dataset_1 = Dataset.from_dict(
+        {"a": data_a[:10], "b": data_b[:10], "c": data_c[:10], "index": torch.arange(10)}
+    )
+    hf_dataset_1.set_transform(hf_transform_to_torch)
+    hf_dataset_2 = Dataset.from_dict({"a": data_a[10:20], "b": data_b[10:], "index": torch.arange(10)})
+    hf_dataset_2.set_transform(hf_transform_to_torch)
+    hf_dataset_3 = Dataset.from_dict({"a": data_a[20:], "c": data_c[10:], "index": torch.arange(10)})
+    hf_dataset_3.set_transform(hf_transform_to_torch)
+    dataset_1 = LeRobotDataset.from_preloaded("d1", hf_dataset=hf_dataset_1)
+    dataset_1.stats = compute_stats(dataset_1, batch_size=len(hf_dataset_1), num_workers=0)
+    dataset_2 = LeRobotDataset.from_preloaded("d2", hf_dataset=hf_dataset_2)
+    dataset_2.stats = compute_stats(dataset_2, batch_size=len(hf_dataset_2), num_workers=0)
+    dataset_3 = LeRobotDataset.from_preloaded("d3", hf_dataset=hf_dataset_3)
+    dataset_3.stats = compute_stats(dataset_3, batch_size=len(hf_dataset_3), num_workers=0)
+    stats = aggregate_stats([dataset_1, dataset_2, dataset_3])
+    for data_key, data in zip(["a", "b", "c"], [data_a, data_b, data_c], strict=True):
+        for agg_fn in ["mean", "min", "max"]:
+            assert torch.allclose(stats[data_key][agg_fn], einops.reduce(data, "n -> 1", agg_fn))
+        assert torch.allclose(stats[data_key]["std"], torch.std(data, correction=0))

@@ -45,11 +45,11 @@ def test_example_1():


 @require_package("gym_pusht")
-def test_examples_2_through_4():
+def test_examples_basic2_basic3_advanced1():
     """
     Train a model with example 3, check the outputs.
     Evaluate the trained model with example 2, check the outputs.
-    Calculate the validation loss with example 4, check the outputs.
+    Calculate the validation loss with advanced example 1, check the outputs.
     """

     ### Test example 3
@@ -97,7 +97,7 @@ def test_examples_2_through_4():
     assert Path("outputs/eval/example_pusht_diffusion/rollout.mp4").exists()

     ## Test example 4
-    file_contents = _read_file("examples/4_calculate_validation_loss.py")
+    file_contents = _read_file("examples/advanced/2_calculate_validation_loss.py")

     # Run on a single example from the last episode, use CPU, and use the local model.
     file_contents = _find_and_replace(
260 tests/test_image_transforms.py Normal file
@@ -0,0 +1,260 @@
#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pathlib import Path

import numpy as np
import pytest
import torch
from PIL import Image
from safetensors.torch import load_file
from torchvision.transforms import v2
from torchvision.transforms.v2 import functional as F  # noqa: N812

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.transforms import RandomSubsetApply, SharpnessJitter, get_image_transforms
from lerobot.common.utils.utils import init_hydra_config, seeded_context
from tests.utils import DEFAULT_CONFIG_PATH, require_x86_64_kernel

ARTIFACT_DIR = Path("tests/data/save_image_transforms_to_safetensors")
DATASET_REPO_ID = "lerobot/aloha_mobile_shrimp"


def load_png_to_tensor(path: Path):
    return torch.from_numpy(np.array(Image.open(path).convert("RGB"))).permute(2, 0, 1)


@pytest.fixture
def img():
    dataset = LeRobotDataset(DATASET_REPO_ID)
    return dataset[0][dataset.camera_keys[0]]


@pytest.fixture
def img_random():
    return torch.rand(3, 480, 640)


@pytest.fixture
def color_jitters():
    return [
        v2.ColorJitter(brightness=0.5),
        v2.ColorJitter(contrast=0.5),
        v2.ColorJitter(saturation=0.5),
    ]


@pytest.fixture
def single_transforms():
    return load_file(ARTIFACT_DIR / "single_transforms.safetensors")


@pytest.fixture
def default_transforms():
    return load_file(ARTIFACT_DIR / "default_transforms.safetensors")


def test_get_image_transforms_no_transform(img):
    tf_actual = get_image_transforms(brightness_min_max=(0.5, 0.5), max_num_transforms=0)
    torch.testing.assert_close(tf_actual(img), img)


@pytest.mark.parametrize("min_max", [(0.5, 0.5), (2.0, 2.0)])
def test_get_image_transforms_brightness(img, min_max):
    tf_actual = get_image_transforms(brightness_weight=1.0, brightness_min_max=min_max)
    tf_expected = v2.ColorJitter(brightness=min_max)
    torch.testing.assert_close(tf_actual(img), tf_expected(img))


@pytest.mark.parametrize("min_max", [(0.5, 0.5), (2.0, 2.0)])
def test_get_image_transforms_contrast(img, min_max):
    tf_actual = get_image_transforms(contrast_weight=1.0, contrast_min_max=min_max)
    tf_expected = v2.ColorJitter(contrast=min_max)
    torch.testing.assert_close(tf_actual(img), tf_expected(img))


@pytest.mark.parametrize("min_max", [(0.5, 0.5), (2.0, 2.0)])
def test_get_image_transforms_saturation(img, min_max):
    tf_actual = get_image_transforms(saturation_weight=1.0, saturation_min_max=min_max)
    tf_expected = v2.ColorJitter(saturation=min_max)
    torch.testing.assert_close(tf_actual(img), tf_expected(img))


@pytest.mark.parametrize("min_max", [(-0.25, -0.25), (0.25, 0.25)])
def test_get_image_transforms_hue(img, min_max):
    tf_actual = get_image_transforms(hue_weight=1.0, hue_min_max=min_max)
    tf_expected = v2.ColorJitter(hue=min_max)
    torch.testing.assert_close(tf_actual(img), tf_expected(img))


@pytest.mark.parametrize("min_max", [(0.5, 0.5), (2.0, 2.0)])
def test_get_image_transforms_sharpness(img, min_max):
    tf_actual = get_image_transforms(sharpness_weight=1.0, sharpness_min_max=min_max)
    tf_expected = SharpnessJitter(sharpness=min_max)
    torch.testing.assert_close(tf_actual(img), tf_expected(img))


def test_get_image_transforms_max_num_transforms(img):
    tf_actual = get_image_transforms(
        brightness_min_max=(0.5, 0.5),
        contrast_min_max=(0.5, 0.5),
        saturation_min_max=(0.5, 0.5),
        hue_min_max=(0.5, 0.5),
        sharpness_min_max=(0.5, 0.5),
        random_order=False,
    )
    tf_expected = v2.Compose(
        [
            v2.ColorJitter(brightness=(0.5, 0.5)),
            v2.ColorJitter(contrast=(0.5, 0.5)),
            v2.ColorJitter(saturation=(0.5, 0.5)),
            v2.ColorJitter(hue=(0.5, 0.5)),
            SharpnessJitter(sharpness=(0.5, 0.5)),
        ]
    )
    torch.testing.assert_close(tf_actual(img), tf_expected(img))


@require_x86_64_kernel
def test_get_image_transforms_random_order(img):
    out_imgs = []
    tf = get_image_transforms(
        brightness_min_max=(0.5, 0.5),
        contrast_min_max=(0.5, 0.5),
        saturation_min_max=(0.5, 0.5),
        hue_min_max=(0.5, 0.5),
        sharpness_min_max=(0.5, 0.5),
        random_order=True,
    )
    with seeded_context(1337):
        for _ in range(10):
            out_imgs.append(tf(img))

    for i in range(1, len(out_imgs)):
        with pytest.raises(AssertionError):
            torch.testing.assert_close(out_imgs[0], out_imgs[i])


@pytest.mark.parametrize(
    "transform, min_max_values",
    [
        ("brightness", [(0.5, 0.5), (2.0, 2.0)]),
        ("contrast", [(0.5, 0.5), (2.0, 2.0)]),
        ("saturation", [(0.5, 0.5), (2.0, 2.0)]),
        ("hue", [(-0.25, -0.25), (0.25, 0.25)]),
        ("sharpness", [(0.5, 0.5), (2.0, 2.0)]),
    ],
)
def test_backward_compatibility_torchvision(transform, min_max_values, img, single_transforms):
    for min_max in min_max_values:
        kwargs = {
            f"{transform}_weight": 1.0,
            f"{transform}_min_max": min_max,
        }
        tf = get_image_transforms(**kwargs)
        actual = tf(img)
        key = f"{transform}_{min_max[0]}_{min_max[1]}"
        expected = single_transforms[key]
        torch.testing.assert_close(actual, expected)


@require_x86_64_kernel
def test_backward_compatibility_default_config(img, default_transforms):
    cfg = init_hydra_config(DEFAULT_CONFIG_PATH)
    cfg_tf = cfg.training.image_transforms
    default_tf = get_image_transforms(
        brightness_weight=cfg_tf.brightness.weight,
        brightness_min_max=cfg_tf.brightness.min_max,
        contrast_weight=cfg_tf.contrast.weight,
        contrast_min_max=cfg_tf.contrast.min_max,
        saturation_weight=cfg_tf.saturation.weight,
        saturation_min_max=cfg_tf.saturation.min_max,
        hue_weight=cfg_tf.hue.weight,
        hue_min_max=cfg_tf.hue.min_max,
        sharpness_weight=cfg_tf.sharpness.weight,
        sharpness_min_max=cfg_tf.sharpness.min_max,
        max_num_transforms=cfg_tf.max_num_transforms,
        random_order=cfg_tf.random_order,
    )

    with seeded_context(1337):
        actual = default_tf(img)

    expected = default_transforms["default"]

    torch.testing.assert_close(actual, expected)


@pytest.mark.parametrize("p", [[0, 1], [1, 0]])
def test_random_subset_apply_single_choice(p, img):
    flips = [v2.RandomHorizontalFlip(p=1), v2.RandomVerticalFlip(p=1)]
    random_choice = RandomSubsetApply(flips, p=p, n_subset=1, random_order=False)
    actual = random_choice(img)

    p_horz, _ = p
    if p_horz:
        torch.testing.assert_close(actual, F.horizontal_flip(img))
    else:
        torch.testing.assert_close(actual, F.vertical_flip(img))


def test_random_subset_apply_random_order(img):
    flips = [v2.RandomHorizontalFlip(p=1), v2.RandomVerticalFlip(p=1)]
    random_order = RandomSubsetApply(flips, p=[0.5, 0.5], n_subset=2, random_order=True)
    # We can't really check whether the transforms are actually applied in random order. However,
    # horizontal and vertical flip are commutative. Meaning, even under the assumption that the transform
    # applies them in random order, we can use a fixed order to compute the expected value.
    actual = random_order(img)
    expected = v2.Compose(flips)(img)
    torch.testing.assert_close(actual, expected)
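For reference, the sampling behavior these tests pin down can be pictured with a small free-standing sketch (a hypothetical helper, not the actual `RandomSubsetApply` implementation): draw `n_subset` distinct transform indices according to `p`, optionally restore their declared order, then apply them in sequence.

```python
import torch

def random_subset_apply(img, transforms, p, n_subset, random_order=False):
    # Draw `n_subset` distinct indices, weighted by `p` (multinomial without replacement).
    selected = torch.multinomial(torch.tensor(p, dtype=torch.float), n_subset)
    if not random_order:
        selected = selected.sort().values  # keep the transforms' declared order
    for i in selected:
        img = transforms[i](img)
    return img
```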


def test_random_subset_apply_valid_transforms(color_jitters, img):
    transform = RandomSubsetApply(color_jitters)
    output = transform(img)
    assert output.shape == img.shape


def test_random_subset_apply_probability_length_mismatch(color_jitters):
    with pytest.raises(ValueError):
        RandomSubsetApply(color_jitters, p=[0.5, 0.5])


@pytest.mark.parametrize("n_subset", [0, 5])
def test_random_subset_apply_invalid_n_subset(color_jitters, n_subset):
    with pytest.raises(ValueError):
        RandomSubsetApply(color_jitters, n_subset=n_subset)


def test_sharpness_jitter_valid_range_tuple(img):
    tf = SharpnessJitter((0.1, 2.0))
    output = tf(img)
    assert output.shape == img.shape


def test_sharpness_jitter_valid_range_float(img):
    tf = SharpnessJitter(0.5)
    output = tf(img)
    assert output.shape == img.shape


def test_sharpness_jitter_invalid_range_min_negative():
    with pytest.raises(ValueError):
        SharpnessJitter((-0.1, 2.0))


def test_sharpness_jitter_invalid_range_max_smaller():
    with pytest.raises(ValueError):
        SharpnessJitter((2.0, 0.1))
@@ -30,7 +30,7 @@ from lerobot.common.policies.factory import get_policy_and_config_classes, make_
 from lerobot.common.policies.normalize import Normalize, Unnormalize
 from lerobot.common.policies.policy_protocol import Policy
 from lerobot.common.utils.utils import init_hydra_config
-from tests.scripts.save_policy_to_safetensor import get_policy_stats
+from tests.scripts.save_policy_to_safetensors import get_policy_stats
 from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, require_cpu, require_env, require_x86_64_kernel


@@ -72,6 +72,8 @@ def test_get_policy_and_config_classes(policy_name: str):
         ),
         # Note: these parameters also need custom logic in the test function for overriding the Hydra config.
         ("pusht", "act", ["env.task=PushT-v0", "dataset_repo_id=lerobot/pusht"]),
+        ("dora_aloha_real", "act_real", []),
+        ("dora_aloha_real", "act_real_no_state", []),
     ],
 )
 @require_env
@@ -84,6 +86,9 @@ def test_policy(env_name, policy_name, extra_overrides):
     - Updating the policy.
     - Using the policy to select actions at inference time.
+    - Test the action can be applied to the policy

+    Note: We test various combinations of policy and dataset. The combinations are by no means exhaustive,
+    and for now we add tests as we see fit.
     """
     cfg = init_hydra_config(
         DEFAULT_CONFIG_PATH,
@@ -135,7 +140,7 @@ def test_policy(env_name, policy_name, extra_overrides):

     dataloader = torch.utils.data.DataLoader(
         dataset,
-        num_workers=4,
+        num_workers=0,
         batch_size=2,
         shuffle=True,
         pin_memory=DEVICE != "cpu",
@@ -291,6 +296,8 @@ def test_normalize(insert_temporal_dim):
         ["policy.n_action_steps=8", "policy.num_inference_steps=10", "policy.down_dims=[128, 256, 512]"],
         ),
         ("aloha", "act", ["policy.n_action_steps=10"]),
+        ("dora_aloha_real", "act_real", ["policy.n_action_steps=10"]),
+        ("dora_aloha_real", "act_real_no_state", ["policy.n_action_steps=10"]),
     ],
 )
 # As artifacts have been generated on an x86_64 kernel, this test won't
352 tests/test_push_dataset_to_hub.py Normal file
@@ -0,0 +1,352 @@
"""
This file contains generic tests to ensure that nothing breaks if we modify the push_dataset_to_hub API.
Also, this file contains backward compatibility tests. Because they are slow and require to download the raw datasets,
we skip them for now in our CI.

Example to run backward compatibility tests locally:
```
DATA_DIR=tests/data python -m pytest --run-skipped tests/test_push_dataset_to_hub.py::test_push_dataset_to_hub_pusht_backward_compatibility
```
"""

from pathlib import Path

import numpy as np
import pytest
import torch

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.push_dataset_to_hub.utils import save_images_concurrently
from lerobot.common.datasets.video_utils import encode_video_frames
from lerobot.scripts.push_dataset_to_hub import push_dataset_to_hub
from tests.utils import require_package_arg


def _mock_download_raw_pusht(raw_dir, num_frames=4, num_episodes=3):
    import zarr

    raw_dir.mkdir(parents=True, exist_ok=True)
    zarr_path = raw_dir / "pusht_cchi_v7_replay.zarr"
    store = zarr.DirectoryStore(zarr_path)
    zarr_data = zarr.group(store=store)

    zarr_data.create_dataset(
        "data/action", shape=(num_frames, 1), chunks=(num_frames, 1), dtype=np.float32, overwrite=True
    )
    zarr_data.create_dataset(
        "data/img",
        shape=(num_frames, 96, 96, 3),
        chunks=(num_frames, 96, 96, 3),
        dtype=np.uint8,
        overwrite=True,
    )
    zarr_data.create_dataset(
        "data/n_contacts", shape=(num_frames, 2), chunks=(num_frames, 2), dtype=np.float32, overwrite=True
    )
    zarr_data.create_dataset(
        "data/state", shape=(num_frames, 5), chunks=(num_frames, 5), dtype=np.float32, overwrite=True
    )
    zarr_data.create_dataset(
        "data/keypoint", shape=(num_frames, 9, 2), chunks=(num_frames, 9, 2), dtype=np.float32, overwrite=True
    )
    zarr_data.create_dataset(
        "meta/episode_ends", shape=(num_episodes,), chunks=(num_episodes,), dtype=np.int32, overwrite=True
    )

    zarr_data["data/action"][:] = np.random.randn(num_frames, 1)
    zarr_data["data/img"][:] = np.random.randint(0, 255, size=(num_frames, 96, 96, 3), dtype=np.uint8)
    zarr_data["data/n_contacts"][:] = np.random.randn(num_frames, 2)
    zarr_data["data/state"][:] = np.random.randn(num_frames, 5)
    zarr_data["data/keypoint"][:] = np.random.randn(num_frames, 9, 2)
    zarr_data["meta/episode_ends"][:] = np.array([1, 3, 4])

    store.close()


def _mock_download_raw_umi(raw_dir, num_frames=4, num_episodes=3):
    import zarr

    raw_dir.mkdir(parents=True, exist_ok=True)
    zarr_path = raw_dir / "cup_in_the_wild.zarr"
    store = zarr.DirectoryStore(zarr_path)
    zarr_data = zarr.group(store=store)

    zarr_data.create_dataset(
        "data/camera0_rgb",
        shape=(num_frames, 96, 96, 3),
        chunks=(num_frames, 96, 96, 3),
        dtype=np.uint8,
        overwrite=True,
    )
    zarr_data.create_dataset(
        "data/robot0_demo_end_pose",
        shape=(num_frames, 5),
        chunks=(num_frames, 5),
        dtype=np.float32,
        overwrite=True,
    )
    zarr_data.create_dataset(
        "data/robot0_demo_start_pose",
        shape=(num_frames, 5),
        chunks=(num_frames, 5),
        dtype=np.float32,
        overwrite=True,
    )
    zarr_data.create_dataset(
        "data/robot0_eef_pos", shape=(num_frames, 5), chunks=(num_frames, 5), dtype=np.float32, overwrite=True
    )
    zarr_data.create_dataset(
        "data/robot0_eef_rot_axis_angle",
        shape=(num_frames, 5),
        chunks=(num_frames, 5),
        dtype=np.float32,
        overwrite=True,
    )
    zarr_data.create_dataset(
        "data/robot0_gripper_width",
        shape=(num_frames, 5),
        chunks=(num_frames, 5),
        dtype=np.float32,
        overwrite=True,
    )
    zarr_data.create_dataset(
        "meta/episode_ends", shape=(num_episodes,), chunks=(num_episodes,), dtype=np.int32, overwrite=True
    )

    zarr_data["data/camera0_rgb"][:] = np.random.randint(0, 255, size=(num_frames, 96, 96, 3), dtype=np.uint8)
    zarr_data["data/robot0_demo_end_pose"][:] = np.random.randn(num_frames, 5)
    zarr_data["data/robot0_demo_start_pose"][:] = np.random.randn(num_frames, 5)
    zarr_data["data/robot0_eef_pos"][:] = np.random.randn(num_frames, 5)
    zarr_data["data/robot0_eef_rot_axis_angle"][:] = np.random.randn(num_frames, 5)
    zarr_data["data/robot0_gripper_width"][:] = np.random.randn(num_frames, 5)
    zarr_data["meta/episode_ends"][:] = np.array([1, 3, 4])

    store.close()


def _mock_download_raw_xarm(raw_dir, num_frames=4):
    import pickle

    dataset_dict = {
        "observations": {
            "rgb": np.random.randint(0, 255, size=(num_frames, 3, 84, 84), dtype=np.uint8),
            "state": np.random.randn(num_frames, 4),
        },
        "actions": np.random.randn(num_frames, 3),
        "rewards": np.random.randn(num_frames),
        "masks": np.random.randn(num_frames),
        "dones": np.array([False, True, True, True]),
    }

    raw_dir.mkdir(parents=True, exist_ok=True)
    pkl_path = raw_dir / "buffer.pkl"
    with open(pkl_path, "wb") as f:
        pickle.dump(dataset_dict, f)


def _mock_download_raw_aloha(raw_dir, num_frames=6, num_episodes=3):
    import h5py

    for ep_idx in range(num_episodes):
        raw_dir.mkdir(parents=True, exist_ok=True)
        path_h5 = raw_dir / f"episode_{ep_idx}.hdf5"
        with h5py.File(str(path_h5), "w") as f:
            f.create_dataset("action", data=np.random.randn(num_frames // num_episodes, 14))
            f.create_dataset("observations/qpos", data=np.random.randn(num_frames // num_episodes, 14))
            f.create_dataset("observations/qvel", data=np.random.randn(num_frames // num_episodes, 14))
            f.create_dataset(
                "observations/images/top",
                data=np.random.randint(
                    0, 255, size=(num_frames // num_episodes, 480, 640, 3), dtype=np.uint8
                ),
            )


def _mock_download_raw_dora(raw_dir, num_frames=6, num_episodes=3, fps=30):
    from datetime import datetime, timedelta, timezone

    import pandas

    def write_parquet(key, timestamps, values):
        data = {
            "timestamp_utc": timestamps,
            key: values,
        }
        df = pandas.DataFrame(data)
        raw_dir.mkdir(parents=True, exist_ok=True)
        df.to_parquet(raw_dir / f"{key}.parquet", engine="pyarrow")

    episode_indices = [None, None, -1, None, None, -1, None, None, -1]
    episode_indices_mapping = [0, 0, 0, 1, 1, 1, 2, 2, 2]
    frame_indices = [0, 1, -1, 0, 1, -1, 0, 1, -1]

    cam_key = "observation.images.cam_high"
    timestamps = []
    actions = []
    states = []
    frames = []
    # `+ num_episodes` for buffer frames associated to episode_index=-1
    for i, frame_idx in enumerate(frame_indices):
        t_utc = datetime.now(timezone.utc) + timedelta(seconds=i / fps)
        action = np.random.randn(21).tolist()
        state = np.random.randn(21).tolist()
        ep_idx = episode_indices_mapping[i]
        frame = [{"path": f"videos/{cam_key}_episode_{ep_idx:06d}.mp4", "timestamp": frame_idx / fps}]
        timestamps.append(t_utc)
        actions.append(action)
        states.append(state)
        frames.append(frame)

    write_parquet(cam_key, timestamps, frames)
    write_parquet("observation.state", timestamps, states)
    write_parquet("action", timestamps, actions)
    write_parquet("episode_index", timestamps, episode_indices)

    # write fake mp4 file for each episode
    for ep_idx in range(num_episodes):
        imgs_array = np.random.randint(0, 255, size=(num_frames // num_episodes, 480, 640, 3), dtype=np.uint8)

        tmp_imgs_dir = raw_dir / "tmp_images"
        save_images_concurrently(imgs_array, tmp_imgs_dir)

        fname = f"{cam_key}_episode_{ep_idx:06d}.mp4"
        video_path = raw_dir / "videos" / fname
        encode_video_frames(tmp_imgs_dir, video_path, fps)


def _mock_download_raw(raw_dir, repo_id):
    if "wrist_gripper" in repo_id:
        _mock_download_raw_dora(raw_dir)
    elif "aloha" in repo_id:
        _mock_download_raw_aloha(raw_dir)
    elif "pusht" in repo_id:
        _mock_download_raw_pusht(raw_dir)
    elif "xarm" in repo_id:
        _mock_download_raw_xarm(raw_dir)
    elif "umi" in repo_id:
        _mock_download_raw_umi(raw_dir)
    else:
        raise ValueError(repo_id)


def test_push_dataset_to_hub_invalid_repo_id(tmpdir):
    with pytest.raises(ValueError):
        push_dataset_to_hub(Path(tmpdir), "raw_format", "invalid_repo_id")


def test_push_dataset_to_hub_out_dir_force_override_false(tmpdir):
    tmpdir = Path(tmpdir)
    out_dir = tmpdir / "out"
    raw_dir = tmpdir / "raw"
    # mkdir to skip download
    raw_dir.mkdir(parents=True, exist_ok=True)
    with pytest.raises(ValueError):
        push_dataset_to_hub(
            raw_dir=raw_dir,
            raw_format="some_format",
            repo_id="user/dataset",
            local_dir=out_dir,
            force_override=False,
        )


@pytest.mark.parametrize(
    "required_packages, raw_format, repo_id",
    [
        (["gym-pusht"], "pusht_zarr", "lerobot/pusht"),
        (None, "xarm_pkl", "lerobot/xarm_lift_medium"),
        (None, "aloha_hdf5", "lerobot/aloha_sim_insertion_scripted"),
        (["imagecodecs"], "umi_zarr", "lerobot/umi_cup_in_the_wild"),
        (None, "dora_parquet", "cadene/wrist_gripper"),
    ],
)
@require_package_arg
def test_push_dataset_to_hub_format(required_packages, tmpdir, raw_format, repo_id):
    num_episodes = 3
    tmpdir = Path(tmpdir)

    raw_dir = tmpdir / f"{repo_id}_raw"
    _mock_download_raw(raw_dir, repo_id)

    local_dir = tmpdir / repo_id

    lerobot_dataset = push_dataset_to_hub(
        raw_dir=raw_dir,
        raw_format=raw_format,
        repo_id=repo_id,
        push_to_hub=False,
        local_dir=local_dir,
        force_override=False,
        cache_dir=tmpdir / "cache",
    )

    # minimal generic tests on the local directory containing LeRobotDataset
    assert (local_dir / "meta_data" / "info.json").exists()
    assert (local_dir / "meta_data" / "stats.safetensors").exists()
    assert (local_dir / "meta_data" / "episode_data_index.safetensors").exists()
    for i in range(num_episodes):
        for cam_key in lerobot_dataset.camera_keys:
            assert (local_dir / "videos" / f"{cam_key}_episode_{i:06d}.mp4").exists()
    assert (local_dir / "train" / "dataset_info.json").exists()
    assert (local_dir / "train" / "state.json").exists()
    assert len(list((local_dir / "train").glob("*.arrow"))) > 0

    # minimal generic tests on the item
    item = lerobot_dataset[0]
    assert "index" in item
    assert "episode_index" in item
    assert "timestamp" in item
    for cam_key in lerobot_dataset.camera_keys:
        assert cam_key in item


@pytest.mark.parametrize(
    "raw_format, repo_id",
    [
        # TODO(rcadene): add raw dataset test artifacts
        ("pusht_zarr", "lerobot/pusht"),
        ("xarm_pkl", "lerobot/xarm_lift_medium"),
        ("aloha_hdf5", "lerobot/aloha_sim_insertion_scripted"),
        ("umi_zarr", "lerobot/umi_cup_in_the_wild"),
        ("dora_parquet", "cadene/wrist_gripper"),
    ],
)
@pytest.mark.skip(
    "Not compatible with our CI since it downloads raw datasets. Run with `DATA_DIR=tests/data python -m pytest --run-skipped tests/test_push_dataset_to_hub.py::test_push_dataset_to_hub_pusht_backward_compatibility`"
)
def test_push_dataset_to_hub_pusht_backward_compatibility(tmpdir, raw_format, repo_id):
    _, dataset_id = repo_id.split("/")

    tmpdir = Path(tmpdir)
    raw_dir = tmpdir / f"{dataset_id}_raw"
    local_dir = tmpdir / repo_id

    push_dataset_to_hub(
        raw_dir=raw_dir,
        raw_format=raw_format,
        repo_id=repo_id,
        push_to_hub=False,
        local_dir=local_dir,
        force_override=False,
        cache_dir=tmpdir / "cache",
        episodes=[0],
    )

    ds_actual = LeRobotDataset(repo_id, root=tmpdir)
    ds_reference = LeRobotDataset(repo_id)

    assert len(ds_reference.hf_dataset) == len(ds_actual.hf_dataset)

    def check_same_items(item1, item2):
        assert item1.keys() == item2.keys(), "Keys mismatch"

        for key in item1:
            if isinstance(item1[key], torch.Tensor) and isinstance(item2[key], torch.Tensor):
                assert torch.equal(item1[key], item2[key]), f"Mismatch found in key: {key}"
            else:
                assert item1[key] == item2[key], f"Mismatch found in key: {key}"

    for i in range(len(ds_reference.hf_dataset)):
        item_reference = ds_reference.hf_dataset[i]
        item_actual = ds_actual.hf_dataset[i]
        check_same_items(item_reference, item_actual)
90 tests/test_sampler.py Normal file
@@ -0,0 +1,90 @@
#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from datasets import Dataset

from lerobot.common.datasets.sampler import EpisodeAwareSampler
from lerobot.common.datasets.utils import (
    calculate_episode_data_index,
    hf_transform_to_torch,
)


def test_drop_n_first_frames():
    dataset = Dataset.from_dict(
        {
            "timestamp": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
            "index": [0, 1, 2, 3, 4, 5],
            "episode_index": [0, 0, 1, 2, 2, 2],
        },
    )
    dataset.set_transform(hf_transform_to_torch)
    episode_data_index = calculate_episode_data_index(dataset)
    sampler = EpisodeAwareSampler(episode_data_index, drop_n_first_frames=1)
    assert sampler.indices == [1, 4, 5]
    assert len(sampler) == 3
    assert list(sampler) == [1, 4, 5]


def test_drop_n_last_frames():
    dataset = Dataset.from_dict(
        {
            "timestamp": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
            "index": [0, 1, 2, 3, 4, 5],
            "episode_index": [0, 0, 1, 2, 2, 2],
        },
    )
    dataset.set_transform(hf_transform_to_torch)
    episode_data_index = calculate_episode_data_index(dataset)
    sampler = EpisodeAwareSampler(episode_data_index, drop_n_last_frames=1)
    assert sampler.indices == [0, 3, 4]
    assert len(sampler) == 3
    assert list(sampler) == [0, 3, 4]


def test_episode_indices_to_use():
    dataset = Dataset.from_dict(
        {
            "timestamp": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
            "index": [0, 1, 2, 3, 4, 5],
            "episode_index": [0, 0, 1, 2, 2, 2],
        },
    )
    dataset.set_transform(hf_transform_to_torch)
    episode_data_index = calculate_episode_data_index(dataset)
    sampler = EpisodeAwareSampler(episode_data_index, episode_indices_to_use=[0, 2])
    assert sampler.indices == [0, 1, 3, 4, 5]
    assert len(sampler) == 5
    assert list(sampler) == [0, 1, 3, 4, 5]


def test_shuffle():
    dataset = Dataset.from_dict(
        {
            "timestamp": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
            "index": [0, 1, 2, 3, 4, 5],
            "episode_index": [0, 0, 1, 2, 2, 2],
        },
    )
    dataset.set_transform(hf_transform_to_torch)
    episode_data_index = calculate_episode_data_index(dataset)
    sampler = EpisodeAwareSampler(episode_data_index, shuffle=False)
    assert sampler.indices == [0, 1, 2, 3, 4, 5]
    assert len(sampler) == 6
    assert list(sampler) == [0, 1, 2, 3, 4, 5]
    sampler = EpisodeAwareSampler(episode_data_index, shuffle=True)
    assert sampler.indices == [0, 1, 2, 3, 4, 5]
    assert len(sampler) == 6
    assert set(sampler) == {0, 1, 2, 3, 4, 5}
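These tests fully pin down the sampler's observable behavior; a minimal implementation consistent with them might look like the following sketch (the real class lives in `lerobot/common/datasets/sampler.py` and may differ in details):

```python
import torch

class EpisodeAwareSampler:
    # Sketch consistent with the tests above, assuming episode_data_index holds
    # "from"/"to" tensors of per-episode frame boundaries (end exclusive).
    def __init__(self, episode_data_index, episode_indices_to_use=None,
                 drop_n_first_frames=0, drop_n_last_frames=0, shuffle=False):
        indices = []
        for ep_idx, (start, end) in enumerate(
            zip(episode_data_index["from"], episode_data_index["to"])
        ):
            if episode_indices_to_use is None or ep_idx in episode_indices_to_use:
                indices.extend(
                    range(start.item() + drop_n_first_frames, end.item() - drop_n_last_frames)
                )
        self.indices = indices
        self.shuffle = shuffle

    def __iter__(self):
        if self.shuffle:
            for i in torch.randperm(len(self.indices)):
                yield self.indices[i]
        else:
            yield from self.indices

    def __len__(self):
        return len(self.indices)
```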
@@ -11,22 +11,24 @@ from lerobot.common.datasets.utils import (
     hf_transform_to_torch,
     reset_episode_index,
 )
-from lerobot.common.utils.utils import seeded_context, set_global_seed
-
-
-@pytest.mark.parametrize(
-    "rand_fn",
-    (
-        [
-            random.random,
-            np.random.random,
-            lambda: torch.rand(1).item(),
-        ]
-        + [lambda: torch.rand(1, device="cuda")]
-        if torch.cuda.is_available()
-        else []
-    ),
+from lerobot.common.utils.utils import (
+    get_global_random_state,
+    seeded_context,
+    set_global_random_state,
+    set_global_seed,
+)
+
+# Random generation functions for testing the seeding and random state get/set.
+rand_fns = [
+    random.random,
+    np.random.random,
+    lambda: torch.rand(1).item(),
+]
+if torch.cuda.is_available():
+    rand_fns.append(lambda: torch.rand(1, device="cuda"))
+
+
+@pytest.mark.parametrize("rand_fn", rand_fns)
 def test_seeding(rand_fn: Callable[[], int]):
     set_global_seed(0)
     a = rand_fn()
@@ -46,6 +48,15 @@ def test_seeding(rand_fn: Callable[[], int]):
     assert c_ == c


+def test_get_set_random_state():
+    """Check that getting the random state, then setting it results in the same random number generation."""
+    random_state_dict = get_global_random_state()
+    rand_numbers = [rand_fn() for rand_fn in rand_fns]
+    set_global_random_state(random_state_dict)
+    rand_numbers_ = [rand_fn() for rand_fn in rand_fns]
+    assert rand_numbers_ == rand_numbers
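`seeded_context`, used throughout these tests, is presumably built from the same get/set primitives; a minimal sketch (assumed, not taken from the diff):

```python
from contextlib import contextmanager

@contextmanager
def seeded_context(seed: int):
    # Snapshot the global RNG state, seed deterministically for the block,
    # then restore the snapshot so surrounding code is unaffected.
    saved_state = get_global_random_state()
    set_global_seed(seed)
    yield
    set_global_random_state(saved_state)
```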


 def test_calculate_episode_data_index():
     dataset = Dataset.from_dict(
         {

@@ -13,6 +13,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from pathlib import Path
+
 import pytest

 from lerobot.scripts.visualize_dataset import visualize_dataset
@@ -23,11 +25,27 @@ from lerobot.scripts.visualize_dataset import visualize_dataset
     ["lerobot/pusht"],
 )
 def test_visualize_dataset(tmpdir, repo_id):
     rrd_path = visualize_dataset(
         repo_id,
         episode_indices=[0],
         output_dir=tmpdir,
         serve=False,
     )
     assert rrd_path.exists()


+@pytest.mark.parametrize(
+    "repo_id",
+    ["lerobot/pusht"],
+)
+@pytest.mark.parametrize("root", [Path(__file__).parent / "data"])
+def test_visualize_local_dataset(tmpdir, repo_id, root):
+    rrd_path = visualize_dataset(
+        repo_id,
+        episode_index=0,
+        batch_size=32,
+        save=True,
+        output_dir=tmpdir,
+        root=root,
+    )
+    assert rrd_path.exists()

@@ -76,6 +76,7 @@ def require_env(func):
     """
     Decorator that skips the test if the required environment package is not installed.
     As it needs 'env_name' in args, it also checks whether it is provided as an argument.
+    If 'env_name' is None, this check is skipped.
     """

     @wraps(func)
@@ -91,7 +92,7 @@ def require_env(func):

         # Perform the package check
         package_name = f"gym_{env_name}"
-        if not is_package_available(package_name):
+        if env_name is not None and not is_package_available(package_name):
             pytest.skip(f"gym-{env_name} not installed")

         return func(*args, **kwargs)
@@ -99,6 +100,38 @@ def require_env(func):
     return wrapper


+def require_package_arg(func):
+    """
+    Decorator that skips the test if the required package is not installed.
+    This is similar to `require_env` but more general in that it can check any package (not just environments).
+    As it needs 'required_packages' in args, it also checks whether it is provided as an argument.
+    If 'required_packages' is None, this check is skipped.
+    """
+
+    @wraps(func)
+    def wrapper(*args, **kwargs):
+        # Determine if 'required_packages' is provided and extract its value
+        arg_names = func.__code__.co_varnames[: func.__code__.co_argcount]
+        if "required_packages" in arg_names:
+            # Get the index of 'required_packages' and retrieve the value from args
+            index = arg_names.index("required_packages")
+            required_packages = args[index] if len(args) > index else kwargs.get("required_packages")
+        else:
+            raise ValueError("Function does not have 'required_packages' as an argument.")
+
+        if required_packages is None:
+            return func(*args, **kwargs)
+
+        # Perform the package check
+        for package in required_packages:
+            if not is_package_available(package):
+                pytest.skip(f"{package} not installed")
+
+        return func(*args, **kwargs)
+
+    return wrapper


 def require_package(package_name):
     """
     Decorator that skips the test if the specified package is not installed.