Compare commits

...

49 Commits

Author SHA1 Message Date
github-actions[bot]
eda63b0749 Apply style fixes 2025-02-27 15:31:53 +00:00
Quentin Gallouédec
cca2aaa84e Break style to test style bot 2025-02-27 16:31:07 +01:00
Mishig
800c4a847f [Vizualisation] independent column names (#783) 2025-02-27 14:47:18 +01:00
Simon Alibert
bba8c4c0d4 Fix pr_style bot (#782) 2025-02-27 13:09:12 +01:00
Simon Alibert
68b369e321 Fix pr_style_bot (#781) 2025-02-27 12:13:36 +01:00
Mishig
8d60ac3ffc [vizualisation] Add pagination for many episodes (#776) 2025-02-26 19:23:37 +01:00
Simon Alibert
659ec4434d Fix nightly (#775) 2025-02-26 16:36:03 +01:00
Simon Alibert
da265ca920 Add pr style bot (#772) 2025-02-25 23:52:25 +01:00
Simon Alibert
a1809ad3de Add typos checks (#770) 2025-02-25 23:51:15 +01:00
Jannik Grothusen
8699a28be0 [QOL] Enable teleoperation during environment reset (#725) 2025-02-25 19:28:26 +01:00
Raul Garreta
65db5afe1c fixes in 7_get_started_with_real_robot.md (#677) 2025-02-25 19:03:29 +01:00
Youssef Bayouli
75d5fa4604 Optimizing Dockerfile (#751) 2025-02-25 18:42:35 +01:00
Yongjin Cho
e64fad2224 Fix the URL to setup hardware Aloha Stationary in the example document (#766) 2025-02-25 18:33:32 +01:00
Haskely
eecf32e77a feat: Add root directory option for dataset configuration (#765)
Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>
2025-02-25 17:27:36 +01:00
Simon Alibert
3354d919fc LeRobotDataset v2.1 (#711)
Co-authored-by: Remi <remi.cadene@huggingface.co>
Co-authored-by: Remi Cadene <re.cadene@gmail.com>
2025-02-25 15:27:29 +01:00
Pepijn
aca464ca72 Add mobile so100 (#724) 2025-02-25 09:06:50 +01:00
Simon Alibert
fe483b1d0d Remove poetry.lock (#737)
Co-authored-by: Remi <remi.cadene@huggingface.co>
2025-02-17 12:03:16 +01:00
Simon Alibert
ddeade077e Conform pyproject to PEP 621 (#621) 2025-02-16 14:28:03 +01:00
Simon Alibert
c4c2ce04e7 Update pre-commits (#733) 2025-02-15 15:51:17 +01:00
Simon Alibert
2cb0bf5d41 Add zizmor pre-commit (#732) 2025-02-15 15:50:10 +01:00
Simon Alibert
b86a2c0b47 Fix wandb logging (#730) 2025-02-14 18:00:12 +01:00
Ilia Larchenko
c574eb4984 Fixed eval.py on MPS (#702) 2025-02-14 00:03:55 +01:00
Simon Alibert
1e49cc4d60 Prevent resuming from hub (#726) 2025-02-13 17:15:55 +01:00
Simon Alibert
e71095960f Fixes following #670 (#719) 2025-02-12 12:53:55 +01:00
Simon Alibert
90e099b39f Remove offline training, refactor train.py and logging/checkpointing (#670)
Co-authored-by: Remi <remi.cadene@huggingface.co>
2025-02-11 10:36:06 +01:00
Simon Alibert
334deb985d Update CI trigger rules (#712) 2025-02-10 17:22:15 +01:00
Simon Alibert
8548a87bd4 Remove dataset tests artifacts (#701) 2025-02-09 14:24:01 +01:00
Remi
638d411cd3 Add Pi0 (#681)
Co-authored-by: Simon Alibert <simon.alibert@huggingface.co>
Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>
Co-authored-by: Pablo <pablo.montalvo.leroux@gmail.com>
2025-02-04 18:01:04 +01:00
Pepijn
dd974529cf User/pepijn/2025 01 31 improved tutorial so100 (#666) 2025-02-03 18:27:55 +01:00
Simon Alibert
43e079f73e Fix nightly tests docker images (#675) 2025-02-02 13:59:33 +01:00
Simon Alibert
6674e36824 Fix Docker cpu/gpu builds (#667) 2025-02-01 12:06:11 +01:00
Pepijn
ae9605f03c fix setting motor id with new dataclass config (#668) 2025-01-31 20:48:46 +01:00
Simon Alibert
3c0a209f9f Simplify configs (#550)
Co-authored-by: Remi <remi.cadene@huggingface.co>
Co-authored-by: HUANG TZU-CHUN <137322177+tc-huang@users.noreply.github.com>
2025-01-31 13:57:37 +01:00
Simon Alibert
1ee1acf8ad Comply with torchvision 0.21 custom transforms (#665) 2025-01-30 22:06:11 +01:00
Thomas Lips
c4d912a241 Check for "/" in feature names (#660) 2025-01-29 21:54:49 +01:00
Morgan Redfield
4323bdce22 updating config instructions for koch 1v1 motors (#658) 2025-01-28 13:20:33 +01:00
HUANG TZU-CHUN
5daa45436d Fix typos in lerobot/scripts/visualize_dataset.py (#656) 2025-01-28 13:07:10 +01:00
Simon Alibert
4def6d6ac2 Fix cluster image (#653) 2025-01-24 11:25:22 +01:00
Jochen Görtler
d8560b8d5f Bump rerun-sdk dependency to 0.21.0 (#618)
Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>
2025-01-20 09:50:11 +01:00
Pradeep Kadubandi
380b836eee Fix for the issue https://github.com/huggingface/lerobot/issues/638 (#639) 2025-01-15 10:50:38 +01:00
Philip Fung
eec6796cb8 fixes to SO-100 readme (#600)
Co-authored-by: Philip Fung <no@one>
Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>
2025-01-10 11:30:01 +01:00
Mishig
25a8597680 [viz] Fixes & updates to html visualizer (#617) 2025-01-09 11:39:54 +01:00
CharlesCNorton
b8b368310c typo fix: batch_convert_dataset_v1_to_v2.py (#615)
Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>
2025-01-09 09:57:45 +01:00
Ville Kuosmanen
5097cd900e fix(visualise): use correct language description for each episode id (#604)
Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>
2025-01-09 09:39:48 +01:00
CharlesCNorton
bc16e1b497 fix(docs): typos in benchmark readme.md (#614)
Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>
2025-01-09 09:35:27 +01:00
Simon Alibert
8f821ecad0 Fix Quality workflow (#622) 2025-01-08 13:35:11 +01:00
CharlesCNorton
4519016e67 Update README.md (#612) 2025-01-03 16:19:37 +01:00
Eugene Mironov
59e2757434 Fix broken create_lerobot_dataset_card (#590) 2024-12-23 15:05:59 +01:00
Mishig
73b64c3089 [vizualizer] for LeRobodDataset V2 (#576) 2024-12-20 16:26:23 +01:00
602 changed files with 13808 additions and 19712 deletions


@@ -8,6 +8,8 @@ on:
   schedule:
     - cron: "0 1 * * *"
 
+permissions: {}
+
 env:
   PYTHON_VERSION: "3.10"
@@ -25,11 +27,14 @@ jobs:
       - name: Set up Docker Buildx
         uses: docker/setup-buildx-action@v3
+        with:
+          cache-binary: false
       - name: Check out code
         uses: actions/checkout@v4
         with:
           lfs: true
+          persist-credentials: false
       - name: Login to DockerHub
         uses: docker/login-action@v3
@@ -60,11 +65,14 @@ jobs:
      - name: Set up Docker Buildx
         uses: docker/setup-buildx-action@v3
+        with:
+          cache-binary: false
       - name: Check out code
         uses: actions/checkout@v4
         with:
           lfs: true
+          persist-credentials: false
       - name: Login to DockerHub
         uses: docker/login-action@v3
@@ -89,9 +97,13 @@ jobs:
     steps:
       - name: Set up Docker Buildx
         uses: docker/setup-buildx-action@v3
+        with:
+          cache-binary: false
      - name: Check out code
         uses: actions/checkout@v4
+        with:
+          persist-credentials: false
       - name: Login to DockerHub
         uses: docker/login-action@v3
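
A note on the pattern repeated across these workflow diffs: an empty top-level `permissions: {}` block drops the default `GITHUB_TOKEN` grants so each workflow or job must opt in explicitly, and `persist-credentials: false` keeps the checkout token out of the working copy. These are the standard hardenings flagged by zizmor (added as a pre-commit hook further down). A minimal hardened job for reference (the job name is illustrative):

```yaml
permissions: {}  # no default token scopes; grant per workflow/job as needed

jobs:
  example:  # illustrative job name
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
        with:
          persist-credentials: false  # don't persist the token in .git/config
```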


@@ -7,6 +7,8 @@ on:
   schedule:
     - cron: "0 2 * * *"
 
+permissions: {}
+
 # env:
 #   SLACK_API_TOKEN: ${{ secrets.SLACK_API_TOKEN }}
 jobs:

.github/workflows/pr_style_bot.yml (new file, +125 lines)

@@ -0,0 +1,125 @@
# Adapted from https://github.com/huggingface/diffusers/blob/main/.github/workflows/pr_style_bot.yml

name: PR Style Bot

on:
  issue_comment:
    types: [created]

permissions:
  contents: write
  pull-requests: write

jobs:
  run-style-bot:
    if: >
      contains(github.event.comment.body, '@bot /style') &&
      github.event.issue.pull_request != null
    runs-on: ubuntu-latest

    steps:
      - name: Extract PR details
        id: pr_info
        uses: actions/github-script@v6
        with:
          script: |
            const prNumber = context.payload.issue.number;
            const { data: pr } = await github.rest.pulls.get({
              owner: context.repo.owner,
              repo: context.repo.repo,
              pull_number: prNumber
            });

            // We capture both the branch ref and the "full_name" of the head repo
            // so that we can check out the correct repository & branch (including forks).
            core.setOutput("prNumber", prNumber);
            core.setOutput("headRef", pr.head.ref);
            core.setOutput("headRepoFullName", pr.head.repo.full_name);

      - name: Check out PR branch
        uses: actions/checkout@v4
        env:
          HEADREPOFULLNAME: ${{ steps.pr_info.outputs.headRepoFullName }}
          HEADREF: ${{ steps.pr_info.outputs.headRef }}
        with:
          persist-credentials: true
          # Instead of checking out the base repo, use the contributor's repo name
          repository: ${{ env.HEADREPOFULLNAME }}
          ref: ${{ env.HEADREF }}
          # You may need fetch-depth: 0 for being able to push
          fetch-depth: 0
          token: ${{ secrets.GITHUB_TOKEN }}

      - name: Debug
        env:
          HEADREPOFULLNAME: ${{ steps.pr_info.outputs.headRepoFullName }}
          HEADREF: ${{ steps.pr_info.outputs.headRef }}
          PRNUMBER: ${{ steps.pr_info.outputs.prNumber }}
        run: |
          echo "PR number: ${PRNUMBER}"
          echo "Head Ref: ${HEADREF}"
          echo "Head Repo Full Name: ${HEADREPOFULLNAME}"

      - name: Set up Python
        uses: actions/setup-python@v4

      - name: Get Ruff Version from pre-commit-config.yaml
        id: get-ruff-version
        run: |
          RUFF_VERSION=$(awk '/repo: https:\/\/github.com\/astral-sh\/ruff-pre-commit/{flag=1;next}/rev:/{if(flag){print $2;exit}}' .pre-commit-config.yaml)
          echo "ruff_version=${RUFF_VERSION}" >> $GITHUB_OUTPUT

      - name: Install Ruff
        env:
          RUFF_VERSION: ${{ steps.get-ruff-version.outputs.ruff_version }}
        run: python -m pip install "ruff==${RUFF_VERSION}"

      - name: Ruff check
        run: ruff check --fix

      - name: Ruff format
        run: ruff format

      - name: Commit and push changes
        id: commit_and_push
        env:
          HEADREPOFULLNAME: ${{ steps.pr_info.outputs.headRepoFullName }}
          HEADREF: ${{ steps.pr_info.outputs.headRef }}
          PRNUMBER: ${{ steps.pr_info.outputs.prNumber }}
          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
        run: |
          echo "HEADREPOFULLNAME: ${HEADREPOFULLNAME}, HEADREF: ${HEADREF}"

          # Configure git with the Actions bot user
          git config user.name "github-actions[bot]"
          git config user.email "github-actions[bot]@users.noreply.github.com"

          # Make sure your 'origin' remote is set to the contributor's fork
          git remote set-url origin "https://x-access-token:${GITHUB_TOKEN}@github.com/${HEADREPOFULLNAME}.git"

          # If there are changes after running style/quality, commit them
          if [ -n "$(git status --porcelain)" ]; then
            git add .
            git commit -m "Apply style fixes"

            # Push to the original contributor's forked branch
            git push origin HEAD:${HEADREF}
            echo "changes_pushed=true" >> $GITHUB_OUTPUT
          else
            echo "No changes to commit."
            echo "changes_pushed=false" >> $GITHUB_OUTPUT
          fi

      - name: Comment on PR with workflow run link
        if: steps.commit_and_push.outputs.changes_pushed == 'true'
        uses: actions/github-script@v6
        with:
          script: |
            const prNumber = parseInt(process.env.prNumber, 10);
            const runUrl = `${process.env.GITHUB_SERVER_URL}/${process.env.GITHUB_REPOSITORY}/actions/runs/${process.env.GITHUB_RUN_ID}`

            await github.rest.issues.createComment({
              owner: context.repo.owner,
              repo: context.repo.repo,
              issue_number: prNumber,
              body: `Style fixes have been applied. [View the workflow run here](${runUrl}).`
            });
        env:
          prNumber: ${{ steps.pr_info.outputs.prNumber }}
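
The bot is triggered by commenting `@bot /style` on a pull request. Its formatting steps can also be reproduced locally before asking it to push; a sketch, assuming a checkout with `.pre-commit-config.yaml` at the repo root (same commands as the workflow):

```bash
# Read the pinned ruff rev from .pre-commit-config.yaml (e.g. "v0.9.6")
RUFF_VERSION=$(awk '/repo: https:\/\/github.com\/astral-sh\/ruff-pre-commit/{flag=1;next}/rev:/{if(flag){print $2;exit}}' .pre-commit-config.yaml)
python -m pip install "ruff==${RUFF_VERSION}"

# Apply the same fixes the bot would commit
ruff check --fix
ruff format
```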


@@ -4,12 +4,12 @@ on:
   workflow_dispatch:
   workflow_call:
   pull_request:
-    branches:
-      - main
   push:
     branches:
       - main
 
+permissions: {}
+
 env:
   PYTHON_VERSION: "3.10"
@@ -19,7 +19,9 @@ jobs:
     runs-on: ubuntu-latest
     steps:
      - name: Checkout Repository
-        uses: actions/checkout@v3
+        uses: actions/checkout@v4
+        with:
+          persist-credentials: false
 
      - name: Set up Python
         uses: actions/setup-python@v4
@@ -30,55 +32,27 @@ jobs:
         id: get-ruff-version
         run: |
           RUFF_VERSION=$(awk '/repo: https:\/\/github.com\/astral-sh\/ruff-pre-commit/{flag=1;next}/rev:/{if(flag){print $2;exit}}' .pre-commit-config.yaml)
-          echo "RUFF_VERSION=${RUFF_VERSION}" >> $GITHUB_ENV
+          echo "ruff_version=${RUFF_VERSION}" >> $GITHUB_OUTPUT
 
      - name: Install Ruff
-        run: python -m pip install "ruff==${{ env.RUFF_VERSION }}"
+        env:
+          RUFF_VERSION: ${{ steps.get-ruff-version.outputs.ruff_version }}
+        run: python -m pip install "ruff==${RUFF_VERSION}"
 
      - name: Ruff check
-        run: ruff check
+        run: ruff check --output-format=github
 
      - name: Ruff format
         run: ruff format --diff
 
-  poetry_check:
-    name: Poetry check
+  typos:
+    name: Typos
     runs-on: ubuntu-latest
     steps:
      - name: Checkout Repository
-        uses: actions/checkout@v3
-
-      - name: Install poetry
-        run: pipx install poetry
-
-      - name: Poetry check
-        run: poetry check
-
-  poetry_relax:
-    name: Poetry relax
-    runs-on: ubuntu-latest
-    steps:
-      - name: Checkout Repository
-        uses: actions/checkout@v3
-
-      - name: Install poetry
-        run: pipx install poetry
-
-      - name: Install poetry-relax
-        run: poetry self add poetry-relax
-
-      - name: Poetry relax
-        id: poetry_relax
-        run: |
-          output=$(poetry relax --check 2>&1)
-          if echo "$output" | grep -q "Proposing updates"; then
-            echo "$output"
-            echo ""
-            echo "Some dependencies have caret '^' version requirement added by poetry by default."
-            echo "Please replace them with '>='. You can do this by hand or use poetry-relax to do this."
-            exit 1
-          else
-            echo "$output"
-          fi
+        uses: actions/checkout@v4
+        with:
+          persist-credentials: false
+
+      - name: typos-action
+        uses: crate-ci/typos@v1.29.10
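
The same checks can be run locally; a sketch, assuming `ruff` and the `typos` CLI are already installed (for example via `pipx`):

```bash
# Lint and check formatting exactly as the Quality workflow does
ruff check --output-format=github
ruff format --diff

# Spell-check the tree (CI pins crate-ci/typos v1.29.10)
typos
```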


@@ -4,12 +4,12 @@ name: Test Dockerfiles
 on:
   pull_request:
-    branches:
-      - main
     paths:
       # Run only when DockerFile files are modified
       - "docker/**"
 
+permissions: {}
+
 env:
   PYTHON_VERSION: "3.10"
@@ -22,6 +22,8 @@ jobs:
     steps:
       - name: Check out code
         uses: actions/checkout@v4
+        with:
+          persist-credentials: false
 
      - name: Get changed files
         id: changed-files
@@ -30,21 +32,18 @@ jobs:
           files: docker/**
           json: "true"
 
-      - name: Run step if only the files listed above change
+      - name: Run step if only the files listed above change # zizmor: ignore[template-injection]
         if: steps.changed-files.outputs.any_changed == 'true'
         id: set-matrix
+        env:
+          ALL_CHANGED_FILES: ${{ steps.changed-files.outputs.all_changed_files }}
         run: |
           echo "matrix=${{ steps.changed-files.outputs.all_changed_files}}" >> $GITHUB_OUTPUT
 
   build_modified_dockerfiles:
     name: Build modified Docker images
     needs: get_changed_files
     runs-on:
       group: aws-general-8-plus
-    if: ${{ needs.get_changed_files.outputs.matrix }} != ''
+    if: needs.get_changed_files.outputs.matrix != ''
     strategy:
       fail-fast: false
       matrix:
@@ -52,9 +51,13 @@ jobs:
     steps:
       - name: Set up Docker Buildx
         uses: docker/setup-buildx-action@v3
+        with:
+          cache-binary: false
      - name: Check out code
         uses: actions/checkout@v4
+        with:
+          persist-credentials: false
      - name: Build Docker image
         uses: docker/build-push-action@v5
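
The `# zizmor: ignore[template-injection]` comment and the new `env:` block address the same issue: interpolating `${{ ... }}` output directly into a `run:` script lets crafted file names inject shell commands. The safer shape, sketched below on a hypothetical step, routes the expression through an environment variable and expands it in the shell:

```yaml
- name: Use changed files safely  # hypothetical step
  env:
    ALL_CHANGED_FILES: ${{ steps.changed-files.outputs.all_changed_files }}
  run: |
    echo "matrix=${ALL_CHANGED_FILES}" >> "$GITHUB_OUTPUT"
```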


@@ -2,14 +2,13 @@ name: Tests
 on:
   pull_request:
-    branches:
-      - main
     paths:
       - "lerobot/**"
       - "tests/**"
       - "examples/**"
       - ".github/**"
-      - "poetry.lock"
+      - "pyproject.toml"
+      - ".pre-commit-config.yaml"
       - "Makefile"
       - ".cache/**"
   push:
@@ -20,10 +19,16 @@ on:
       - "tests/**"
       - "examples/**"
       - ".github/**"
-      - "poetry.lock"
+      - "pyproject.toml"
+      - ".pre-commit-config.yaml"
       - "Makefile"
       - ".cache/**"
 
+permissions: {}
+
+env:
+  UV_VERSION: "0.6.0"
+
 jobs:
   pytest:
     name: Pytest
@@ -34,6 +39,7 @@ jobs:
      - uses: actions/checkout@v4
         with:
           lfs: true # Ensure LFS files are pulled
+          persist-credentials: false
 
      - name: Install apt dependencies
         # portaudio19-dev is needed to install pyaudio
@@ -41,25 +47,19 @@ jobs:
           sudo apt-get update && \
           sudo apt-get install -y libegl1-mesa-dev ffmpeg portaudio19-dev
 
-      - name: Install poetry
-        run: |
-          pipx install poetry && poetry config virtualenvs.in-project true
-          echo "${{ github.workspace }}/.venv/bin" >> $GITHUB_PATH
-
-      # TODO(rcadene, aliberts): python 3.12 seems to be used in the tests, not python 3.10
-      - name: Set up Python 3.10
-        uses: actions/setup-python@v5
+      - name: Install uv and python
+        uses: astral-sh/setup-uv@v5
         with:
+          enable-cache: true
+          version: ${{ env.UV_VERSION }}
           python-version: "3.10"
-          cache: "poetry"
 
-      - name: Install poetry dependencies
-        run: |
-          poetry install --all-extras
+      - name: Install lerobot (all extras)
+        run: uv sync --all-extras
 
      - name: Test with pytest
         run: |
-          pytest tests -v --cov=./lerobot --durations=0 \
+          uv run pytest tests -v --cov=./lerobot --durations=0 \
             -W ignore::DeprecationWarning:imageio_ffmpeg._utils:7 \
             -W ignore::UserWarning:torch.utils.data.dataloader:558 \
             -W ignore::UserWarning:gymnasium.utils.env_checker:247 \
@@ -74,66 +74,63 @@ jobs:
      - uses: actions/checkout@v4
         with:
           lfs: true # Ensure LFS files are pulled
+          persist-credentials: false
 
      - name: Install apt dependencies
         run: sudo apt-get update && sudo apt-get install -y ffmpeg
 
-      - name: Install poetry
-        run: |
-          pipx install poetry && poetry config virtualenvs.in-project true
-          echo "${{ github.workspace }}/.venv/bin" >> $GITHUB_PATH
-
-      # TODO(rcadene, aliberts): python 3.12 seems to be used in the tests, not python 3.10
-      - name: Set up Python 3.10
-        uses: actions/setup-python@v5
+      - name: Install uv and python
+        uses: astral-sh/setup-uv@v5
         with:
+          enable-cache: true
+          version: ${{ env.UV_VERSION }}
           python-version: "3.10"
 
-      - name: Install poetry dependencies
-        run: |
-          poetry install --extras "test"
+      - name: Install lerobot
+        run: uv sync --extra "test"
 
      - name: Test with pytest
         run: |
-          pytest tests -v --cov=./lerobot --durations=0 \
+          uv run pytest tests -v --cov=./lerobot --durations=0 \
            -W ignore::DeprecationWarning:imageio_ffmpeg._utils:7 \
            -W ignore::UserWarning:torch.utils.data.dataloader:558 \
            -W ignore::UserWarning:gymnasium.utils.env_checker:247 \
            && rm -rf tests/outputs outputs
 
-  # TODO(aliberts, rcadene): redesign after v2 migration / removing hydra
-  # end-to-end:
-  #   name: End-to-end
-  #   runs-on: ubuntu-latest
-  #   env:
-  #     MUJOCO_GL: egl
-  #   steps:
-  #     - uses: actions/checkout@v4
-  #       with:
-  #         lfs: true # Ensure LFS files are pulled
-  #     - name: Install apt dependencies
-  #       # portaudio19-dev is needed to install pyaudio
-  #       run: |
-  #         sudo apt-get update && \
-  #         sudo apt-get install -y libegl1-mesa-dev portaudio19-dev
-  #     - name: Install poetry
-  #       run: |
-  #         pipx install poetry && poetry config virtualenvs.in-project true
-  #         echo "${{ github.workspace }}/.venv/bin" >> $GITHUB_PATH
-  #     - name: Set up Python 3.10
-  #       uses: actions/setup-python@v5
-  #       with:
-  #         python-version: "3.10"
-  #         cache: "poetry"
-  #     - name: Install poetry dependencies
-  #       run: |
-  #         poetry install --all-extras
-  #     - name: Test end-to-end
-  #       run: |
-  #         make test-end-to-end \
-  #         && rm -rf outputs
+  end-to-end:
+    name: End-to-end
+    runs-on: ubuntu-latest
+    env:
+      MUJOCO_GL: egl
+    steps:
+      - uses: actions/checkout@v4
+        with:
+          lfs: true # Ensure LFS files are pulled
+          persist-credentials: false
+      - name: Install apt dependencies
+        # portaudio19-dev is needed to install pyaudio
+        run: |
+          sudo apt-get update && \
+          sudo apt-get install -y libegl1-mesa-dev portaudio19-dev
+      - name: Install uv and python
+        uses: astral-sh/setup-uv@v5
+        with:
+          enable-cache: true
+          version: ${{ env.UV_VERSION }}
+          python-version: "3.10"
+      - name: Install lerobot (all extras)
+        run: |
+          uv venv
+          uv sync --all-extras
+      - name: venv
+        run: |
+          echo "PYTHON_PATH=${{ github.workspace }}/.venv/bin/python" >> $GITHUB_ENV
+      - name: Test end-to-end
+        run: |
+          make test-end-to-end \
+          && rm -rf outputs
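
The uv-based CI steps translate directly to a local run; a sketch, assuming uv is installed (CI pins 0.6.0) and the repo is checked out with LFS files pulled:

```bash
# Create/refresh the environment with every extra, then run the suite
uv sync --all-extras
uv run pytest tests -v --cov=./lerobot --durations=0
```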


@@ -3,8 +3,7 @@ on:
 
 name: Secret Leaks
 
-permissions:
-  contents: read
+permissions: {}
 
 jobs:
   trufflehog:
@@ -14,6 +13,8 @@ jobs:
         uses: actions/checkout@v4
         with:
           fetch-depth: 0
+          persist-credentials: false
 
      - name: Secret Scanning
         uses: trufflesecurity/trufflehog@main
         with:

.gitignore (+4 lines)

@@ -49,6 +49,10 @@ share/python-wheels/
 *.egg
 MANIFEST
 
+# uv/poetry lock files
+poetry.lock
+uv.lock
+
 # PyInstaller
 # Usually these files are written by a python script from a template
 # before PyInstaller builds the exe, so as to inject date/other infos into it.


@@ -13,25 +13,26 @@ repos:
       - id: check-toml
       - id: end-of-file-fixer
       - id: trailing-whitespace
+  - repo: https://github.com/crate-ci/typos
+    rev: v1.29.10
+    hooks:
+      - id: typos
+        args: [--force-exclude]
   - repo: https://github.com/asottile/pyupgrade
-    rev: v3.19.0
+    rev: v3.19.1
     hooks:
       - id: pyupgrade
   - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: v0.8.2
+    rev: v0.9.6
     hooks:
       - id: ruff
         args: [--fix]
       - id: ruff-format
-  - repo: https://github.com/python-poetry/poetry
-    rev: 1.8.0
-    hooks:
-      - id: poetry-check
-      - id: poetry-lock
-        args:
-          - "--check"
-          - "--no-update"
   - repo: https://github.com/gitleaks/gitleaks
-    rev: v8.21.2
+    rev: v8.23.3
     hooks:
       - id: gitleaks
+  - repo: https://github.com/woodruffw/zizmor-pre-commit
+    rev: v1.3.1
+    hooks:
+      - id: zizmor
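
With this config, the new typos and zizmor hooks run on every commit once pre-commit is set up; for a one-off sweep of the whole tree:

```bash
pre-commit install            # register the git hook (one-time)
pre-commit run --all-files    # run typos, ruff, gitleaks, zizmor, etc. on everything
```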


@@ -129,38 +129,71 @@ Follow these steps to start contributing:
    🚨 **Do not** work on the `main` branch.
 
-4. for development, we use `poetry` instead of just `pip` to easily track our dependencies.
-   If you don't have it already, follow the [instructions](https://python-poetry.org/docs/#installation) to install it.
+4. for development, we advise to use a tool like `poetry` or `uv` instead of just `pip` to easily track our dependencies.
+   Follow the instructions to [install poetry](https://python-poetry.org/docs/#installation) (use a version >=2.1.0) or to [install uv](https://docs.astral.sh/uv/getting-started/installation/#installation-methods) if you don't have one of them already.
 
    Set up a development environment with conda or miniconda:
    ```bash
    conda create -y -n lerobot-dev python=3.10 && conda activate lerobot-dev
    ```
 
-   To develop on 🤗 LeRobot, you will at least need to install the `dev` and `test` extras dependencies along with the core library:
+   If you're using `uv`, it can manage python versions so you can instead do:
    ```bash
-   poetry install --sync --extras "dev test"
+   uv venv --python 3.10 && source .venv/bin/activate
    ```
+
+   To develop on 🤗 LeRobot, you will at least need to install the `dev` and `test` extras dependencies along with the core library:
+
+   using `poetry`
+   ```bash
+   poetry sync --extras "dev test"
+   ```
+
+   using `uv`
+   ```bash
+   uv sync --extra dev --extra test
+   ```
 
    You can also install the project with all its dependencies (including environments):
+
+   using `poetry`
    ```bash
-   poetry install --sync --all-extras
+   poetry sync --all-extras
+   ```
+
+   using `uv`
+   ```bash
+   uv sync --all-extras
    ```
 
   > **Note:** If you don't install simulation environments with `--all-extras`, the tests that require them will be skipped when running the pytest suite locally. However, they *will* be tested in the CI. In general, we advise you to install everything and test locally before pushing.
 
-   Whichever command you chose to install the project (e.g. `poetry install --sync --all-extras`), you should run it again when pulling code with an updated version of `pyproject.toml` and `poetry.lock` in order to synchronize your virtual environment with the new dependencies.
+   Whichever command you chose to install the project (e.g. `poetry sync --all-extras`), you should run it again when pulling code with an updated version of `pyproject.toml` and `poetry.lock` in order to synchronize your virtual environment with the new dependencies.
 
    The equivalent of `pip install some-package`, would just be:
+
+   using `poetry`
    ```bash
    poetry add some-package
    ```
+
+   using `uv`
+   ```bash
+   uv add some-package
+   ```
 
    When making changes to the poetry sections of the `pyproject.toml`, you should run the following command to lock dependencies.
+
+   using `poetry`
    ```bash
-   poetry lock --no-update
+   poetry lock
    ```
+
+   using `uv`
+   ```bash
+   uv lock
+   ```
 
 5. Develop the features on your branch.
 
    As you work on the features, you should make sure that the test suite
@@ -195,7 +228,7 @@ Follow these steps to start contributing:
    git commit
    ```
 
-   Note, if you already commited some changes that have a wrong formatting, you can use:
+   Note, if you already committed some changes that have a wrong formatting, you can use:
    ```bash
    pre-commit run --all-files
    ```

Makefile (236 changed lines)

@@ -2,10 +2,10 @@
 
 PYTHON_PATH := $(shell which python)
 
-# If Poetry is installed, redefine PYTHON_PATH to use the Poetry-managed Python
-POETRY_CHECK := $(shell command -v poetry)
-ifneq ($(POETRY_CHECK),)
-	PYTHON_PATH := $(shell poetry run which python)
+# If uv is installed and a virtual environment exists, use it
+UV_CHECK := $(shell command -v uv)
+ifneq ($(UV_CHECK),)
+	PYTHON_PATH := $(shell .venv/bin/python)
 endif
 
 export PATH := $(dir $(PYTHON_PATH)):$(PATH)
@@ -20,171 +20,109 @@ build-gpu:
 test-end-to-end:
 	${MAKE} DEVICE=$(DEVICE) test-act-ete-train
+	${MAKE} DEVICE=$(DEVICE) test-act-ete-train-resume
 	${MAKE} DEVICE=$(DEVICE) test-act-ete-eval
-	${MAKE} DEVICE=$(DEVICE) test-act-ete-train-amp
-	${MAKE} DEVICE=$(DEVICE) test-act-ete-eval-amp
 	${MAKE} DEVICE=$(DEVICE) test-diffusion-ete-train
 	${MAKE} DEVICE=$(DEVICE) test-diffusion-ete-eval
 	${MAKE} DEVICE=$(DEVICE) test-tdmpc-ete-train
-	${MAKE} DEVICE=$(DEVICE) test-tdmpc-ete-train-with-online
 	${MAKE} DEVICE=$(DEVICE) test-tdmpc-ete-eval
-	${MAKE} DEVICE=$(DEVICE) test-default-ete-eval
-	${MAKE} DEVICE=$(DEVICE) test-act-pusht-tutorial
 
 test-act-ete-train:
 	python lerobot/scripts/train.py \
-		policy=act \
-		policy.dim_model=64 \
-		env=aloha \
-		wandb.enable=False \
-		training.offline_steps=2 \
-		training.online_steps=0 \
-		eval.n_episodes=1 \
-		eval.batch_size=1 \
-		device=$(DEVICE) \
-		training.save_checkpoint=true \
-		training.save_freq=2 \
-		policy.n_action_steps=20 \
-		policy.chunk_size=20 \
-		training.batch_size=2 \
-		training.image_transforms.enable=true \
-		hydra.run.dir=tests/outputs/act/
+		--policy.type=act \
+		--policy.dim_model=64 \
+		--policy.n_action_steps=20 \
+		--policy.chunk_size=20 \
+		--env.type=aloha \
+		--env.episode_length=5 \
+		--dataset.repo_id=lerobot/aloha_sim_transfer_cube_human \
+		--dataset.image_transforms.enable=true \
+		--dataset.episodes="[0]" \
+		--batch_size=2 \
+		--steps=4 \
+		--eval_freq=2 \
+		--eval.n_episodes=1 \
+		--eval.batch_size=1 \
+		--save_freq=2 \
+		--save_checkpoint=true \
+		--log_freq=1 \
+		--wandb.enable=false \
+		--device=$(DEVICE) \
+		--output_dir=tests/outputs/act/
+
+test-act-ete-train-resume:
+	python lerobot/scripts/train.py \
+		--config_path=tests/outputs/act/checkpoints/000002/pretrained_model/train_config.json \
+		--resume=true
 
 test-act-ete-eval:
 	python lerobot/scripts/eval.py \
-		-p tests/outputs/act/checkpoints/000002/pretrained_model \
-		eval.n_episodes=1 \
-		eval.batch_size=1 \
-		env.episode_length=8 \
-		device=$(DEVICE) \
+		--policy.path=tests/outputs/act/checkpoints/000004/pretrained_model \
+		--env.type=aloha \
+		--env.episode_length=5 \
+		--eval.n_episodes=1 \
+		--eval.batch_size=1 \
+		--device=$(DEVICE)
 
-test-act-ete-train-amp:
-	python lerobot/scripts/train.py \
-		policy=act \
-		policy.dim_model=64 \
-		env=aloha \
-		wandb.enable=False \
-		training.offline_steps=2 \
-		training.online_steps=0 \
-		eval.n_episodes=1 \
-		eval.batch_size=1 \
-		device=$(DEVICE) \
-		training.save_checkpoint=true \
-		training.save_freq=2 \
-		policy.n_action_steps=20 \
-		policy.chunk_size=20 \
-		training.batch_size=2 \
-		hydra.run.dir=tests/outputs/act_amp/ \
-		training.image_transforms.enable=true \
-		use_amp=true
-
-test-act-ete-eval-amp:
-	python lerobot/scripts/eval.py \
-		-p tests/outputs/act_amp/checkpoints/000002/pretrained_model \
-		eval.n_episodes=1 \
-		eval.batch_size=1 \
-		env.episode_length=8 \
-		device=$(DEVICE) \
-		use_amp=true
 
 test-diffusion-ete-train:
 	python lerobot/scripts/train.py \
-		policy=diffusion \
-		policy.down_dims=\[64,128,256\] \
-		policy.diffusion_step_embed_dim=32 \
-		policy.num_inference_steps=10 \
-		env=pusht \
-		wandb.enable=False \
-		training.offline_steps=2 \
-		training.online_steps=0 \
-		eval.n_episodes=1 \
-		eval.batch_size=1 \
-		device=$(DEVICE) \
-		training.save_checkpoint=true \
-		training.save_freq=2 \
-		training.batch_size=2 \
-		training.image_transforms.enable=true \
-		hydra.run.dir=tests/outputs/diffusion/
+		--policy.type=diffusion \
+		--policy.down_dims='[64,128,256]' \
+		--policy.diffusion_step_embed_dim=32 \
+		--policy.num_inference_steps=10 \
+		--env.type=pusht \
+		--env.episode_length=5 \
+		--dataset.repo_id=lerobot/pusht \
+		--dataset.image_transforms.enable=true \
+		--dataset.episodes="[0]" \
+		--batch_size=2 \
+		--steps=2 \
+		--eval_freq=2 \
+		--eval.n_episodes=1 \
+		--eval.batch_size=1 \
+		--save_checkpoint=true \
+		--save_freq=2 \
+		--log_freq=1 \
+		--wandb.enable=false \
+		--device=$(DEVICE) \
+		--output_dir=tests/outputs/diffusion/
 
 test-diffusion-ete-eval:
 	python lerobot/scripts/eval.py \
-		-p tests/outputs/diffusion/checkpoints/000002/pretrained_model \
-		eval.n_episodes=1 \
-		eval.batch_size=1 \
-		env.episode_length=8 \
-		device=$(DEVICE) \
+		--policy.path=tests/outputs/diffusion/checkpoints/000002/pretrained_model \
+		--env.type=pusht \
+		--env.episode_length=5 \
+		--eval.n_episodes=1 \
+		--eval.batch_size=1 \
+		--device=$(DEVICE)
 
 test-tdmpc-ete-train:
 	python lerobot/scripts/train.py \
-		policy=tdmpc \
-		env=xarm \
-		env.task=XarmLift-v0 \
-		dataset_repo_id=lerobot/xarm_lift_medium \
-		wandb.enable=False \
-		training.offline_steps=2 \
-		training.online_steps=0 \
-		eval.n_episodes=1 \
-		eval.batch_size=1 \
-		env.episode_length=2 \
-		device=$(DEVICE) \
-		training.save_checkpoint=true \
-		training.save_freq=2 \
-		training.batch_size=2 \
-		training.image_transforms.enable=true \
-		hydra.run.dir=tests/outputs/tdmpc/
-
-test-tdmpc-ete-train-with-online:
-	python lerobot/scripts/train.py \
-		env=pusht \
-		env.gym.obs_type=environment_state_agent_pos \
-		policy=tdmpc_pusht_keypoints \
-		eval.n_episodes=1 \
-		eval.batch_size=1 \
-		env.episode_length=10 \
-		device=$(DEVICE) \
-		training.offline_steps=2 \
-		training.online_steps=20 \
-		training.save_checkpoint=false \
-		training.save_freq=10 \
-		training.batch_size=2 \
-		training.online_rollout_n_episodes=2 \
-		training.online_rollout_batch_size=2 \
-		training.online_steps_between_rollouts=10 \
-		training.online_buffer_capacity=15 \
-		eval.use_async_envs=true \
-		hydra.run.dir=tests/outputs/tdmpc_online/
+		--policy.type=tdmpc \
+		--env.type=xarm \
+		--env.task=XarmLift-v0 \
+		--env.episode_length=5 \
+		--dataset.repo_id=lerobot/xarm_lift_medium \
+		--dataset.image_transforms.enable=true \
+		--dataset.episodes="[0]" \
+		--batch_size=2 \
+		--steps=2 \
+		--eval_freq=2 \
+		--eval.n_episodes=1 \
+		--eval.batch_size=1 \
+		--save_checkpoint=true \
+		--save_freq=2 \
+		--log_freq=1 \
+		--wandb.enable=false \
+		--device=$(DEVICE) \
+		--output_dir=tests/outputs/tdmpc/
 
 test-tdmpc-ete-eval:
 	python lerobot/scripts/eval.py \
-		-p tests/outputs/tdmpc/checkpoints/000002/pretrained_model \
-		eval.n_episodes=1 \
-		eval.batch_size=1 \
-		env.episode_length=8 \
-		device=$(DEVICE) \
-
-test-default-ete-eval:
-	python lerobot/scripts/eval.py \
-		--config lerobot/configs/default.yaml \
-		eval.n_episodes=1 \
-		eval.batch_size=1 \
-		env.episode_length=8 \
-		device=$(DEVICE) \
-
-test-act-pusht-tutorial:
-	cp examples/advanced/1_train_act_pusht/act_pusht.yaml lerobot/configs/policy/created_by_Makefile.yaml
-	python lerobot/scripts/train.py \
-		policy=created_by_Makefile.yaml \
-		env=pusht \
-		wandb.enable=False \
-		training.offline_steps=2 \
-		eval.n_episodes=1 \
-		eval.batch_size=1 \
-		env.episode_length=2 \
-		device=$(DEVICE) \
-		training.save_model=true \
-		training.save_freq=2 \
-		training.batch_size=2 \
-		training.image_transforms.enable=true \
-		hydra.run.dir=tests/outputs/act_pusht/
-	rm lerobot/configs/policy/created_by_Makefile.yaml
+		--policy.path=tests/outputs/tdmpc/checkpoints/000002/pretrained_model \
+		--env.type=xarm \
+		--env.episode_length=5 \
+		--env.task=XarmLift-v0 \
+		--eval.n_episodes=1 \
+		--eval.batch_size=1 \
+		--device=$(DEVICE)
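
The rewritten targets double as copy-paste examples of the new flag-style CLI, and `test-act-ete-train-resume` shows the general pattern for resuming any run from its saved `train_config.json`. A sketch (the `outputs/train/my_run` path is illustrative; the `tests/outputs/` paths come from the targets above):

```bash
# Run the full end-to-end suite, or a single stage, on CPU
make test-end-to-end DEVICE=cpu
make test-act-ete-train DEVICE=cpu

# Resume any training run from its consolidated config
python lerobot/scripts/train.py \
    --config_path=outputs/train/my_run/checkpoints/last/pretrained_model/train_config.json \
    --resume=true
```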


@@ -68,7 +68,7 @@
 
 ### Acknowledgment
 
-- Thanks to Tony Zaho, Zipeng Fu and colleagues for open sourcing ACT policy, ALOHA environments and datasets. Ours are adapted from [ALOHA](https://tonyzhaozh.github.io/aloha) and [Mobile ALOHA](https://mobile-aloha.github.io).
+- Thanks to Tony Zhao, Zipeng Fu and colleagues for open sourcing ACT policy, ALOHA environments and datasets. Ours are adapted from [ALOHA](https://tonyzhaozh.github.io/aloha) and [Mobile ALOHA](https://mobile-aloha.github.io).
 - Thanks to Cheng Chi, Zhenjia Xu and colleagues for open sourcing Diffusion policy, Pusht environment and datasets, as well as UMI datasets. Ours are adapted from [Diffusion Policy](https://diffusion-policy.cs.columbia.edu) and [UMI Gripper](https://umi-gripper.github.io).
 - Thanks to Nicklas Hansen, Yunhai Feng and colleagues for open sourcing TDMPC policy, Simxarm environments and datasets. Ours are adapted from [TDMPC](https://github.com/nicklashansen/tdmpc) and [FOWM](https://www.yunhaifeng.com/FOWM).
 - Thanks to Antonio Loquercio and Ashish Kumar for their early support.
@@ -122,10 +122,7 @@ wandb login
 ├── examples             # contains demonstration examples, start here to learn about LeRobot
 |   └── advanced         # contains even more examples for those who have mastered the basics
 ├── lerobot
-|   ├── configs          # contains hydra yaml files with all options that you can override in the command line
-|   |   ├── default.yaml   # selected by default, it loads pusht environment and diffusion policy
-|   |   ├── env            # various sim environments and their datasets: aloha.yaml, pusht.yaml, xarm.yaml
-|   |   └── policy         # various policies: act.yaml, diffusion.yaml, tdmpc.yaml
+|   ├── configs          # contains config classes with all options that you can override in the command line
 |   ├── common           # contains classes and utilities
 |   |   ├── datasets       # various datasets of human demonstrations: aloha, pusht, xarm
 |   |   ├── envs           # various sim environments: aloha, pusht, xarm
@@ -213,7 +210,7 @@ A `LeRobotDataset` is serialised using several widespread file formats for each
 - videos are stored in mp4 format to save space
 - metadata are stored in plain json/jsonl files
 
-Dataset can be uploaded/downloaded from the HuggingFace hub seamlessly. To work on a local dataset, you can use the `local_files_only` argument and specify its location with the `root` argument if it's not in the default `~/.cache/huggingface/lerobot` location.
+Dataset can be uploaded/downloaded from the HuggingFace hub seamlessly. To work on a local dataset, you can specify its location with the `root` argument if it's not in the default `~/.cache/huggingface/lerobot` location.
 
 ### Evaluate a pretrained policy
@@ -222,87 +219,48 @@ Check out [example 2](./examples/2_evaluate_pretrained_policy.py) that illustrat
 We also provide a more capable script to parallelize the evaluation over multiple environments during the same rollout. Here is an example with a pretrained model hosted on [lerobot/diffusion_pusht](https://huggingface.co/lerobot/diffusion_pusht):
 ```bash
 python lerobot/scripts/eval.py \
-    -p lerobot/diffusion_pusht \
-    eval.n_episodes=10 \
-    eval.batch_size=10
+    --policy.path=lerobot/diffusion_pusht \
+    --env.type=pusht \
+    --eval.batch_size=10 \
+    --eval.n_episodes=10 \
+    --use_amp=false \
+    --device=cuda
 ```
 
 Note: After training your own policy, you can re-evaluate the checkpoints with:
 
 ```bash
-python lerobot/scripts/eval.py -p {OUTPUT_DIR}/checkpoints/last/pretrained_model
+python lerobot/scripts/eval.py --policy.path={OUTPUT_DIR}/checkpoints/last/pretrained_model
 ```
 
 See `python lerobot/scripts/eval.py --help` for more instructions.
 
 ### Train your own policy
 
-Check out [example 3](./examples/3_train_policy.py) that illustrates how to train a model using our core library in python, and [example 4](./examples/4_train_policy_with_script.md) that shows how to use our training script from command line.
-
-In general, you can use our training script to easily train any policy. Here is an example of training the ACT policy on trajectories collected by humans on the Aloha simulation environment for the insertion task:
-
-```bash
-python lerobot/scripts/train.py \
-    policy=act \
-    env=aloha \
-    env.task=AlohaInsertion-v0 \
-    dataset_repo_id=lerobot/aloha_sim_insertion_human \
-```
-
-The experiment directory is automatically generated and will show up in yellow in your terminal. It looks like `outputs/train/2024-05-05/20-21-12_aloha_act_default`. You can manually specify an experiment directory by adding this argument to the `train.py` python command:
-
-```bash
-hydra.run.dir=your/new/experiment/dir
-```
-
-In the experiment directory there will be a folder called `checkpoints` which will have the following structure:
-
-```bash
-checkpoints
-├── 000250  # checkpoint_dir for training step 250
-│   ├── pretrained_model  # Hugging Face pretrained model dir
-│   │   ├── config.json  # Hugging Face pretrained model config
-│   │   ├── config.yaml  # consolidated Hydra config
-│   │   ├── model.safetensors  # model weights
-│   │   └── README.md  # Hugging Face model card
-│   └── training_state.pth  # optimizer/scheduler/rng state and training step
-```
-
-To resume training from a checkpoint, you can add these to the `train.py` python command:
-
-```bash
-hydra.run.dir=your/original/experiment/dir resume=true
-```
-
-It will load the pretrained model, optimizer and scheduler states for training. For more information please see our tutorial on training resumption [here](https://github.com/huggingface/lerobot/blob/main/examples/5_resume_training.md).
-
-To use wandb for logging training and evaluation curves, make sure you've run `wandb login` as a one-time setup step. Then, when running the training command above, enable WandB in the configuration by adding:
-
-```bash
-wandb.enable=true
-```
-
-A link to the wandb logs for the run will also show up in yellow in your terminal. Here is an example of what they look like in your browser. Please also check [here](https://github.com/huggingface/lerobot/blob/main/examples/4_train_policy_with_script.md#typical-logs-and-metrics) for the explanation of some commonly used metrics in logs.
+Check out [example 3](./examples/3_train_policy.py) that illustrate how to train a model using our core library in python, and [example 4](./examples/4_train_policy_with_script.md) that shows how to use our training script from command line.
+
+To use wandb for logging training and evaluation curves, make sure you've run `wandb login` as a one-time setup step. Then, when running the training command above, enable WandB in the configuration by adding `--wandb.enable=true`.
+
+A link to the wandb logs for the run will also show up in yellow in your terminal. Here is an example of what they look like in your browser. Please also check [here](./examples/4_train_policy_with_script.md#typical-logs-and-metrics) for the explanation of some commonly used metrics in logs.
 
 ![](media/wandb.png)
 
-Note: For efficiency, during training every checkpoint is evaluated on a low number of episodes. You may use `eval.n_episodes=500` to evaluate on more episodes than the default. Or, after training, you may want to re-evaluate your best checkpoints on more episodes or change the evaluation settings. See `python lerobot/scripts/eval.py --help` for more instructions.
+Note: For efficiency, during training every checkpoint is evaluated on a low number of episodes. You may use `--eval.n_episodes=500` to evaluate on more episodes than the default. Or, after training, you may want to re-evaluate your best checkpoints on more episodes or change the evaluation settings. See `python lerobot/scripts/eval.py --help` for more instructions.
 
 #### Reproduce state-of-the-art (SOTA)
 
-We have organized our configuration files (found under [`lerobot/configs`](./lerobot/configs)) such that they reproduce SOTA results from a given model variant in their respective original works. Simply running:
+We provide some pretrained policies on our [hub page](https://huggingface.co/lerobot) that can achieve state-of-the-art performances.
+You can reproduce their training by loading the config from their run. Simply running:
 
 ```bash
-python lerobot/scripts/train.py policy=diffusion env=pusht
+python lerobot/scripts/train.py --config_path=lerobot/diffusion_pusht
 ```
 
 reproduces SOTA results for Diffusion Policy on the PushT task.
 
-Pretrained policies, along with reproduction details, can be found under the "Models" section of https://huggingface.co/lerobot.
-
 ## Contribute
 
 If you would like to contribute to 🤗 LeRobot, please check out our [contribution guide](https://github.com/huggingface/lerobot/blob/main/CONTRIBUTING.md).
 
-### Add a new dataset
+<!-- ### Add a new dataset
 
 To add a dataset to the hub, you need to login using a write-access token, which can be generated from the [Hugging Face settings](https://huggingface.co/settings/tokens):
 ```bash
@@ -320,7 +278,7 @@ python lerobot/scripts/push_dataset_to_hub.py \
 
 See `python lerobot/scripts/push_dataset_to_hub.py --help` for more instructions.
 
-If your dataset format is not supported, implement your own in `lerobot/common/datasets/push_dataset_to_hub/${raw_format}_format.py` by copying examples like [pusht_zarr](https://github.com/huggingface/lerobot/blob/main/lerobot/common/datasets/push_dataset_to_hub/pusht_zarr_format.py), [umi_zarr](https://github.com/huggingface/lerobot/blob/main/lerobot/common/datasets/push_dataset_to_hub/umi_zarr_format.py), [aloha_hdf5](https://github.com/huggingface/lerobot/blob/main/lerobot/common/datasets/push_dataset_to_hub/aloha_hdf5_format.py), or [xarm_pkl](https://github.com/huggingface/lerobot/blob/main/lerobot/common/datasets/push_dataset_to_hub/xarm_pkl_format.py).
+If your dataset format is not supported, implement your own in `lerobot/common/datasets/push_dataset_to_hub/${raw_format}_format.py` by copying examples like [pusht_zarr](https://github.com/huggingface/lerobot/blob/main/lerobot/common/datasets/push_dataset_to_hub/pusht_zarr_format.py), [umi_zarr](https://github.com/huggingface/lerobot/blob/main/lerobot/common/datasets/push_dataset_to_hub/umi_zarr_format.py), [aloha_hdf5](https://github.com/huggingface/lerobot/blob/main/lerobot/common/datasets/push_dataset_to_hub/aloha_hdf5_format.py), or [xarm_pkl](https://github.com/huggingface/lerobot/blob/main/lerobot/common/datasets/push_dataset_to_hub/xarm_pkl_format.py). -->
 
 ### Add a pretrained policy
 
@@ -330,7 +288,7 @@ Once you have trained a policy you may upload it to the Hugging Face hub using a
 
 You first need to find the checkpoint folder located inside your experiment directory (e.g. `outputs/train/2024-05-05/20-21-12_aloha_act_default/checkpoints/002500`). Within that there is a `pretrained_model` directory which should contain:
 - `config.json`: A serialized version of the policy configuration (following the policy's dataclass config).
 - `model.safetensors`: A set of `torch.nn.Module` parameters, saved in [Hugging Face Safetensors](https://huggingface.co/docs/safetensors/index) format.
-- `config.yaml`: A consolidated Hydra training configuration containing the policy, environment, and dataset configs. The policy configuration should match `config.json` exactly. The environment config is useful for anyone who wants to evaluate your policy. The dataset config just serves as a paper trail for reproducibility.
+- `train_config.json`: A consolidated configuration containing all parameters used for training. The policy configuration should match `config.json` exactly. This is useful for anyone who wants to evaluate your policy or for reproducibility.
 
 To upload these to the hub, run the following:
 ```bash


@@ -21,7 +21,7 @@ How to decode videos?
 
 ## Variables
 
 **Image content & size**
-We don't expect the same optimal settings for a dataset of images from a simulation, or from real-world in an appartment, or in a factory, or outdoor, or with lots of moving objects in the scene, etc. Similarly, loading times might not vary linearly with the image size (resolution).
+We don't expect the same optimal settings for a dataset of images from a simulation, or from real-world in an apartment, or in a factory, or outdoor, or with lots of moving objects in the scene, etc. Similarly, loading times might not vary linearly with the image size (resolution).
 For these reasons, we run this benchmark on four representative datasets:
 - `lerobot/pusht_image`: (96 x 96 pixels) simulation with simple geometric shapes, fixed camera.
 - `aliberts/aloha_mobile_shrimp_image`: (480 x 640 pixels) real-world indoor, moving camera.
@@ -63,7 +63,7 @@
 Note that this differs significantly from a typical use case like watching a movie, in which every frame is loaded sequentially from the beginning to the end and it's acceptable to have big values for `-g`.
 
-Additionally, because some policies might request single timestamps that are a few frames appart, we also have the following scenario:
+Additionally, because some policies might request single timestamps that are a few frames apart, we also have the following scenario:
 - `2_frames_4_space`: 2 frames with 4 consecutive frames of spacing in between (e.g `[t, t + 5 / fps]`),
 
 However, due to how video decoding is implemented with `pyav`, we don't have access to an accurate seek so in practice this scenario is essentially the same as `6_frames` since all 6 frames between `t` and `t + 5 / fps` will be decoded.
@@ -85,8 +85,8 @@
 **Average Structural Similarity Index Measure (higher is better)**
 `avg_ssim` evaluates the perceived quality of images by comparing luminance, contrast, and structure. SSIM values range from -1 to 1, where 1 indicates perfect similarity.
 
-One aspect that can't be measured here with those metrics is the compatibility of the encoding accross platforms, in particular on web browser, for visualization purposes.
-h264, h265 and AV1 are all commonly used codecs and should not be pose an issue. However, the chroma subsampling (`pix_fmt`) format might affect compatibility:
+One aspect that can't be measured here with those metrics is the compatibility of the encoding across platforms, in particular on web browser, for visualization purposes.
+h264, h265 and AV1 are all commonly used codecs and should not pose an issue. However, the chroma subsampling (`pix_fmt`) format might affect compatibility:
 - `yuv420p` is more widely supported across various platforms, including web browsers.
 - `yuv444p` offers higher color fidelity but might not be supported as broadly.
@@ -114,9 +114,9 @@
 Additional encoding parameters exist that are not included in this benchmark. In particular:
 - `-preset` which allows for selecting encoding presets. This represents a collection of options that will provide a certain encoding speed to compression ratio. By leaving this parameter unspecified, it is considered to be `medium` for libx264 and libx265 and `8` for libsvtav1.
-- `-tune` which allows to optimize the encoding for certains aspects (e.g. film quality, fast decoding, etc.).
+- `-tune` which allows to optimize the encoding for certain aspects (e.g. film quality, fast decoding, etc.).
 
-See the documentation mentioned above for more detailled info on these settings and for a more comprehensive list of other parameters.
+See the documentation mentioned above for more detailed info on these settings and for a more comprehensive list of other parameters.
 
 Similarly on the decoding side, other decoders exist but are not implemented in our current benchmark. To name a few:
 - `torchaudio`
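
For context, the encoding knobs this README discusses map onto ffmpeg flags roughly as follows; a hypothetical invocation for illustration, not the benchmark's actual settings:

```bash
# Encode a frame sequence with a small keyframe interval (-g 2) and
# full-resolution chroma (yuv444p); browser targets may prefer yuv420p.
ffmpeg -i frames/%06d.png -c:v libx264 -g 2 -pix_fmt yuv444p -preset medium out.mp4
```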


@@ -1,32 +1,29 @@
# Configure image
ARG PYTHON_VERSION=3.10
FROM python:${PYTHON_VERSION}-slim

# Configure environment variables
ARG PYTHON_VERSION
ENV DEBIAN_FRONTEND=noninteractive
ENV MUJOCO_GL="egl"
ENV PATH="/opt/venv/bin:$PATH"

# Install dependencies and set up Python in a single layer
RUN apt-get update && apt-get install -y --no-install-recommends \
    build-essential cmake git \
    libglib2.0-0 libgl1-mesa-glx libegl1-mesa ffmpeg \
    speech-dispatcher libgeos-dev \
    && ln -s /usr/bin/python${PYTHON_VERSION} /usr/bin/python \
    && python -m venv /opt/venv \
    && apt-get clean && rm -rf /var/lib/apt/lists/* \
    && echo "source /opt/venv/bin/activate" >> /root/.bashrc

# Clone repository and install LeRobot in a single layer
COPY . /lerobot
WORKDIR /lerobot
RUN /opt/venv/bin/pip install --upgrade --no-cache-dir pip \
    && /opt/venv/bin/pip install --no-cache-dir ".[test, aloha, xarm, pusht, dynamixel]" \
    --extra-index-url https://download.pytorch.org/whl/cpu

# Execute in bash shell rather than python
CMD ["/bin/bash"]


@@ -13,7 +13,7 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
    sed gawk grep curl wget zip unzip \
    tcpdump sysstat screen tmux \
    libglib2.0-0 libgl1-mesa-glx libegl1-mesa \
    speech-dispatcher portaudio19-dev libgeos-dev \
    python${PYTHON_VERSION} python${PYTHON_VERSION}-venv \
    && apt-get clean && rm -rf /var/lib/apt/lists/*


@@ -1,30 +1,24 @@
FROM nvidia/cuda:12.4.1-base-ubuntu22.04

# Configure environment variables
ARG PYTHON_VERSION=3.10
ENV DEBIAN_FRONTEND=noninteractive
ENV MUJOCO_GL="egl"
ENV PATH="/opt/venv/bin:$PATH"

# Install dependencies and set up Python in a single layer
RUN apt-get update && apt-get install -y --no-install-recommends \
    build-essential cmake git \
    libglib2.0-0 libgl1-mesa-glx libegl1-mesa ffmpeg \
    speech-dispatcher libgeos-dev \
    python${PYTHON_VERSION}-dev python${PYTHON_VERSION}-venv \
    && ln -s /usr/bin/python${PYTHON_VERSION} /usr/bin/python \
    && python -m venv /opt/venv \
    && apt-get clean && rm -rf /var/lib/apt/lists/* \
    && echo "source /opt/venv/bin/activate" >> /root/.bashrc

# Clone repository and install LeRobot in a single layer
COPY . /lerobot
WORKDIR /lerobot
RUN /opt/venv/bin/pip install --upgrade --no-cache-dir pip \
    && /opt/venv/bin/pip install --no-cache-dir ".[test, aloha, xarm, pusht, dynamixel]"
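As with the CPU image, a usage sketch (path and tag are our assumptions); `--gpus all` additionally requires the NVIDIA Container Toolkit on the host:

```bash
# Build the GPU image and run it with access to all host GPUs.
docker build -f docker/lerobot-gpu/Dockerfile -t lerobot:gpu .
docker run -it --rm --gpus all lerobot:gpu
```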


@@ -1,58 +1,99 @@
# Using the [SO-100](https://github.com/TheRobotStudio/SO-ARM100) with LeRobot

## Table of Contents

- [A. Source the parts](#a-source-the-parts)
- [B. Install LeRobot](#b-install-lerobot)
- [C. Configure the motors](#c-configure-the-motors)
- [D. Assemble the arms](#d-assemble-the-arms)
- [E. Calibrate](#e-calibrate)
- [F. Teleoperate](#f-teleoperate)
- [G. Record a dataset](#g-record-a-dataset)
- [H. Visualize a dataset](#h-visualize-a-dataset)
- [I. Replay an episode](#i-replay-an-episode)
- [J. Train a policy](#j-train-a-policy)
- [K. Evaluate your policy](#k-evaluate-your-policy)
- [L. More Information](#l-more-information)

## A. Source the parts

Follow this [README](https://github.com/TheRobotStudio/SO-ARM100). It contains the bill of materials, with a link to source the parts, as well as the instructions to 3D print the parts, and advice if it's your first time printing or if you don't own a 3D printer.

Before assembling, you will first need to configure your motors. To this end, we provide a nice script, so let's first install LeRobot. After configuration, we will also guide you through assembly.

## B. Install LeRobot

> [!TIP]
> We use the Command Prompt (cmd) quite a lot. If you are not comfortable using the cmd or want to brush up on using the command line, you can have a look here: [Command line crash course](https://developer.mozilla.org/en-US/docs/Learn_web_development/Getting_started/Environment_setup/Command_line)

On your computer:

#### 1. [Install Miniconda](https://docs.anaconda.com/miniconda/install/#quick-command-line-install):

#### 2. Restart shell

Copy and paste this into your shell: `source ~/.bashrc` (or, on Mac: `source ~/.bash_profile`, or `source ~/.zshrc` if you're using zsh)

#### 3. Create and activate a fresh conda environment for lerobot

<details>
<summary><strong>Video install instructions</strong></summary>
<video src="https://github.com/user-attachments/assets/17172d3b-3b64-4b80-9cf1-b2b7c5cbd236"></video>
</details>

```bash
conda create -y -n lerobot python=3.10
```

Then activate your conda environment (do this each time you open a shell to use lerobot!):

```bash
conda activate lerobot
```

#### 4. Clone LeRobot:

```bash
git clone https://github.com/huggingface/lerobot.git ~/lerobot
```

#### 5. Install LeRobot with dependencies for the feetech motors:

```bash
cd ~/lerobot && pip install -e ".[feetech]"
```

*EXTRA: For Linux only (not Mac)*: install extra dependencies for recording datasets:

```bash
conda install -y -c conda-forge ffmpeg
pip uninstall -y opencv-python
conda install -y -c conda-forge "opencv>=4.10.0"
```
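To sanity-check these extra dependencies, you can confirm that both tools are visible from the active environment (a minimal check of our own, assuming the `lerobot` env is active):

```bash
# Confirm the conda-provided ffmpeg and OpenCV are the ones being picked up.
ffmpeg -version | head -n 1
python -c "import cv2; print(cv2.__version__)"
```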
Great :hugs:! You are now done installing LeRobot and we can begin assembling the SO100 arms :robot:.
Every time you now want to use LeRobot you can go to the `~/lerobot` folder where we installed LeRobot and run one of the commands.
## C. Configure the motors
> [!NOTE]
> Throughout this tutorial you will find videos on how to do the steps; the full video tutorial can be found here: [assembly video](https://www.youtube.com/watch?v=FioA2oeFZ5I).

### 1. Find the USB ports associated to each arm

Designate one bus servo adapter and 6 motors for your leader arm, and similarly the other bus servo adapter and 6 motors for the follower arm. It's convenient to label them and write on each motor whether it's for the follower (`F`) or the leader (`L`), along with its ID from 1 to 6 (F1...F6 and L1...L6).
#### a. Run the script to find port
<details>
<summary><strong>Video finding port</strong></summary>
<video src="https://github.com/user-attachments/assets/4a21a14d-2046-4805-93c4-ee97a30ba33f"></video>
<video src="https://github.com/user-attachments/assets/1cc3aecf-c16d-4ff9-aec7-8c175afbbce2"></video>
</details>
To find the port for each bus servo adapter, run the utility script:
```bash
python lerobot/scripts/find_motors_bus_port.py
```

#### b. Example outputs

Example output when identifying the leader arm's port (e.g., `/dev/tty.usbmodem575E0031751` on Mac, or possibly `/dev/ttyACM0` on Linux):

```
Finding all available ports for the MotorBus.
@@ -64,7 +105,6 @@ Remove the usb cable from your DynamixelMotorsBus and press Enter when done.
The port of this DynamixelMotorsBus is /dev/tty.usbmodem575E0031751
Reconnect the usb cable.
```

Example output when identifying the follower arm's port (e.g., `/dev/tty.usbmodem575E0032081`, or possibly `/dev/ttyACM1` on Linux):

```
Finding all available ports for the MotorBus.
@@ -77,14 +117,73 @@ The port of this DynamixelMotorsBus is /dev/tty.usbmodem575E0032081
Reconnect the usb cable.
```

#### c. Troubleshooting

On Linux, you might need to give access to the USB ports by running:

```bash
sudo chmod 666 /dev/ttyACM0
sudo chmod 666 /dev/ttyACM1
```
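Note that `chmod 666` only lasts until the device is unplugged or the machine reboots. A persistent alternative (our suggestion, not part of the tutorial) is a udev rule, or adding your user to the `dialout` group:

```bash
# Option 1: persistent permissions via a udev rule (the rule file name is arbitrary).
echo 'KERNEL=="ttyACM[0-9]*", MODE="0666"' | sudo tee /etc/udev/rules.d/99-lerobot-serial.rules
sudo udevadm control --reload-rules && sudo udevadm trigger

# Option 2: add your user to the dialout group (takes effect after re-login).
sudo usermod -aG dialout $USER
```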
#### d. Update config file

IMPORTANTLY: Now that you have your ports, update the **port** default values of [`SO100RobotConfig`](../lerobot/common/robot_devices/robots/configs.py). You will find something like:
```python
@RobotConfig.register_subclass("so100")
@dataclass
class So100RobotConfig(ManipulatorRobotConfig):
    calibration_dir: str = ".cache/calibration/so100"
    # `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
    # Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
    # the number of motors in your follower arms.
    max_relative_target: int | None = None

    leader_arms: dict[str, MotorsBusConfig] = field(
        default_factory=lambda: {
            "main": FeetechMotorsBusConfig(
                port="/dev/tty.usbmodem58760431091",  # <-- UPDATE HERE
                motors={
                    # name: (index, model)
                    "shoulder_pan": [1, "sts3215"],
                    "shoulder_lift": [2, "sts3215"],
                    "elbow_flex": [3, "sts3215"],
                    "wrist_flex": [4, "sts3215"],
                    "wrist_roll": [5, "sts3215"],
                    "gripper": [6, "sts3215"],
                },
            ),
        }
    )

    follower_arms: dict[str, MotorsBusConfig] = field(
        default_factory=lambda: {
            "main": FeetechMotorsBusConfig(
                port="/dev/tty.usbmodem585A0076891",  # <-- UPDATE HERE
                motors={
                    # name: (index, model)
                    "shoulder_pan": [1, "sts3215"],
                    "shoulder_lift": [2, "sts3215"],
                    "elbow_flex": [3, "sts3215"],
                    "wrist_flex": [4, "sts3215"],
                    "wrist_roll": [5, "sts3215"],
                    "gripper": [6, "sts3215"],
                },
            ),
        }
    )
```
### 2. Assembling the Base
Let's begin with assembling the follower arm base
#### a. Set IDs for all 12 motors
<details>
<summary><strong>Video configuring motor</strong></summary>
<video src="https://github.com/user-attachments/assets/ef9b3317-2e11-4858-b9d3-f0a02fb48ecf"></video>
<video src="https://github.com/user-attachments/assets/f36b5ed5-c803-4ebe-8947-b39278776a0d"></video>
</details>
Plug your first motor F1 and run this script to set its ID to 1. It will also set its present position to 2048, so expect your motor to rotate. Replace the text after `--port` with the corresponding follower control board port and run this command in cmd:

```bash
python lerobot/scripts/configure_motor.py \
  --port /dev/tty.usbmodem58760432961 \
@@ -94,7 +193,8 @@ python lerobot/scripts/configure_motor.py \
  --ID 1
```

> [!NOTE]
> These motors are currently limited. They can take values between 0 and 4096 only, which corresponds to a full turn. They can't turn more than that. 2048 is at the middle of this range, so we can take -2048 steps (180 degrees anticlockwise) and reach the maximum range, or take +2048 steps (180 degrees clockwise) and reach the maximum range. The configuration step also sets the homing offset to 0, so that if you misassembled the arm, you can always update the homing offset to account for a shift up to ± 2048 steps (± 180 degrees).

Then unplug your motor and plug the second motor and set its ID to 2.

```bash
@@ -108,23 +208,50 @@ python lerobot/scripts/configure_motor.py \
Redo the process for all your motors until ID 6. Do the same for the 6 motors of the leader arm.

#### b. Remove the gears of the 6 leader motors

<details>
<summary><strong>Video removing gears</strong></summary>
<video src="https://github.com/user-attachments/assets/0c95b88c-5b85-413d-ba19-aee2f864f2a7"></video>
</details>

Follow the video for removing gears. You need to remove the gear for the motors of the leader arm. As a result, you will only use the position encoding of the motor and reduce friction to more easily operate the leader arm.

#### c. Add motor horn to all 12 motors

<details>
<summary><strong>Video adding motor horn</strong></summary>
<video src="https://github.com/user-attachments/assets/ef3391a4-ad05-4100-b2bd-1699bf86c969"></video>
</details>

Follow the video for adding the motor horn. For SO-100, you need to align the holes on the motor horn to the motor spline to be approximately 1:30, 4:30, 7:30 and 10:30. Try to avoid rotating the motor while doing so, to keep position 2048 set during configuration. It is especially tricky for the leader motors, as the motor is more sensitive without the gears, but it's OK if it's a bit rotated.

## D. Assemble the arms

<details>
<summary><strong>Video assembling arms</strong></summary>
<video src="https://github.com/user-attachments/assets/488a39de-0189-4461-9de3-05b015f90cca"></video>
</details>

Follow the video for assembling the arms. It is important to insert the cables into the motor that is being assembled before you assemble the motor into the arm! Inserting the cables beforehand is much easier than doing it afterward. The first arm should take a bit more than 1 hour to assemble, but once you get used to it, you can do it in under 1 hour for the second arm.

## E. Calibrate
Next, you'll need to calibrate your SO-100 robot to ensure that the leader and follower arms have the same position values when they are in the same physical position. This calibration is essential because it allows a neural network trained on one SO-100 robot to work on another.

#### a. Manual calibration of follower arm

> [!IMPORTANT]
> Contrary to step 6 of the [assembly video](https://youtu.be/FioA2oeFZ5I?t=724), which illustrates the auto calibration, we will actually do manual calibration of the follower for now.

You will need to move the follower arm to these positions sequentially:
@@ -134,13 +261,15 @@ You will need to move the follower arm to these positions sequentially:
Make sure both arms are connected and run this script to launch manual calibration:

```bash
python lerobot/scripts/control_robot.py \
  --robot.type=so100 \
  --robot.cameras='{}' \
  --control.type=calibrate \
  --control.arms='["main_follower"]'
```
#### b. Manual calibration of leader arm

Follow step 6 of the [assembly video](https://youtu.be/FioA2oeFZ5I?t=724) which illustrates the manual calibration. You will need to move the leader arm to these positions sequentially:

| 1. Zero position | 2. Rotated position | 3. Rest position |
|---|---|---|
@@ -148,31 +277,34 @@ Follow step 6 of the [assembly video](https://www.youtube.com/watch?v=FioA2oeFZ5
Run this script to launch manual calibration:

```bash
python lerobot/scripts/control_robot.py \
  --robot.type=so100 \
  --robot.cameras='{}' \
  --control.type=calibrate \
  --control.arms='["main_leader"]'
```
## F. Teleoperate

**Simple teleop**

Then you are ready to teleoperate your robot! Run this simple script (it won't connect to and display the cameras):

```bash
python lerobot/scripts/control_robot.py \
  --robot.type=so100 \
  --robot.cameras='{}' \
  --control.type=teleoperate
```
#### a. Teleop with displaying cameras

Follow [this guide to set up your cameras](https://github.com/huggingface/lerobot/blob/main/examples/7_get_started_with_real_robot.md#c-add-your-cameras-with-opencvcamera). Then you will be able to display the cameras on your computer while you are teleoperating by running the following code. This is useful to prepare your setup before recording your first dataset.

```bash
python lerobot/scripts/control_robot.py \
  --robot.type=so100 \
  --control.type=teleoperate
```
## G. Record a dataset

Once you're familiar with teleoperation, you can record your first dataset with SO-100.
@@ -189,87 +321,95 @@ echo $HF_USER
Record 2 episodes and upload your dataset to the hub:

```bash
python lerobot/scripts/control_robot.py \
  --robot.type=so100 \
  --control.type=record \
  --control.fps=30 \
  --control.single_task="Grasp a lego block and put it in the bin." \
  --control.repo_id=${HF_USER}/so100_test \
  --control.tags='["so100","tutorial"]' \
  --control.warmup_time_s=5 \
  --control.episode_time_s=30 \
  --control.reset_time_s=30 \
  --control.num_episodes=2 \
  --control.push_to_hub=true
```
Note: You can resume recording by adding `--control.resume=true`.

## H. Visualize a dataset

If you uploaded your dataset to the hub with `--control.push_to_hub=true`, you can [visualize your dataset online](https://huggingface.co/spaces/lerobot/visualize_dataset) by copy pasting your repo id given by:

```bash
echo ${HF_USER}/so100_test
```
If you didn't upload the dataset (i.e., you used `--control.push_to_hub=false`), you can also visualize it locally (a window will open in your browser at `http://127.0.0.1:9090` with the visualization tool):

```bash
python lerobot/scripts/visualize_dataset_html.py \
  --repo-id ${HF_USER}/so100_test \
  --local-files-only 1
```
## I. Replay an episode

Now try to replay the first episode on your robot:

```bash
python lerobot/scripts/control_robot.py \
  --robot.type=so100 \
  --control.type=replay \
  --control.fps=30 \
  --control.repo_id=${HF_USER}/so100_test \
  --control.episode=0
```
## J. Train a policy

To train a policy to control your robot, use the [`python lerobot/scripts/train.py`](../lerobot/scripts/train.py) script. A few arguments are required. Here is an example command:

```bash
python lerobot/scripts/train.py \
  --dataset.repo_id=${HF_USER}/so100_test \
  --policy.type=act \
  --output_dir=outputs/train/act_so100_test \
  --job_name=act_so100_test \
  --device=cuda \
  --wandb.enable=true
```
Let's explain it:
1. We provided the dataset as argument with `--dataset.repo_id=${HF_USER}/so100_test`.
2. We provided the policy with `policy.type=act`. This loads configurations from [`configuration_act.py`](../lerobot/common/policies/act/configuration_act.py). Importantly, this policy will automatically adapt to the number of motor states, motor actions and cameras of your robot (e.g. `laptop` and `phone`) which have been saved in your dataset.
3. We provided `device=cuda` since we are training on an Nvidia GPU, but you could use `device=mps` to train on Apple silicon.
4. We provided `wandb.enable=true` to use [Weights and Biases](https://docs.wandb.ai/quickstart) for visualizing training plots. This is optional but if you use it, make sure you are logged in by running `wandb login`.

Training should take several hours. You will find checkpoints in `outputs/train/act_so100_test/checkpoints`.
## K. Evaluate your policy

You can use the `record` function from [`lerobot/scripts/control_robot.py`](../lerobot/scripts/control_robot.py) but with a policy checkpoint as input. For instance, run this command to record 10 evaluation episodes:

```bash
python lerobot/scripts/control_robot.py \
  --robot.type=so100 \
  --control.type=record \
  --control.fps=30 \
  --control.single_task="Grasp a lego block and put it in the bin." \
  --control.repo_id=${HF_USER}/eval_act_so100_test \
  --control.tags='["tutorial"]' \
  --control.warmup_time_s=5 \
  --control.episode_time_s=30 \
  --control.reset_time_s=30 \
  --control.num_episodes=10 \
  --control.push_to_hub=true \
  --control.policy.path=outputs/train/act_so100_test/checkpoints/last/pretrained_model
```
As you can see, it's almost the same command as previously used to record your training dataset. Two things changed:
1. There is an additional `--control.policy.path` argument which indicates the path to your policy checkpoint (e.g. `outputs/train/act_so100_test/checkpoints/last/pretrained_model`). You can also use the model repository if you uploaded a model checkpoint to the hub (e.g. `${HF_USER}/act_so100_test`).
2. The dataset name begins with `eval` to reflect that you are running inference (e.g. `${HF_USER}/eval_act_so100_test`).
## L. More Information

Follow this [previous tutorial](https://github.com/huggingface/lerobot/blob/main/examples/7_get_started_with_real_robot.md#4-train-a-policy-on-your-data) for a more in-depth tutorial on controlling real robots with LeRobot.

> [!TIP]
> If you have any questions or need help, please reach out on [Discord](https://discord.com/invite/s3KuuzsPFb) in the channel [`#so100-arm`](https://discord.com/channels/1216765309076115607/1237741463832363039).

examples/11_use_lekiwi.md

@@ -0,0 +1,463 @@
# Using the [LeKiwi](https://github.com/SIGRobotics-UIUC/LeKiwi) Robot with LeRobot
## Table of Contents
- [A. Source the parts](#a-source-the-parts)
- [B. Install software on Pi](#b-install-software-on-pi)
- [C. Install LeRobot on laptop](#c-install-lerobot-on-laptop)
- [D. Assembly](#d-assembly)
- [E. Calibration](#e-calibration)
- [F. Teleoperate](#f-teleoperate)
- [G. Record a dataset](#g-record-a-dataset)
- [H. Visualize a dataset](#h-visualize-a-dataset)
- [I. Replay an episode](#i-replay-an-episode)
- [J. Train a policy](#j-train-a-policy)
- [K. Evaluate your policy](#k-evaluate-your-policy)
> [!TIP]
> If you have any questions or need help, please reach out on [Discord](https://discord.com/invite/s3KuuzsPFb) in the channel [`#mobile-so-100-arm`](https://discord.com/channels/1216765309076115607/1318390825528332371).
## A. Source the parts
Follow this [README](https://github.com/SIGRobotics-UIUC/LeKiwi). It contains the bill of materials, with a link to source the parts, as well as the instructions to 3D print the parts, and advice if it's your first time printing or if you don't own a 3D printer.
Before assembling, you will first need to configure your motors. To this end, we provide a nice script, so let's first install LeRobot. After configuration, we will also guide you through assembly.
## B. Install software on Pi
Now we have to set up the remote PC that will run on the LeKiwi Robot. This is normally a Raspberry Pi, but it can be any PC that runs on 5V and has enough USB ports (2 or more) for the cameras and the motor control board.
### Install OS
For setting up the Raspberry Pi and its SD card, see: [Setup PI](https://www.raspberrypi.com/documentation/computers/getting-started.html). It explains how to download the [Imager](https://www.raspberrypi.com/software/) to install Raspberry Pi OS or Ubuntu.
### Setup SSH
After setting up your Pi, you should enable and set up [SSH](https://www.raspberrypi.com/news/coding-on-raspberry-pi-remotely-with-visual-studio-code/) (Secure Shell Protocol) so you can log in to the Pi from your laptop without needing a screen, keyboard and mouse attached to the Pi. A great tutorial on how to do this can be found [here](https://www.raspberrypi.com/documentation/computers/remote-access.html#ssh). You can log in to your Pi from your Command Prompt (cmd), or if you use VSCode you can use [this](https://marketplace.visualstudio.com/items?itemName=ms-vscode-remote.remote-ssh) extension.
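Optionally, you can install an SSH key on the Pi so you don't have to type your password on every connection (a standard sketch; the username and hostname are placeholders):

```bash
# Generate a key pair if you don't already have one, then copy the public key over.
ssh-keygen -t ed25519
ssh-copy-id pi@raspberrypi.local  # replace with your Pi's username and address
```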
### Install LeRobot
On your Raspberry Pi:
#### 1. [Install Miniconda](https://docs.anaconda.com/miniconda/install/#quick-command-line-install):
#### 2. Restart shell
Copy and paste this into your shell: `source ~/.bashrc` (or, on Mac: `source ~/.bash_profile`, or `source ~/.zshrc` if you're using zsh)
#### 3. Create and activate a fresh conda environment for lerobot
<details>
<summary><strong>Video install instructions</strong></summary>
<video src="https://github.com/user-attachments/assets/17172d3b-3b64-4b80-9cf1-b2b7c5cbd236"></video>
</details>
```bash
conda create -y -n lerobot python=3.10
```
Then activate your conda environment (do this each time you open a shell to use lerobot!):
```bash
conda activate lerobot
```
#### 4. Clone LeRobot:
```bash
git clone https://github.com/huggingface/lerobot.git ~/lerobot
```
#### 5. Install LeRobot with dependencies for the feetech motors:
```bash
cd ~/lerobot && pip install -e ".[feetech]"
```
## C. Install LeRobot on laptop
If you already have LeRobot installed on your laptop you can skip this step; otherwise, please follow along as we repeat the steps we did on the Pi.
> [!TIP]
> We use the Command Prompt (cmd) quite a lot. If you are not comfortable using the cmd or want to brush up on using the command line, you can have a look here: [Command line crash course](https://developer.mozilla.org/en-US/docs/Learn_web_development/Getting_started/Environment_setup/Command_line)
On your computer:
#### 1. [Install Miniconda](https://docs.anaconda.com/miniconda/install/#quick-command-line-install):
#### 2. Restart shell
Copy and paste this into your shell: `source ~/.bashrc` (or, on Mac: `source ~/.bash_profile`, or `source ~/.zshrc` if you're using zsh)
#### 3. Create and activate a fresh conda environment for lerobot
<details>
<summary><strong>Video install instructions</strong></summary>
<video src="https://github.com/user-attachments/assets/17172d3b-3b64-4b80-9cf1-b2b7c5cbd236"></video>
</details>
```bash
conda create -y -n lerobot python=3.10
```
Then activate your conda environment (do this each time you open a shell to use lerobot!):
```bash
conda activate lerobot
```
#### 4. Clone LeRobot:
```bash
git clone https://github.com/huggingface/lerobot.git ~/lerobot
```
#### 5. Install LeRobot with dependencies for the feetech motors:
```bash
cd ~/lerobot && pip install -e ".[feetech]"
```
*EXTRA: For Linux only (not Mac)*: install extra dependencies for recording datasets:
```bash
conda install -y -c conda-forge ffmpeg
pip uninstall -y opencv-python
conda install -y -c conda-forge "opencv>=4.10.0"
```
Great :hugs:! You are now done installing LeRobot and we can begin assembling the SO100 arms and Mobile base :robot:.
Every time you now want to use LeRobot you can go to the `~/lerobot` folder where we installed LeRobot and run one of the commands.
# D. Assembly
First we will assemble the two SO100 arms. One to attach to the mobile base and one for teleoperation. Then we will assemble the mobile base.
## SO100 Arms
### Configure motors
The instructions for configuring the motors can be found [here](https://github.com/huggingface/lerobot/blob/main/examples/10_use_so100.md#c-configure-the-motors) in step C of the SO100 tutorial. Besides the IDs for the arm motors, we also need to set the motor IDs for the mobile base. These need to be in a specific order to work. Below is an image of the motor IDs and motor mounting positions for the mobile base. Note that we only use one motor control board on LeKiwi, so the motor IDs for the wheels are 7, 8 and 9.
<img src="../media/lekiwi/motor_ids.webp?raw=true" alt="Motor ID's for mobile robot" title="Motor ID's for mobile robot" width="60%">
### Assemble arms
[Assemble arms instruction](https://github.com/huggingface/lerobot/blob/main/examples/10_use_so100.md#d-assemble-the-arms)
## Mobile base (LeKiwi)
[Assemble LeKiwi](https://github.com/SIGRobotics-UIUC/LeKiwi)
### Update config
The config files on the LeKiwi (Raspberry Pi) and on the laptop should be the same. First, we should find the IP address of the Raspberry Pi of the mobile manipulator; this is the same IP address used for SSH. We also need the USB port of the leader arm's control board on the laptop and the port of the control board on LeKiwi. We can find these ports with the following script.
#### a. Run the script to find port
<details>
<summary><strong>Video finding port</strong></summary>
<video src="https://github.com/user-attachments/assets/4a21a14d-2046-4805-93c4-ee97a30ba33f"></video>
<video src="https://github.com/user-attachments/assets/1cc3aecf-c16d-4ff9-aec7-8c175afbbce2"></video>
</details>
To find the port for each bus servo adapter, run the utility script:
```bash
python lerobot/scripts/find_motors_bus_port.py
```
#### b. Example outputs
Example output when identifying the leader arm's port (e.g., `/dev/tty.usbmodem575E0031751` on Mac, or possibly `/dev/ttyACM0` on Linux):
```
Finding all available ports for the MotorBus.
['/dev/tty.usbmodem575E0032081', '/dev/tty.usbmodem575E0031751']
Remove the usb cable from your DynamixelMotorsBus and press Enter when done.
[...Disconnect leader arm and press Enter...]
The port of this DynamixelMotorsBus is /dev/tty.usbmodem575E0031751
Reconnect the usb cable.
```
Example output when identifying the follower arm's port (e.g., `/dev/tty.usbmodem575E0032081`, or possibly `/dev/ttyACM1` on Linux):
```
Finding all available ports for the MotorBus.
['/dev/tty.usbmodem575E0032081', '/dev/tty.usbmodem575E0031751']
Remove the usb cable from your DynamixelMotorsBus and press Enter when done.
[...Disconnect follower arm and press Enter...]
The port of this DynamixelMotorsBus is /dev/tty.usbmodem575E0032081
Reconnect the usb cable.
```
#### c. Troubleshooting
On Linux, you might need to give access to the USB ports by running:
```bash
sudo chmod 666 /dev/ttyACM0
sudo chmod 666 /dev/ttyACM1
```
#### d. Update config file
IMPORTANTLY: Now that you have the ports of the leader and follower arms and the IP address of the mobile SO100, update the **ip** in the network configuration, the **port** in `leader_arms`, and the **port** in `follower_arms` (the control board on LeKiwi) in the [`LeKiwiRobotConfig`](../lerobot/common/robot_devices/robots/configs.py) file, where you will find something like:
```python
@RobotConfig.register_subclass("lekiwi")
@dataclass
class LeKiwiRobotConfig(RobotConfig):
    # `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
    # Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
    # the number of motors in your follower arms.
    max_relative_target: int | None = None

    # Network Configuration
    ip: str = "172.17.133.91"
    port: int = 5555
    video_port: int = 5556

    cameras: dict[str, CameraConfig] = field(
        default_factory=lambda: {
            "mobile": OpenCVCameraConfig(camera_index="/dev/video0", fps=30, width=640, height=480),
            "mobile2": OpenCVCameraConfig(camera_index="/dev/video2", fps=30, width=640, height=480),
        }
    )

    calibration_dir: str = ".cache/calibration/lekiwi"

    leader_arms: dict[str, MotorsBusConfig] = field(
        default_factory=lambda: {
            "main": FeetechMotorsBusConfig(
                port="/dev/tty.usbmodem585A0077581",
                motors={
                    # name: (index, model)
                    "shoulder_pan": [1, "sts3215"],
                    "shoulder_lift": [2, "sts3215"],
                    "elbow_flex": [3, "sts3215"],
                    "wrist_flex": [4, "sts3215"],
                    "wrist_roll": [5, "sts3215"],
                    "gripper": [6, "sts3215"],
                },
            ),
        }
    )

    follower_arms: dict[str, MotorsBusConfig] = field(
        default_factory=lambda: {
            "main": FeetechMotorsBusConfig(
                port="/dev/ttyACM0",
                motors={
                    # name: (index, model)
                    "shoulder_pan": [1, "sts3215"],
                    "shoulder_lift": [2, "sts3215"],
                    "elbow_flex": [3, "sts3215"],
                    "wrist_flex": [4, "sts3215"],
                    "wrist_roll": [5, "sts3215"],
                    "gripper": [6, "sts3215"],
                    "left_wheel": (7, "sts3215"),
                    "back_wheel": (8, "sts3215"),
                    "right_wheel": (9, "sts3215"),
                },
            ),
        }
    )

    mock: bool = False
```
# E. Calibration
Now we have to calibrate the leader arm and the follower arm. The wheel motors don't have to be calibrated.
### Calibrate follower arm (on mobile base)
> [!IMPORTANT]
> Contrarily to step 6 of the [assembly video](https://youtu.be/FioA2oeFZ5I?t=724) which illustrates the auto calibration, we will actually do manual calibration of follower for now.
You will need to move the follower arm to these positions sequentially:
| 1. Zero position | 2. Rotated position | 3. Rest position |
|---|---|---|
| <img src="../media/lekiwi/mobile_calib_zero.webp?raw=true" alt="SO-100 follower arm zero position" title="SO-100 follower arm zero position" style="width:100%;"> | <img src="../media/lekiwi/mobile_calib_rotated.webp?raw=true" alt="SO-100 follower arm rotated position" title="SO-100 follower arm rotated position" style="width:100%;"> | <img src="../media/lekiwi/mobile_calib_rest.webp?raw=true" alt="SO-100 follower arm rest position" title="SO-100 follower arm rest position" style="width:100%;"> |
Make sure the arm is connected to the Raspberry Pi and run this script (on the Raspberry Pi) to launch manual calibration:
```bash
python lerobot/scripts/control_robot.py \
--robot.type=lekiwi \
--robot.cameras='{}' \
--control.type=calibrate \
--control.arms='["main_follower"]'
```
### Calibrate leader arm
Then calibrate the leader arm (which is attached to the laptop/pc). You will need to move the leader arm to these positions sequentially:
| 1. Zero position | 2. Rotated position | 3. Rest position |
|---|---|---|
| <img src="../media/so100/leader_zero.webp?raw=true" alt="SO-100 leader arm zero position" title="SO-100 leader arm zero position" style="width:100%;"> | <img src="../media/so100/leader_rotated.webp?raw=true" alt="SO-100 leader arm rotated position" title="SO-100 leader arm rotated position" style="width:100%;"> | <img src="../media/so100/leader_rest.webp?raw=true" alt="SO-100 leader arm rest position" title="SO-100 leader arm rest position" style="width:100%;"> |
Run this script (on your laptop/pc) to launch manual calibration:
```bash
python lerobot/scripts/control_robot.py \
--robot.type=lekiwi \
--robot.cameras='{}' \
--control.type=calibrate \
--control.arms='["main_leader"]'
```
# F. Teleoperate
To teleoperate, SSH into your Raspberry Pi, run `conda activate lerobot`, and then run this script:
```bash
python lerobot/scripts/control_robot.py \
--robot.type=lekiwi \
--control.type=remote_robot
```
Then on your laptop, also run `conda activate lerobot` and this script:
```bash
python lerobot/scripts/control_robot.py \
--robot.type=lekiwi \
--control.type=teleoperate \
--control.fps=30
```
You should see on your laptop something like this: `[INFO] Connected to remote robot at tcp://172.17.133.91:5555 and video stream at tcp://172.17.133.91:5556.` Now you can move the leader arm and use the keyboard (w,a,s,d) to drive forward, left, backward, right; use (z,x) to turn left or right; and use (r,f) to increase and decrease the speed of the mobile robot. There are three speed modes, see the table below:
| Speed Mode | Linear Speed (m/s) | Rotation Speed (deg/s) |
|------------|-------------------|-----------------------|
| Fast | 0.4 | 90 |
| Medium | 0.25 | 60 |
| Slow | 0.1 | 30 |
| Key | Action |
|------|--------------------------------|
| W | Move forward |
| A | Move left |
| S | Move backward |
| D | Move right |
| Z | Turn left |
| X | Turn right |
| R | Increase speed |
| F | Decrease speed |
> [!TIP]
> If you use a different keyboard you can change the keys for each command in the [`LeKiwiRobotConfig`](../lerobot/common/robot_devices/robots/configs.py).
## Troubleshoot communication
If you are having trouble connecting to the Mobile SO100, follow these steps to diagnose and resolve the issue.
### 1. Verify IP Address Configuration
Make sure that the correct IP address for the Pi is set in the configuration file. To check the Raspberry Pi's IP address, run (on the Pi command line):
```bash
hostname -I
```
### 2. Check if Pi is reachable from laptop/pc
Try pinging the Raspberry Pi from your laptop:
```bash
ping <your_pi_ip_address>
```
If the ping fails:
- Ensure the Pi is powered on and connected to the same network.
- Check if SSH is enabled on the Pi.
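If ping succeeds but teleoperation still fails, you can also check that the ZMQ ports from `LeKiwiRobotConfig` (5555 and 5556 by default) are reachable while the `remote_robot` script is running on the Pi (a suggestion of ours, using netcat):

```bash
# Check that the command and video-stream ports are open on the Pi.
nc -zv <your_pi_ip_address> 5555
nc -zv <your_pi_ip_address> 5556
```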
### 3. Try SSH connection
If you can't SSH into the Pi, it might not be properly connected. Use:
```bash
ssh <your_pi_user_name>@<your_pi_ip_address>
```
If you get a connection error:
- Ensure SSH is enabled on the Pi by running:
```bash
sudo raspi-config
```
Then navigate to: **Interfacing Options -> SSH** and enable it.
### 4. Same config file
Make sure the configuration file on your laptop/pc and the one on the Raspberry Pi are the same.
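One simple way to keep them in sync is to copy the config file from your laptop to the Pi (a sketch; the username and address are placeholders, and we assume both machines cloned LeRobot into `~/lerobot`):

```bash
# Copy the robot config from the laptop to the Pi so both sides match.
scp ~/lerobot/lerobot/common/robot_devices/robots/configs.py \
    pi@<your_pi_ip_address>:~/lerobot/lerobot/common/robot_devices/robots/configs.py
```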
# G. Record a dataset
Once you're familiar with teleoperation, you can record your first dataset with LeKiwi.
If you want to use the Hugging Face hub features for uploading your dataset and you haven't previously done it, make sure you've logged in using a write-access token, which can be generated from the [Hugging Face settings](https://huggingface.co/settings/tokens):
```bash
huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
```
Store your Hugging Face repository name in a variable to run these commands:
```bash
HF_USER=$(huggingface-cli whoami | head -n 1)
echo $HF_USER
```
Record 2 episodes and upload your dataset to the hub:
```bash
python lerobot/scripts/control_robot.py \
--robot.type=lekiwi \
--control.type=record \
--control.fps=30 \
--control.single_task="Grasp a lego block and put it in the bin." \
--control.repo_id=${HF_USER}/lekiwi_test \
--control.tags='["tutorial"]' \
--control.warmup_time_s=5 \
--control.episode_time_s=30 \
--control.reset_time_s=30 \
--control.num_episodes=2 \
--control.push_to_hub=true
```
Note: You can resume recording by adding `--control.resume=true`.
# H. Visualize a dataset
If you uploaded your dataset to the hub with `--control.push_to_hub=true`, you can [visualize your dataset online](https://huggingface.co/spaces/lerobot/visualize_dataset) by copy pasting your repo id given by:
```bash
echo ${HF_USER}/lekiwi_test
```
If you didn't upload the dataset (i.e., you used `--control.push_to_hub=false`), you can also visualize it locally (a window will open in your browser at `http://127.0.0.1:9090` with the visualization tool):
```bash
python lerobot/scripts/visualize_dataset_html.py \
--repo-id ${HF_USER}/lekiwi_test \
--local-files-only 1
```
# I. Replay an episode
Now try to replay the first episode on your robot:
```bash
python lerobot/scripts/control_robot.py \
--robot.type=lekiwi \
--control.type=replay \
--control.fps=30 \
--control.repo_id=${HF_USER}/lekiwi_test \
--control.episode=0
```
## J. Train a policy
To train a policy to control your robot, use the [`python lerobot/scripts/train.py`](../lerobot/scripts/train.py) script. A few arguments are required. Here is an example command:
```bash
python lerobot/scripts/train.py \
--dataset.repo_id=${HF_USER}/lekiwi_test \
--policy.type=act \
--output_dir=outputs/train/act_lekiwi_test \
--job_name=act_lekiwi_test \
--device=cuda \
--wandb.enable=true
```
Let's explain it:
1. We provided the dataset as argument with `--dataset.repo_id=${HF_USER}/lekiwi_test`.
2. We provided the policy with `policy.type=act`. This loads configurations from [`configuration_act.py`](../lerobot/common/policies/act/configuration_act.py). Importantly, this policy will automatically adapt to the number of motor states, motor actions and cameras of your robot (e.g. `laptop` and `phone`) which have been saved in your dataset.
3. We provided `device=cuda` since we are training on an Nvidia GPU, but you could use `device=mps` to train on Apple silicon.
4. We provided `wandb.enable=true` to use [Weights and Biases](https://docs.wandb.ai/quickstart) for visualizing training plots. This is optional but if you use it, make sure you are logged in by running `wandb login`.
Training should take several hours. You will find checkpoints in `outputs/train/act_lekiwi_test/checkpoints`.
## K. Evaluate your policy
You can use the `record` function from [`lerobot/scripts/control_robot.py`](../lerobot/scripts/control_robot.py) but with a policy checkpoint as input. For instance, run this command to record 10 evaluation episodes:
```bash
python lerobot/scripts/control_robot.py \
--robot.type=lekiwi \
--control.type=record \
--control.fps=30 \
--control.single_task="Drive to the red block and pick it up" \
--control.repo_id=${HF_USER}/eval_act_lekiwi_test \
--control.tags='["tutorial"]' \
--control.warmup_time_s=5 \
--control.episode_time_s=30 \
--control.reset_time_s=30 \
--control.num_episodes=10 \
--control.push_to_hub=true \
--control.policy.path=outputs/train/act_lekiwi_test/checkpoints/last/pretrained_model
```
As you can see, it's almost the same command as previously used to record your training dataset. Two things changed:
1. There is an additional `--control.policy.path` argument which indicates the path to your policy checkpoint (e.g. `outputs/train/act_lekiwi_test/checkpoints/last/pretrained_model`). You can also use the model repository if you uploaded a model checkpoint to the hub (e.g. `${HF_USER}/act_lekiwi_test`).
2. The dataset name begins with `eval` to reflect that you are running inference (e.g. `${HF_USER}/eval_act_lekiwi_test`).


@@ -2,7 +2,7 @@ This tutorial explains how to use [Moss v1](https://github.com/jess-moss/moss-ro
## Source the parts

Follow this [README](https://github.com/jess-moss/moss-robot-arms). It contains the bill of materials, with a link to source the parts, as well as the instructions to 3D print the parts, and advice if it's your first time printing or if you don't own a 3D printer already.

**Important**: Before assembling, you will first need to configure your motors. To this end, we provide a nice script, so let's first install LeRobot. After configuration, we will also guide you through assembly.
@@ -83,6 +83,54 @@ sudo chmod 666 /dev/ttyACM0
sudo chmod 666 /dev/ttyACM1
```
#### Update config file
IMPORTANTLY: Now that you have your ports, update the **port** default values of [`MossRobotConfig`](../lerobot/common/robot_devices/robots/configs.py). You will find something like:
```python
@RobotConfig.register_subclass("moss")
@dataclass
class MossRobotConfig(ManipulatorRobotConfig):
    calibration_dir: str = ".cache/calibration/moss"
    # `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
    # Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
    # the number of motors in your follower arms.
    max_relative_target: int | None = None

    leader_arms: dict[str, MotorsBusConfig] = field(
        default_factory=lambda: {
            "main": FeetechMotorsBusConfig(
                port="/dev/tty.usbmodem58760431091",  # <-- UPDATE HERE
                motors={
                    # name: (index, model)
                    "shoulder_pan": [1, "sts3215"],
                    "shoulder_lift": [2, "sts3215"],
                    "elbow_flex": [3, "sts3215"],
                    "wrist_flex": [4, "sts3215"],
                    "wrist_roll": [5, "sts3215"],
                    "gripper": [6, "sts3215"],
                },
            ),
        }
    )

    follower_arms: dict[str, MotorsBusConfig] = field(
        default_factory=lambda: {
            "main": FeetechMotorsBusConfig(
                port="/dev/tty.usbmodem585A0076891",  # <-- UPDATE HERE
                motors={
                    # name: (index, model)
                    "shoulder_pan": [1, "sts3215"],
                    "shoulder_lift": [2, "sts3215"],
                    "elbow_flex": [3, "sts3215"],
                    "wrist_flex": [4, "sts3215"],
                    "wrist_roll": [5, "sts3215"],
                    "gripper": [6, "sts3215"],
                },
            ),
        }
    )
```
**Configure your motors**

Plug your first motor and run this script to set its ID to 1. It will also set its present position to 2048, so expect your motor to rotate:

```bash
@@ -134,9 +182,11 @@ You will need to move the follower arm to these positions sequentially:
Make sure both arms are connected and run this script to launch manual calibration:

```bash
python lerobot/scripts/control_robot.py \
  --robot.type=moss \
  --robot.cameras='{}' \
  --control.type=calibrate \
  --control.arms='["main_follower"]'
```
**Manual calibration of leader arm**
@@ -148,9 +198,11 @@ Follow step 6 of the [assembly video](https://www.youtube.com/watch?v=DA91NJOtMi
Run this script to launch manual calibration:

```bash
python lerobot/scripts/control_robot.py \
  --robot.type=moss \
  --robot.cameras='{}' \
  --control.type=calibrate \
  --control.arms='["main_leader"]'
```
## Teleoperate
@@ -158,18 +210,19 @@ python lerobot/scripts/control_robot.py calibrate \
**Simple teleop**

Then you are ready to teleoperate your robot! Run this simple script (it won't connect to and display the cameras):

```bash
python lerobot/scripts/control_robot.py \
  --robot.type=moss \
  --robot.cameras='{}' \
  --control.type=teleoperate
```
**Teleop with displaying cameras**

Follow [this guide to set up your cameras](https://github.com/huggingface/lerobot/blob/main/examples/7_get_started_with_real_robot.md#c-add-your-cameras-with-opencvcamera). Then you will be able to display the cameras on your computer while you are teleoperating by running the following code. This is useful to prepare your setup before recording your first dataset.

```bash
python lerobot/scripts/control_robot.py \
  --robot.type=moss \
  --control.type=teleoperate
```
## Record a dataset
@@ -189,40 +242,46 @@ echo $HF_USER
Record 2 episodes and upload your dataset to the hub:

```bash
python lerobot/scripts/control_robot.py \
  --robot.type=moss \
  --control.type=record \
  --control.fps=30 \
  --control.single_task="Grasp a lego block and put it in the bin." \
  --control.repo_id=${HF_USER}/moss_test \
  --control.tags='["moss","tutorial"]' \
  --control.warmup_time_s=5 \
  --control.episode_time_s=30 \
  --control.reset_time_s=30 \
  --control.num_episodes=2 \
  --control.push_to_hub=true
```
Note: You can resume recording by adding `--control.resume=true`.
## Visualize a dataset

If you uploaded your dataset to the hub with `--control.push_to_hub=true`, you can [visualize your dataset online](https://huggingface.co/spaces/lerobot/visualize_dataset) by copy pasting your repo id given by:

```bash
echo ${HF_USER}/moss_test
```
If you didn't upload the dataset (i.e., you used `--control.push_to_hub=false`), you can also visualize it locally with:

```bash
python lerobot/scripts/visualize_dataset_html.py \
  --repo-id ${HF_USER}/moss_test \
  --local-files-only 1
```
## Replay an episode

Now try to replay the first episode on your robot:
```bash
python lerobot/scripts/control_robot.py \
  --robot.type=moss \
  --control.type=replay \
  --control.fps=30 \
  --control.repo_id=${HF_USER}/moss_test \
  --control.episode=0
```
## Train a policy

To train a policy to control your robot, use the [`python lerobot/scripts/train.py`](../lerobot/scripts/train.py) script. A few arguments are required. Here is an example command:
```bash
python lerobot/scripts/train.py \
  --dataset.repo_id=${HF_USER}/moss_test \
  --policy.type=act \
  --output_dir=outputs/train/act_moss_test \
  --job_name=act_moss_test \
  --device=cuda \
  --wandb.enable=true
```

Let's explain it:
1. We provided the dataset as argument with `--dataset.repo_id=${HF_USER}/moss_test`.
2. We provided the policy with `--policy.type=act`. This loads configurations from [`configuration_act.py`](../lerobot/common/policies/act/configuration_act.py). Importantly, this policy will automatically adapt to the number of motor states, motor actions and cameras of your robot (e.g. `laptop` and `phone`) which have been saved in your dataset.
3. We provided `--device=cuda` since we are training on a Nvidia GPU, but you could use `--device=mps` to train on Apple silicon.
4. We provided `--wandb.enable=true` to use [Weights and Biases](https://docs.wandb.ai/quickstart) for visualizing training plots. This is optional but if you use it, make sure you are logged in by running `wandb login`.

Training should take several hours. You will find checkpoints in `outputs/train/act_moss_test/checkpoints`.
## Evaluate your policy

You can use the `record` function from [`lerobot/scripts/control_robot.py`](../lerobot/scripts/control_robot.py) but with a policy checkpoint as input. For instance, run this command to record 10 evaluation episodes:
```bash
python lerobot/scripts/control_robot.py \
  --robot.type=moss \
  --control.type=record \
  --control.fps=30 \
  --control.single_task="Grasp a lego block and put it in the bin." \
  --control.repo_id=${HF_USER}/eval_act_moss_test \
  --control.tags='["tutorial"]' \
  --control.warmup_time_s=5 \
  --control.episode_time_s=30 \
  --control.reset_time_s=30 \
  --control.num_episodes=10 \
  --control.push_to_hub=true \
  --control.policy.path=outputs/train/act_moss_test/checkpoints/last/pretrained_model
```
As you can see, it's almost the same command as previously used to record your training dataset. Two things changed:
1. There is an additional `--control.policy.path` argument which indicates the path to your policy checkpoint (e.g. `outputs/train/act_moss_test/checkpoints/last/pretrained_model`). You can also use the model repository if you uploaded a model checkpoint to the hub (e.g. `${HF_USER}/act_moss_test`).
2. The dataset name begins with `eval` to reflect that you are running inference (e.g. `${HF_USER}/eval_act_moss_test`).
## More

---
""" """
This scripts demonstrates how to evaluate a pretrained policy from the HuggingFace Hub or from your local This scripts demonstrates how to evaluate a pretrained policy from the HuggingFace Hub or from your local
training outputs directory. In the latter case, you might want to run examples/3_train_policy.py first. training outputs directory. In the latter case, you might want to run examples/3_train_policy.py first.
It requires the installation of the 'gym_pusht' simulation environment. Install it by running:
```bash
pip install -e ".[pusht]"`
```
""" """
from pathlib import Path

import gymnasium as gym
import imageio
import numpy
import torch

from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
output_directory = Path("outputs/eval/example_pusht_diffusion")
output_directory.mkdir(parents=True, exist_ok=True)

# Select your device
device = "cuda"

# Provide the [hugging face repo id](https://huggingface.co/lerobot/diffusion_pusht):
pretrained_policy_path = "lerobot/diffusion_pusht"
# OR a path to a local outputs/train folder.
# pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")

policy = DiffusionPolicy.from_pretrained(pretrained_policy_path, map_location=device)
# Initialize evaluation environment to render two observation types:
# an image of the scene and state/position of the agent. The environment
# also automatically stops running after 300 interactions/steps.
env = gym.make(
    "gym_pusht/PushT-v0",
    obs_type="pixels_agent_pos",
    max_episode_steps=300,
)

# We can verify that the shapes of the features expected by the policy match the ones from the observations
# produced by the environment
print(policy.config.input_features)
print(env.observation_space)

# Similarly, we can check that the actions produced by the policy will match the actions expected by the
# environment
print(policy.config.output_features)
print(env.action_space)

# Reset the policy and environments to prepare for rollout
policy.reset()
numpy_observation, info = env.reset(seed=42)

---

from pathlib import Path
import torch

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
from lerobot.common.datasets.utils import dataset_to_policy_features
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
from lerobot.configs.types import FeatureType
def main():
    # Create a directory to store the training checkpoint.
    output_directory = Path("outputs/train/example_pusht_diffusion")
    output_directory.mkdir(parents=True, exist_ok=True)

    # Select your device
    device = torch.device("cuda")

    # Number of offline training steps (we'll only do offline training for this example).
    # Adjust as you prefer. 5000 steps are needed to get something worth evaluating.
    training_steps = 5000
    log_freq = 1

    # When starting from scratch (i.e. not from a pretrained policy), we need to specify 2 things before
    # creating the policy:
    #   - input/output shapes: to properly size the policy
    #   - dataset stats: for normalization and denormalization of input/outputs
    dataset_metadata = LeRobotDatasetMetadata("lerobot/pusht")
    features = dataset_to_policy_features(dataset_metadata.features)
    output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
    input_features = {key: ft for key, ft in features.items() if key not in output_features}

    # Policies are initialized with a configuration class, in this case `DiffusionConfig`. For this example,
    # we'll just use the defaults and so no arguments other than input/output features need to be passed.
    cfg = DiffusionConfig(input_features=input_features, output_features=output_features)

    # We can now instantiate our policy with this config and the dataset stats.
    policy = DiffusionPolicy(cfg, dataset_stats=dataset_metadata.stats)
    policy.train()
    policy.to(device)

    # Another policy-dataset interaction is with the delta_timestamps. Each policy expects a given number of
    # frames, which can differ for inputs, outputs and rewards (if there are some).
    delta_timestamps = {
        "observation.image": [i / dataset_metadata.fps for i in cfg.observation_delta_indices],
        "observation.state": [i / dataset_metadata.fps for i in cfg.observation_delta_indices],
        "action": [i / dataset_metadata.fps for i in cfg.action_delta_indices],
    }

    # In this case with the standard configuration for Diffusion Policy, it is equivalent to this:
    delta_timestamps = {
        # Load the previous image and state at -0.1 seconds before current frame,
        # then load current image and state corresponding to 0.0 second.
        "observation.image": [-0.1, 0.0],
        "observation.state": [-0.1, 0.0],
        # Load the previous action (-0.1), the next action to be executed (0.0),
        # and 14 future actions with a 0.1 seconds spacing. All these actions will be
        # used to supervise the policy.
        "action": [-0.1, 0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4],
    }

    # We can then instantiate the dataset with this delta_timestamps configuration.
    dataset = LeRobotDataset("lerobot/pusht", delta_timestamps=delta_timestamps)

    # Then we create our optimizer and dataloader for offline training.
    optimizer = torch.optim.Adam(policy.parameters(), lr=1e-4)
    dataloader = torch.utils.data.DataLoader(
        dataset,
        num_workers=4,
        batch_size=64,
        shuffle=True,
        pin_memory=device.type != "cpu",
        drop_last=True,
    )

    # Run training loop.
    step = 0
    done = False
    while not done:
        for batch in dataloader:
            batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
            loss, _ = policy.forward(batch)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            if step % log_freq == 0:
                print(f"step: {step} loss: {loss.item():.3f}")
            step += 1
            if step >= training_steps:
                done = True
                break

    # Save a policy checkpoint.
    policy.save_pretrained(output_directory)


if __name__ == "__main__":
    main()

---
This tutorial will explain the training script, how to use it, and particularly how to configure everything needed for the training run.
> **Note:** The following assumes you're running these commands on a machine equipped with a cuda GPU. If you don't have one (or if you're using a Mac), you can add `--device=cpu` (`--device=mps` respectively). However, be advised that the code executes much slower on cpu.
## The training script

LeRobot offers a training script at [`lerobot/scripts/train.py`](../../lerobot/scripts/train.py). At a high level it does the following:

- Initializes/loads a configuration for the following steps.
- Instantiates a dataset.
- (Optional) Instantiates a simulation environment corresponding to that dataset.
- Instantiates a policy.
- Runs a standard training loop with forward pass, backward pass, optimization step, and occasional logging, evaluation (of the policy on the environment), and checkpointing.

## Overview of the configuration system
In the training script, the main function `train` expects a `TrainPipelineConfig` object:
```python
# train.py
@parser.wrap()
def train(cfg: TrainPipelineConfig):
```
You can inspect the `TrainPipelineConfig` defined in [`lerobot/configs/train.py`](../../lerobot/configs/train.py), which is heavily commented and meant to be a reference for understanding any option.

When running the script, command-line inputs are parsed thanks to the `@parser.wrap()` decorator, and an instance of this class is automatically generated. Under the hood, this is done with [Draccus](https://github.com/dlwh/draccus), a tool dedicated to this purpose. If you're familiar with Hydra, Draccus can similarly load configurations from config files (.json, .yaml) and also override their values through command-line inputs. Unlike Hydra, these configurations are pre-defined in the code through dataclasses rather than being defined entirely in config files. This allows for more rigorous serialization/deserialization and typing, and lets you manipulate configurations as objects directly in the code rather than as dictionaries or namespaces (which enables nice features in an IDE such as autocomplete, jump-to-def, etc.).

Let's have a look at a simplified example. Amongst others, the training config has the following attributes:
```python
@dataclass
class TrainPipelineConfig:
    dataset: DatasetConfig
    env: envs.EnvConfig | None = None
    policy: PreTrainedConfig | None = None
```
in which `DatasetConfig`, for example, is defined as such:
```python
@dataclass
class DatasetConfig:
    repo_id: str
    episodes: list[int] | None = None
    video_backend: str = "pyav"
```
This creates a hierarchical relationship: assuming we have a `cfg` instance of `TrainPipelineConfig`, we can access the `repo_id` value with `cfg.dataset.repo_id`.
From the command line, we can specify this value using a very similar syntax: `--dataset.repo_id=repo/id`.

By default, every field takes the default value specified in its dataclass. If a field doesn't have a default value, it needs to be specified either from the command line or from a config file whose path is also given on the command line (more on this below). In the example above, the `dataset` field doesn't have a default value, which means it must be specified.
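To make this concrete, here is a minimal, self-contained sketch of the same pattern (assuming the `draccus` package is installed; `demo.py` and these simplified classes are illustrative, not LeRobot's actual definitions):
```python
from dataclasses import dataclass

import draccus


@dataclass
class DatasetConfig:
    repo_id: str  # required: no default value
    video_backend: str = "pyav"


@dataclass
class TrainPipelineConfig:
    dataset: DatasetConfig


if __name__ == "__main__":
    # Running `python demo.py --dataset.repo_id=lerobot/pusht` populates the nested field.
    cfg = draccus.parse(config_class=TrainPipelineConfig)
    print(cfg.dataset.repo_id)  # -> "lerobot/pusht"
```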
## Specifying values from the CLI

Let's say that we want to train [Diffusion Policy](../../lerobot/common/policies/diffusion) on the [pusht](https://huggingface.co/datasets/lerobot/pusht) dataset, using the [gym_pusht](https://github.com/huggingface/gym-pusht) environment for evaluation. The command to do so would look like this:
```bash
python lerobot/scripts/train.py \
    --dataset.repo_id=lerobot/pusht \
    --policy.type=diffusion \
    --env.type=pusht
```

Let's break this down:
- To specify the dataset, we just need to specify its `repo_id` on the hub, which is the only required argument in the `DatasetConfig`. The rest of the fields have default values, and in this case we are fine with those, so we can just add the option `--dataset.repo_id=lerobot/pusht`.
- To specify the policy, we can just select the diffusion policy using `--policy` appended with `.type` (see the sketch after this list). Here, `.type` is a special argument which allows us to select config classes inheriting from `draccus.ChoiceRegistry` and that have been decorated with the `register_subclass()` method. For a better explanation of this feature, have a look at this [Draccus demo](https://github.com/dlwh/draccus?tab=readme-ov-file#more-flexible-configuration-with-choice-types). In our code, we use this mechanism mainly to select policies, environments, robots, and some other components like optimizers. The policies available to select are located in [lerobot/common/policies](../../lerobot/common/policies).
- Similarly, we select the environment with `--env.type=pusht`. The different environment configs are available in [`lerobot/common/envs/configs.py`](../../lerobot/common/envs/configs.py).
Let's look at another example. Say you've been training [ACT](../../lerobot/common/policies/act) on [lerobot/aloha_sim_insertion_human](https://huggingface.co/datasets/lerobot/aloha_sim_insertion_human), using the [gym-aloha](https://github.com/huggingface/gym-aloha) environment for evaluation, with:
```bash
python lerobot/scripts/train.py \
    --policy.type=act \
    --dataset.repo_id=lerobot/aloha_sim_insertion_human \
    --env.type=aloha \
    --output_dir=outputs/train/act_aloha_insertion
```
> Notice we added `--output_dir` to explicitly tell where to write outputs from this run (checkpoints, training state, configs, etc.). This is not mandatory, and if you don't specify it, a default directory will be created from the current date and time, `env.type` and `policy.type`. This will typically look like `outputs/train/2025-01-24/16-10-05_aloha_act`.

We now want to train a different policy for aloha on another task. We'll change the dataset and use [lerobot/aloha_sim_transfer_cube_human](https://huggingface.co/datasets/lerobot/aloha_sim_transfer_cube_human) instead. Of course, we also need to change the task of the environment to match this other task.
Looking at the [`AlohaEnv`](../../lerobot/common/envs/configs.py) config, the task is `"AlohaInsertion-v0"` by default, which corresponds to the task we trained on in the command above. The [gym-aloha](https://github.com/huggingface/gym-aloha?tab=readme-ov-file#description) environment also has the `AlohaTransferCube-v0` task, which corresponds to this other task we want to train on. Putting this together, we can train this new policy on this different task using:
```bash
python lerobot/scripts/train.py \
    --policy.type=act \
    --dataset.repo_id=lerobot/aloha_sim_transfer_cube_human \
    --env.type=aloha \
    --env.task=AlohaTransferCube-v0 \
    --output_dir=outputs/train/act_aloha_transfer
```
## Loading from a config file

Now, let's assume that we want to reproduce the run just above. That run has produced a `train_config.json` file in its checkpoints, which serializes the `TrainPipelineConfig` instance it used:
```json
{
    "dataset": {
        "repo_id": "lerobot/aloha_sim_transfer_cube_human",
        "episodes": null,
        ...
    },
    "env": {
        "type": "aloha",
        "task": "AlohaTransferCube-v0",
        "fps": 50,
        ...
    },
    "policy": {
        "type": "act",
        "n_obs_steps": 1,
        ...
    },
    ...
}
```
We can then simply load the config values from this file using:
```bash
python lerobot/scripts/train.py \
    --config_path=outputs/train/act_aloha_transfer/checkpoints/last/pretrained_model/ \
    --output_dir=outputs/train/act_aloha_transfer_2
```
`--config_path` is also a special argument: it allows you to initialize the config from a local config file. It can point to a directory that contains `train_config.json`, or directly to the config file itself.

Similarly to Hydra, we can still override some parameters in the CLI if we want to, e.g.:
```bash
python lerobot/scripts/train.py \
    --config_path=outputs/train/act_aloha_transfer/checkpoints/last/pretrained_model/ \
    --output_dir=outputs/train/act_aloha_transfer_2 \
    --policy.n_action_steps=80
```
> Note: While `--output_dir` is not required in general, in this case we need to specify it since it would otherwise take the value from the `train_config.json` (which is `outputs/train/act_aloha_transfer`). In order to prevent accidental deletion of previous run checkpoints, we raise an error if you're trying to write to an existing directory. This is not the case when resuming a run, which is what you'll learn next.

`--config_path` can also accept the repo_id of a repo on the hub that contains a `train_config.json` file, e.g. running:
```bash
python lerobot/scripts/train.py --config_path=lerobot/diffusion_pusht
```
will start a training run with the same configuration used for training [lerobot/diffusion_pusht](https://huggingface.co/lerobot/diffusion_pusht).
## Resume training

Being able to resume a training run is important in case it crashed or was aborted for any reason. We'll demonstrate how to do that here.
Let's reuse the command from the previous run and add a few more options:
```bash
python lerobot/scripts/train.py \
    --policy.type=act \
    --dataset.repo_id=lerobot/aloha_sim_transfer_cube_human \
    --env.type=aloha \
    --env.task=AlohaTransferCube-v0 \
    --log_freq=25 \
    --save_freq=100 \
    --output_dir=outputs/train/run_resumption
```
Here we've taken care to set the log frequency and checkpointing frequency to low numbers so we can showcase resumption. You should be able to see some logging and have a first checkpoint within 1 minute (depending on hardware). Wait for the first checkpoint to happen; you should see a line like this in your terminal:
```
INFO 2025-01-24 16:10:56 ts/train.py:263 Checkpoint policy after step 100
```
Now let's simulate a crash by killing the process (hit `ctrl`+`c`). We can then simply resume this run from the last checkpoint available with:
```bash
python lerobot/scripts/train.py \
    --config_path=outputs/train/run_resumption/checkpoints/last/pretrained_model/ \
    --resume=true
```
You should see from the logging that your training picks up from where it left off.

Another reason you might want to resume a run is simply to extend training and add more training steps. The number of training steps is set by the option `--steps`, which is 100,000 by default.
You could double the number of steps of the previous run with:
```bash
python lerobot/scripts/train.py \
    --config_path=outputs/train/run_resumption/checkpoints/last/pretrained_model/ \
    --resume=true \
    --steps=200000
```
## Outputs of a run
In the output directory, there will be a folder called `checkpoints` with the following structure:
```bash
outputs/train/run_resumption/checkpoints
├── 000100 # checkpoint_dir for training step 100
│ ├── pretrained_model/
│ │ ├── config.json # policy config
│ │ ├── model.safetensors # policy weights
│ │ └── train_config.json # train config
│ └── training_state/
│ ├── optimizer_param_groups.json # optimizer param groups
│ ├── optimizer_state.safetensors # optimizer state
│ ├── rng_state.safetensors # rng states
│ ├── scheduler_state.json # scheduler state
│ └── training_step.json # training step
├── 000200
└── last -> 000200 # symlink to the last available checkpoint
```
## Fine-tuning a pre-trained policy
In addition to the features currently in Draccus, we've added a special `.path` argument for the policy, which allows you to load a policy as you would with `PreTrainedPolicy.from_pretrained()`. In that case, `path` can be a local directory that contains a checkpoint or a repo_id pointing to a pretrained policy on the hub.
For example, we could fine-tune a [policy pre-trained on the aloha transfer task](https://huggingface.co/lerobot/act_aloha_sim_transfer_cube_human) on the aloha insertion task. We can achieve this with:
```bash
python lerobot/scripts/train.py \
--policy.path=lerobot/act_aloha_sim_transfer_cube_human \
--dataset.repo_id=lerobot/aloha_sim_insertion_human \
--env.type=aloha \
--env.task=AlohaInsertion-v0
```
When doing so, keep in mind that the features of the fine-tuning dataset must match the input/output features of the pretrained policy.
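One way to sanity-check this before launching a fine-tuning run is to compare the policy's expected input features with the dataset's features (a sketch reusing the metadata helpers shown in the training example earlier in this document):
```python
from lerobot.common.datasets.lerobot_dataset import LeRobotDatasetMetadata
from lerobot.common.datasets.utils import dataset_to_policy_features
from lerobot.common.policies.act.modeling_act import ACTPolicy

policy = ACTPolicy.from_pretrained("lerobot/act_aloha_sim_transfer_cube_human")
meta = LeRobotDatasetMetadata("lerobot/aloha_sim_insertion_human")

# Compare what the policy expects with what the dataset provides.
print(policy.config.input_features)
print(dataset_to_policy_features(meta.features))
```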
## Typical logs and metrics

When you start the training process, you will first see your full configuration being printed in the terminal. You can check it to make sure that you configured your run correctly. The final configuration will also be saved with the checkpoint.

After that, you will see a training log like this one:
```
INFO 2024-08-14 13:35:12 ts/train.py:192 step:0 smpl:64 ep:1 epch:0.00 loss:1.112 grdn:15.387 lr:2.0e-07 updt_s:1.738 data_s:4.774
```
or an evaluation log like:
```
INFO 2024-08-14 13:38:45 ts/train.py:226 step:100 smpl:6K ep:52 epch:0.25 ∑rwrd:20.693 success:0.0% eval_s:120.266
```
These logs will also be saved in wandb if `wandb.enable` is set to `true`. Here is the meaning of some abbreviations:
- `smpl`: number of samples seen during training.
- `ep`: number of episodes seen during training. An episode contains multiple samples in a complete manipulation task.
- `epch`: number of times all unique samples have been seen (epoch).

Some metrics are useful for initial performance profiling. For example, if you find the current GPU utilization is low via the `nvidia-smi` command and `data_s` sometimes is too high, you may need to modify the batch size or number of dataloading workers to accelerate dataloading. We also recommend the [pytorch profiler](https://github.com/huggingface/lerobot?tab=readme-ov-file#improve-your-code-with-profiling) for detailed performance probing.
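As a quick, self-contained illustration of the kind of probing the profiler enables (standard `torch.profiler` API, independent of LeRobot; the model here is a toy stand-in):
```python
import torch
from torch.profiler import ProfilerActivity, profile

model = torch.nn.Linear(512, 512)
batch = torch.randn(256, 512)

# Profile a few forward passes and print where CPU time is spent.
with profile(activities=[ProfilerActivity.CPU], record_shapes=True) as prof:
    for _ in range(10):
        model(batch)

print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=5))
```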
## In short

We'll summarize here the main use cases to remember from this tutorial.

#### Train a policy from scratch (CLI)
```bash
python lerobot/scripts/train.py \
    --policy.type=act \  # <- select 'act' policy
    --env.type=pusht \  # <- select 'pusht' environment
    --dataset.repo_id=lerobot/pusht  # <- train on this dataset
```

#### Train a policy from scratch (config file + CLI)
```bash
python lerobot/scripts/train.py \
    --config_path=path/to/pretrained_model \  # <- can also be a repo_id
    --policy.n_action_steps=80  # <- you may still override values
```

#### Resume/continue a training run
```bash
python lerobot/scripts/train.py \
    --config_path=checkpoint/pretrained_model/ \
    --resume=true \
    --steps=200000  # <- you can change some training parameters
```

#### Fine-tuning
```bash
python lerobot/scripts/train.py \
    --policy.path=lerobot/act_aloha_sim_transfer_cube_human \  # <- can also be a local path to a checkpoint
    --dataset.repo_id=lerobot/aloha_sim_insertion_human \
    --env.type=aloha \
    --env.task=AlohaInsertion-v0
```
---
Now that you know the basics of how to train a policy, you might want to know how to apply this knowledge to a real robot, and how to record your own datasets and train policies for your specific task.
If that's the case, head over to the next tutorial [`7_get_started_with_real_robot.md`](./7_get_started_with_real_robot.md).
Or in the meantime, happy training! 🤗


---

Using `pip`:
```bash
pip install -e ".[dynamixel]"
```

Using `poetry`:
```bash
poetry sync --extras "dynamixel"
```

Using `uv`:
```bash
uv sync --extra "dynamixel"
```
/!\ For Linux only, ffmpeg and opencv require conda install for now. Run this exact sequence of commands:

Finally, connect both arms to your computer via USB. Note that the USB doesn't provide any power, and both arms need to be plugged in with their associated power supply to be detected by your computer.
Now you are ready to configure your motors for the first time, as detailed in the sections below. In the upcoming sections, you'll learn about our classes and functions by running some python code in an interactive session, or by copy-pasting it in a python file.

If you have already configured your motors, you can streamline the process by directly running the teleoperate script (which is detailed further in the tutorial):
```bash
python lerobot/scripts/control_robot.py \
  --robot.type=koch \
  --control.type=teleoperate
```
It will automatically:
1. Identify any missing calibrations and initiate the calibration procedure.
2. Connect the robot and start teleoperation.
### a. Control your motors with DynamixelMotorsBus

You can use the [`DynamixelMotorsBus`](../lerobot/common/robot_devices/motors/dynamixel.py) to communicate with the motors connected as a chain to the corresponding USB bus. This class leverages the Python [Dynamixel SDK](https://emanual.robotis.com/docs/en/software/dynamixel/dynamixel_sdk/sample_code/python_read_write_protocol_2_0/#python-read-write-protocol-20) to facilitate reading from and writing to the motors.
**First Configuration of your motors**
You will need to unplug each motor in turn and run a command to identify the motor. The motor will save its own identification, so you only need to do this once. Start by unplugging all of the motors.

Do the leader arm first, as all of its motors are of the same type. Plug in your first motor on your leader arm and run this script to set its ID to 1.
```bash
python lerobot/scripts/configure_motor.py \
--port /dev/tty.usbmodem58760432961 \
--brand dynamixel \
--model xl330-m288 \
--baudrate 1000000 \
--ID 1
```
Then unplug your first motor, plug in the second motor, and set its ID to 2.
```bash
python lerobot/scripts/configure_motor.py \
--port /dev/tty.usbmodem58760432961 \
--brand dynamixel \
--model xl330-m288 \
--baudrate 1000000 \
--ID 2
```
Repeat the process for the remaining motors up to ID 6.
The process for the follower arm is almost the same, but the follower arm has two types of motors. For the first two motors, make sure you set the model to `xl430-w250`. _Important: configuring follower motors requires plugging and unplugging power. Make sure you use the 5V power for the XL330s and the 12V power for the XL430s!_
After all of your motors are configured properly, you're ready to plug them all together in a daisy-chain as shown in the original video.
**Instantiate the DynamixelMotorsBus**

To begin, create two instances of the [`DynamixelMotorsBus`](../lerobot/common/robot_devices/motors/dynamixel.py), one for each arm, using their corresponding USB ports (e.g. `DynamixelMotorsBus(port="/dev/tty.usbmodem575E0031751")`).
```
The port of this DynamixelMotorsBus is /dev/tty.usbmodem575E0032081
Reconnect the usb cable.
```
Troubleshooting: On Linux, you might need to give access to the USB ports by running this command with your ports:
```bash
sudo chmod 666 /dev/tty.usbmodem575E0032081
sudo chmod 666 /dev/tty.usbmodem575E0031751
```
*Listing and Configuring Motors*

To assign indices to the motors, run this code in an interactive Python session. Replace the `port` values with the ones you identified earlier:
```python
from lerobot.common.robot_devices.motors.configs import DynamixelMotorsBusConfig
from lerobot.common.robot_devices.motors.dynamixel import DynamixelMotorsBus

leader_config = DynamixelMotorsBusConfig(
    port="/dev/tty.usbmodem575E0031751",
    motors={
        # name: (index, model)
        "shoulder_pan": (1, "xl330-m077"),
        "shoulder_lift": (2, "xl330-m077"),
        "elbow_flex": (3, "xl330-m077"),
        "wrist_flex": (4, "xl330-m077"),
        "wrist_roll": (5, "xl330-m077"),
        "gripper": (6, "xl330-m077"),
    },
)

follower_config = DynamixelMotorsBusConfig(
    port="/dev/tty.usbmodem575E0032081",
    motors={
        # name: (index, model)
        "shoulder_pan": (1, "xl430-w250"),
        "shoulder_lift": (2, "xl430-w250"),
        "elbow_flex": (3, "xl330-m288"),
        "wrist_flex": (4, "xl330-m288"),
        "wrist_roll": (5, "xl330-m288"),
        "gripper": (6, "xl330-m288"),
    },
)

leader_arm = DynamixelMotorsBus(leader_config)
follower_arm = DynamixelMotorsBus(follower_config)
```
Important: now that you have your ports, update [`KochRobotConfig`](../lerobot/common/robot_devices/robots/configs.py). You will find something like:
```python
@RobotConfig.register_subclass("koch")
@dataclass
class KochRobotConfig(ManipulatorRobotConfig):
    calibration_dir: str = ".cache/calibration/koch"
    # `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
    # Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
    # the number of motors in your follower arms.
    max_relative_target: int | None = None

    leader_arms: dict[str, MotorsBusConfig] = field(
        default_factory=lambda: {
            "main": DynamixelMotorsBusConfig(
                port="/dev/tty.usbmodem585A0085511",  # <-- UPDATE HERE
                motors={
                    # name: (index, model)
                    "shoulder_pan": [1, "xl330-m077"],
                    "shoulder_lift": [2, "xl330-m077"],
                    "elbow_flex": [3, "xl330-m077"],
                    "wrist_flex": [4, "xl330-m077"],
                    "wrist_roll": [5, "xl330-m077"],
                    "gripper": [6, "xl330-m077"],
                },
            ),
        }
    )

    follower_arms: dict[str, MotorsBusConfig] = field(
        default_factory=lambda: {
            "main": DynamixelMotorsBusConfig(
                port="/dev/tty.usbmodem585A0076891",  # <-- UPDATE HERE
                motors={
                    # name: (index, model)
                    "shoulder_pan": [1, "xl430-w250"],
                    "shoulder_lift": [2, "xl430-w250"],
                    "elbow_flex": [3, "xl330-m288"],
                    "wrist_flex": [4, "xl330-m288"],
                    "wrist_roll": [5, "xl330-m288"],
                    "gripper": [6, "xl330-m288"],
                },
            ),
        }
    )
```
This configuration class is used to instantiate your robot across all scripts. We'll cover how this works later on.
**Connect and Configure your Motors**

Before you can start using your motors, you'll need to configure them to ensure proper communication. When you first connect the motors, the [`DynamixelMotorsBus`](../lerobot/common/robot_devices/motors/dynamixel.py) automatically detects any mismatch between the current motor indices (factory set to `1`) and the specified indices (e.g., `1, 2, 3, 4, 5, 6`). This triggers a configuration procedure that requires you to unplug the power cord and motors, then reconnect each motor sequentially, starting from the one closest to the bus.
**Instantiate the ManipulatorRobot**

Before you can teleoperate your robot, you need to instantiate the [`ManipulatorRobot`](../lerobot/common/robot_devices/robots/manipulator.py) using the previously defined `leader_config` and `follower_config`.

For the Koch v1.1 robot, we only have one leader, so we refer to it as `"main"` and define it as `leader_arms={"main": leader_config}`. We do the same for the follower arm. For other robots (like the Aloha), which may have two pairs of leader and follower arms, you would define them like this: `leader_arms={"left": left_leader_config, "right": right_leader_config}`. Same thing for the follower arms.

Run the following code to instantiate your manipulator robot:
```python
from lerobot.common.robot_devices.robots.configs import KochRobotConfig
from lerobot.common.robot_devices.robots.manipulator import ManipulatorRobot

robot_config = KochRobotConfig(
    leader_arms={"main": leader_config},
    follower_arms={"main": follower_config},
    cameras={},  # We don't use any camera for now
)
robot = ManipulatorRobot(robot_config)
```
The `KochRobotConfig` is used to set the associated settings and calibration process. For instance, we activate the torque of the gripper of the leader Koch v1.1 arm and position it at a 40 degree angle to use it as a trigger.

For the [Aloha bimanual robot](https://aloha-2.github.io), we would use `AlohaRobotConfig` to set different settings such as a secondary ID for shadow joints (shoulder, elbow). Specific to Aloha, LeRobot comes with default calibration files stored in `.cache/calibration/aloha_default`. Assuming the motors have been properly assembled, no manual calibration step is expected for Aloha.
**Calibrate and Connect the ManipulatorRobot** **Calibrate and Connect the ManipulatorRobot**
@@ -354,7 +398,7 @@ And here are the corresponding positions for the leader arm:
You can watch a [video tutorial of the calibration procedure](https://youtu.be/8drnU9uRY24) for more details. You can watch a [video tutorial of the calibration procedure](https://youtu.be/8drnU9uRY24) for more details.
During calibration, we count the number of full 360-degree rotations your motors have made since they were first used. That's why we ask you to move to this arbitrary "zero" position. We don't actually "set" the zero position, so you don't need to be accurate. After calculating these "offsets" to shift the motor values around 0, we need to assess the rotation direction of each motor, which might differ. That's why we ask you to rotate all motors to roughly 90 degrees, to measure if the values changed negatively or positively.
Finally, the rest position ensures that the follower and leader arms are roughly aligned after calibration, preventing sudden movements that could damage the motors when starting teleoperation.
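To make the offsets and drive-mode logic concrete, here is a minimal sketch of the calibration math described above. The variable names, the made-up tick values, and the 4096-ticks-per-turn encoder resolution are illustrative assumptions; the actual logic lives in LeRobot's motors bus code.
```python
import numpy as np

# Raw motor ticks captured at the arbitrary "zero" position (made-up values).
zero_pos = np.array([2048, 2050, 2047, 2049, 2048, 2046])
# Raw motor ticks after rotating every motor to roughly +90 degrees.
rotated_pos = np.array([3072, 1024, 3070, 1026, 3071, 1025])

# Offsets shift the raw values so that the "zero" position reads as 0.
offsets = zero_pos.copy()

# If rotating to +90 degrees made a raw value decrease, that motor counts in
# the opposite direction and its readings must be flipped.
drive_mode = (rotated_pos - zero_pos) < 0

calibrated = (rotated_pos - offsets) * np.where(drive_mode, -1, 1)
print(calibrated)  # every motor now reads roughly +1024 ticks, i.e. ~90 degrees on a 4096-tick encoder
```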
@@ -579,9 +623,11 @@ Note: Some cameras may take a few seconds to warm up, and the first frame might
Finally, run this code to instantiate and connect your camera:
```python
from lerobot.common.robot_devices.cameras.configs import OpenCVCameraConfig
from lerobot.common.robot_devices.cameras.opencv import OpenCVCamera

config = OpenCVCameraConfig(camera_index=0)
camera = OpenCVCamera(config)
camera.connect()
color_image = camera.read()
```
@@ -603,7 +649,7 @@ uint8
With certain cameras, you can also specify additional parameters like frame rate, resolution, and color mode during instantiation. For instance:
```python
config = OpenCVCameraConfig(camera_index=0, fps=30, width=640, height=480)
```
If the provided arguments are not compatible with the camera, an exception will be raised.
@@ -617,18 +663,20 @@ camera.disconnect()
**Instantiate your robot with cameras**
Additionally, you can set up your robot to work with your cameras.
Modify the following Python code with the appropriate camera names and configurations:
```python
robot = ManipulatorRobot(
    KochRobotConfig(
        leader_arms={"main": leader_config},
        follower_arms={"main": follower_config},
        calibration_dir=".cache/calibration/koch",
        cameras={
            "laptop": OpenCVCameraConfig(0, fps=30, width=640, height=480),
            "phone": OpenCVCameraConfig(1, fps=30, width=640, height=480),
        },
    )
)
robot.connect()
```
@@ -652,39 +700,20 @@ torch.Size([3, 480, 640])
255
```
### d. Use `control_robot.py` and our `teleoperate` function

Instead of manually running the python code in a terminal window, you can use [`lerobot/scripts/control_robot.py`](../lerobot/scripts/control_robot.py) to instantiate your robot by providing the robot configurations via command line and control your robot with various modes as explained next.
Try running this code to teleoperate your robot (if you don't have a camera, keep reading):
```bash
python lerobot/scripts/control_robot.py \
  --robot.type=koch \
  --control.type=teleoperate
```
You will see a lot of lines appearing like this one:
```
INFO 2024-08-10 11:15:03 ol_robot.py:209 dt: 5.12 (195.1hz) dtRlead: 4.93 (203.0hz) dtWfoll: 0.19 (5239.0hz)
```
It contains
@@ -694,21 +723,12 @@ It contains
- `dtRlead: 4.93 (203.0hz)` which is the number of milliseconds it took to read the position of the leader arm using `leader_arm.read("Present_Position")`.
- `dtWfoll: 0.22 (4446.9hz)` which is the number of milliseconds it took to set a new goal position for the follower arm using `follower_arm.write("Goal_position", leader_pos)`; note that writing is done asynchronously, so it takes less time than reading.
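To see where these numbers come from, here is a hedged sketch of the timing loop, reusing the `leader_arm` and `follower_arm` buses defined earlier in this tutorial; the import path of `busy_wait` is an assumption.
```python
import time

from lerobot.common.robot_devices.utils import busy_wait  # assumed location of busy_wait

fps = 200
for _ in range(1000):
    t0 = time.perf_counter()

    leader_pos = leader_arm.read("Present_Position")  # measured as dtRlead
    follower_arm.write("Goal_position", leader_pos)   # measured as dtWfoll (asynchronous, hence faster)

    # Sleep away whatever remains of the 1/fps budget, then log the total loop time.
    dt_s = time.perf_counter() - t0
    busy_wait(1 / fps - dt_s)
    dt_total = time.perf_counter() - t0
    print(f"dt: {dt_total * 1000:.2f} ({1 / dt_total:.1f}hz)")
```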
Importantly: If you don't have any camera, you can remove them dynamically with this [draccus](https://github.com/dlwh/draccus) syntax `--robot.cameras='{}'`:
```bash
python lerobot/scripts/control_robot.py \
  --robot.type=koch \
  --robot.cameras='{}' \
  --control.type=teleoperate
```
We advise creating a new yaml file when the command becomes too long.
@@ -744,23 +764,23 @@ for _ in range(record_time_s * fps):
Importantly, many utilities are still missing. For instance, if you have cameras, you will need to save the images on disk to not go out of RAM, and to do so in threads to not slow down communication with your robot. Also, you will need to store your data in a format optimized for training and web sharing like [`LeRobotDataset`](../lerobot/common/datasets/lerobot_dataset.py). More on this in the next section.
### a. Use the `record` function
You can use the `record` function from [`lerobot/scripts/control_robot.py`](../lerobot/scripts/control_robot.py) to achieve efficient data recording. It encompasses many recording utilities:
1. Frames from cameras are saved on disk in threads, and encoded into videos at the end of each episode recording.
2. Video streams from cameras are displayed in a window so that you can verify them.
3. Data is stored with [`LeRobotDataset`](../lerobot/common/datasets/lerobot_dataset.py) format which is pushed to your Hugging Face page (unless `--control.push_to_hub=false` is provided).
4. Checkpoints are done during recording, so if any issue occurs, you can resume recording by re-running the same command again with `--control.resume=true` (see the sketch after this list). You will need to manually delete the dataset directory if you want to start recording from scratch.
5. Set the flow of data recording using command line arguments:
   - `--control.warmup_time_s=10` defines the number of seconds before starting data collection. It allows the robot devices to warmup and synchronize (10 seconds by default).
   - `--control.episode_time_s=60` defines the number of seconds for data recording for each episode (60 seconds by default).
   - `--control.reset_time_s=60` defines the number of seconds for resetting the environment after each episode (60 seconds by default).
   - `--control.num_episodes=50` defines the number of episodes to record (50 by default).
6. Control the flow during data recording using keyboard keys:
   - Press right arrow `->` at any time during episode recording to early stop and go to resetting. Same during resetting, to early stop and to go to the next episode recording.
   - Press left arrow `<-` at any time during episode recording or resetting to early stop, cancel the current episode, and re-record it.
   - Press escape `ESC` at any time during episode recording to end the session early and go straight to video encoding and dataset uploading.
7. Similarly to `teleoperate`, you can also use the command line to override anything.
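As an example of point 4 above, resuming an interrupted session could look like this (a hedged sketch: in practice you would keep all the flags of your original `record` command, shortened here for readability):
```bash
python lerobot/scripts/control_robot.py \
  --robot.type=koch \
  --control.type=record \
  --control.repo_id=${HF_USER}/koch_test \
  --control.resume=true
```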
Before trying `record`, if you want to push your dataset to the hub, make sure you've logged in using a write-access token, which can be generated from the [Hugging Face settings](https://huggingface.co/settings/tokens):
```bash
@@ -771,27 +791,29 @@ Also, store your Hugging Face repository name in a variable (e.g. `cadene` or `l
HF_USER=$(huggingface-cli whoami | head -n 1)
echo $HF_USER
```
If you don't want to push to hub, use `--control.push_to_hub=false`.
Now run this to record 2 episodes:
```bash ```bash
python lerobot/scripts/control_robot.py record \ python lerobot/scripts/control_robot.py \
--robot-path lerobot/configs/robot/koch.yaml \ --robot.type=koch \
--fps 30 \ --control.type=record \
--repo-id ${HF_USER}/koch_test \ --control.single_task="Grasp a lego block and put it in the bin." \
--tags tutorial \ --control.fps=30 \
--warmup-time-s 5 \ --control.repo_id=${HF_USER}/koch_test \
--episode-time-s 30 \ --control.tags='["tutorial"]' \
--reset-time-s 30 \ --control.warmup_time_s=5 \
--num-episodes 2 --control.episode_time_s=30 \
--control.reset_time_s=30 \
--control.num_episodes=2 \
--control.push_to_hub=true
``` ```
This will write your dataset locally to `~/.cache/huggingface/lerobot/{repo-id}` (e.g. `~/.cache/huggingface/lerobot/cadene/koch_test`) and push it on the hub at `https://huggingface.co/datasets/{HF_USER}/{repo-id}`. Your dataset will be automatically tagged with `LeRobot` for the community to find it easily, and you can also add custom tags (in this case `tutorial` for example).
You can look for other LeRobot datasets on the hub by searching for `LeRobot` tags: https://huggingface.co/datasets?other=LeRobot
You will see a lot of lines appearing like this one:
```
INFO 2024-08-10 15:02:58 ol_robot.py:219 dt:33.34 (30.0hz) dtRlead: 5.06 (197.5hz) dtWfoll: 0.25 (3963.7hz) dtRfoll: 6.22 (160.7hz) dtRlaptop: 32.57 (30.7hz) dtRphone: 33.84 (29.5hz)
```
@@ -803,8 +825,8 @@ It contains:
- `dtRlead: 5.06 (197.5hz)` which is the delta time of reading the present position of the leader arm.
- `dtWfoll: 0.25 (3963.7hz)` which is the delta time of writing the goal position on the follower arm; writing is asynchronous so it takes less time than reading.
- `dtRfoll: 6.22 (160.7hz)` which is the delta time of reading the present position on the follower arm.
- `dtRlaptop: 32.57 (30.7hz)` which is the delta time of capturing an image from the laptop camera in the thread running asynchronously.
- `dtRphone: 33.84 (29.5hz)` which is the delta time of capturing an image from the phone camera in the thread running asynchronously.
Troubleshooting:
- On Linux, if you encounter a hanging issue when using cameras, uninstall opencv and re-install it with conda:
@@ -824,7 +846,7 @@ At the end of data recording, your dataset will be uploaded on your Hugging Face
echo https://huggingface.co/datasets/${HF_USER}/koch_test
```
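You can also load the recorded dataset back in Python to sanity-check it. This is a hedged sketch: the `num_episodes`/`num_frames` attributes and the `observation.images.laptop` key (derived from the camera names above) are assumptions about your setup.
```python
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

dataset = LeRobotDataset("cadene/koch_test")  # replace with your own {HF_USER}/koch_test

# Basic bookkeeping: episode/frame counts should match what you recorded.
print(dataset.num_episodes, dataset.num_frames)

# Each item is a dict of torch tensors; image tensors are channel-first.
frame = dataset[0]
print(frame["observation.state"].shape)
print(frame["observation.images.laptop"].shape)  # e.g. torch.Size([3, 480, 640])
```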
### b. Advice for recording a dataset
Once you're comfortable with data recording, it's time to create a larger dataset for training. A good starting task is grasping an object at different locations and placing it in a bin. We suggest recording at least 50 episodes, with 10 episodes per location. Keep the cameras fixed and maintain consistent grasping behavior throughout the recordings.
@@ -842,6 +864,8 @@ python lerobot/scripts/visualize_dataset_html.py \
  --repo-id ${HF_USER}/koch_test
```
Note: You might need to add `--local-files-only 1` if your dataset was not uploaded to the Hugging Face hub.
This will launch a local web server that looks like this:
<div style="text-align:center;">
  <img src="../media/tutorial/visualize_dataset_html.webp?raw=true" alt="Dataset visualization web page" title="Dataset visualization web page" width="100%">
@@ -853,11 +877,12 @@ A useful feature of [`lerobot/scripts/control_robot.py`](../lerobot/scripts/cont
To replay the first episode of the dataset you just recorded, run the following command:
```bash
python lerobot/scripts/control_robot.py \
  --robot.type=koch \
  --control.type=replay \
  --control.fps=30 \
  --control.repo_id=${HF_USER}/koch_test \
  --control.episode=0
```
Your robot should replicate movements similar to those you recorded. For example, check out [this video](https://x.com/RemiCadene/status/1793654950905680090) where we use `replay` on an Aloha robot from [Trossen Robotics](https://www.trossenrobotics.com).
@@ -869,50 +894,17 @@ Your robot should replicate movements similar to those you recorded. For example
To train a policy to control your robot, use the [`python lerobot/scripts/train.py`](../lerobot/scripts/train.py) script. A few arguments are required. Here is an example command:
```bash
python lerobot/scripts/train.py \
  --dataset.repo_id=${HF_USER}/koch_test \
  --policy.type=act \
  --output_dir=outputs/train/act_koch_test \
  --job_name=act_koch_test \
  --device=cuda \
  --wandb.enable=true
```
Let's explain it:
1. We provided the dataset as argument with `--dataset.repo_id=${HF_USER}/koch_test`.
2. We provided the policy with `--policy.type=act`. This loads configurations from [`configuration_act.py`](../lerobot/common/policies/act/configuration_act.py). Importantly, this policy will automatically adapt to the number of motor states, motor actions and cameras of your robot (e.g. `laptop` and `phone`) which have been saved in your dataset.
3. We provided `--device=cuda` since we are training on an Nvidia GPU, but you could use `--device=mps` to train on Apple silicon.
4. We provided `--wandb.enable=true` to use [Weights and Biases](https://docs.wandb.ai/quickstart) for visualizing training plots. This is optional but if you use it, make sure you are logged in by running `wandb login`.
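Once training is done, you will probably want the checkpoint on the hub so you can point `--control.policy.path` at it later. A hedged sketch with `huggingface-cli` (the repo id is an assumption; pick your own):
```bash
huggingface-cli upload ${HF_USER}/act_koch_test \
  outputs/train/act_koch_test/checkpoints/last/pretrained_model
```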
@@ -978,34 +970,36 @@ for _ in range(inference_time_s * fps):
    busy_wait(1 / fps - dt_s)
```
### a. Use our `record` function
Ideally, when controlling your robot with your neural network, you would want to record evaluation episodes and to be able to visualize them later on, or even train on them like in Reinforcement Learning. This pretty much corresponds to recording a new dataset but with a neural network providing the actions instead of teleoperation.
To this end, you can use the `record` function from [`lerobot/scripts/control_robot.py`](../lerobot/scripts/control_robot.py) but with a policy checkpoint as input. For instance, run this command to record 10 evaluation episodes:
```bash
python lerobot/scripts/control_robot.py \
  --robot.type=koch \
  --control.type=record \
  --control.fps=30 \
  --control.repo_id=${HF_USER}/eval_act_koch_test \
  --control.tags='["tutorial"]' \
  --control.warmup_time_s=5 \
  --control.episode_time_s=30 \
  --control.reset_time_s=30 \
  --control.num_episodes=10 \
  --control.push_to_hub=true \
  --control.policy.path=outputs/train/act_koch_test/checkpoints/last/pretrained_model
```
As you can see, it's almost the same command as previously used to record your training dataset. Two things changed:
1. There is an additional `--control.policy.path` argument which indicates the path to your policy checkpoint (e.g. `outputs/train/act_koch_test/checkpoints/last/pretrained_model`). You can also use the model repository if you uploaded a model checkpoint to the hub (e.g. `${HF_USER}/act_koch_test`).
2. The dataset name begins with `eval` to reflect that you are running inference (e.g. `${HF_USER}/eval_act_koch_test`).
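For instance (expanding on point 1), if you uploaded your checkpoint to the hub as suggested earlier, the same evaluation command works with the hub repo id instead of the local path, all other flags unchanged:
```bash
python lerobot/scripts/control_robot.py \
  --robot.type=koch \
  --control.type=record \
  --control.fps=30 \
  --control.repo_id=${HF_USER}/eval_act_koch_test \
  --control.tags='["tutorial"]' \
  --control.warmup_time_s=5 \
  --control.episode_time_s=30 \
  --control.reset_time_s=30 \
  --control.num_episodes=10 \
  --control.push_to_hub=true \
  --control.policy.path=${HF_USER}/act_koch_test
```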
### b. Visualize evaluation afterwards
You can then visualize your evaluation dataset by running the same command as before but with the new inference dataset as argument:
```bash
python lerobot/scripts/visualize_dataset.py \
  --repo-id ${HF_USER}/eval_act_koch_test
```
## 6. Next step
@@ -92,20 +92,22 @@ Serial Number = stretch-se3-3054
**Calibrate (Optional)**
Before operating Stretch, you need to [home](https://docs.hello-robot.com/0.3/getting_started/stretch_hardware_overview/#homing) it first. Be mindful about giving Stretch some space as this procedure will move the robot's arm and gripper. Now run this command:
```bash
python lerobot/scripts/control_robot.py \
  --robot.type=stretch \
  --control.type=calibrate
```
This is equivalent to running `stretch_robot_home.py`
> **Note:** If you run any of the LeRobot scripts below and Stretch is not properly homed, it will automatically home/calibrate first.
**Teleoperate**
Before trying teleoperation, you need to activate the gamepad controller by pressing the middle button. For more info, see Stretch's [doc](https://docs.hello-robot.com/0.3/getting_started/hello_robot/#gamepad-teleoperation).
Now try out teleoperation (see above documentation to learn about the gamepad controls):
```bash
python lerobot/scripts/control_robot.py \
  --robot.type=stretch \
  --control.type=teleoperate
```
This is essentially the same as running `stretch_gamepad_teleop.py`
@@ -125,16 +127,18 @@ echo $HF_USER
Record one episode:
```bash
python lerobot/scripts/control_robot.py \
  --robot.type=stretch \
  --control.type=record \
  --control.fps=30 \
  --control.single_task="Grasp a lego block and put it in the bin." \
  --control.repo_id=${HF_USER}/stretch_test \
  --control.tags='["tutorial"]' \
  --control.warmup_time_s=5 \
  --control.episode_time_s=30 \
  --control.reset_time_s=30 \
  --control.num_episodes=1 \
  --control.push_to_hub=true
```
> **Note:** If you're using ssh to connect to Stretch and run this script, you won't be able to visualize its camera feeds (though they will still be recording). To see the camera streams, use [tethered](https://docs.hello-robot.com/0.3/getting_started/connecting_to_stretch/#tethered-setup) or [untethered setup](https://docs.hello-robot.com/0.3/getting_started/connecting_to_stretch/#untethered-setup).
@@ -142,11 +146,12 @@ python lerobot/scripts/control_robot.py record \
**Replay an episode**
Now try to replay this episode (make sure the robot's initial position is the same):
```bash
python lerobot/scripts/control_robot.py \
  --robot.type=stretch \
  --control.type=replay \
  --control.fps=30 \
  --control.repo_id=${HF_USER}/stretch_test \
  --control.episode=0
```
Follow [previous tutorial](https://github.com/huggingface/lerobot/blob/main/examples/7_get_started_with_real_robot.md#4-train-a-policy-on-your-data) to train a policy on your data and run inference on your robot. You will need to adapt the code for Stretch.
@@ -2,7 +2,7 @@ This tutorial explains how to use [Aloha and Aloha 2 stationary](https://www.tro
## Setup
Follow the [documentation from Trossen Robotics](https://docs.trossenrobotics.com/aloha_docs/2.0/getting_started/stationary/hardware_setup.html) for setting up the hardware and plugging the 4 arms and 4 cameras to your computer.
## Install LeRobot
@@ -51,16 +51,18 @@ Teleoperation consists in manually operating the leader arms to move the followe
By running the following code, you can start your first **SAFE** teleoperation:
```bash
python lerobot/scripts/control_robot.py \
  --robot.type=aloha \
  --robot.max_relative_target=5 \
  --control.type=teleoperate
```
By adding `--robot.max_relative_target=5`, we override the default value for `max_relative_target` defined in [`AlohaRobotConfig`](lerobot/common/robot_devices/robots/configs.py). It is expected to be `5` to limit the magnitude of the movement for more safety, but the teleoperation won't be smooth. When you feel confident, you can disable this limit by adding `--robot.max_relative_target=null` to the command line:
```bash
python lerobot/scripts/control_robot.py \
  --robot.type=aloha \
  --robot.max_relative_target=null \
  --control.type=teleoperate
```
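To build intuition for what this flag does, here is a hedged sketch of the clamping idea behind `max_relative_target` (a hypothetical helper, not the actual LeRobot implementation):
```python
import numpy as np

def clamp_goal_position(present_pos: np.ndarray, goal_pos: np.ndarray, max_relative_target: float) -> np.ndarray:
    # Limit how far each motor is allowed to move in a single control step.
    diff = np.clip(goal_pos - present_pos, -max_relative_target, max_relative_target)
    return present_pos + diff

# A goal 252 ticks away is approached in increments of at most 5 ticks per step.
print(clamp_goal_position(np.array([2048.0]), np.array([2300.0]), 5))  # -> [2053.]
```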
## Record a dataset
@@ -80,27 +82,29 @@ echo $HF_USER
Record 2 episodes and upload your dataset to the hub:
```bash
python lerobot/scripts/control_robot.py \
  --robot.type=aloha \
  --robot.max_relative_target=null \
  --control.type=record \
  --control.fps=30 \
  --control.single_task="Grasp a lego block and put it in the bin." \
  --control.repo_id=${HF_USER}/aloha_test \
  --control.tags='["tutorial"]' \
  --control.warmup_time_s=5 \
  --control.episode_time_s=30 \
  --control.reset_time_s=30 \
  --control.num_episodes=2 \
  --control.push_to_hub=true
```
## Visualize a dataset
If you uploaded your dataset to the hub with `--control.push_to_hub=true`, you can [visualize your dataset online](https://huggingface.co/spaces/lerobot/visualize_dataset) by copy-pasting your repo id given by:
```bash
echo ${HF_USER}/aloha_test
```
If you didn't upload it (i.e. you used `--control.push_to_hub=false`), you can also visualize it locally with:
```bash
python lerobot/scripts/visualize_dataset_html.py \
  --repo-id ${HF_USER}/aloha_test
```
@@ -109,16 +113,17 @@ python lerobot/scripts/visualize_dataset_html.py \
## Replay an episode
**/!\ FOR SAFETY, READ THIS /!\**
Replay consists in automatically replaying the sequence of actions (i.e. goal positions for your motors) recorded in a given dataset episode. Make sure the current initial position of your robot is similar to the one in your episode, so that your follower arms don't move too fast to go to the first goal positions. For safety, you might want to add `--robot.max_relative_target=5` to your command line as explained above.
Now try to replay the first episode on your robot:
```bash
python lerobot/scripts/control_robot.py \
  --robot.type=aloha \
  --robot.max_relative_target=null \
  --control.type=replay \
  --control.fps=30 \
  --control.repo_id=${HF_USER}/aloha_test \
  --control.episode=0
```
## Train a policy
@@ -126,49 +131,51 @@ python lerobot/scripts/control_robot.py replay \
To train a policy to control your robot, use the [`python lerobot/scripts/train.py`](../lerobot/scripts/train.py) script. A few arguments are required. Here is an example command:
```bash
python lerobot/scripts/train.py \
  --dataset.repo_id=${HF_USER}/aloha_test \
  --policy.type=act \
  --output_dir=outputs/train/act_aloha_test \
  --job_name=act_aloha_test \
  --device=cuda \
  --wandb.enable=true
```
Let's explain it:
1. We provided the dataset as argument with `--dataset.repo_id=${HF_USER}/aloha_test`.
2. We provided the policy with `--policy.type=act`. This loads configurations from [`configuration_act.py`](../lerobot/common/policies/act/configuration_act.py). Importantly, this policy will automatically adapt to the number of motor states, motor actions and cameras of your robot (e.g. `cam_right_wrist`, `cam_left_wrist`, `cam_high`, and `cam_low`) which have been saved in your dataset.
3. We provided `--device=cuda` since we are training on an Nvidia GPU, but you could use `--device=mps` to train on Apple silicon.
4. We provided `--wandb.enable=true` to use [Weights and Biases](https://docs.wandb.ai/quickstart) for visualizing training plots. This is optional but if you use it, make sure you are logged in by running `wandb login`.
For more information on the `train` script see the previous tutorial: [`examples/4_train_policy_with_script.md`](../examples/4_train_policy_with_script.md)
Training should take several hours. You will find checkpoints in `outputs/train/act_aloha_test/checkpoints`.
## Evaluate your policy
You can use the `record` function from [`lerobot/scripts/control_robot.py`](../lerobot/scripts/control_robot.py) but with a policy checkpoint as input. For instance, run this command to record 10 evaluation episodes:
```bash
python lerobot/scripts/control_robot.py \
  --robot.type=aloha \
  --control.type=record \
  --control.fps=30 \
  --control.single_task="Grasp a lego block and put it in the bin." \
  --control.repo_id=${HF_USER}/eval_act_aloha_test \
  --control.tags='["tutorial"]' \
  --control.warmup_time_s=5 \
  --control.episode_time_s=30 \
  --control.reset_time_s=30 \
  --control.num_episodes=10 \
  --control.push_to_hub=true \
  --control.policy.path=outputs/train/act_aloha_test/checkpoints/last/pretrained_model \
  --control.num_image_writer_processes=1
```
As you can see, it's almost the same command as previously used to record your training dataset. Three things changed:
1. There is an additional `--control.policy.path` argument which indicates the path to your policy checkpoint (e.g. `outputs/train/act_aloha_test/checkpoints/last/pretrained_model`). You can also use the model repository if you uploaded a model checkpoint to the hub (e.g. `${HF_USER}/act_aloha_test`).
2. The dataset name begins with `eval` to reflect that you are running inference (e.g. `${HF_USER}/eval_act_aloha_test`).
3. We use `--control.num_image_writer_processes=1` instead of the default value (`0`). On our computer, using a dedicated process to write images from the 4 cameras on disk allows reaching a constant 30 fps during inference. Feel free to explore different values for `--control.num_image_writer_processes`.
## More
Follow this [previous tutorial](https://github.com/huggingface/lerobot/blob/main/examples/7_get_started_with_real_robot.md#4-train-a-policy-on-your-data) for a more in-depth explanation.
If you have any questions or need help, please reach out on Discord in the channel `#aloha-arm`.
@@ -1,87 +0,0 @@
# @package _global_
# Change the seed to match what PushT eval uses
# (to avoid evaluating on seeds used for generating the training data).
seed: 100000
# Change the dataset repository to the PushT one.
dataset_repo_id: lerobot/pusht
override_dataset_stats:
observation.image:
# stats from imagenet, since we use a pretrained vision model
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
training:
offline_steps: 80000
online_steps: 0
eval_freq: 10000
save_freq: 100000
log_freq: 250
save_model: true
batch_size: 8
lr: 1e-5
lr_backbone: 1e-5
weight_decay: 1e-4
grad_clip_norm: 10
online_steps_between_rollouts: 1
delta_timestamps:
action: "[i / ${fps} for i in range(${policy.chunk_size})]"
eval:
n_episodes: 50
batch_size: 50
# See `configuration_act.py` for more details.
policy:
name: act
# Input / output structure.
n_obs_steps: 1
chunk_size: 100 # chunk_size
n_action_steps: 100
input_shapes:
observation.image: [3, 96, 96]
observation.state: ["${env.state_dim}"]
output_shapes:
action: ["${env.action_dim}"]
# Normalization / Unnormalization
input_normalization_modes:
observation.image: mean_std
# Use min_max normalization just because it's more standard.
observation.state: min_max
output_normalization_modes:
# Use min_max normalization just because it's more standard.
action: min_max
# Architecture.
# Vision backbone.
vision_backbone: resnet18
pretrained_backbone_weights: ResNet18_Weights.IMAGENET1K_V1
replace_final_stride_with_dilation: false
# Transformer layers.
pre_norm: false
dim_model: 512
n_heads: 8
dim_feedforward: 3200
feedforward_activation: relu
n_encoder_layers: 4
# Note: Although the original ACT implementation has 7 for `n_decoder_layers`, there is a bug in the code
# that means only the first layer is used. Here we match the original implementation by setting this to 1.
# See this issue https://github.com/tonyzhaozh/act/issues/25#issue-2258740521.
n_decoder_layers: 1
# VAE.
use_vae: true
latent_dim: 32
n_vae_encoder_layers: 4
# Inference.
temporal_ensemble_coeff: null
# Training and loss computation.
dropout: 0.1
kl_weight: 10.0
@@ -1,70 +0,0 @@
In this tutorial we will learn how to adapt a policy configuration to be compatible with a new environment and dataset. As a concrete example, we will adapt the default configuration for ACT to be compatible with the PushT environment and dataset.
If you haven't already read our tutorial on the [training script and configuration tooling](../4_train_policy_with_script.md) please do so prior to tackling this tutorial.
Let's get started!
Suppose we want to train ACT for PushT. Well, there are aspects of the ACT configuration that are specific to the ALOHA environments, and these happen to be incompatible with PushT. Therefore, trying to run the following will almost certainly raise an exception of sorts (e.g. a feature dimension mismatch):
```bash
python lerobot/scripts/train.py policy=act env=pusht dataset_repo_id=lerobot/pusht
```
We need to adapt the parameters of the ACT policy configuration to the PushT environment. The most important ones are the image keys.
ALOHA's datasets and environments typically use a variable number of cameras. In `lerobot/configs/policy/act.yaml` you may notice two relevant sections. Here we show you the minimal diff needed to adjust to PushT:
```diff
override_dataset_stats:
- observation.images.top:
+ observation.image:
# stats from imagenet, since we use a pretrained vision model
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
policy:
input_shapes:
- observation.images.top: [3, 480, 640]
+ observation.image: [3, 96, 96]
observation.state: ["${env.state_dim}"]
output_shapes:
action: ["${env.action_dim}"]
input_normalization_modes:
- observation.images.top: mean_std
+ observation.image: mean_std
observation.state: min_max
output_normalization_modes:
action: min_max
```
Here we've accounted for the following:
- PushT uses "observation.image" for its image key.
- PushT provides smaller images.
_Side note: technically we could override these via the CLI, but with many changes it gets a bit messy, and we also have a bit of a challenge in that we're using `.` in our observation keys which is treated by Hydra as a hierarchical separator_.
For your convenience, we provide [`act_pusht.yaml`](./act_pusht.yaml) in this directory. It contains the diff above, plus some other (optional) ones that are explained within. Please copy it into `lerobot/configs/policy` with:
```bash
cp examples/advanced/1_train_act_pusht/act_pusht.yaml lerobot/configs/policy/act_pusht.yaml
```
(remember from a [previous tutorial](../4_train_policy_with_script.md) that Hydra will look in the `lerobot/configs` directory). Now try running the following.
<!-- Note to contributor: are you changing this command? Note that it's tested in `Makefile`, so change it there too! -->
```bash
python lerobot/scripts/train.py policy=act_pusht env=pusht
```
Notice that this is much the same as the command that failed at the start of the tutorial, only:
- Now we are using `policy=act_pusht` to point to our new configuration file.
- We can drop `dataset_repo_id=lerobot/pusht` as the change is incorporated in our new configuration file.
Hurrah! You're now training ACT for the PushT environment.
---
The bottom line of this tutorial is that when training policies for different environments and datasets you will need to understand what parts of the policy configuration are specific to those and make changes accordingly.
Happy coding! 🤗
@@ -9,76 +9,82 @@ on the target environment, whether that be in simulation or the real world.
""" """
import math import math
from pathlib import Path
import torch import torch
from huggingface_hub import snapshot_download
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
device = torch.device("cuda")
# Download the diffusion policy for pusht environment def main():
pretrained_policy_path = Path(snapshot_download("lerobot/diffusion_pusht")) device = torch.device("cuda")
# OR uncomment the following to evaluate a policy from the local outputs/train folder.
# pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")
policy = DiffusionPolicy.from_pretrained(pretrained_policy_path) # Download the diffusion policy for pusht environment
policy.eval() pretrained_policy_path = "lerobot/diffusion_pusht"
policy.to(device) # OR uncomment the following to evaluate a policy from the local outputs/train folder.
# pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")
# Set up the dataset. policy = DiffusionPolicy.from_pretrained(pretrained_policy_path)
delta_timestamps = { policy.eval()
# Load the previous image and state at -0.1 seconds before current frame, policy.to(device)
# then load current image and state corresponding to 0.0 second.
"observation.image": [-0.1, 0.0],
"observation.state": [-0.1, 0.0],
# Load the previous action (-0.1), the next action to be executed (0.0),
# and 14 future actions with a 0.1 seconds spacing. All these actions will be
# used to calculate the loss.
"action": [-0.1, 0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4],
}
# Load the last 10% of episodes of the dataset as a validation set. # Set up the dataset.
# - Load dataset metadata delta_timestamps = {
dataset_metadata = LeRobotDatasetMetadata("lerobot/pusht") # Load the previous image and state at -0.1 seconds before current frame,
# - Calculate train and val episodes # then load current image and state corresponding to 0.0 second.
-total_episodes = dataset_metadata.total_episodes
-episodes = list(range(dataset_metadata.total_episodes))
-num_train_episodes = math.floor(total_episodes * 90 / 100)
-train_episodes = episodes[:num_train_episodes]
-val_episodes = episodes[num_train_episodes:]
-print(f"Number of episodes in full dataset: {total_episodes}")
-print(f"Number of episodes in training dataset (90% subset): {len(train_episodes)}")
-print(f"Number of episodes in validation dataset (10% subset): {len(val_episodes)}")
-
-# - Load train an val datasets
-train_dataset = LeRobotDataset("lerobot/pusht", episodes=train_episodes, delta_timestamps=delta_timestamps)
-val_dataset = LeRobotDataset("lerobot/pusht", episodes=val_episodes, delta_timestamps=delta_timestamps)
-print(f"Number of frames in training dataset (90% subset): {len(train_dataset)}")
-print(f"Number of frames in validation dataset (10% subset): {len(val_dataset)}")
-
-# Create dataloader for evaluation.
-val_dataloader = torch.utils.data.DataLoader(
-    val_dataset,
-    num_workers=4,
-    batch_size=64,
-    shuffle=False,
-    pin_memory=device != torch.device("cpu"),
-    drop_last=False,
-)
-
-# Run validation loop.
-loss_cumsum = 0
-n_examples_evaluated = 0
-for batch in val_dataloader:
-    batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
-    output_dict = policy.forward(batch)
-
-    loss_cumsum += output_dict["loss"].item()
-    n_examples_evaluated += batch["index"].shape[0]
-
-# Calculate the average loss over the validation set.
-average_loss = loss_cumsum / n_examples_evaluated
-print(f"Average loss on validation set: {average_loss:.4f}")
+        "observation.image": [-0.1, 0.0],
+        "observation.state": [-0.1, 0.0],
+        # Load the previous action (-0.1), the next action to be executed (0.0),
+        # and 14 future actions with a 0.1 seconds spacing. All these actions will be
+        # used to calculate the loss.
+        "action": [-0.1, 0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4],
+    }
+
+    # Load the last 10% of episodes of the dataset as a validation set.
+    # - Load dataset metadata
+    dataset_metadata = LeRobotDatasetMetadata("lerobot/pusht")
+    # - Calculate train and val episodes
+    total_episodes = dataset_metadata.total_episodes
+    episodes = list(range(dataset_metadata.total_episodes))
+    num_train_episodes = math.floor(total_episodes * 90 / 100)
+    train_episodes = episodes[:num_train_episodes]
+    val_episodes = episodes[num_train_episodes:]
+    print(f"Number of episodes in full dataset: {total_episodes}")
+    print(f"Number of episodes in training dataset (90% subset): {len(train_episodes)}")
+    print(f"Number of episodes in validation dataset (10% subset): {len(val_episodes)}")
+
+    # - Load train an val datasets
+    train_dataset = LeRobotDataset(
+        "lerobot/pusht", episodes=train_episodes, delta_timestamps=delta_timestamps
+    )
+    val_dataset = LeRobotDataset("lerobot/pusht", episodes=val_episodes, delta_timestamps=delta_timestamps)
+    print(f"Number of frames in training dataset (90% subset): {len(train_dataset)}")
+    print(f"Number of frames in validation dataset (10% subset): {len(val_dataset)}")
+
+    # Create dataloader for evaluation.
+    val_dataloader = torch.utils.data.DataLoader(
+        val_dataset,
+        num_workers=4,
+        batch_size=64,
+        shuffle=False,
+        pin_memory=device != torch.device("cpu"),
+        drop_last=False,
+    )
+
+    # Run validation loop.
+    loss_cumsum = 0
+    n_examples_evaluated = 0
+    for batch in val_dataloader:
+        batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
+        loss, _ = policy.forward(batch)
+
+        loss_cumsum += loss.item()
+        n_examples_evaluated += batch["index"].shape[0]
+
+    # Calculate the average loss over the validation set.
+    average_loss = loss_cumsum / n_examples_evaluated
+    print(f"Average loss on validation set: {average_loss:.4f}")
+
+
+if __name__ == "__main__":
+    main()
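The loss handling on the right-hand side reflects the refactored policy API: `policy.forward(batch)` now returns a `(loss, output_dict)` tuple, where the old example read the loss out of a returned dict. A minimal compatibility sketch for scripts that must run against both conventions (`forward_loss` is a hypothetical helper, not part of this diff):

```python
def forward_loss(policy, batch):
    """Return the loss tensor under either forward() convention (hypothetical helper)."""
    out = policy.forward(batch)
    return out[0] if isinstance(out, tuple) else out["loss"]
```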

View File

@@ -2,9 +2,10 @@ import shutil
 from pathlib import Path

 import numpy as np
-import torch
+from huggingface_hub import HfApi

-from lerobot.common.datasets.lerobot_dataset import LEROBOT_HOME, LeRobotDataset
+from lerobot.common.constants import HF_LEROBOT_HOME
+from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
 from lerobot.common.datasets.push_dataset_to_hub._download_raw import download_raw

 PUSHT_TASK = "Push the T-shaped blue block onto the T-shaped green target surface."
@@ -44,7 +45,7 @@ PUSHT_FEATURES = {
         "dtype": None,
         "shape": (3, 96, 96),
         "names": [
-            "channel",
+            "channels",
             "height",
             "width",
         ],
@@ -89,9 +90,9 @@ def calculate_coverage(zarr_data):
     num_frames = len(block_pos)

-    coverage = np.zeros((num_frames,))
+    coverage = np.zeros((num_frames,), dtype=np.float32)
     # 8 keypoints with 2 coords each
-    keypoints = np.zeros((num_frames, 16))
+    keypoints = np.zeros((num_frames, 16), dtype=np.float32)

     # Set x, y, theta (in radians)
     goal_pos_angle = np.array([256, 256, np.pi / 4])
@@ -117,7 +118,7 @@ def calculate_coverage(zarr_data):
         intersection_area = goal_geom.intersection(block_geom).area
         goal_area = goal_geom.area
         coverage[i] = intersection_area / goal_area
-        keypoints[i] = torch.from_numpy(PushTEnv.get_keypoints(block_shapes).flatten())
+        keypoints[i] = PushTEnv.get_keypoints(block_shapes).flatten()

     return coverage, keypoints
@@ -134,8 +135,8 @@ def main(raw_dir: Path, repo_id: str, mode: str = "video", push_to_hub: bool = T
     if mode not in ["video", "image", "keypoints"]:
         raise ValueError(mode)

-    if (LEROBOT_HOME / repo_id).exists():
-        shutil.rmtree(LEROBOT_HOME / repo_id)
+    if (HF_LEROBOT_HOME / repo_id).exists():
+        shutil.rmtree(HF_LEROBOT_HOME / repo_id)

     if not raw_dir.exists():
         download_raw(raw_dir, repo_id="lerobot-raw/pusht_raw")
@@ -148,6 +149,10 @@ def main(raw_dir: Path, repo_id: str, mode: str = "video", push_to_hub: bool = T
     action = zarr_data["action"][:]
     image = zarr_data["img"]  # (b, h, w, c)

+    if image.dtype == np.float32 and image.max() == np.float32(255):
+        # HACK: images are loaded as float32 but they actually encode uint8 data
+        image = image.astype(np.uint8)
+
     episode_data_index = {
         "from": np.concatenate(([0], zarr_data.meta["episode_ends"][:-1])),
         "to": zarr_data.meta["episode_ends"],
@@ -175,28 +180,30 @@ def main(raw_dir: Path, repo_id: str, mode: str = "video", push_to_hub: bool = T
         for frame_idx in range(num_frames):
             i = from_idx + frame_idx
+            idx = i + (frame_idx < num_frames - 1)
             frame = {
-                "action": torch.from_numpy(action[i]),
+                "action": action[i],
                 # Shift reward and success by +1 until the last item of the episode
-                "next.reward": reward[i + (frame_idx < num_frames - 1)],
-                "next.success": success[i + (frame_idx < num_frames - 1)],
+                "next.reward": reward[idx : idx + 1],
+                "next.success": success[idx : idx + 1],
+                "task": PUSHT_TASK,
             }

-            frame["observation.state"] = torch.from_numpy(agent_pos[i])
+            frame["observation.state"] = agent_pos[i]

             if mode == "keypoints":
-                frame["observation.environment_state"] = torch.from_numpy(keypoints[i])
+                frame["observation.environment_state"] = keypoints[i]
             else:
-                frame["observation.image"] = torch.from_numpy(image[i])
+                frame["observation.image"] = image[i]

             dataset.add_frame(frame)

-        dataset.save_episode(task=PUSHT_TASK)
-
-    dataset.consolidate()
+        dataset.save_episode()

     if push_to_hub:
         dataset.push_to_hub()
+        hub_api = HfApi()
+        hub_api.create_tag(repo_id, tag=CODEBASE_VERSION, repo_type="dataset")

 if __name__ == "__main__":
@@ -218,5 +225,5 @@ if __name__ == "__main__":
     main(raw_dir, repo_id=repo_id, mode=mode)

     # Uncomment if you want to load the local dataset and explore it
-    # dataset = LeRobotDataset(repo_id=repo_id, local_files_only=True)
+    # dataset = LeRobotDataset(repo_id=repo_id)
     # breakpoint()
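For other porting scripts, the changes above reduce to three moves: drop the `torch.from_numpy` wrapping (numpy arrays are accepted directly), put the task string inside each frame instead of passing it to `save_episode()`, and delete the `consolidate()` call. A minimal recording sketch under those assumptions (repo id, feature spec, and values are illustrative, not from the diff):

```python
import numpy as np

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

dataset = LeRobotDataset.create(
    repo_id="user/toy_dataset",  # illustrative
    fps=10,
    features={"observation.state": {"dtype": "float32", "shape": (2,), "names": ["x", "y"]}},
)
for _ in range(50):
    # the task string now travels with every frame
    dataset.add_frame({"observation.state": np.zeros(2, dtype=np.float32), "task": "Go to the origin."})
dataset.save_episode()  # no task argument, and no consolidate() afterwards
```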

View File

@@ -58,7 +58,6 @@ available_tasks_per_env = {
     ],
     "pusht": ["PushT-v0"],
     "xarm": ["XarmLift-v0"],
-    "dora_aloha_real": ["DoraAloha-v0", "DoraKoch-v0", "DoraReachy2-v0"],
 }

 available_envs = list(available_tasks_per_env.keys())
@@ -86,23 +85,6 @@ available_datasets_per_env = {
         "lerobot/xarm_push_medium_image",
         "lerobot/xarm_push_medium_replay_image",
     ],
-    "dora_aloha_real": [
-        "lerobot/aloha_static_battery",
-        "lerobot/aloha_static_candy",
-        "lerobot/aloha_static_coffee",
-        "lerobot/aloha_static_coffee_new",
-        "lerobot/aloha_static_cups_open",
-        "lerobot/aloha_static_fork_pick_up",
-        "lerobot/aloha_static_pingpong_test",
-        "lerobot/aloha_static_pro_pencil",
-        "lerobot/aloha_static_screw_driver",
-        "lerobot/aloha_static_tape",
-        "lerobot/aloha_static_thread_velcro",
-        "lerobot/aloha_static_towel",
-        "lerobot/aloha_static_vinh_cup",
-        "lerobot/aloha_static_vinh_cup_left",
-        "lerobot/aloha_static_ziploc_slide",
-    ],
 }

 available_real_world_datasets = [
@@ -221,7 +203,6 @@ available_policies_per_env = {
     "xarm": ["tdmpc"],
     "koch_real": ["act_koch_real"],
     "aloha_real": ["act_aloha_real"],
-    "dora_aloha_real": ["act_aloha_real"],
 }

 env_task_pairs = [(env, task) for env, tasks in available_tasks_per_env.items() for task in tasks]

View File

@@ -0,0 +1,32 @@
+# keys
+import os
+from pathlib import Path
+
+from huggingface_hub.constants import HF_HOME
+
+OBS_ENV = "observation.environment_state"
+OBS_ROBOT = "observation.state"
+OBS_IMAGE = "observation.image"
+OBS_IMAGES = "observation.images"
+ACTION = "action"
+
+# files & directories
+CHECKPOINTS_DIR = "checkpoints"
+LAST_CHECKPOINT_LINK = "last"
+PRETRAINED_MODEL_DIR = "pretrained_model"
+TRAINING_STATE_DIR = "training_state"
+RNG_STATE = "rng_state.safetensors"
+TRAINING_STEP = "training_step.json"
+OPTIMIZER_STATE = "optimizer_state.safetensors"
+OPTIMIZER_PARAM_GROUPS = "optimizer_param_groups.json"
+SCHEDULER_STATE = "scheduler_state.json"
+
+# cache dir
+default_cache_path = Path(HF_HOME) / "lerobot"
+HF_LEROBOT_HOME = Path(os.getenv("HF_LEROBOT_HOME", default_cache_path)).expanduser()
+
+if "LEROBOT_HOME" in os.environ:
+    raise ValueError(
+        f"You have a 'LEROBOT_HOME' environment variable set to '{os.getenv('LEROBOT_HOME')}'.\n"
+        "'LEROBOT_HOME' is deprecated, please use 'HF_LEROBOT_HOME' instead."
+    )
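Since `HF_LEROBOT_HOME` is resolved once at import time, overriding the cache location has to happen before the first `lerobot` import. A quick sketch (path illustrative):

```python
import os

os.environ["HF_LEROBOT_HOME"] = "/data/lerobot"  # must be set before importing lerobot
from lerobot.common.constants import HF_LEROBOT_HOME

print(HF_LEROBOT_HOME)  # /data/lerobot
```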

View File

@@ -0,0 +1,54 @@
+import packaging.version
+
+V2_MESSAGE = """
+The dataset you requested ({repo_id}) is in {version} format.
+
+We introduced a new format since v2.0 which is not backward compatible with v1.x.
+Please, use our conversion script. Modify the following command with your own task description:
+
+```
+python lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py \\
+    --repo-id {repo_id} \\
+    --single-task "TASK DESCRIPTION."  # <---- /!\\ Replace TASK DESCRIPTION /!\\
+```
+
+A few examples to replace TASK DESCRIPTION: "Pick up the blue cube and place it into the bin.", "Insert the
+peg into the socket.", "Slide open the ziploc bag.", "Take the elevator to the 1st floor.", "Open the top
+cabinet, store the pot inside it then close the cabinet.", "Push the T-shaped block onto the T-shaped
+target.", "Grab the spray paint on the shelf and place it in the bin on top of the robot dog.", "Fold the
+sweatshirt.", ...
+
+If you encounter a problem, contact LeRobot maintainers on [Discord](https://discord.com/invite/s3KuuzsPFb)
+or open an [issue on GitHub](https://github.com/huggingface/lerobot/issues/new/choose).
+"""
+
+V21_MESSAGE = """
+The dataset you requested ({repo_id}) is in {version} format.
+While current version of LeRobot is backward-compatible with it, the version of your dataset still uses global
+stats instead of per-episode stats. Update your dataset stats to the new format using this command:
+
+```
+python lerobot/common/datasets/v21/convert_dataset_v20_to_v21.py --repo-id={repo_id}
+```
+
+If you encounter a problem, contact LeRobot maintainers on [Discord](https://discord.com/invite/s3KuuzsPFb)
+or open an [issue on GitHub](https://github.com/huggingface/lerobot/issues/new/choose).
+"""
+
+FUTURE_MESSAGE = """
+The dataset you requested ({repo_id}) is only available in {version} format.
+As we cannot ensure forward compatibility with it, please update your current version of lerobot.
+"""
+
+
+class CompatibilityError(Exception): ...
+
+
+class BackwardCompatibilityError(CompatibilityError):
+    def __init__(self, repo_id: str, version: packaging.version.Version):
+        message = V2_MESSAGE.format(repo_id=repo_id, version=version)
+        super().__init__(message)
+
+
+class ForwardCompatibilityError(CompatibilityError):
+    def __init__(self, repo_id: str, version: packaging.version.Version):
+        message = FUTURE_MESSAGE.format(repo_id=repo_id, version=version)
+        super().__init__(message)
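A small sketch of how these exceptions are meant to be consumed downstream (the version value is illustrative):

```python
import packaging.version

from lerobot.common.datasets.backward_compatibility import BackwardCompatibilityError

try:
    raise BackwardCompatibilityError("lerobot/pusht", packaging.version.parse("1.6"))
except BackwardCompatibilityError as err:
    print(err)  # V2_MESSAGE with the repo id and version filled in
```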

View File

@@ -13,202 +13,164 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from copy import deepcopy
-from math import ceil
-
-import einops
-import torch
-import tqdm
-
-
-def get_stats_einops_patterns(dataset, num_workers=0):
-    """These einops patterns will be used to aggregate batches and compute statistics.
-
-    Note: We assume the images are in channel first format
-    """
-
-    dataloader = torch.utils.data.DataLoader(
-        dataset,
-        num_workers=num_workers,
-        batch_size=2,
-        shuffle=False,
-    )
-    batch = next(iter(dataloader))
-
-    stats_patterns = {}
-
-    for key in dataset.features:
-        # sanity check that tensors are not float64
-        assert batch[key].dtype != torch.float64
-
-        # if isinstance(feats_type, (VideoFrame, Image)):
-        if key in dataset.meta.camera_keys:
-            # sanity check that images are channel first
-            _, c, h, w = batch[key].shape
-            assert c < h and c < w, f"expect channel first images, but instead {batch[key].shape}"
-
-            # sanity check that images are float32 in range [0,1]
-            assert batch[key].dtype == torch.float32, f"expect torch.float32, but instead {batch[key].dtype=}"
-            assert batch[key].max() <= 1, f"expect pixels lower than 1, but instead {batch[key].max()=}"
-            assert batch[key].min() >= 0, f"expect pixels greater than 1, but instead {batch[key].min()=}"
-
-            stats_patterns[key] = "b c h w -> c 1 1"
-        elif batch[key].ndim == 2:
-            stats_patterns[key] = "b c -> c "
-        elif batch[key].ndim == 1:
-            stats_patterns[key] = "b -> 1"
-        else:
-            raise ValueError(f"{key}, {batch[key].shape}")
-
-    return stats_patterns
-
-
-def compute_stats(dataset, batch_size=8, num_workers=8, max_num_samples=None):
-    """Compute mean/std and min/max statistics of all data keys in a LeRobotDataset."""
-    if max_num_samples is None:
-        max_num_samples = len(dataset)
-
-    # for more info on why we need to set the same number of workers, see `load_from_videos`
-    stats_patterns = get_stats_einops_patterns(dataset, num_workers)
-
-    # mean and std will be computed incrementally while max and min will track the running value.
-    mean, std, max, min = {}, {}, {}, {}
-    for key in stats_patterns:
-        mean[key] = torch.tensor(0.0).float()
-        std[key] = torch.tensor(0.0).float()
-        max[key] = torch.tensor(-float("inf")).float()
-        min[key] = torch.tensor(float("inf")).float()
-
-    def create_seeded_dataloader(dataset, batch_size, seed):
-        generator = torch.Generator()
-        generator.manual_seed(seed)
-        dataloader = torch.utils.data.DataLoader(
-            dataset,
-            num_workers=num_workers,
-            batch_size=batch_size,
-            shuffle=True,
-            drop_last=False,
-            generator=generator,
-        )
-        return dataloader
-
-    # Note: Due to be refactored soon. The point of storing `first_batch` is to make sure we don't get
-    # surprises when rerunning the sampler.
-    first_batch = None
-    running_item_count = 0  # for online mean computation
-    dataloader = create_seeded_dataloader(dataset, batch_size, seed=1337)
-    for i, batch in enumerate(
-        tqdm.tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute mean, min, max")
-    ):
-        this_batch_size = len(batch["index"])
-        running_item_count += this_batch_size
-        if first_batch is None:
-            first_batch = deepcopy(batch)
-        for key, pattern in stats_patterns.items():
-            batch[key] = batch[key].float()
-            # Numerically stable update step for mean computation.
-            batch_mean = einops.reduce(batch[key], pattern, "mean")
-            # Hint: to update the mean we need x̄ₙ = (Nₙ₋₁x̄ₙ₋₁ + Bₙxₙ) / Nₙ, where the subscript represents
-            # the update step, N is the running item count, B is this batch size, x̄ is the running mean,
-            # and x is the current batch mean. Some rearrangement is then required to avoid risking
-            # numerical overflow. Another hint: Nₙ₋₁ = Nₙ - Bₙ. Rearrangement yields
-            # x̄ₙ = x̄ₙ₋₁ + Bₙ * (xₙ - x̄ₙ₋₁) / Nₙ
-            mean[key] = mean[key] + this_batch_size * (batch_mean - mean[key]) / running_item_count
-            max[key] = torch.maximum(max[key], einops.reduce(batch[key], pattern, "max"))
-            min[key] = torch.minimum(min[key], einops.reduce(batch[key], pattern, "min"))
-
-        if i == ceil(max_num_samples / batch_size) - 1:
-            break
-
-    first_batch_ = None
-    running_item_count = 0  # for online std computation
-    dataloader = create_seeded_dataloader(dataset, batch_size, seed=1337)
-    for i, batch in enumerate(
-        tqdm.tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute std")
-    ):
-        this_batch_size = len(batch["index"])
-        running_item_count += this_batch_size
-        # Sanity check to make sure the batches are still in the same order as before.
-        if first_batch_ is None:
-            first_batch_ = deepcopy(batch)
-            for key in stats_patterns:
-                assert torch.equal(first_batch_[key], first_batch[key])
-        for key, pattern in stats_patterns.items():
-            batch[key] = batch[key].float()
-            # Numerically stable update step for mean computation (where the mean is over squared
-            # residuals).See notes in the mean computation loop above.
-            batch_std = einops.reduce((batch[key] - mean[key]) ** 2, pattern, "mean")
-            std[key] = std[key] + this_batch_size * (batch_std - std[key]) / running_item_count
-
-        if i == ceil(max_num_samples / batch_size) - 1:
-            break
-
-    for key in stats_patterns:
-        std[key] = torch.sqrt(std[key])
-
-    stats = {}
-    for key in stats_patterns:
-        stats[key] = {
-            "mean": mean[key],
-            "std": std[key],
-            "max": max[key],
-            "min": min[key],
-        }
-    return stats
-
-
-def aggregate_stats(ls_datasets) -> dict[str, torch.Tensor]:
-    """Aggregate stats of multiple LeRobot datasets into one set of stats without recomputing from scratch.
-
-    The final stats will have the union of all data keys from each of the datasets.
-
-    The final stats will have the union of all data keys from each of the datasets. For instance:
-    - new_max = max(max_dataset_0, max_dataset_1, ...)
-    - new_min = min(min_dataset_0, min_dataset_1, ...)
-    - new_mean = (mean of all data)
-    - new_std = (std of all data)
-    """
-    data_keys = set()
-    for dataset in ls_datasets:
-        data_keys.update(dataset.meta.stats.keys())
-    stats = {k: {} for k in data_keys}
-    for data_key in data_keys:
-        for stat_key in ["min", "max"]:
-            # compute `max(dataset_0["max"], dataset_1["max"], ...)`
-            stats[data_key][stat_key] = einops.reduce(
-                torch.stack(
-                    [ds.meta.stats[data_key][stat_key] for ds in ls_datasets if data_key in ds.meta.stats],
-                    dim=0,
-                ),
-                "n ... -> ...",
-                stat_key,
-            )
-        total_samples = sum(d.num_frames for d in ls_datasets if data_key in d.meta.stats)
-        # Compute the "sum" statistic by multiplying each mean by the number of samples in the respective
-        # dataset, then divide by total_samples to get the overall "mean".
-        # NOTE: the brackets around (d.num_frames / total_samples) are needed tor minimize the risk of
-        # numerical overflow!
-        stats[data_key]["mean"] = sum(
-            d.meta.stats[data_key]["mean"] * (d.num_frames / total_samples)
-            for d in ls_datasets
-            if data_key in d.meta.stats
-        )
-        # The derivation for standard deviation is a little more involved but is much in the same spirit as
-        # the computation of the mean.
-        # Given two sets of data where the statistics are known:
-        # σ_combined = sqrt[ (n1 * (σ1^2 + d1^2) + n2 * (σ2^2 + d2^2)) / (n1 + n2) ]
-        # where d1 = μ1 - μ_combined, d2 = μ2 - μ_combined
-        # NOTE: the brackets around (d.num_frames / total_samples) are needed tor minimize the risk of
-        # numerical overflow!
-        stats[data_key]["std"] = torch.sqrt(
-            sum(
-                (
-                    d.meta.stats[data_key]["std"] ** 2
-                    + (d.meta.stats[data_key]["mean"] - stats[data_key]["mean"]) ** 2
-                )
-                * (d.num_frames / total_samples)
-                for d in ls_datasets
-                if data_key in d.meta.stats
-            )
-        )
-    return stats
+import numpy as np
+
+from lerobot.common.datasets.utils import load_image_as_numpy
+
+
+def estimate_num_samples(
+    dataset_len: int, min_num_samples: int = 100, max_num_samples: int = 10_000, power: float = 0.75
+) -> int:
+    """Heuristic to estimate the number of samples based on dataset size.
+    The power controls the sample growth relative to dataset size.
+    Lower the power for less number of samples.
+
+    For default arguments, we have:
+    - from 1 to ~500, num_samples=100
+    - at 1000, num_samples=177
+    - at 2000, num_samples=299
+    - at 5000, num_samples=594
+    - at 10000, num_samples=1000
+    - at 20000, num_samples=1681
+    """
+    if dataset_len < min_num_samples:
+        min_num_samples = dataset_len
+    return max(min_num_samples, min(int(dataset_len**power), max_num_samples))
+
+
+def sample_indices(data_len: int) -> list[int]:
+    num_samples = estimate_num_samples(data_len)
+    return np.round(np.linspace(0, data_len - 1, num_samples)).astype(int).tolist()
+
+
+def auto_downsample_height_width(img: np.ndarray, target_size: int = 150, max_size_threshold: int = 300):
+    _, height, width = img.shape
+
+    if max(width, height) < max_size_threshold:
+        # no downsampling needed
+        return img
+
+    downsample_factor = int(width / target_size) if width > height else int(height / target_size)
+    return img[:, ::downsample_factor, ::downsample_factor]
+
+
+def sample_images(image_paths: list[str]) -> np.ndarray:
+    sampled_indices = sample_indices(len(image_paths))
+
+    images = None
+    for i, idx in enumerate(sampled_indices):
+        path = image_paths[idx]
+        # we load as uint8 to reduce memory usage
+        img = load_image_as_numpy(path, dtype=np.uint8, channel_first=True)
+        img = auto_downsample_height_width(img)
+
+        if images is None:
+            images = np.empty((len(sampled_indices), *img.shape), dtype=np.uint8)
+
+        images[i] = img
+
+    return images
+
+
+def get_feature_stats(array: np.ndarray, axis: tuple, keepdims: bool) -> dict[str, np.ndarray]:
+    return {
+        "min": np.min(array, axis=axis, keepdims=keepdims),
+        "max": np.max(array, axis=axis, keepdims=keepdims),
+        "mean": np.mean(array, axis=axis, keepdims=keepdims),
+        "std": np.std(array, axis=axis, keepdims=keepdims),
+        "count": np.array([len(array)]),
+    }
+
+
+def compute_episode_stats(episode_data: dict[str, list[str] | np.ndarray], features: dict) -> dict:
+    ep_stats = {}
+    for key, data in episode_data.items():
+        if features[key]["dtype"] == "string":
+            continue  # HACK: we should receive np.arrays of strings
+        elif features[key]["dtype"] in ["image", "video"]:
+            ep_ft_array = sample_images(data)  # data is a list of image paths
+            axes_to_reduce = (0, 2, 3)  # keep channel dim
+            keepdims = True
+        else:
+            ep_ft_array = data  # data is already a np.ndarray
+            axes_to_reduce = 0  # compute stats over the first axis
+            keepdims = data.ndim == 1  # keep as np.array
+
+        ep_stats[key] = get_feature_stats(ep_ft_array, axis=axes_to_reduce, keepdims=keepdims)
+
+        # finally, we normalize and remove batch dim for images
+        if features[key]["dtype"] in ["image", "video"]:
+            ep_stats[key] = {
+                k: v if k == "count" else np.squeeze(v / 255.0, axis=0) for k, v in ep_stats[key].items()
+            }
+
+    return ep_stats
+
+
+def _assert_type_and_shape(stats_list: list[dict[str, dict]]):
+    for i in range(len(stats_list)):
+        for fkey in stats_list[i]:
+            for k, v in stats_list[i][fkey].items():
+                if not isinstance(v, np.ndarray):
+                    raise ValueError(
+                        f"Stats must be composed of numpy array, but key '{k}' of feature '{fkey}' is of type '{type(v)}' instead."
+                    )
+                if v.ndim == 0:
+                    raise ValueError("Number of dimensions must be at least 1, and is 0 instead.")
+                if k == "count" and v.shape != (1,):
+                    raise ValueError(f"Shape of 'count' must be (1), but is {v.shape} instead.")
+                if "image" in fkey and k != "count" and v.shape != (3, 1, 1):
+                    raise ValueError(f"Shape of '{k}' must be (3,1,1), but is {v.shape} instead.")
+
+
+def aggregate_feature_stats(stats_ft_list: list[dict[str, dict]]) -> dict[str, dict[str, np.ndarray]]:
+    """Aggregates stats for a single feature."""
+    means = np.stack([s["mean"] for s in stats_ft_list])
+    variances = np.stack([s["std"] ** 2 for s in stats_ft_list])
+    counts = np.stack([s["count"] for s in stats_ft_list])
+    total_count = counts.sum(axis=0)
+
+    # Prepare weighted mean by matching number of dimensions
+    while counts.ndim < means.ndim:
+        counts = np.expand_dims(counts, axis=-1)
+
+    # Compute the weighted mean
+    weighted_means = means * counts
+    total_mean = weighted_means.sum(axis=0) / total_count
+
+    # Compute the variance using the parallel algorithm
+    delta_means = means - total_mean
+    weighted_variances = (variances + delta_means**2) * counts
+    total_variance = weighted_variances.sum(axis=0) / total_count
+
+    return {
+        "min": np.min(np.stack([s["min"] for s in stats_ft_list]), axis=0),
+        "max": np.max(np.stack([s["max"] for s in stats_ft_list]), axis=0),
+        "mean": total_mean,
+        "std": np.sqrt(total_variance),
+        "count": total_count,
+    }
+
+
+def aggregate_stats(stats_list: list[dict[str, dict]]) -> dict[str, dict[str, np.ndarray]]:
+    """Aggregate stats from multiple compute_stats outputs into a single set of stats.
+
+    The final stats will have the union of all data keys from each of the stats dicts.
+    For instance:
+    - new_min = min(min_dataset_0, min_dataset_1, ...)
+    - new_max = max(max_dataset_0, max_dataset_1, ...)
+    - new_mean = (mean of all data, weighted by counts)
+    - new_std = (std of all data)
+    """
+
+    _assert_type_and_shape(stats_list)
+
+    data_keys = {key for stats in stats_list for key in stats}
+    aggregated_stats = {key: {} for key in data_keys}

+    for key in data_keys:
+        stats_with_key = [stats[key] for stats in stats_list if key in stats]
+        aggregated_stats[key] = aggregate_feature_stats(stats_with_key)
+
+    return aggregated_stats
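The weighted combination in `aggregate_feature_stats` is the same parallel algorithm the old comment spelled out: σ_combined = sqrt[ Σᵢ nᵢ(σᵢ² + (μᵢ − μ_combined)²) / Σᵢ nᵢ ]. A self-contained check of that identity against stats computed on the concatenated data (synthetic arrays, not from the diff):

```python
import numpy as np

def feature_stats(x: np.ndarray) -> dict[str, np.ndarray]:
    # per-chunk stats in the same layout as get_feature_stats (axis=0, ddof=0)
    return {"mean": x.mean(axis=0), "std": x.std(axis=0), "count": np.array([len(x)])}

rng = np.random.default_rng(0)
a = rng.normal(size=(100, 2))
b = rng.normal(loc=3.0, size=(50, 2))
s1, s2 = feature_stats(a), feature_stats(b)

counts = np.stack([s1["count"], s2["count"]])  # (2, 1), broadcasts against means
means = np.stack([s1["mean"], s2["mean"]])
variances = np.stack([s1["std"] ** 2, s2["std"] ** 2])

total_count = counts.sum(axis=0)
total_mean = (means * counts).sum(axis=0) / total_count
total_var = ((variances + (means - total_mean) ** 2) * counts).sum(axis=0) / total_count

full = np.concatenate([a, b])
assert np.allclose(total_mean, full.mean(axis=0))
assert np.allclose(np.sqrt(total_var), full.std(axis=0))
```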

View File

@@ -14,103 +14,105 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
 from pprint import pformat

 import torch
-from omegaconf import ListConfig, OmegaConf

-from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, MultiLeRobotDataset
-from lerobot.common.datasets.transforms import get_image_transforms
+from lerobot.common.datasets.lerobot_dataset import (
+    LeRobotDataset,
+    LeRobotDatasetMetadata,
+    MultiLeRobotDataset,
+)
+from lerobot.common.datasets.transforms import ImageTransforms
+from lerobot.configs.policies import PreTrainedConfig
+from lerobot.configs.train import TrainPipelineConfig
+
+IMAGENET_STATS = {
+    "mean": [[[0.485]], [[0.456]], [[0.406]]],  # (c,1,1)
+    "std": [[[0.229]], [[0.224]], [[0.225]]],  # (c,1,1)
+}

-def resolve_delta_timestamps(cfg):
-    """Resolves delta_timestamps config key (in-place) by using `eval`.
-
-    Doesn't do anything if delta_timestamps is not specified or has already been resolve (as evidenced by
-    the data type of its values).
-    """
-    delta_timestamps = cfg.training.get("delta_timestamps")
-    if delta_timestamps is not None:
-        for key in delta_timestamps:
-            if isinstance(delta_timestamps[key], str):
-                # TODO(rcadene, alexander-soare): remove `eval` to avoid exploit
-                cfg.training.delta_timestamps[key] = eval(delta_timestamps[key])
-
-
-def make_dataset(cfg, split: str = "train") -> LeRobotDataset | MultiLeRobotDataset:
-    """
-    Args:
-        cfg: A Hydra config as per the LeRobot config scheme.
-        split: Select the data subset used to create an instance of LeRobotDataset.
-            All datasets hosted on [lerobot](https://huggingface.co/lerobot) contain only one subset: "train".
-            Thus, by default, `split="train"` selects all the available data. `split` aims to work like the
-            slicer in the hugging face datasets:
-            https://huggingface.co/docs/datasets/v2.19.0/loading#slice-splits
-            As of now, it only supports `split="train[:n]"` to load the first n frames of the dataset or
-            `split="train[n:]"` to load the last n frames. For instance `split="train[:1000]"`.
-    Returns:
-        The LeRobotDataset.
-    """
-    if not isinstance(cfg.dataset_repo_id, (str, ListConfig)):
-        raise ValueError(
-            "Expected cfg.dataset_repo_id to be either a single string to load one dataset or a list of "
-            "strings to load multiple datasets."
-        )
-
-    # A soft check to warn if the environment matches the dataset. Don't check if we are using a real world env (dora).
-    if cfg.env.name != "dora":
-        if isinstance(cfg.dataset_repo_id, str):
-            dataset_repo_ids = [cfg.dataset_repo_id]  # single dataset
-        else:
-            dataset_repo_ids = cfg.dataset_repo_id  # multiple datasets
-
-        for dataset_repo_id in dataset_repo_ids:
-            if cfg.env.name not in dataset_repo_id:
-                logging.warning(
-                    f"There might be a mismatch between your training dataset ({dataset_repo_id=}) and your "
-                    f"environment ({cfg.env.name=})."
-                )
-
-    resolve_delta_timestamps(cfg)
-
-    image_transforms = None
-    if cfg.training.image_transforms.enable:
-        cfg_tf = cfg.training.image_transforms
-        image_transforms = get_image_transforms(
-            brightness_weight=cfg_tf.brightness.weight,
-            brightness_min_max=cfg_tf.brightness.min_max,
-            contrast_weight=cfg_tf.contrast.weight,
-            contrast_min_max=cfg_tf.contrast.min_max,
-            saturation_weight=cfg_tf.saturation.weight,
-            saturation_min_max=cfg_tf.saturation.min_max,
-            hue_weight=cfg_tf.hue.weight,
-            hue_min_max=cfg_tf.hue.min_max,
-            sharpness_weight=cfg_tf.sharpness.weight,
-            sharpness_min_max=cfg_tf.sharpness.min_max,
-            max_num_transforms=cfg_tf.max_num_transforms,
-            random_order=cfg_tf.random_order,
-        )
-
-    if isinstance(cfg.dataset_repo_id, str):
-        # TODO (aliberts): add 'episodes' arg from config after removing hydra
+def resolve_delta_timestamps(
+    cfg: PreTrainedConfig, ds_meta: LeRobotDatasetMetadata
+) -> dict[str, list] | None:
+    """Resolves delta_timestamps by reading from the 'delta_indices' properties of the PreTrainedConfig.
+
+    Args:
+        cfg (PreTrainedConfig): The PreTrainedConfig to read delta_indices from.
+        ds_meta (LeRobotDatasetMetadata): The dataset from which features and fps are used to build
+            delta_timestamps against.
+
+    Returns:
+        dict[str, list] | None: A dictionary of delta_timestamps, e.g.:
+            {
+                "observation.state": [-0.04, -0.02, 0]
+                "observation.action": [-0.02, 0, 0.02]
+            }
+            returns `None` if the the resulting dict is empty.
+    """
+    delta_timestamps = {}
+    for key in ds_meta.features:
+        if key == "next.reward" and cfg.reward_delta_indices is not None:
+            delta_timestamps[key] = [i / ds_meta.fps for i in cfg.reward_delta_indices]
+        if key == "action" and cfg.action_delta_indices is not None:
+            delta_timestamps[key] = [i / ds_meta.fps for i in cfg.action_delta_indices]
+        if key.startswith("observation.") and cfg.observation_delta_indices is not None:
+            delta_timestamps[key] = [i / ds_meta.fps for i in cfg.observation_delta_indices]
+
+    if len(delta_timestamps) == 0:
+        delta_timestamps = None
+
+    return delta_timestamps
+
+
+def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDataset:
+    """Handles the logic of setting up delta timestamps and image transforms before creating a dataset.
+
+    Args:
+        cfg (TrainPipelineConfig): A TrainPipelineConfig config which contains a DatasetConfig and a PreTrainedConfig.
+
+    Raises:
+        NotImplementedError: The MultiLeRobotDataset is currently deactivated.
+
+    Returns:
+        LeRobotDataset | MultiLeRobotDataset
+    """
+    image_transforms = (
+        ImageTransforms(cfg.dataset.image_transforms) if cfg.dataset.image_transforms.enable else None
+    )
+
+    if isinstance(cfg.dataset.repo_id, str):
+        ds_meta = LeRobotDatasetMetadata(
+            cfg.dataset.repo_id, root=cfg.dataset.root, revision=cfg.dataset.revision
+        )
+        delta_timestamps = resolve_delta_timestamps(cfg.policy, ds_meta)
         dataset = LeRobotDataset(
-            cfg.dataset_repo_id,
-            delta_timestamps=cfg.training.get("delta_timestamps"),
+            cfg.dataset.repo_id,
+            root=cfg.dataset.root,
+            episodes=cfg.dataset.episodes,
+            delta_timestamps=delta_timestamps,
             image_transforms=image_transforms,
-            video_backend=cfg.video_backend,
+            revision=cfg.dataset.revision,
+            video_backend=cfg.dataset.video_backend,
         )
     else:
+        raise NotImplementedError("The MultiLeRobotDataset isn't supported for now.")
         dataset = MultiLeRobotDataset(
-            cfg.dataset_repo_id,
-            delta_timestamps=cfg.training.get("delta_timestamps"),
+            cfg.dataset.repo_id,
+            # TODO(aliberts): add proper support for multi dataset
+            # delta_timestamps=delta_timestamps,
             image_transforms=image_transforms,
-            video_backend=cfg.video_backend,
+            video_backend=cfg.dataset.video_backend,
         )
         logging.info(
             "Multiple datasets were provided. Applied the following index mapping to the provided datasets: "
             f"{pformat(dataset.repo_id_to_index, indent=2)}"
         )

-    if cfg.get("override_dataset_stats"):
-        for key, stats_dict in cfg.override_dataset_stats.items():
-            for stats_type, listconfig in stats_dict.items():
-                # example of stats_type: min, max, mean, std
-                stats = OmegaConf.to_container(listconfig, resolve=True)
-                dataset.meta.stats[key][stats_type] = torch.tensor(stats, dtype=torch.float32)
+    if cfg.dataset.use_imagenet_stats:
+        for key in dataset.meta.camera_keys:
+            for stats_type, stats in IMAGENET_STATS.items():
+                dataset.meta.stats[key][stats_type] = torch.tensor(stats, dtype=torch.float32)

     return dataset
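The heart of the new `resolve_delta_timestamps` is an index-to-seconds conversion: each delta index i becomes i / fps. A minimal sketch of that conversion, assuming fps=10 and index windows like the ones in the validation-loss example earlier in this compare view:

```python
fps = 10
observation_delta_indices = [-1, 0]         # previous and current frame
action_delta_indices = list(range(-1, 15))  # previous action plus the next 15

delta_timestamps = {
    "observation.state": [i / fps for i in observation_delta_indices],
    "action": [i / fps for i in action_delta_indices],
}
print(delta_timestamps["observation.state"])  # [-0.1, 0.0]
print(delta_timestamps["action"][:3])         # [-0.1, 0.0, 0.1]
```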

View File

@@ -38,22 +38,40 @@ def safe_stop_image_writer(func):
     return wrapper

-def image_array_to_image(image_array: np.ndarray) -> PIL.Image.Image:
+def image_array_to_pil_image(image_array: np.ndarray, range_check: bool = True) -> PIL.Image.Image:
     # TODO(aliberts): handle 1 channel and 4 for depth images
-    if image_array.ndim == 3 and image_array.shape[0] in [1, 3]:
+    if image_array.ndim != 3:
+        raise ValueError(f"The array has {image_array.ndim} dimensions, but 3 is expected for an image.")
+
+    if image_array.shape[0] == 3:
         # Transpose from pytorch convention (C, H, W) to (H, W, C)
         image_array = image_array.transpose(1, 2, 0)
+
+    elif image_array.shape[-1] != 3:
+        raise NotImplementedError(
+            f"The image has {image_array.shape[-1]} channels, but 3 is required for now."
+        )
+
     if image_array.dtype != np.uint8:
-        # Assume the image is in [0, 1] range for floating-point data
-        image_array = np.clip(image_array, 0, 1)
+        if range_check:
+            max_ = image_array.max().item()
+            min_ = image_array.min().item()
+            if max_ > 1.0 or min_ < 0.0:
+                raise ValueError(
+                    "The image data type is float, which requires values in the range [0.0, 1.0]. "
+                    f"However, the provided range is [{min_}, {max_}]. Please adjust the range or "
+                    "provide a uint8 image with values in the range [0, 255]."
+                )
+
         image_array = (image_array * 255).astype(np.uint8)
+
     return PIL.Image.fromarray(image_array)

 def write_image(image: np.ndarray | PIL.Image.Image, fpath: Path):
     try:
         if isinstance(image, np.ndarray):
-            img = image_array_to_image(image)
+            img = image_array_to_pil_image(image)
         elif isinstance(image, PIL.Image.Image):
             img = image
         else:
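A short sketch of the convention `image_array_to_pil_image` now enforces: float arrays must be (C, H, W) with values in [0.0, 1.0] and get scaled to uint8, while out-of-range floats raise instead of being silently clipped (array values are synthetic):

```python
import numpy as np
import PIL.Image

img_chw = np.random.rand(3, 96, 96).astype(np.float32)  # float32 in [0, 1], channel first
img_hwc = img_chw.transpose(1, 2, 0)                     # (H, W, C), as in the function above
pil_img = PIL.Image.fromarray((img_hwc * 255).astype(np.uint8))
print(pil_img.size)  # (96, 96)
```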

View File

@@ -14,49 +14,54 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-import os
 import shutil
-from functools import cached_property
 from pathlib import Path
 from typing import Callable

 import datasets
 import numpy as np
+import packaging.version
 import PIL.Image
 import torch
 import torch.utils
-from datasets import load_dataset
-from huggingface_hub import create_repo, snapshot_download, upload_folder
+from datasets import concatenate_datasets, load_dataset
+from huggingface_hub import HfApi, snapshot_download
+from huggingface_hub.constants import REPOCARD_NAME

-from lerobot.common.datasets.compute_stats import aggregate_stats, compute_stats
+from lerobot.common.constants import HF_LEROBOT_HOME
+from lerobot.common.datasets.compute_stats import aggregate_stats, compute_episode_stats
 from lerobot.common.datasets.image_writer import AsyncImageWriter, write_image
 from lerobot.common.datasets.utils import (
     DEFAULT_FEATURES,
     DEFAULT_IMAGE_PATH,
-    EPISODES_PATH,
     INFO_PATH,
-    STATS_PATH,
     TASKS_PATH,
     append_jsonlines,
+    backward_compatible_episodes_stats,
     check_delta_timestamps,
     check_timestamps_sync,
     check_version_compatibility,
-    create_branch,
     create_empty_dataset_info,
     create_lerobot_dataset_card,
-    embed_images,
     get_delta_indices,
     get_episode_data_index,
     get_features_from_robot,
     get_hf_features_from_features,
-    get_hub_safe_version,
+    get_safe_version,
     hf_transform_to_torch,
+    is_valid_version,
     load_episodes,
+    load_episodes_stats,
     load_info,
     load_stats,
     load_tasks,
-    serialize_dict,
+    validate_episode_buffer,
+    validate_frame,
+    write_episode,
+    write_episode_stats,
+    write_info,
     write_json,
-    write_parquet,
 )
 from lerobot.common.datasets.video_utils import (
     VideoFrame,
@@ -66,9 +71,7 @@ from lerobot.common.datasets.video_utils import (
 )
 from lerobot.common.robot_devices.robots.utils import Robot

-# For maintainers, see lerobot/common/datasets/push_dataset_to_hub/CODEBASE_VERSION.md
-CODEBASE_VERSION = "v2.0"
-LEROBOT_HOME = Path(os.getenv("LEROBOT_HOME", "~/.cache/huggingface/lerobot")).expanduser()
+CODEBASE_VERSION = "v2.1"

 class LeRobotDatasetMetadata:
@@ -76,19 +79,36 @@ class LeRobotDatasetMetadata:
         self,
         repo_id: str,
         root: str | Path | None = None,
-        local_files_only: bool = False,
+        revision: str | None = None,
+        force_cache_sync: bool = False,
     ):
         self.repo_id = repo_id
-        self.root = Path(root) if root is not None else LEROBOT_HOME / repo_id
-        self.local_files_only = local_files_only
+        self.revision = revision if revision else CODEBASE_VERSION
+        self.root = Path(root) if root is not None else HF_LEROBOT_HOME / repo_id

-        # Load metadata
-        (self.root / "meta").mkdir(exist_ok=True, parents=True)
-        self.pull_from_repo(allow_patterns="meta/")
+        try:
+            if force_cache_sync:
+                raise FileNotFoundError
+            self.load_metadata()
+        except (FileNotFoundError, NotADirectoryError):
+            if is_valid_version(self.revision):
+                self.revision = get_safe_version(self.repo_id, self.revision)
+
+            (self.root / "meta").mkdir(exist_ok=True, parents=True)
+            self.pull_from_repo(allow_patterns="meta/")
+            self.load_metadata()
+
+    def load_metadata(self):
         self.info = load_info(self.root)
-        self.stats = load_stats(self.root)
-        self.tasks = load_tasks(self.root)
+        check_version_compatibility(self.repo_id, self._version, CODEBASE_VERSION)
+        self.tasks, self.task_to_task_index = load_tasks(self.root)
         self.episodes = load_episodes(self.root)
+        if self._version < packaging.version.parse("v2.1"):
+            self.stats = load_stats(self.root)
+            self.episodes_stats = backward_compatible_episodes_stats(self.stats, self.episodes)
+        else:
+            self.episodes_stats = load_episodes_stats(self.root)
+        self.stats = aggregate_stats(list(self.episodes_stats.values()))

     def pull_from_repo(
         self,
@@ -98,21 +118,16 @@ class LeRobotDatasetMetadata:
         snapshot_download(
             self.repo_id,
             repo_type="dataset",
-            revision=self._hub_version,
+            revision=self.revision,
             local_dir=self.root,
             allow_patterns=allow_patterns,
             ignore_patterns=ignore_patterns,
-            local_files_only=self.local_files_only,
         )

-    @cached_property
-    def _hub_version(self) -> str | None:
-        return None if self.local_files_only else get_hub_safe_version(self.repo_id, CODEBASE_VERSION)
-
     @property
-    def _version(self) -> str:
+    def _version(self) -> packaging.version.Version:
         """Codebase version used to create this dataset."""
-        return self.info["codebase_version"]
+        return packaging.version.parse(self.info["codebase_version"])

     def get_data_file_path(self, ep_index: int) -> Path:
         ep_chunk = self.get_episode_chunk(ep_index)
@@ -202,54 +217,65 @@ class LeRobotDatasetMetadata:
"""Max number of episodes per chunk.""" """Max number of episodes per chunk."""
return self.info["chunks_size"] return self.info["chunks_size"]
@property def get_task_index(self, task: str) -> int | None:
def task_to_task_index(self) -> dict:
return {task: task_idx for task_idx, task in self.tasks.items()}
def get_task_index(self, task: str) -> int:
""" """
Given a task in natural language, returns its task_index if the task already exists in the dataset, Given a task in natural language, returns its task_index if the task already exists in the dataset,
otherwise creates a new task_index. otherwise return None.
""" """
task_index = self.task_to_task_index.get(task, None) return self.task_to_task_index.get(task, None)
return task_index if task_index is not None else self.total_tasks
def save_episode(self, episode_index: int, episode_length: int, task: str, task_index: int) -> None: def add_task(self, task: str):
"""
Given a task in natural language, add it to the dictionary of tasks.
"""
if task in self.task_to_task_index:
raise ValueError(f"The task '{task}' already exists and can't be added twice.")
task_index = self.info["total_tasks"]
self.task_to_task_index[task] = task_index
self.tasks[task_index] = task
self.info["total_tasks"] += 1
task_dict = {
"task_index": task_index,
"task": task,
}
append_jsonlines(task_dict, self.root / TASKS_PATH)
def save_episode(
self,
episode_index: int,
episode_length: int,
episode_tasks: list[str],
episode_stats: dict[str, dict],
) -> None:
self.info["total_episodes"] += 1 self.info["total_episodes"] += 1
self.info["total_frames"] += episode_length self.info["total_frames"] += episode_length
if task_index not in self.tasks:
self.info["total_tasks"] += 1
self.tasks[task_index] = task
task_dict = {
"task_index": task_index,
"task": task,
}
append_jsonlines(task_dict, self.root / TASKS_PATH)
chunk = self.get_episode_chunk(episode_index) chunk = self.get_episode_chunk(episode_index)
if chunk >= self.total_chunks: if chunk >= self.total_chunks:
self.info["total_chunks"] += 1 self.info["total_chunks"] += 1
self.info["splits"] = {"train": f"0:{self.info['total_episodes']}"} self.info["splits"] = {"train": f"0:{self.info['total_episodes']}"}
self.info["total_videos"] += len(self.video_keys) self.info["total_videos"] += len(self.video_keys)
write_json(self.info, self.root / INFO_PATH) if len(self.video_keys) > 0:
self.update_video_info()
write_info(self.info, self.root)
episode_dict = { episode_dict = {
"episode_index": episode_index, "episode_index": episode_index,
"tasks": [task], "tasks": episode_tasks,
"length": episode_length, "length": episode_length,
} }
self.episodes.append(episode_dict) self.episodes[episode_index] = episode_dict
append_jsonlines(episode_dict, self.root / EPISODES_PATH) write_episode(episode_dict, self.root)
# TODO(aliberts): refactor stats in save_episodes self.episodes_stats[episode_index] = episode_stats
# image_sampling = int(self.fps / 2) # sample 2 img/s for the stats self.stats = aggregate_stats([self.stats, episode_stats]) if self.stats else episode_stats
# ep_stats = compute_episode_stats(episode_buffer, self.features, episode_length, image_sampling=image_sampling) write_episode_stats(episode_index, episode_stats, self.root)
# ep_stats = serialize_dict(ep_stats)
# append_jsonlines(ep_stats, self.root / STATS_PATH)
def write_video_info(self) -> None: def update_video_info(self) -> None:
""" """
Warning: this function writes info from first episode videos, implicitly assuming that all videos have Warning: this function writes info from first episode videos, implicitly assuming that all videos have
been encoded the same way. Also, this means it assumes the first episode exists. been encoded the same way. Also, this means it assumes the first episode exists.
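With this hunk, the task registry is kept in both directions (`tasks` maps index to string, `task_to_task_index` the reverse), and `get_task_index` no longer allocates new indices, so callers check it before `add_task`. A hedged sketch of that handshake, assuming `meta` is a `LeRobotDatasetMetadata` for a locally writable dataset:

```python
task = "Push the T-shaped block onto the T-shaped target."
task_index = meta.get_task_index(task)
if task_index is None:
    meta.add_task(task)  # raises ValueError if the task already exists
    task_index = meta.get_task_index(task)
assert meta.tasks[task_index] == task
```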
@@ -259,8 +285,6 @@ class LeRobotDatasetMetadata:
             video_path = self.root / self.get_video_file_path(ep_index=0, vid_key=key)
             self.info["features"][key]["info"] = get_video_info(video_path)

-        write_json(self.info, self.root / INFO_PATH)
-
     def __repr__(self):
         feature_keys = list(self.features)
         return (
@@ -286,7 +310,7 @@ class LeRobotDatasetMetadata:
"""Creates metadata for a LeRobotDataset.""" """Creates metadata for a LeRobotDataset."""
obj = cls.__new__(cls) obj = cls.__new__(cls)
obj.repo_id = repo_id obj.repo_id = repo_id
obj.root = Path(root) if root is not None else LEROBOT_HOME / repo_id obj.root = Path(root) if root is not None else HF_LEROBOT_HOME / repo_id
obj.root.mkdir(parents=True, exist_ok=False) obj.root.mkdir(parents=True, exist_ok=False)
@@ -306,12 +330,21 @@ class LeRobotDatasetMetadata:
         # TODO(aliberts, rcadene): implement sanity check for features
-        features = {**features, **DEFAULT_FEATURES}
+        # check if none of the features contains a "/" in their names,
+        # as this would break the dict flattening in the stats computation, which uses '/' as separator
+        for key in features:
+            if "/" in key:
+                raise ValueError(f"Feature names should not contain '/'. Found '/' in feature '{key}'.")
+
+        features = {**features, **DEFAULT_FEATURES}

-        obj.tasks, obj.stats, obj.episodes = {}, {}, []
+        obj.tasks, obj.task_to_task_index = {}, {}
+        obj.episodes_stats, obj.stats, obj.episodes = {}, {}, {}
         obj.info = create_empty_dataset_info(CODEBASE_VERSION, fps, robot_type, features, use_videos)
         if len(obj.video_keys) > 0 and not use_videos:
             raise ValueError()
         write_json(obj.info, obj.root / INFO_PATH)
-        obj.local_files_only = True
+        obj.revision = None
         return obj
@@ -324,8 +357,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
         image_transforms: Callable | None = None,
         delta_timestamps: dict[list[float]] | None = None,
         tolerance_s: float = 1e-4,
+        revision: str | None = None,
+        force_cache_sync: bool = False,
         download_videos: bool = True,
-        local_files_only: bool = False,
         video_backend: str | None = None,
     ):
         """
@@ -335,7 +369,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         - On your local disk in the 'root' folder. This is typically the case when you recorded your
           dataset locally and you may or may not have pushed it to the hub yet. Instantiating this class
           with 'root' will load your dataset directly from disk. This can happen while you're offline (no
-          internet connection), in that case, use local_files_only=True.
+          internet connection).

         - On the Hugging Face Hub at the address https://huggingface.co/datasets/{repo_id} and not on
           your local disk in the 'root' folder. Instantiating this class with this 'repo_id' will download
@@ -355,7 +389,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
             - info contains various information about the dataset like shapes, keys, fps etc.
             - stats stores the dataset statistics of the different modalities for normalization
             - tasks contains the prompts for each task of the dataset, which can be used for
-              task-conditionned training.
+              task-conditioned training.
             - hf_dataset (from datasets.Dataset), which will read any values from parquet files.
             - videos (optional) from which frames are loaded to be synchronous with data from parquet files.
@@ -417,24 +451,28 @@ class LeRobotDataset(torch.utils.data.Dataset):
                 timestamps is separated to the next by 1/fps +/- tolerance_s. This also applies to frames
                 decoded from video files. It is also used to check that `delta_timestamps` (when provided) are
                 multiples of 1/fps. Defaults to 1e-4.
+            revision (str, optional): An optional Git revision id which can be a branch name, a tag, or a
+                commit hash. Defaults to current codebase version tag.
+            sync_cache_first (bool, optional): Flag to sync and refresh local files first. If True and files
+                are already present in the local cache, this will be faster. However, files loaded might not
+                be in sync with the version on the hub, especially if you specified 'revision'. Defaults to
+                False.
             download_videos (bool, optional): Flag to download the videos. Note that when set to True but the
                 video files are already present on local disk, they won't be downloaded again. Defaults to
                 True.
-            local_files_only (bool, optional): Flag to use local files only. If True, no requests to the hub
-                will be made. Defaults to False.
             video_backend (str | None, optional): Video backend to use for decoding videos. There is currently
                 a single option which is the pyav decoder used by Torchvision. Defaults to pyav.
         """
         super().__init__()
         self.repo_id = repo_id
-        self.root = Path(root) if root else LEROBOT_HOME / repo_id
+        self.root = Path(root) if root else HF_LEROBOT_HOME / repo_id
         self.image_transforms = image_transforms
         self.delta_timestamps = delta_timestamps
         self.episodes = episodes
         self.tolerance_s = tolerance_s
+        self.revision = revision if revision else CODEBASE_VERSION
         self.video_backend = video_backend if video_backend else "pyav"
         self.delta_indices = None
-        self.local_files_only = local_files_only

         # Unused attributes
         self.image_writer = None
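A usage sketch for the constructor arguments documented above (repo id and episode list are illustrative):

```python
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

dataset = LeRobotDataset(
    "lerobot/pusht",
    episodes=[0, 1, 2],      # load only a subset of episodes
    revision="v2.1",         # branch, tag, or commit hash; defaults to the codebase version tag
    force_cache_sync=False,  # True re-syncs the local cache against the hub first
)
print(dataset.meta.total_episodes, len(dataset))
```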
@@ -443,64 +481,86 @@ class LeRobotDataset(torch.utils.data.Dataset):
         self.root.mkdir(exist_ok=True, parents=True)

         # Load metadata
-        self.meta = LeRobotDatasetMetadata(self.repo_id, self.root, self.local_files_only)
-
-        # Check version
-        check_version_compatibility(self.repo_id, self.meta._version, CODEBASE_VERSION)
+        self.meta = LeRobotDatasetMetadata(
+            self.repo_id, self.root, self.revision, force_cache_sync=force_cache_sync
+        )
+        if self.episodes is not None and self.meta._version >= packaging.version.parse("v2.1"):
+            episodes_stats = [self.meta.episodes_stats[ep_idx] for ep_idx in self.episodes]
+            self.stats = aggregate_stats(episodes_stats)

         # Load actual data
-        self.download_episodes(download_videos)
-        self.hf_dataset = self.load_hf_dataset()
+        try:
+            if force_cache_sync:
+                raise FileNotFoundError
+            assert all((self.root / fpath).is_file() for fpath in self.get_episodes_file_paths())
+            self.hf_dataset = self.load_hf_dataset()
+        except (AssertionError, FileNotFoundError, NotADirectoryError):
+            self.revision = get_safe_version(self.repo_id, self.revision)
+            self.download_episodes(download_videos)
+            self.hf_dataset = self.load_hf_dataset()

         self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes)

         # Check timestamps
-        check_timestamps_sync(self.hf_dataset, self.episode_data_index, self.fps, self.tolerance_s)
+        timestamps = torch.stack(self.hf_dataset["timestamp"]).numpy()
+        episode_indices = torch.stack(self.hf_dataset["episode_index"]).numpy()
+        ep_data_index_np = {k: t.numpy() for k, t in self.episode_data_index.items()}
+        check_timestamps_sync(timestamps, episode_indices, ep_data_index_np, self.fps, self.tolerance_s)

         # Setup delta_indices
         if self.delta_timestamps is not None:
             check_delta_timestamps(self.delta_timestamps, self.fps, self.tolerance_s)
             self.delta_indices = get_delta_indices(self.delta_timestamps, self.fps)

-        # Available stats implies all videos have been encoded and dataset is iterable
-        self.consolidated = self.meta.stats is not None
-
     def push_to_hub(
         self,
+        branch: str | None = None,
         tags: list | None = None,
         license: str | None = "apache-2.0",
         push_videos: bool = True,
         private: bool = False,
+        allow_patterns: list[str] | str | None = None,
+        upload_large_folder: bool = False,
         **card_kwargs,
     ) -> None:
-        if not self.consolidated:
-            logging.warning(
-                "You are trying to upload to the hub a LeRobotDataset that has not been consolidated yet. "
-                "Consolidating first."
-            )
-            self.consolidate()
-
         ignore_patterns = ["images/"]
         if not push_videos:
             ignore_patterns.append("videos/")

-        create_repo(
+        hub_api = HfApi()
+        hub_api.create_repo(
             repo_id=self.repo_id,
             private=private,
             repo_type="dataset",
             exist_ok=True,
         )
+        if branch:
+            hub_api.create_branch(
+                repo_id=self.repo_id,
+                branch=branch,
+                revision=self.revision,
+                repo_type="dataset",
+                exist_ok=True,
+            )

-        upload_folder(
-            repo_id=self.repo_id,
-            folder_path=self.root,
-            repo_type="dataset",
-            ignore_patterns=ignore_patterns,
-        )
-        card = create_lerobot_dataset_card(
-            tags=tags, dataset_info=self.meta.info, license=license, **card_kwargs
-        )
-        card.push_to_hub(repo_id=self.repo_id, repo_type="dataset")
-        create_branch(repo_id=self.repo_id, branch=CODEBASE_VERSION, repo_type="dataset")
+        upload_kwargs = {
+            "repo_id": self.repo_id,
+            "folder_path": self.root,
+            "repo_type": "dataset",
+            "revision": branch,
+            "allow_patterns": allow_patterns,
+            "ignore_patterns": ignore_patterns,
+        }
+        if upload_large_folder:
+            hub_api.upload_large_folder(**upload_kwargs)
+        else:
+            hub_api.upload_folder(**upload_kwargs)
+
+        if not hub_api.file_exists(self.repo_id, REPOCARD_NAME, repo_type="dataset", revision=branch):
+            card = create_lerobot_dataset_card(
+                tags=tags, dataset_info=self.meta.info, license=license, **card_kwargs
+            )
+            card.push_to_hub(repo_id=self.repo_id, repo_type="dataset", revision=branch)

     def pull_from_repo(
         self,
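Continuing the previous sketch, the new upload options compose as below; `upload_large_folder` switches to the Hub's resumable large-folder upload, and the branch is created from `dataset.revision` if it does not exist yet (branch name illustrative):

```python
dataset.push_to_hub(
    branch="user-fixes",       # created from dataset.revision when missing
    private=True,
    upload_large_folder=True,  # resumable upload path for big datasets
)
```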
@@ -510,11 +570,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
         snapshot_download(
             self.repo_id,
             repo_type="dataset",
-            revision=self.meta._hub_version,
+            revision=self.revision,
             local_dir=self.root,
             allow_patterns=allow_patterns,
             ignore_patterns=ignore_patterns,
-            local_files_only=self.local_files_only,
         )

     def download_episodes(self, download_videos: bool = True) -> None:
@@ -528,17 +587,23 @@ class LeRobotDataset(torch.utils.data.Dataset):
         files = None
         ignore_patterns = None if download_videos else "videos/"
         if self.episodes is not None:
-            files = [str(self.meta.get_data_file_path(ep_idx)) for ep_idx in self.episodes]
-            if len(self.meta.video_keys) > 0 and download_videos:
-                video_files = [
-                    str(self.meta.get_video_file_path(ep_idx, vid_key))
-                    for vid_key in self.meta.video_keys
-                    for ep_idx in self.episodes
-                ]
-                files += video_files
+            files = self.get_episodes_file_paths()

         self.pull_from_repo(allow_patterns=files, ignore_patterns=ignore_patterns)

+    def get_episodes_file_paths(self) -> list[Path]:
+        episodes = self.episodes if self.episodes is not None else list(range(self.meta.total_episodes))
+        fpaths = [str(self.meta.get_data_file_path(ep_idx)) for ep_idx in episodes]
+        if len(self.meta.video_keys) > 0:
+            video_files = [
+                str(self.meta.get_video_file_path(ep_idx, vid_key))
+                for vid_key in self.meta.video_keys
+                for ep_idx in episodes
+            ]
+            fpaths += video_files
+
+        return fpaths
+
     def load_hf_dataset(self) -> datasets.Dataset:
         """hf_dataset contains all the observations, states, actions, rewards, etc."""
         if self.episodes is None:
@@ -550,7 +615,15 @@ class LeRobotDataset(torch.utils.data.Dataset):
         # TODO(aliberts): hf_dataset.set_format("torch")
         hf_dataset.set_transform(hf_transform_to_torch)
+        return hf_dataset
+
+    def create_hf_dataset(self) -> datasets.Dataset:
+        features = get_hf_features_from_features(self.features)
+        ft_dict = {col: [] for col in features}
+        hf_dataset = datasets.Dataset.from_dict(ft_dict, features=features, split="train")
+        # TODO(aliberts): hf_dataset.set_format("torch")
+        hf_dataset.set_transform(hf_transform_to_torch)
         return hf_dataset

     @property
@@ -617,7 +690,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
             if key not in self.meta.video_keys
         }

-    def _query_videos(self, query_timestamps: dict[str, list[float]], ep_idx: int) -> dict:
+    def _query_videos(self, query_timestamps: dict[str, list[float]], ep_idx: int) -> dict[str, torch.Tensor]:
         """Note: When using data workers (e.g. DataLoader with num_workers>0), do not call this function
         in the main process (e.g. by using a second Dataloader with num_workers=0). It will result in a
         Segmentation Fault. This probably happens because a memory reference to the video loader is created in
@@ -647,8 +720,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
query_indices = None query_indices = None
if self.delta_indices is not None: if self.delta_indices is not None:
current_ep_idx = self.episodes.index(ep_idx) if self.episodes is not None else ep_idx query_indices, padding = self._get_query_indices(idx, ep_idx)
query_indices, padding = self._get_query_indices(idx, current_ep_idx)
query_result = self._query_hf_dataset(query_indices) query_result = self._query_hf_dataset(query_indices)
item = {**item, **padding} item = {**item, **padding}
for key, val in query_result.items(): for key, val in query_result.items():
@@ -665,6 +737,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
for cam in image_keys: for cam in image_keys:
item[cam] = self.image_transforms(item[cam]) item[cam] = self.image_transforms(item[cam])
# Add task as a string
task_idx = item["task_index"].item()
item["task"] = self.meta.tasks[task_idx]
return item return item
def __repr__(self): def __repr__(self):
@@ -680,10 +756,13 @@ class LeRobotDataset(torch.utils.data.Dataset):
def create_episode_buffer(self, episode_index: int | None = None) -> dict: def create_episode_buffer(self, episode_index: int | None = None) -> dict:
current_ep_idx = self.meta.total_episodes if episode_index is None else episode_index current_ep_idx = self.meta.total_episodes if episode_index is None else episode_index
return { ep_buffer = {}
"size": 0, # size and task are special cases that are not in self.features
**{key: current_ep_idx if key == "episode_index" else [] for key in self.features}, ep_buffer["size"] = 0
} ep_buffer["task"] = []
for key in self.features:
ep_buffer[key] = current_ep_idx if key == "episode_index" else []
return ep_buffer
def _get_image_file_path(self, episode_index: int, image_key: str, frame_index: int) -> Path: def _get_image_file_path(self, episode_index: int, image_key: str, frame_index: int) -> Path:
fpath = DEFAULT_IMAGE_PATH.format( fpath = DEFAULT_IMAGE_PATH.format(
@@ -705,25 +784,35 @@ class LeRobotDataset(torch.utils.data.Dataset):
temporary directory — nothing is written to disk. To save those frames, the 'save_episode()' method temporary directory — nothing is written to disk. To save those frames, the 'save_episode()' method
then needs to be called. then needs to be called.
""" """
# TODO(aliberts, rcadene): Add sanity check for the input, check it's numpy or torch, # Convert torch to numpy if needed
# check the dtype and shape matches, etc. for name in frame:
if isinstance(frame[name], torch.Tensor):
frame[name] = frame[name].numpy()
validate_frame(frame, self.features)
if self.episode_buffer is None: if self.episode_buffer is None:
self.episode_buffer = self.create_episode_buffer() self.episode_buffer = self.create_episode_buffer()
# Automatically add frame_index and timestamp to episode buffer
frame_index = self.episode_buffer["size"] frame_index = self.episode_buffer["size"]
timestamp = frame.pop("timestamp") if "timestamp" in frame else frame_index / self.fps timestamp = frame.pop("timestamp") if "timestamp" in frame else frame_index / self.fps
self.episode_buffer["frame_index"].append(frame_index) self.episode_buffer["frame_index"].append(frame_index)
self.episode_buffer["timestamp"].append(timestamp) self.episode_buffer["timestamp"].append(timestamp)
# Add frame features to episode_buffer
for key in frame: for key in frame:
if key not in self.features: if key == "task":
raise ValueError(key) # Note: we associate the task in natural language to its task index during `save_episode`
self.episode_buffer["task"].append(frame["task"])
continue
if self.features[key]["dtype"] not in ["image", "video"]: if key not in self.features:
item = frame[key].numpy() if isinstance(frame[key], torch.Tensor) else frame[key] raise ValueError(
self.episode_buffer[key].append(item) f"An element of the frame is not in the features. '{key}' not in '{self.features.keys()}'."
elif self.features[key]["dtype"] in ["image", "video"]: )
if self.features[key]["dtype"] in ["image", "video"]:
img_path = self._get_image_file_path( img_path = self._get_image_file_path(
episode_index=self.episode_buffer["episode_index"], image_key=key, frame_index=frame_index episode_index=self.episode_buffer["episode_index"], image_key=key, frame_index=frame_index
) )
@@ -731,80 +820,95 @@ class LeRobotDataset(torch.utils.data.Dataset):
img_path.parent.mkdir(parents=True, exist_ok=True) img_path.parent.mkdir(parents=True, exist_ok=True)
self._save_image(frame[key], img_path) self._save_image(frame[key], img_path)
self.episode_buffer[key].append(str(img_path)) self.episode_buffer[key].append(str(img_path))
else:
self.episode_buffer[key].append(frame[key])
self.episode_buffer["size"] += 1 self.episode_buffer["size"] += 1
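Together with the reworked `save_episode` below, a recording loop now looks roughly like this minimal sketch (the repo id, feature set, and values are illustrative; in practice the feature dict is derived from the robot config):
```python
import numpy as np

# Hypothetical dataset created beforehand, e.g. with:
# dataset = LeRobotDataset.create(repo_id="user/my_dataset", fps=30, features={...})
for _ in range(100):
    dataset.add_frame(
        {
            "observation.state": np.zeros(6, dtype=np.float32),
            "action": np.zeros(6, dtype=np.float32),
            # Natural-language task; it is mapped to a task_index in save_episode().
            "task": "Pick up the cube and place it in the bin.",
        }
    )

# Writes the parquet table, encodes videos (if any) and resets the buffer.
dataset.save_episode()
```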
def save_episode(self, task: str, encode_videos: bool = True, episode_data: dict | None = None) -> None: def save_episode(self, episode_data: dict | None = None) -> None:
""" """
This will save to disk the current episode in self.episode_buffer. Note that since it affects files on This will save to disk the current episode in self.episode_buffer.
disk, it sets self.consolidated to False to ensure proper consolidation later on before uploading to
the hub.
Use 'encode_videos' if you want to encode videos during the saving of this episode. Otherwise, Args:
you can do it later with dataset.consolidate(). This is to give more flexibility on when to spend episode_data (dict | None, optional): Dict containing the episode data to save. If None, this will
time for video encoding. save the current episode in self.episode_buffer, which is filled with 'add_frame'. Defaults to
None.
""" """
if not episode_data: if not episode_data:
episode_buffer = self.episode_buffer episode_buffer = self.episode_buffer
validate_episode_buffer(episode_buffer, self.meta.total_episodes, self.features)
# size and task are special cases that won't be added to hf_dataset
episode_length = episode_buffer.pop("size") episode_length = episode_buffer.pop("size")
tasks = episode_buffer.pop("task")
episode_tasks = list(set(tasks))
episode_index = episode_buffer["episode_index"] episode_index = episode_buffer["episode_index"]
if episode_index != self.meta.total_episodes:
# TODO(aliberts): Add option to use existing episode_index
raise NotImplementedError(
"You might have manually provided the episode_buffer with an episode_index that doesn't "
"match the total number of episodes in the dataset. This is not supported for now."
)
if episode_length == 0: episode_buffer["index"] = np.arange(self.meta.total_frames, self.meta.total_frames + episode_length)
raise ValueError( episode_buffer["episode_index"] = np.full((episode_length,), episode_index)
"You must add one or several frames with `add_frame` before calling `add_episode`."
)
task_index = self.meta.get_task_index(task) # Add new tasks to the tasks dictionary
for task in episode_tasks:
task_index = self.meta.get_task_index(task)
if task_index is None:
self.meta.add_task(task)
if not set(episode_buffer.keys()) == set(self.features): # Given tasks in natural language, find their corresponding task indices
raise ValueError() episode_buffer["task_index"] = np.array([self.meta.get_task_index(task) for task in tasks])
for key, ft in self.features.items(): for key, ft in self.features.items():
if key == "index": # index, episode_index, task_index are already processed above, and image and video
episode_buffer[key] = np.arange( # are processed separately by storing image path and frame info as meta data
self.meta.total_frames, self.meta.total_frames + episode_length if key in ["index", "episode_index", "task_index"] or ft["dtype"] in ["image", "video"]:
)
elif key == "episode_index":
episode_buffer[key] = np.full((episode_length,), episode_index)
elif key == "task_index":
episode_buffer[key] = np.full((episode_length,), task_index)
elif ft["dtype"] in ["image", "video"]:
continue continue
elif len(ft["shape"]) == 1 and ft["shape"][0] == 1: episode_buffer[key] = np.stack(episode_buffer[key])
episode_buffer[key] = np.array(episode_buffer[key], dtype=ft["dtype"])
elif len(ft["shape"]) == 1 and ft["shape"][0] > 1:
episode_buffer[key] = np.stack(episode_buffer[key])
else:
raise ValueError(key)
self._wait_image_writer() self._wait_image_writer()
self._save_episode_table(episode_buffer, episode_index) self._save_episode_table(episode_buffer, episode_index)
ep_stats = compute_episode_stats(episode_buffer, self.features)
self.meta.save_episode(episode_index, episode_length, task, task_index) if len(self.meta.video_keys) > 0:
if encode_videos and len(self.meta.video_keys) > 0:
video_paths = self.encode_episode_videos(episode_index) video_paths = self.encode_episode_videos(episode_index)
for key in self.meta.video_keys: for key in self.meta.video_keys:
episode_buffer[key] = video_paths[key] episode_buffer[key] = video_paths[key]
# `meta.save_episode` must be executed after encoding the videos
self.meta.save_episode(episode_index, episode_length, episode_tasks, ep_stats)
ep_data_index = get_episode_data_index(self.meta.episodes, [episode_index])
ep_data_index_np = {k: t.numpy() for k, t in ep_data_index.items()}
check_timestamps_sync(
episode_buffer["timestamp"],
episode_buffer["episode_index"],
ep_data_index_np,
self.fps,
self.tolerance_s,
)
video_files = list(self.root.rglob("*.mp4"))
assert len(video_files) == self.num_episodes * len(self.meta.video_keys)
parquet_files = list(self.root.rglob("*.parquet"))
assert len(parquet_files) == self.num_episodes
# delete images
img_dir = self.root / "images"
if img_dir.is_dir():
shutil.rmtree(self.root / "images")
if not episode_data: # Reset the buffer if not episode_data: # Reset the buffer
self.episode_buffer = self.create_episode_buffer() self.episode_buffer = self.create_episode_buffer()
self.consolidated = False
def _save_episode_table(self, episode_buffer: dict, episode_index: int) -> None: def _save_episode_table(self, episode_buffer: dict, episode_index: int) -> None:
episode_dict = {key: episode_buffer[key] for key in self.hf_features} episode_dict = {key: episode_buffer[key] for key in self.hf_features}
ep_dataset = datasets.Dataset.from_dict(episode_dict, features=self.hf_features, split="train") ep_dataset = datasets.Dataset.from_dict(episode_dict, features=self.hf_features, split="train")
ep_dataset = embed_images(ep_dataset)
self.hf_dataset = concatenate_datasets([self.hf_dataset, ep_dataset])
self.hf_dataset.set_transform(hf_transform_to_torch)
ep_data_path = self.root / self.meta.get_data_file_path(ep_index=episode_index) ep_data_path = self.root / self.meta.get_data_file_path(ep_index=episode_index)
ep_data_path.parent.mkdir(parents=True, exist_ok=True) ep_data_path.parent.mkdir(parents=True, exist_ok=True)
write_parquet(ep_dataset, ep_data_path) ep_dataset.to_parquet(ep_data_path)
def clear_episode_buffer(self) -> None: def clear_episode_buffer(self) -> None:
episode_index = self.episode_buffer["episode_index"] episode_index = self.episode_buffer["episode_index"]
@@ -833,7 +937,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
def stop_image_writer(self) -> None: def stop_image_writer(self) -> None:
""" """
Whenever wrapping this dataset inside a parallelized DataLoader, this needs to be called first to Whenever wrapping this dataset inside a parallelized DataLoader, this needs to be called first to
remove the image_write in order for the LeRobotDataset object to be pickleable and parallelized. remove the image_writer in order for the LeRobotDataset object to be pickleable and parallelized.
""" """
if self.image_writer is not None: if self.image_writer is not None:
self.image_writer.stop() self.image_writer.stop()
@@ -873,38 +977,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
return video_paths return video_paths
def consolidate(self, run_compute_stats: bool = True, keep_image_files: bool = False) -> None:
self.hf_dataset = self.load_hf_dataset()
self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes)
check_timestamps_sync(self.hf_dataset, self.episode_data_index, self.fps, self.tolerance_s)
if len(self.meta.video_keys) > 0:
self.encode_videos()
self.meta.write_video_info()
if not keep_image_files:
img_dir = self.root / "images"
if img_dir.is_dir():
shutil.rmtree(self.root / "images")
video_files = list(self.root.rglob("*.mp4"))
assert len(video_files) == self.num_episodes * len(self.meta.video_keys)
parquet_files = list(self.root.rglob("*.parquet"))
assert len(parquet_files) == self.num_episodes
if run_compute_stats:
self.stop_image_writer()
# TODO(aliberts): refactor stats in save_episodes
self.meta.stats = compute_stats(self)
serialized_stats = serialize_dict(self.meta.stats)
write_json(serialized_stats, self.root / STATS_PATH)
self.consolidated = True
else:
logging.warning(
"Skipping computation of the dataset statistics, dataset is not fully consolidated."
)
@classmethod @classmethod
def create( def create(
cls, cls,
@@ -933,7 +1005,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
) )
obj.repo_id = obj.meta.repo_id obj.repo_id = obj.meta.repo_id
obj.root = obj.meta.root obj.root = obj.meta.root
obj.local_files_only = obj.meta.local_files_only obj.revision = None
obj.tolerance_s = tolerance_s obj.tolerance_s = tolerance_s
obj.image_writer = None obj.image_writer = None
@@ -943,14 +1015,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
# TODO(aliberts, rcadene, alexander-soare): Merge this with OnlineBuffer/DataBuffer # TODO(aliberts, rcadene, alexander-soare): Merge this with OnlineBuffer/DataBuffer
obj.episode_buffer = obj.create_episode_buffer() obj.episode_buffer = obj.create_episode_buffer()
# This bool indicates that the current LeRobotDataset instance is in sync with the files on disk. It
# is used to know when certain operations are need (for instance, computing dataset statistics). In
# order to be able to push the dataset to the hub, it needs to be consolidated first by calling
# self.consolidate().
obj.consolidated = True
obj.episodes = None obj.episodes = None
obj.hf_dataset = None obj.hf_dataset = obj.create_hf_dataset()
obj.image_transforms = None obj.image_transforms = None
obj.delta_timestamps = None obj.delta_timestamps = None
obj.delta_indices = None obj.delta_indices = None
@@ -975,12 +1041,11 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
delta_timestamps: dict[list[float]] | None = None, delta_timestamps: dict[list[float]] | None = None,
tolerances_s: dict | None = None, tolerances_s: dict | None = None,
download_videos: bool = True, download_videos: bool = True,
local_files_only: bool = False,
video_backend: str | None = None, video_backend: str | None = None,
): ):
super().__init__() super().__init__()
self.repo_ids = repo_ids self.repo_ids = repo_ids
self.root = Path(root) if root else LEROBOT_HOME self.root = Path(root) if root else HF_LEROBOT_HOME
self.tolerances_s = tolerances_s if tolerances_s else {repo_id: 1e-4 for repo_id in repo_ids} self.tolerances_s = tolerances_s if tolerances_s else {repo_id: 1e-4 for repo_id in repo_ids}
# Construct the underlying datasets passing everything but `transform` and `delta_timestamps` which # Construct the underlying datasets passing everything but `transform` and `delta_timestamps` which
# are handled by this class. # are handled by this class.
@@ -993,7 +1058,6 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
delta_timestamps=delta_timestamps, delta_timestamps=delta_timestamps,
tolerance_s=self.tolerances_s[repo_id], tolerance_s=self.tolerances_s[repo_id],
download_videos=download_videos, download_videos=download_videos,
local_files_only=local_files_only,
video_backend=video_backend, video_backend=video_backend,
) )
for repo_id in repo_ids for repo_id in repo_ids
@@ -1021,7 +1085,10 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
self.image_transforms = image_transforms self.image_transforms = image_transforms
self.delta_timestamps = delta_timestamps self.delta_timestamps = delta_timestamps
self.stats = aggregate_stats(self._datasets) # TODO(rcadene, aliberts): We should not perform this aggregation for datasets
# with multiple robots of different ranges. Instead we should have one normalization
# per robot.
self.stats = aggregate_stats([dataset.meta.stats for dataset in self._datasets])
@property @property
def repo_id_to_index(self): def repo_id_to_index(self):


@@ -1,56 +0,0 @@
## Using / Updating `CODEBASE_VERSION` (for maintainers)
Since the datasets we push to the hub are decoupled from the evolution of this repo, to ensure compatibility of
the datasets with our code we use a `CODEBASE_VERSION` variable (defined in
lerobot/common/datasets/lerobot_dataset.py).
For instance, [`lerobot/pusht`](https://huggingface.co/datasets/lerobot/pusht) has many versions to maintain backward compatibility between LeRobot codebase versions:
- [v1.0](https://huggingface.co/datasets/lerobot/pusht/tree/v1.0)
- [v1.1](https://huggingface.co/datasets/lerobot/pusht/tree/v1.1)
- [v1.2](https://huggingface.co/datasets/lerobot/pusht/tree/v1.2)
- [v1.3](https://huggingface.co/datasets/lerobot/pusht/tree/v1.3)
- [v1.4](https://huggingface.co/datasets/lerobot/pusht/tree/v1.4)
- [v1.5](https://huggingface.co/datasets/lerobot/pusht/tree/v1.5)
- [v1.6](https://huggingface.co/datasets/lerobot/pusht/tree/v1.6) <-- last version
- [main](https://huggingface.co/datasets/lerobot/pusht/tree/main) <-- points to the last version
Starting with v1.6, every dataset pushed to the hub or saved locally also has this version number in its
`info.json` metadata.
### Uploading a new dataset
If you are pushing a new dataset, you don't need to worry about any of the instructions below, nor about
compatibility with previous codebase versions. The `push_dataset_to_hub.py` script will automatically tag your
dataset with the current `CODEBASE_VERSION`.
### Updating an existing dataset
If you want to update an existing dataset, you need to change the `CODEBASE_VERSION` in `lerobot_dataset.py`
before running `push_dataset_to_hub.py`. This is especially useful if you introduce a breaking change,
intentionally or not (i.e. something that is not backward compatible, such as modifying the reward functions used,
deleting some frames at the end of an episode, etc.). That way, people running a previous version of the
codebase won't be affected by your change, and backward compatibility is maintained.
However, you will then need to update the version of ALL the other datasets so that they have the new
`CODEBASE_VERSION` as a branch in their Hugging Face dataset repository. Don't worry, there is an easy way
that doesn't require running `push_dataset_to_hub.py`. You can simply "branch out" from the `main` branch of the
HF dataset repo by running the following script, which corresponds to a `git checkout -b` (so no copy or upload is needed):
```python
from huggingface_hub import HfApi
from lerobot import available_datasets
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
api = HfApi()
for repo_id in available_datasets:
dataset_info = api.list_repo_refs(repo_id, repo_type="dataset")
branches = [b.name for b in dataset_info.branches]
if CODEBASE_VERSION in branches:
print(f"{repo_id} already @{CODEBASE_VERSION}, skipping.")
continue
else:
# Now create a branch named after the new version by branching out from "main"
# which is expected to be the preceding version
api.create_branch(repo_id, repo_type="dataset", branch=CODEBASE_VERSION, revision="main")
print(f"{repo_id} successfully updated @{CODEBASE_VERSION}")
```


@@ -152,7 +152,7 @@ def download_raw(raw_dir: Path, repo_id: str):
stacklevel=1, stacklevel=1,
) )
# Send warning if raw_dir isn't well formated # Send warning if raw_dir isn't well formatted
if raw_dir.parts[-2] != user_id or raw_dir.parts[-1] != dataset_id: if raw_dir.parts[-2] != user_id or raw_dir.parts[-1] != dataset_id:
warnings.warn( warnings.warn(
f"""`raw_dir` ({raw_dir}) doesn't contain a community or user id `/` the name of the dataset that f"""`raw_dir` ({raw_dir}) doesn't contain a community or user id `/` the name of the dataset that


@@ -68,11 +68,11 @@ def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episod
modality_df, modality_df,
on="timestamp_utc", on="timestamp_utc",
# "nearest" is the best option over "backward", since the latter can desynchronizes camera timestamps by # "nearest" is the best option over "backward", since the latter can desynchronizes camera timestamps by
# matching timestamps that are too far appart, in order to fit the backward constraints. It's not the case for "nearest". # matching timestamps that are too far apart, in order to fit the backward constraints. It's not the case for "nearest".
# However, note that "nearest" might synchronize the reference camera with other cameras on slightly future timestamps. # However, note that "nearest" might synchronize the reference camera with other cameras on slightly future timestamps.
# are too far appart. # are too far apart.
direction="nearest", direction="nearest",
tolerance=pd.Timedelta(f"{1/fps} seconds"), tolerance=pd.Timedelta(f"{1 / fps} seconds"),
) )
# Remove rows with episode_index -1 which indicates data that correspond to in-between episodes # Remove rows with episode_index -1 which indicates data that correspond to in-between episodes
df = df[df["episode_index"] != -1] df = df[df["episode_index"] != -1]
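To see why `direction="nearest"` is preferred here, a small self-contained `pandas.merge_asof` sketch (toy timestamps and column names, not taken from the actual data):
```python
import pandas as pd

fps = 50
ref = pd.DataFrame({"timestamp_utc": pd.to_datetime([0.00, 0.02, 0.04], unit="s")})
cam = pd.DataFrame(
    {
        "timestamp_utc": pd.to_datetime([0.001, 0.021, 0.039], unit="s"),
        "frame": [0, 1, 2],
    }
)

# "nearest" pairs each reference row with the closest camera frame in time,
# possibly slightly in the future; "backward" only considers earlier frames
# and can therefore match rows that are much further apart.
merged = pd.merge_asof(
    ref,
    cam,
    on="timestamp_utc",
    direction="nearest",
    tolerance=pd.Timedelta(f"{1 / fps} seconds"),
)
print(merged)
```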
@@ -126,7 +126,7 @@ def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episod
videos_dir.parent.mkdir(parents=True, exist_ok=True) videos_dir.parent.mkdir(parents=True, exist_ok=True)
videos_dir.symlink_to((raw_dir / "videos").absolute()) videos_dir.symlink_to((raw_dir / "videos").absolute())
# sanity check the video paths are well formated # sanity check the video paths are well formatted
for key in df: for key in df:
if "observation.images." not in key: if "observation.images." not in key:
continue continue
@@ -143,7 +143,7 @@ def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episod
# it is the case for video_frame dictionary = [{"path": ..., "timestamp": ...}] # it is the case for video_frame dictionary = [{"path": ..., "timestamp": ...}]
data_dict[key] = [video_frame[0] for video_frame in df[key].values] data_dict[key] = [video_frame[0] for video_frame in df[key].values]
# sanity check the video path is well formated # sanity check the video path is well formatted
video_path = videos_dir.parent / data_dict[key][0]["path"] video_path = videos_dir.parent / data_dict[key][0]["path"]
if not video_path.exists(): if not video_path.exists():
raise ValueError(f"Video file not found in {video_path}") raise ValueError(f"Video file not found in {video_path}")


@@ -17,7 +17,7 @@
For all datasets in the RLDS format. For all datasets in the RLDS format.
For https://github.com/google-deepmind/open_x_embodiment (OPENX) datasets. For https://github.com/google-deepmind/open_x_embodiment (OPENX) datasets.
NOTE: You need to install tensorflow and tensorflow_datsets before running this script. NOTE: You need to install tensorflow and tensorflow_datasets before running this script.
Example: Example:
python lerobot/scripts/push_dataset_to_hub.py \ python lerobot/scripts/push_dataset_to_hub.py \


@@ -14,7 +14,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import collections import collections
from typing import Any, Callable, Dict, Sequence from dataclasses import dataclass, field
from typing import Any, Callable, Sequence
import torch import torch
from torchvision.transforms import v2 from torchvision.transforms import v2
@@ -65,6 +66,8 @@ class RandomSubsetApply(Transform):
self.n_subset = n_subset self.n_subset = n_subset
self.random_order = random_order self.random_order = random_order
self.selected_transforms = None
def forward(self, *inputs: Any) -> Any: def forward(self, *inputs: Any) -> Any:
needs_unpacking = len(inputs) > 1 needs_unpacking = len(inputs) > 1
@@ -72,9 +75,9 @@ class RandomSubsetApply(Transform):
if not self.random_order: if not self.random_order:
selected_indices = selected_indices.sort().values selected_indices = selected_indices.sort().values
selected_transforms = [self.transforms[i] for i in selected_indices] self.selected_transforms = [self.transforms[i] for i in selected_indices]
for transform in selected_transforms: for transform in self.selected_transforms:
outputs = transform(*inputs) outputs = transform(*inputs)
inputs = outputs if needs_unpacking else (outputs,) inputs = outputs if needs_unpacking else (outputs,)
@@ -129,69 +132,118 @@ class SharpnessJitter(Transform):
return float(sharpness[0]), float(sharpness[1]) return float(sharpness[0]), float(sharpness[1])
def _generate_value(self, left: float, right: float) -> float: def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]:
return torch.empty(1).uniform_(left, right).item() sharpness_factor = torch.empty(1).uniform_(self.sharpness[0], self.sharpness[1]).item()
return {"sharpness_factor": sharpness_factor}
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: def transform(self, inpt: Any, params: dict[str, Any]) -> Any:
sharpness_factor = self._generate_value(self.sharpness[0], self.sharpness[1]) sharpness_factor = params["sharpness_factor"]
return self._call_kernel(F.adjust_sharpness, inpt, sharpness_factor=sharpness_factor) return self._call_kernel(F.adjust_sharpness, inpt, sharpness_factor=sharpness_factor)
def get_image_transforms( @dataclass
brightness_weight: float = 1.0, class ImageTransformConfig:
brightness_min_max: tuple[float, float] | None = None, """
contrast_weight: float = 1.0, For each transform, the following parameters are available:
contrast_min_max: tuple[float, float] | None = None, weight: This represents the multinomial probability (with no replacement)
saturation_weight: float = 1.0, used for sampling the transform. If the sum of the weights is not 1,
saturation_min_max: tuple[float, float] | None = None, they will be normalized.
hue_weight: float = 1.0, type: The name of the class used. This is either a class available under torchvision.transforms.v2 or a
hue_min_max: tuple[float, float] | None = None, custom transform defined here.
sharpness_weight: float = 1.0, kwargs: Lower & upper bound respectively used for sampling the transform's parameter
sharpness_min_max: tuple[float, float] | None = None, (following uniform distribution) when it's applied.
max_num_transforms: int | None = None, """
random_order: bool = False,
):
def check_value(name, weight, min_max):
if min_max is not None:
if len(min_max) != 2:
raise ValueError(
f"`{name}_min_max` is expected to be a tuple of 2 dimensions, but {min_max} provided."
)
if weight < 0.0:
raise ValueError(
f"`{name}_weight` is expected to be 0 or positive, but is negative ({weight})."
)
check_value("brightness", brightness_weight, brightness_min_max) weight: float = 1.0
check_value("contrast", contrast_weight, contrast_min_max) type: str = "Identity"
check_value("saturation", saturation_weight, saturation_min_max) kwargs: dict[str, Any] = field(default_factory=dict)
check_value("hue", hue_weight, hue_min_max)
check_value("sharpness", sharpness_weight, sharpness_min_max)
weights = []
transforms = []
if brightness_min_max is not None and brightness_weight > 0.0:
weights.append(brightness_weight)
transforms.append(v2.ColorJitter(brightness=brightness_min_max))
if contrast_min_max is not None and contrast_weight > 0.0:
weights.append(contrast_weight)
transforms.append(v2.ColorJitter(contrast=contrast_min_max))
if saturation_min_max is not None and saturation_weight > 0.0:
weights.append(saturation_weight)
transforms.append(v2.ColorJitter(saturation=saturation_min_max))
if hue_min_max is not None and hue_weight > 0.0:
weights.append(hue_weight)
transforms.append(v2.ColorJitter(hue=hue_min_max))
if sharpness_min_max is not None and sharpness_weight > 0.0:
weights.append(sharpness_weight)
transforms.append(SharpnessJitter(sharpness=sharpness_min_max))
n_subset = len(transforms) @dataclass
if max_num_transforms is not None: class ImageTransformsConfig:
n_subset = min(n_subset, max_num_transforms) """
These transforms are all using standard torchvision.transforms.v2
You can find out how these transformations affect images here:
https://pytorch.org/vision/0.18/auto_examples/transforms/plot_transforms_illustrations.html
We use a custom RandomSubsetApply container to sample them.
"""
if n_subset == 0: # Set this flag to `true` to enable transforms during training
return v2.Identity() enable: bool = False
# This is the maximum number of transforms (sampled from these below) that will be applied to each frame.
# It's an integer in the interval [1, number_of_available_transforms].
max_num_transforms: int = 3
# By default, transforms are applied in Torchvision's suggested order (shown below).
# Set this to True to apply them in a random order.
random_order: bool = False
tfs: dict[str, ImageTransformConfig] = field(
default_factory=lambda: {
"brightness": ImageTransformConfig(
weight=1.0,
type="ColorJitter",
kwargs={"brightness": (0.8, 1.2)},
),
"contrast": ImageTransformConfig(
weight=1.0,
type="ColorJitter",
kwargs={"contrast": (0.8, 1.2)},
),
"saturation": ImageTransformConfig(
weight=1.0,
type="ColorJitter",
kwargs={"saturation": (0.5, 1.5)},
),
"hue": ImageTransformConfig(
weight=1.0,
type="ColorJitter",
kwargs={"hue": (-0.05, 0.05)},
),
"sharpness": ImageTransformConfig(
weight=1.0,
type="SharpnessJitter",
kwargs={"sharpness": (0.5, 1.5)},
),
}
)
def make_transform_from_config(cfg: ImageTransformConfig):
if cfg.type == "Identity":
return v2.Identity(**cfg.kwargs)
elif cfg.type == "ColorJitter":
return v2.ColorJitter(**cfg.kwargs)
elif cfg.type == "SharpnessJitter":
return SharpnessJitter(**cfg.kwargs)
else: else:
# TODO(rcadene, aliberts): add v2.ToDtype float16? raise ValueError(f"Transform '{cfg.type}' is not valid.")
return RandomSubsetApply(transforms, p=weights, n_subset=n_subset, random_order=random_order)
class ImageTransforms(Transform):
"""A class to compose image transforms based on configuration."""
def __init__(self, cfg: ImageTransformsConfig) -> None:
super().__init__()
self._cfg = cfg
self.weights = []
self.transforms = {}
for tf_name, tf_cfg in cfg.tfs.items():
if tf_cfg.weight <= 0.0:
continue
self.transforms[tf_name] = make_transform_from_config(tf_cfg)
self.weights.append(tf_cfg.weight)
n_subset = min(len(self.transforms), cfg.max_num_transforms)
if n_subset == 0 or not cfg.enable:
self.tf = v2.Identity()
else:
self.tf = RandomSubsetApply(
transforms=list(self.transforms.values()),
p=self.weights,
n_subset=n_subset,
random_order=cfg.random_order,
)
def forward(self, *inputs: Any) -> Any:
return self.tf(*inputs)
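A usage sketch of the resulting config-driven API (values are illustrative; the defaults shown above already cover the common jitters):
```python
import torch

cfg = ImageTransformsConfig(
    enable=True,
    max_num_transforms=2,
    tfs={
        "brightness": ImageTransformConfig(
            weight=1.0, type="ColorJitter", kwargs={"brightness": (0.8, 1.2)}
        ),
        "sharpness": ImageTransformConfig(
            weight=0.5, type="SharpnessJitter", kwargs={"sharpness": (0.5, 1.5)}
        ),
    },
)

transforms = ImageTransforms(cfg)
img = torch.rand(3, 224, 224)  # dummy (C, H, W) float image in [0, 1]
out = transforms(img)
```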


@@ -13,32 +13,42 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import contextlib
import importlib.resources import importlib.resources
import json import json
import logging import logging
import textwrap from collections.abc import Iterator
from itertools import accumulate from itertools import accumulate
from pathlib import Path from pathlib import Path
from pprint import pformat from pprint import pformat
from types import SimpleNamespace
from typing import Any from typing import Any
import datasets import datasets
import jsonlines import jsonlines
import numpy as np import numpy as np
import pyarrow.compute as pc import packaging.version
import torch import torch
from datasets.table import embed_table_storage from datasets.table import embed_table_storage
from huggingface_hub import DatasetCard, DatasetCardData, HfApi from huggingface_hub import DatasetCard, DatasetCardData, HfApi
from PIL import Image as PILImage from PIL import Image as PILImage
from torchvision import transforms from torchvision import transforms
from lerobot.common.datasets.backward_compatibility import (
V21_MESSAGE,
BackwardCompatibilityError,
ForwardCompatibilityError,
)
from lerobot.common.robot_devices.robots.utils import Robot from lerobot.common.robot_devices.robots.utils import Robot
from lerobot.common.utils.utils import is_valid_numpy_dtype_string
from lerobot.configs.types import DictLike, FeatureType, PolicyFeature
DEFAULT_CHUNK_SIZE = 1000 # Max number of episodes per chunk DEFAULT_CHUNK_SIZE = 1000 # Max number of episodes per chunk
INFO_PATH = "meta/info.json" INFO_PATH = "meta/info.json"
EPISODES_PATH = "meta/episodes.jsonl" EPISODES_PATH = "meta/episodes.jsonl"
STATS_PATH = "meta/stats.json" STATS_PATH = "meta/stats.json"
EPISODES_STATS_PATH = "meta/episodes_stats.jsonl"
TASKS_PATH = "meta/tasks.jsonl" TASKS_PATH = "meta/tasks.jsonl"
DEFAULT_VIDEO_PATH = "videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4" DEFAULT_VIDEO_PATH = "videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4"
@@ -96,18 +106,39 @@ def unflatten_dict(d: dict, sep: str = "/") -> dict:
return outdict return outdict
def get_nested_item(obj: DictLike, flattened_key: str, sep: str = "/") -> Any:
split_keys = flattened_key.split(sep)
getter = obj[split_keys[0]]
if len(split_keys) == 1:
return getter
for key in split_keys[1:]:
getter = getter[key]
return getter
def serialize_dict(stats: dict[str, torch.Tensor | np.ndarray | dict]) -> dict: def serialize_dict(stats: dict[str, torch.Tensor | np.ndarray | dict]) -> dict:
serialized_dict = {key: value.tolist() for key, value in flatten_dict(stats).items()} serialized_dict = {}
for key, value in flatten_dict(stats).items():
if isinstance(value, (torch.Tensor, np.ndarray)):
serialized_dict[key] = value.tolist()
elif isinstance(value, np.generic):
serialized_dict[key] = value.item()
elif isinstance(value, (int, float)):
serialized_dict[key] = value
else:
raise NotImplementedError(f"The value '{value}' of type '{type(value)}' is not supported.")
return unflatten_dict(serialized_dict) return unflatten_dict(serialized_dict)
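For instance (illustrative stats dict, not from a real dataset):
```python
import numpy as np

stats = {"action": {"mean": np.array([0.1, 0.2]), "count": np.int64(100)}}
serialize_dict(stats)
# -> {"action": {"mean": [0.1, 0.2], "count": 100}}
```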
def write_parquet(dataset: datasets.Dataset, fpath: Path) -> None: def embed_images(dataset: datasets.Dataset) -> datasets.Dataset:
# Embed image bytes into the table before saving to parquet # Embed image bytes into the table before saving to parquet
format = dataset.format format = dataset.format
dataset = dataset.with_format("arrow") dataset = dataset.with_format("arrow")
dataset = dataset.map(embed_table_storage, batched=False) dataset = dataset.map(embed_table_storage, batched=False)
dataset = dataset.with_format(**format) dataset = dataset.with_format(**format)
dataset.to_parquet(fpath) return dataset
def load_json(fpath: Path) -> Any: def load_json(fpath: Path) -> Any:
@@ -138,6 +169,10 @@ def append_jsonlines(data: dict, fpath: Path) -> None:
writer.write(data) writer.write(data)
def write_info(info: dict, local_dir: Path):
write_json(info, local_dir / INFO_PATH)
def load_info(local_dir: Path) -> dict: def load_info(local_dir: Path) -> dict:
info = load_json(local_dir / INFO_PATH) info = load_json(local_dir / INFO_PATH)
for ft in info["features"].values(): for ft in info["features"].values():
@@ -145,29 +180,76 @@ def load_info(local_dir: Path) -> dict:
return info return info
def load_stats(local_dir: Path) -> dict: def write_stats(stats: dict, local_dir: Path):
if not (local_dir / STATS_PATH).exists(): serialized_stats = serialize_dict(stats)
return None write_json(serialized_stats, local_dir / STATS_PATH)
stats = load_json(local_dir / STATS_PATH)
stats = {key: torch.tensor(value) for key, value in flatten_dict(stats).items()}
def cast_stats_to_numpy(stats) -> dict[str, dict[str, np.ndarray]]:
stats = {key: np.array(value) for key, value in flatten_dict(stats).items()}
return unflatten_dict(stats) return unflatten_dict(stats)
def load_tasks(local_dir: Path) -> dict: def load_stats(local_dir: Path) -> dict[str, dict[str, np.ndarray]]:
if not (local_dir / STATS_PATH).exists():
return None
stats = load_json(local_dir / STATS_PATH)
return cast_stats_to_numpy(stats)
def write_task(task_index: int, task: dict, local_dir: Path):
task_dict = {
"task_index": task_index,
"task": task,
}
append_jsonlines(task_dict, local_dir / TASKS_PATH)
def load_tasks(local_dir: Path) -> tuple[dict, dict]:
tasks = load_jsonlines(local_dir / TASKS_PATH) tasks = load_jsonlines(local_dir / TASKS_PATH)
return {item["task_index"]: item["task"] for item in sorted(tasks, key=lambda x: x["task_index"])} tasks = {item["task_index"]: item["task"] for item in sorted(tasks, key=lambda x: x["task_index"])}
task_to_task_index = {task: task_index for task_index, task in tasks.items()}
return tasks, task_to_task_index
def write_episode(episode: dict, local_dir: Path):
append_jsonlines(episode, local_dir / EPISODES_PATH)
def load_episodes(local_dir: Path) -> dict: def load_episodes(local_dir: Path) -> dict:
return load_jsonlines(local_dir / EPISODES_PATH) episodes = load_jsonlines(local_dir / EPISODES_PATH)
return {item["episode_index"]: item for item in sorted(episodes, key=lambda x: x["episode_index"])}
def load_image_as_numpy(fpath: str | Path, dtype="float32", channel_first: bool = True) -> np.ndarray: def write_episode_stats(episode_index: int, episode_stats: dict, local_dir: Path):
# We wrap episode_stats in a dictionary since `episode_stats["episode_index"]`
# is a dictionary of stats and not an integer.
episode_stats = {"episode_index": episode_index, "stats": serialize_dict(episode_stats)}
append_jsonlines(episode_stats, local_dir / EPISODES_STATS_PATH)
def load_episodes_stats(local_dir: Path) -> dict:
episodes_stats = load_jsonlines(local_dir / EPISODES_STATS_PATH)
return {
item["episode_index"]: cast_stats_to_numpy(item["stats"])
for item in sorted(episodes_stats, key=lambda x: x["episode_index"])
}
def backward_compatible_episodes_stats(
stats: dict[str, dict[str, np.ndarray]], episodes: list[int]
) -> dict[str, dict[str, np.ndarray]]:
return {ep_idx: stats for ep_idx in episodes}
def load_image_as_numpy(
fpath: str | Path, dtype: np.dtype = np.float32, channel_first: bool = True
) -> np.ndarray:
img = PILImage.open(fpath).convert("RGB") img = PILImage.open(fpath).convert("RGB")
img_array = np.array(img, dtype=dtype) img_array = np.array(img, dtype=dtype)
if channel_first: # (H, W, C) -> (C, H, W) if channel_first: # (H, W, C) -> (C, H, W)
img_array = np.transpose(img_array, (2, 0, 1)) img_array = np.transpose(img_array, (2, 0, 1))
if "float" in dtype: if np.issubdtype(dtype, np.floating):
img_array /= 255.0 img_array /= 255.0
return img_array return img_array
@@ -186,77 +268,82 @@ def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
elif first_item is None: elif first_item is None:
pass pass
else: else:
items_dict[key] = [torch.tensor(x) for x in items_dict[key]] items_dict[key] = [x if isinstance(x, str) else torch.tensor(x) for x in items_dict[key]]
return items_dict return items_dict
def _get_major_minor(version: str) -> tuple[int]: def is_valid_version(version: str) -> bool:
split = version.strip("v").split(".") try:
return int(split[0]), int(split[1]) packaging.version.parse(version)
return True
except packaging.version.InvalidVersion:
class BackwardCompatibilityError(Exception): return False
def __init__(self, repo_id, version):
message = textwrap.dedent(f"""
BackwardCompatibilityError: The dataset you requested ({repo_id}) is in {version} format.
We introduced a new format since v2.0 which is not backward compatible with v1.x.
Please, use our conversion script. Modify the following command with your own task description:
```
python lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py \\
--repo-id {repo_id} \\
--single-task "TASK DESCRIPTION." # <---- /!\\ Replace TASK DESCRIPTION /!\\
```
A few examples to replace TASK DESCRIPTION: "Pick up the blue cube and place it into the bin.",
"Insert the peg into the socket.", "Slide open the ziploc bag.", "Take the elevator to the 1st floor.",
"Open the top cabinet, store the pot inside it then close the cabinet.", "Push the T-shaped block onto the T-shaped target.",
"Grab the spray paint on the shelf and place it in the bin on top of the robot dog.", "Fold the sweatshirt.", ...
If you encounter a problem, contact LeRobot maintainers on [Discord](https://discord.com/invite/s3KuuzsPFb)
or open an [issue on GitHub](https://github.com/huggingface/lerobot/issues/new/choose).
""")
super().__init__(message)
def check_version_compatibility( def check_version_compatibility(
repo_id: str, version_to_check: str, current_version: str, enforce_breaking_major: bool = True repo_id: str,
version_to_check: str | packaging.version.Version,
current_version: str | packaging.version.Version,
enforce_breaking_major: bool = True,
) -> None: ) -> None:
current_major, _ = _get_major_minor(current_version) v_check = (
major_to_check, _ = _get_major_minor(version_to_check) packaging.version.parse(version_to_check)
if major_to_check < current_major and enforce_breaking_major: if not isinstance(version_to_check, packaging.version.Version)
raise BackwardCompatibilityError(repo_id, version_to_check) else version_to_check
elif float(version_to_check.strip("v")) < float(current_version.strip("v")): )
logging.warning( v_current = (
f"""The dataset you requested ({repo_id}) was created with a previous version ({version_to_check}) of the packaging.version.parse(current_version)
codebase. The current codebase version is {current_version}. You should be fine since if not isinstance(current_version, packaging.version.Version)
backward compatibility is maintained. If you encounter a problem, contact LeRobot maintainers on else current_version
Discord ('https://discord.com/invite/s3KuuzsPFb') or open an issue on github.""", )
) if v_check.major < v_current.major and enforce_breaking_major:
raise BackwardCompatibilityError(repo_id, v_check)
elif v_check.minor < v_current.minor:
logging.warning(V21_MESSAGE.format(repo_id=repo_id, version=v_check))
def get_hub_safe_version(repo_id: str, version: str) -> str: def get_repo_versions(repo_id: str) -> list[packaging.version.Version]:
"""Returns available valid versions (branches and tags) on given repo."""
api = HfApi() api = HfApi()
dataset_info = api.list_repo_refs(repo_id, repo_type="dataset") repo_refs = api.list_repo_refs(repo_id, repo_type="dataset")
branches = [b.name for b in dataset_info.branches] repo_refs = [b.name for b in repo_refs.branches + repo_refs.tags]
if version not in branches: repo_versions = []
num_version = float(version.strip("v")) for ref in repo_refs:
hub_num_versions = [float(v.strip("v")) for v in branches if v.startswith("v")] with contextlib.suppress(packaging.version.InvalidVersion):
if num_version >= 2.0 and all(v < 2.0 for v in hub_num_versions): repo_versions.append(packaging.version.parse(ref))
raise BackwardCompatibilityError(repo_id, version)
logging.warning( return repo_versions
f"""You are trying to load a dataset from {repo_id} created with a previous version of the
codebase. The following versions are available: {branches}.
The requested version ('{version}') is not found. You should be fine since def get_safe_version(repo_id: str, version: str | packaging.version.Version) -> str:
backward compatibility is maintained. If you encounter a problem, contact LeRobot maintainers on """
Discord ('https://discord.com/invite/s3KuuzsPFb') or open an issue on github.""", Returns the version if available on repo or the latest compatible one.
) Otherwise, will throw a `CompatibilityError`.
if "main" not in branches: """
raise ValueError(f"Version 'main' not found on {repo_id}") target_version = (
return "main" packaging.version.parse(version) if not isinstance(version, packaging.version.Version) else version
else: )
return version hub_versions = get_repo_versions(repo_id)
if target_version in hub_versions:
return f"v{target_version}"
compatibles = [
v for v in hub_versions if v.major == target_version.major and v.minor <= target_version.minor
]
if compatibles:
return_version = max(compatibles)
if return_version < target_version:
logging.warning(f"Revision {version} for {repo_id} not found, using version v{return_version}")
return f"v{return_version}"
lower_major = [v for v in hub_versions if v.major < target_version.major]
if lower_major:
raise BackwardCompatibilityError(repo_id, max(lower_major))
upper_versions = [v for v in hub_versions if v > target_version]
assert len(upper_versions) > 0
raise ForwardCompatibilityError(repo_id, min(upper_versions))
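The selection rule boils down to the following sketch using `packaging.version` directly (made-up hub versions for illustration):
```python
import packaging.version

hub_versions = [packaging.version.parse(v) for v in ["v1.6", "v2.0", "v2.1"]]
target = packaging.version.parse("v2.2")  # requested but absent from the hub

# Same major, lower-or-equal minor -> closest compatible version (with a warning).
compatibles = [v for v in hub_versions if v.major == target.major and v.minor <= target.minor]
print(max(compatibles))  # 2.1

# If only lower majors existed, BackwardCompatibilityError would be raised;
# if only higher versions existed, ForwardCompatibilityError.
```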
def get_hf_features_from_features(features: dict) -> datasets.Features: def get_hf_features_from_features(features: dict) -> datasets.Features:
@@ -268,11 +355,20 @@ def get_hf_features_from_features(features: dict) -> datasets.Features:
hf_features[key] = datasets.Image() hf_features[key] = datasets.Image()
elif ft["shape"] == (1,): elif ft["shape"] == (1,):
hf_features[key] = datasets.Value(dtype=ft["dtype"]) hf_features[key] = datasets.Value(dtype=ft["dtype"])
else: elif len(ft["shape"]) == 1:
assert len(ft["shape"]) == 1
hf_features[key] = datasets.Sequence( hf_features[key] = datasets.Sequence(
length=ft["shape"][0], feature=datasets.Value(dtype=ft["dtype"]) length=ft["shape"][0], feature=datasets.Value(dtype=ft["dtype"])
) )
elif len(ft["shape"]) == 2:
hf_features[key] = datasets.Array2D(shape=ft["shape"], dtype=ft["dtype"])
elif len(ft["shape"]) == 3:
hf_features[key] = datasets.Array3D(shape=ft["shape"], dtype=ft["dtype"])
elif len(ft["shape"]) == 4:
hf_features[key] = datasets.Array4D(shape=ft["shape"], dtype=ft["dtype"])
elif len(ft["shape"]) == 5:
hf_features[key] = datasets.Array5D(shape=ft["shape"], dtype=ft["dtype"])
else:
raise ValueError(f"Corresponding feature is not valid: {ft}")
return datasets.Features(hf_features) return datasets.Features(hf_features)
@@ -287,6 +383,37 @@ def get_features_from_robot(robot: Robot, use_videos: bool = True) -> dict:
return {**robot.motor_features, **camera_ft, **DEFAULT_FEATURES} return {**robot.motor_features, **camera_ft, **DEFAULT_FEATURES}
def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFeature]:
# TODO(aliberts): Implement "type" in dataset features and simplify this
policy_features = {}
for key, ft in features.items():
shape = ft["shape"]
if ft["dtype"] in ["image", "video"]:
type = FeatureType.VISUAL
if len(shape) != 3:
raise ValueError(f"Number of dimensions of {key} != 3 (shape={shape})")
names = ft["names"]
# Backward compatibility for "channel" which is an error introduced in LeRobotDataset v2.0 for ported datasets.
if names[2] in ["channel", "channels"]: # (h, w, c) -> (c, h, w)
shape = (shape[2], shape[0], shape[1])
elif key == "observation.environment_state":
type = FeatureType.ENV
elif key.startswith("observation"):
type = FeatureType.STATE
elif key == "action":
type = FeatureType.ACTION
else:
continue
policy_features[key] = PolicyFeature(
type=type,
shape=shape,
)
return policy_features
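For example (illustrative feature dict; note the (h, w, c) -> (c, h, w) fix applied to ported datasets):
```python
features = {
    "observation.images.cam": {
        "dtype": "video",
        "shape": (480, 640, 3),
        "names": ["height", "width", "channels"],
    },
    "observation.state": {"dtype": "float32", "shape": (6,), "names": None},
    "action": {"dtype": "float32", "shape": (6,), "names": None},
}

policy_features = dataset_to_policy_features(features)
# observation.images.cam -> FeatureType.VISUAL, shape (3, 480, 640)
# observation.state      -> FeatureType.STATE,  shape (6,)
# action                 -> FeatureType.ACTION, shape (6,)
```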
def create_empty_dataset_info( def create_empty_dataset_info(
codebase_version: str, codebase_version: str,
fps: int, fps: int,
@@ -312,88 +439,85 @@ def create_empty_dataset_info(
def get_episode_data_index( def get_episode_data_index(
episode_dicts: list[dict], episodes: list[int] | None = None episode_dicts: dict[dict], episodes: list[int] | None = None
) -> dict[str, torch.Tensor]: ) -> dict[str, torch.Tensor]:
episode_lengths = {ep_idx: ep_dict["length"] for ep_idx, ep_dict in enumerate(episode_dicts)} episode_lengths = {ep_idx: ep_dict["length"] for ep_idx, ep_dict in episode_dicts.items()}
if episodes is not None: if episodes is not None:
episode_lengths = {ep_idx: episode_lengths[ep_idx] for ep_idx in episodes} episode_lengths = {ep_idx: episode_lengths[ep_idx] for ep_idx in episodes}
cumulative_lenghts = list(accumulate(episode_lengths.values())) cumulative_lengths = list(accumulate(episode_lengths.values()))
return { return {
"from": torch.LongTensor([0] + cumulative_lenghts[:-1]), "from": torch.LongTensor([0] + cumulative_lengths[:-1]),
"to": torch.LongTensor(cumulative_lenghts), "to": torch.LongTensor(cumulative_lengths),
}
def calculate_total_episode(
hf_dataset: datasets.Dataset, raise_if_not_contiguous: bool = True
) -> dict[str, torch.Tensor]:
episode_indices = sorted(hf_dataset.unique("episode_index"))
total_episodes = len(episode_indices)
if raise_if_not_contiguous and episode_indices != list(range(total_episodes)):
raise ValueError("episode_index values are not sorted and contiguous.")
return total_episodes
def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> dict[str, torch.Tensor]:
episode_lengths = []
table = hf_dataset.data.table
total_episodes = calculate_total_episode(hf_dataset)
for ep_idx in range(total_episodes):
ep_table = table.filter(pc.equal(table["episode_index"], ep_idx))
episode_lengths.insert(ep_idx, len(ep_table))
cumulative_lenghts = list(accumulate(episode_lengths))
return {
"from": torch.LongTensor([0] + cumulative_lenghts[:-1]),
"to": torch.LongTensor(cumulative_lenghts),
} }
def check_timestamps_sync( def check_timestamps_sync(
hf_dataset: datasets.Dataset, timestamps: np.ndarray,
episode_data_index: dict[str, torch.Tensor], episode_indices: np.ndarray,
episode_data_index: dict[str, np.ndarray],
fps: int, fps: int,
tolerance_s: float, tolerance_s: float,
raise_value_error: bool = True, raise_value_error: bool = True,
) -> bool: ) -> bool:
""" """
This check is to make sure that each timestamps is separated to the next by 1/fps +/- tolerance to This check is to make sure that each timestamp is separated from the next by (1/fps) +/- tolerance
account for possible numerical error. to account for possible numerical error.
"""
timestamps = torch.stack(hf_dataset["timestamp"])
diffs = torch.diff(timestamps)
within_tolerance = torch.abs(diffs - 1 / fps) <= tolerance_s
# We mask differences between the timestamp at the end of an episode Args:
# and the one at the start of the next episode since these are expected timestamps (np.ndarray): Array of timestamps in seconds.
# to be outside tolerance. episode_indices (np.ndarray): Array indicating the episode index for each timestamp.
mask = torch.ones(len(diffs), dtype=torch.bool) episode_data_index (dict[str, np.ndarray]): A dictionary that includes 'to',
ignored_diffs = episode_data_index["to"][:-1] - 1 which identifies indices for the end of each episode.
fps (int): Frames per second. Used to check the expected difference between consecutive timestamps.
tolerance_s (float): Allowed deviation from the expected (1/fps) difference.
raise_value_error (bool): Whether to raise a ValueError if the check fails.
Returns:
bool: True if all checked timestamp differences lie within tolerance, False otherwise.
Raises:
ValueError: If the check fails and `raise_value_error` is True.
"""
if timestamps.shape != episode_indices.shape:
raise ValueError(
"timestamps and episode_indices should have the same shape. "
f"Found {timestamps.shape=} and {episode_indices.shape=}."
)
# Consecutive differences
diffs = np.diff(timestamps)
within_tolerance = np.abs(diffs - (1.0 / fps)) <= tolerance_s
# Mask to ignore differences at the boundaries between episodes
mask = np.ones(len(diffs), dtype=bool)
ignored_diffs = episode_data_index["to"][:-1] - 1 # indices at the end of each episode
mask[ignored_diffs] = False mask[ignored_diffs] = False
filtered_within_tolerance = within_tolerance[mask] filtered_within_tolerance = within_tolerance[mask]
if not torch.all(filtered_within_tolerance): # Check if all remaining diffs are within tolerance
if not np.all(filtered_within_tolerance):
# Track original indices before masking # Track original indices before masking
original_indices = torch.arange(len(diffs)) original_indices = np.arange(len(diffs))
filtered_indices = original_indices[mask] filtered_indices = original_indices[mask]
outside_tolerance_filtered_indices = torch.nonzero(~filtered_within_tolerance) # .squeeze() outside_tolerance_filtered_indices = np.nonzero(~filtered_within_tolerance)[0]
outside_tolerance_indices = filtered_indices[outside_tolerance_filtered_indices] outside_tolerance_indices = filtered_indices[outside_tolerance_filtered_indices]
episode_indices = torch.stack(hf_dataset["episode_index"])
outside_tolerances = [] outside_tolerances = []
for idx in outside_tolerance_indices: for idx in outside_tolerance_indices:
entry = { entry = {
"timestamps": [timestamps[idx], timestamps[idx + 1]], "timestamps": [timestamps[idx], timestamps[idx + 1]],
"diff": diffs[idx], "diff": diffs[idx],
"episode_index": episode_indices[idx].item(), "episode_index": episode_indices[idx].item()
if hasattr(episode_indices[idx], "item")
else episode_indices[idx],
} }
outside_tolerances.append(entry) outside_tolerances.append(entry)
if raise_value_error: if raise_value_error:
raise ValueError( raise ValueError(
f"""One or several timestamps unexpectedly violate the tolerance inside episode range. f"""One or several timestamps unexpectedly violate the tolerance inside episode range.
This might be due to synchronization issues with timestamps during data collection. This might be due to synchronization issues during data collection.
\n{pformat(outside_tolerances)}""" \n{pformat(outside_tolerances)}"""
) )
return False return False
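A toy call illustrating the episode-boundary masking (two short episodes at 30 fps, synthetic arrays):
```python
import numpy as np

fps = 30
timestamps = np.array([0.0, 1 / fps, 2 / fps, 0.0, 1 / fps])  # two episodes
episode_indices = np.array([0, 0, 0, 1, 1])
episode_data_index = {"from": np.array([0, 3]), "to": np.array([3, 5])}

# The 2/30 -> 0.0 jump between episodes is masked out, so this passes.
check_timestamps_sync(timestamps, episode_indices, episode_data_index, fps, tolerance_s=1e-4)
```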
@@ -434,7 +558,7 @@ def check_delta_timestamps(
def get_delta_indices(delta_timestamps: dict[str, list[float]], fps: int) -> dict[str, list[int]]: def get_delta_indices(delta_timestamps: dict[str, list[float]], fps: int) -> dict[str, list[int]]:
delta_indices = {} delta_indices = {}
for key, delta_ts in delta_timestamps.items(): for key, delta_ts in delta_timestamps.items():
delta_indices[key] = (torch.tensor(delta_ts) * fps).long().tolist() delta_indices[key] = [round(d * fps) for d in delta_ts]
return delta_indices return delta_indices
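For instance, at 30 fps (illustrative deltas):
```python
fps = 30
delta_timestamps = {"action": [-0.1, 0.0, 0.1, 0.2]}

delta_indices = {key: [round(d * fps) for d in ts] for key, ts in delta_timestamps.items()}
print(delta_indices)  # {'action': [-3, 0, 3, 6]}
```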
@@ -477,7 +601,6 @@ def create_lerobot_dataset_card(
Note: If specified, license must be one of https://huggingface.co/docs/hub/repositories-licenses. Note: If specified, license must be one of https://huggingface.co/docs/hub/repositories-licenses.
""" """
card_tags = ["LeRobot"] card_tags = ["LeRobot"]
card_template_path = importlib.resources.path("lerobot.common.datasets", "card_template.md")
if tags: if tags:
card_tags += tags card_tags += tags
@@ -497,8 +620,180 @@ def create_lerobot_dataset_card(
], ],
) )
card_template = (importlib.resources.files("lerobot.common.datasets") / "card_template.md").read_text()
return DatasetCard.from_template( return DatasetCard.from_template(
card_data=card_data, card_data=card_data,
template_path=str(card_template_path), template_str=card_template,
**kwargs, **kwargs,
) )
class IterableNamespace(SimpleNamespace):
"""
A namespace object that supports both dictionary-like iteration and dot notation access.
Automatically converts nested dictionaries into IterableNamespaces.
This class extends SimpleNamespace to provide:
- Dictionary-style iteration over keys
- Access to items via both dot notation (obj.key) and brackets (obj["key"])
- Dictionary-like methods: items(), keys(), values()
- Recursive conversion of nested dictionaries
Args:
dictionary: Optional dictionary to initialize the namespace
**kwargs: Additional keyword arguments passed to SimpleNamespace
Examples:
>>> data = {"name": "Alice", "details": {"age": 25}}
>>> ns = IterableNamespace(data)
>>> ns.name
'Alice'
>>> ns.details.age
25
>>> list(ns.keys())
['name', 'details']
>>> for key, value in ns.items():
... print(f"{key}: {value}")
name: Alice
details: IterableNamespace(age=25)
"""
def __init__(self, dictionary: dict[str, Any] = None, **kwargs):
super().__init__(**kwargs)
if dictionary is not None:
for key, value in dictionary.items():
if isinstance(value, dict):
setattr(self, key, IterableNamespace(value))
else:
setattr(self, key, value)
def __iter__(self) -> Iterator[str]:
return iter(vars(self))
def __getitem__(self, key: str) -> Any:
return vars(self)[key]
def items(self):
return vars(self).items()
def values(self):
return vars(self).values()
def keys(self):
return vars(self).keys()
def validate_frame(frame: dict, features: dict):
optional_features = {"timestamp"}
expected_features = (set(features) - set(DEFAULT_FEATURES.keys())) | {"task"}
actual_features = set(frame.keys())
error_message = validate_features_presence(actual_features, expected_features, optional_features)
if "task" in frame:
error_message += validate_feature_string("task", frame["task"])
common_features = actual_features & (expected_features | optional_features)
for name in common_features - {"task"}:
error_message += validate_feature_dtype_and_shape(name, features[name], frame[name])
if error_message:
raise ValueError(error_message)
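For reference, a minimal sketch of a call that passes these checks (the feature name and shape are hypothetical, and `DEFAULT_FEATURES` is assumed to cover the auto-filled keys such as `index` and `episode_index`):

```python
import numpy as np

# Hypothetical feature spec and frame; validate_frame raises ValueError on any mismatch.
features = {"observation.state": {"dtype": "float32", "shape": (6,)}}
frame = {
    "observation.state": np.zeros(6, dtype=np.float32),
    "task": "pick up the cube",  # "task" is always required and must be a string
}
validate_frame(frame, features)  # returns silently when everything matches
```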
def validate_features_presence(
actual_features: set[str], expected_features: set[str], optional_features: set[str]
):
error_message = ""
missing_features = expected_features - actual_features
extra_features = actual_features - (expected_features | optional_features)
if missing_features or extra_features:
error_message += "Feature mismatch in `frame` dictionary:\n"
if missing_features:
error_message += f"Missing features: {missing_features}\n"
if extra_features:
error_message += f"Extra features: {extra_features}\n"
return error_message
def validate_feature_dtype_and_shape(name: str, feature: dict, value: np.ndarray | PILImage.Image | str):
expected_dtype = feature["dtype"]
expected_shape = feature["shape"]
if is_valid_numpy_dtype_string(expected_dtype):
return validate_feature_numpy_array(name, expected_dtype, expected_shape, value)
elif expected_dtype in ["image", "video"]:
return validate_feature_image_or_video(name, expected_shape, value)
elif expected_dtype == "string":
return validate_feature_string(name, value)
else:
raise NotImplementedError(f"The feature dtype '{expected_dtype}' is not implemented yet.")
def validate_feature_numpy_array(
name: str, expected_dtype: str, expected_shape: list[int], value: np.ndarray
):
error_message = ""
if isinstance(value, np.ndarray):
actual_dtype = value.dtype
actual_shape = value.shape
if actual_dtype != np.dtype(expected_dtype):
error_message += f"The feature '{name}' of dtype '{actual_dtype}' is not of the expected dtype '{expected_dtype}'.\n"
if actual_shape != tuple(expected_shape):
error_message += f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{expected_shape}'.\n"
else:
error_message += f"The feature '{name}' is not a 'np.ndarray'. Expected type is '{expected_dtype}', but type '{type(value)}' provided instead.\n"
return error_message
def validate_feature_image_or_video(name: str, expected_shape: list[int], value: np.ndarray | PILImage.Image):
# Note: The check of pixels range ([0,1] for float and [0,255] for uint8) is done by the image writer threads.
error_message = ""
if isinstance(value, np.ndarray):
actual_shape = value.shape
c, h, w = expected_shape
if len(actual_shape) != 3 or (actual_shape != (c, h, w) and actual_shape != (h, w, c)):
error_message += f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{(c, h, w)}' or '{(h, w, c)}'.\n"
elif isinstance(value, PILImage.Image):
pass
else:
error_message += f"The feature '{name}' is expected to be of type 'PIL.Image' or 'np.ndarray' channel first or channel last, but type '{type(value)}' provided instead.\n"
return error_message
def validate_feature_string(name: str, value: str):
if not isinstance(value, str):
return f"The feature '{name}' is expected to be of type 'str', but type '{type(value)}' provided instead.\n"
return ""
def validate_episode_buffer(episode_buffer: dict, total_episodes: int, features: dict):
if "size" not in episode_buffer:
raise ValueError("size key not found in episode_buffer")
if "task" not in episode_buffer:
raise ValueError("task key not found in episode_buffer")
if episode_buffer["episode_index"] != total_episodes:
# TODO(aliberts): Add option to use existing episode_index
raise NotImplementedError(
"You might have manually provided the episode_buffer with an episode_index that doesn't "
"match the total number of episodes already in the dataset. This is not supported for now."
)
if episode_buffer["size"] == 0:
raise ValueError("You must add one or several frames with `add_frame` before calling `add_episode`.")
buffer_keys = set(episode_buffer.keys()) - {"task", "size"}
if buffer_keys != set(features):
raise ValueError(
"Features from `episode_buffer` don't match the ones in `features`.\n"
f"In episode_buffer not in features: {buffer_keys - set(features)}\n"
f"In features not in episode_buffer: {set(features) - buffer_keys}"
)
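And a sketch of a buffer that would satisfy validate_episode_buffer (names are illustrative; `features` is assumed to include the default keys such as "episode_index" alongside the robot features):

```python
# Keys other than "task" and "size" must exactly match `features`, and
# "episode_index" must equal the number of episodes already in the dataset.
features = {"episode_index": ..., "observation.state": ..., "action": ...}
episode_buffer = {
    "episode_index": 3,  # == total_episodes
    "size": 50,          # number of frames added via add_frame
    "task": "pick up the cube",
    "observation.state": ...,
    "action": ...,
}
validate_episode_buffer(episode_buffer, total_episodes=3, features=features)
```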

View File

@@ -26,13 +26,14 @@ from pathlib import Path
from textwrap import dedent from textwrap import dedent
from lerobot import available_datasets from lerobot import available_datasets
from lerobot.common.datasets.v2.convert_dataset_v1_to_v2 import convert_dataset, parse_robot_config from lerobot.common.datasets.v2.convert_dataset_v1_to_v2 import convert_dataset
from lerobot.common.robot_devices.robots.configs import AlohaRobotConfig
LOCAL_DIR = Path("data/") LOCAL_DIR = Path("data/")
ALOHA_CONFIG = Path("lerobot/configs/robot/aloha.yaml") # spellchecker:off
ALOHA_MOBILE_INFO = { ALOHA_MOBILE_INFO = {
"robot_config": parse_robot_config(ALOHA_CONFIG), "robot_config": AlohaRobotConfig(),
"license": "mit", "license": "mit",
"url": "https://mobile-aloha.github.io/", "url": "https://mobile-aloha.github.io/",
"paper": "https://arxiv.org/abs/2401.02117", "paper": "https://arxiv.org/abs/2401.02117",
@@ -45,7 +46,7 @@ ALOHA_MOBILE_INFO = {
}""").lstrip(), }""").lstrip(),
} }
ALOHA_STATIC_INFO = { ALOHA_STATIC_INFO = {
"robot_config": parse_robot_config(ALOHA_CONFIG), "robot_config": AlohaRobotConfig(),
"license": "mit", "license": "mit",
"url": "https://tonyzhaozh.github.io/aloha/", "url": "https://tonyzhaozh.github.io/aloha/",
"paper": "https://arxiv.org/abs/2304.13705", "paper": "https://arxiv.org/abs/2304.13705",
@@ -159,11 +160,11 @@ DATASETS = {
**ALOHA_STATIC_INFO, **ALOHA_STATIC_INFO,
}, },
"aloha_static_vinh_cup": { "aloha_static_vinh_cup": {
"single_task": "Pick up the platic cup with the right arm, then pop its lid open with the left arm.", "single_task": "Pick up the plastic cup with the right arm, then pop its lid open with the left arm.",
**ALOHA_STATIC_INFO, **ALOHA_STATIC_INFO,
}, },
"aloha_static_vinh_cup_left": { "aloha_static_vinh_cup_left": {
"single_task": "Pick up the platic cup with the left arm, then pop its lid open with the right arm.", "single_task": "Pick up the plastic cup with the left arm, then pop its lid open with the right arm.",
**ALOHA_STATIC_INFO, **ALOHA_STATIC_INFO,
}, },
"aloha_static_ziploc_slide": {"single_task": "Slide open the ziploc bag.", **ALOHA_STATIC_INFO}, "aloha_static_ziploc_slide": {"single_task": "Slide open the ziploc bag.", **ALOHA_STATIC_INFO},
@@ -856,6 +857,7 @@ DATASETS = {
}""").lstrip(), }""").lstrip(),
}, },
} }
# spellchecker:on
def batch_convert(): def batch_convert():

View File

@@ -17,7 +17,7 @@
""" """
This script will help you convert any LeRobot dataset already pushed to the hub from codebase version 1.6 to This script will help you convert any LeRobot dataset already pushed to the hub from codebase version 1.6 to
2.0. You will be required to provide the 'tasks', which is a short but accurate description in plain English 2.0. You will be required to provide the 'tasks', which is a short but accurate description in plain English
for each of the tasks performed in the dataset. This will allow you to easily train models with task-conditionning. for each of the tasks performed in the dataset. This will allow you to easily train models with task-conditioning.
We support 3 different scenarios for these tasks (see instructions below): We support 3 different scenarios for these tasks (see instructions below):
1. Single task dataset: all episodes of your dataset have the same single task. 1. Single task dataset: all episodes of your dataset have the same single task.
@@ -130,7 +130,7 @@ from lerobot.common.datasets.utils import (
create_branch, create_branch,
create_lerobot_dataset_card, create_lerobot_dataset_card,
flatten_dict, flatten_dict,
get_hub_safe_version, get_safe_version,
load_json, load_json,
unflatten_dict, unflatten_dict,
write_json, write_json,
@@ -141,7 +141,8 @@ from lerobot.common.datasets.video_utils import (
get_image_pixel_channels, get_image_pixel_channels,
get_video_info, get_video_info,
) )
from lerobot.common.utils.utils import init_hydra_config from lerobot.common.robot_devices.robots.configs import RobotConfig
from lerobot.common.robot_devices.robots.utils import make_robot_config
V16 = "v1.6" V16 = "v1.6"
V20 = "v2.0" V20 = "v2.0"
@@ -152,19 +153,18 @@ V1_INFO_PATH = "meta_data/info.json"
V1_STATS_PATH = "meta_data/stats.safetensors" V1_STATS_PATH = "meta_data/stats.safetensors"
def parse_robot_config(config_path: Path, config_overrides: list[str] | None = None) -> tuple[str, dict]: def parse_robot_config(robot_cfg: RobotConfig) -> tuple[str, dict]:
robot_cfg = init_hydra_config(config_path, config_overrides) if robot_cfg.type in ["aloha", "koch"]:
if robot_cfg["robot_type"] in ["aloha", "koch"]:
state_names = [ state_names = [
f"{arm}_{motor}" if len(robot_cfg["follower_arms"]) > 1 else motor f"{arm}_{motor}" if len(robot_cfg.follower_arms) > 1 else motor
for arm in robot_cfg["follower_arms"] for arm in robot_cfg.follower_arms
for motor in robot_cfg["follower_arms"][arm]["motors"] for motor in robot_cfg.follower_arms[arm].motors
] ]
action_names = [ action_names = [
# f"{arm}_{motor}" for arm in ["left", "right"] for motor in robot_cfg["leader_arms"][arm]["motors"] # f"{arm}_{motor}" for arm in ["left", "right"] for motor in robot_cfg["leader_arms"][arm]["motors"]
f"{arm}_{motor}" if len(robot_cfg["leader_arms"]) > 1 else motor f"{arm}_{motor}" if len(robot_cfg.leader_arms) > 1 else motor
for arm in robot_cfg["leader_arms"] for arm in robot_cfg.leader_arms
for motor in robot_cfg["leader_arms"][arm]["motors"] for motor in robot_cfg.leader_arms[arm].motors
] ]
# elif robot_cfg["robot_type"] == "stretch3": TODO # elif robot_cfg["robot_type"] == "stretch3": TODO
else: else:
@@ -173,7 +173,7 @@ def parse_robot_config(config_path: Path, config_overrides: list[str] | None = N
) )
return { return {
"robot_type": robot_cfg["robot_type"], "robot_type": robot_cfg.type,
"names": { "names": {
"observation.state": state_names, "observation.state": state_names,
"observation.effort": state_names, "observation.effort": state_names,
@@ -203,7 +203,10 @@ def convert_stats_to_json(v1_dir: Path, v2_dir: Path) -> None:
torch.testing.assert_close(stats_json[key], stats[key]) torch.testing.assert_close(stats_json[key], stats[key])
def get_features_from_hf_dataset(dataset: Dataset, robot_config: dict | None = None) -> dict[str, list]: def get_features_from_hf_dataset(
dataset: Dataset, robot_config: RobotConfig | None = None
) -> dict[str, list]:
robot_config = parse_robot_config(robot_config) if robot_config is not None else None
features = {} features = {}
for key, ft in dataset.features.items(): for key, ft in dataset.features.items():
if isinstance(ft, datasets.Value): if isinstance(ft, datasets.Value):
@@ -224,11 +227,11 @@ def get_features_from_hf_dataset(dataset: Dataset, robot_config: dict | None = N
image = dataset[0][key] # Assuming first row image = dataset[0][key] # Assuming first row
channels = get_image_pixel_channels(image) channels = get_image_pixel_channels(image)
shape = (image.height, image.width, channels) shape = (image.height, image.width, channels)
names = ["height", "width", "channel"] names = ["height", "width", "channels"]
elif ft._type == "VideoFrame": elif ft._type == "VideoFrame":
dtype = "video" dtype = "video"
shape = None # Add shape later shape = None # Add shape later
names = ["height", "width", "channel"] names = ["height", "width", "channels"]
features[key] = { features[key] = {
"dtype": dtype, "dtype": dtype,
@@ -436,11 +439,11 @@ def convert_dataset(
single_task: str | None = None, single_task: str | None = None,
tasks_path: Path | None = None, tasks_path: Path | None = None,
tasks_col: Path | None = None, tasks_col: Path | None = None,
robot_config: dict | None = None, robot_config: RobotConfig | None = None,
test_branch: str | None = None, test_branch: str | None = None,
**card_kwargs, **card_kwargs,
): ):
v1 = get_hub_safe_version(repo_id, V16) v1 = get_safe_version(repo_id, V16)
v1x_dir = local_dir / V16 / repo_id v1x_dir = local_dir / V16 / repo_id
v20_dir = local_dir / V20 / repo_id v20_dir = local_dir / V20 / repo_id
v1x_dir.mkdir(parents=True, exist_ok=True) v1x_dir.mkdir(parents=True, exist_ok=True)
@@ -532,7 +535,7 @@ def convert_dataset(
episode_lengths = split_parquet_by_episodes(dataset, total_episodes, total_chunks, v20_dir) episode_lengths = split_parquet_by_episodes(dataset, total_episodes, total_chunks, v20_dir)
if robot_config is not None: if robot_config is not None:
robot_type = robot_config["robot_type"] robot_type = robot_config.type
repo_tags = [robot_type] repo_tags = [robot_type]
else: else:
robot_type = "unknown" robot_type = "unknown"
@@ -621,16 +624,10 @@ def main():
help="The path to a .json file containing one language instruction for each episode_index", help="The path to a .json file containing one language instruction for each episode_index",
) )
parser.add_argument( parser.add_argument(
"--robot-config", "--robot",
type=Path,
default=None,
help="Path to the robot's config yaml the dataset during conversion.",
)
parser.add_argument(
"--robot-overrides",
type=str, type=str,
nargs="*", default=None,
help="Any key=value arguments to override the robot config values (use dots for.nested=overrides)", help="Robot config used for the dataset during conversion (e.g. 'koch', 'aloha', 'so100', etc.)",
) )
parser.add_argument( parser.add_argument(
"--local-dir", "--local-dir",
@@ -655,8 +652,10 @@ def main():
if not args.local_dir: if not args.local_dir:
args.local_dir = Path("/tmp/lerobot_dataset_v2") args.local_dir = Path("/tmp/lerobot_dataset_v2")
robot_config = parse_robot_config(args.robot_config, args.robot_overrides) if args.robot_config else None robot_config = make_robot_config(args.robot) if args.robot is not None else None
del args.robot_config, args.robot_overrides del args.robot
convert_dataset(**vars(args), robot_config=robot_config) convert_dataset(**vars(args), robot_config=robot_config)

View File

@@ -0,0 +1,73 @@
import logging
import traceback
from pathlib import Path
from datasets import get_dataset_config_info
from huggingface_hub import HfApi
from lerobot import available_datasets
from lerobot.common.datasets.lerobot_dataset import LeRobotDatasetMetadata
from lerobot.common.datasets.utils import INFO_PATH, write_info
from lerobot.common.datasets.v21.convert_dataset_v20_to_v21 import V20, SuppressWarnings
LOCAL_DIR = Path("data/")
hub_api = HfApi()
def fix_dataset(repo_id: str) -> str:
if not hub_api.revision_exists(repo_id, V20, repo_type="dataset"):
return f"{repo_id}: skipped (not in {V20})."
dataset_info = get_dataset_config_info(repo_id, "default")
with SuppressWarnings():
lerobot_metadata = LeRobotDatasetMetadata(repo_id, revision=V20, force_cache_sync=True)
meta_features = {key for key, ft in lerobot_metadata.features.items() if ft["dtype"] != "video"}
parquet_features = set(dataset_info.features)
diff_parquet_meta = parquet_features - meta_features
diff_meta_parquet = meta_features - parquet_features
if diff_parquet_meta:
raise ValueError(f"In parquet not in info.json: {parquet_features - meta_features}")
if not diff_meta_parquet:
return f"{repo_id}: skipped (no diff)"
if diff_meta_parquet:
logging.warning(f"In info.json not in parquet: {meta_features - parquet_features}")
assert diff_meta_parquet == {"language_instruction"}
lerobot_metadata.features.pop("language_instruction")
write_info(lerobot_metadata.info, lerobot_metadata.root)
commit_info = hub_api.upload_file(
path_or_fileobj=lerobot_metadata.root / INFO_PATH,
path_in_repo=INFO_PATH,
repo_id=repo_id,
repo_type="dataset",
revision=V20,
commit_message="Remove 'language_instruction'",
create_pr=True,
)
return f"{repo_id}: success - PR: {commit_info.pr_url}"
def batch_fix():
status = {}
LOCAL_DIR.mkdir(parents=True, exist_ok=True)
logfile = LOCAL_DIR / "fix_features_v20.txt"
for num, repo_id in enumerate(available_datasets, start=1):
print(f"\nConverting {repo_id} ({num}/{len(available_datasets)})")
print("---------------------------------------------------------")
try:
status = fix_dataset(repo_id)
except Exception:
status = f"{repo_id}: failed\n {traceback.format_exc()}"
logging.info(status)
with open(logfile, "a") as file:
file.write(status + "\n")
if __name__ == "__main__":
batch_fix()

View File

@@ -0,0 +1,54 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script is for internal use to convert all datasets under the 'lerobot' hub user account to v2.1.
"""
import traceback
from pathlib import Path
from huggingface_hub import HfApi
from lerobot import available_datasets
from lerobot.common.datasets.v21.convert_dataset_v20_to_v21 import V21, convert_dataset
LOCAL_DIR = Path("data/")
def batch_convert():
status = {}
LOCAL_DIR.mkdir(parents=True, exist_ok=True)
logfile = LOCAL_DIR / "conversion_log_v21.txt"
hub_api = HfApi()
for num, repo_id in enumerate(available_datasets, start=1):
print(f"\nConverting {repo_id} ({num}/{len(available_datasets)})")
print("---------------------------------------------------------")
try:
if hub_api.revision_exists(repo_id, V21, repo_type="dataset"):
status = f"{repo_id}: success (already in {V21})."
else:
convert_dataset(repo_id)
status = f"{repo_id}: success."
except Exception:
status = f"{repo_id}: failed\n {traceback.format_exc()}"
with open(logfile, "a") as file:
file.write(status + "\n")
if __name__ == "__main__":
batch_convert()

View File

@@ -0,0 +1,100 @@
"""
This script will help you convert any LeRobot dataset already pushed to the hub from codebase version 2.0 to
2.1. It will:
- Generate per-episode stats and write them in `episodes_stats.jsonl`.
- Check consistency between these new stats and the old ones.
- Remove the deprecated `stats.json`.
- Update codebase_version in `info.json`.
- Push this new version to the hub on the 'main' branch and tag it with "v2.1".
Usage:
```bash
python lerobot/common/datasets/v21/convert_dataset_v20_to_v21.py \
--repo-id=aliberts/koch_tutorial
```
"""
import argparse
import logging
from huggingface_hub import HfApi
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
from lerobot.common.datasets.utils import EPISODES_STATS_PATH, STATS_PATH, load_stats, write_info
from lerobot.common.datasets.v21.convert_stats import check_aggregate_stats, convert_stats
V20 = "v2.0"
V21 = "v2.1"
class SuppressWarnings:
def __enter__(self):
self.previous_level = logging.getLogger().getEffectiveLevel()
logging.getLogger().setLevel(logging.ERROR)
def __exit__(self, exc_type, exc_val, exc_tb):
logging.getLogger().setLevel(self.previous_level)
def convert_dataset(
repo_id: str,
branch: str | None = None,
num_workers: int = 4,
):
with SuppressWarnings():
dataset = LeRobotDataset(repo_id, revision=V20, force_cache_sync=True)
if (dataset.root / EPISODES_STATS_PATH).is_file():
(dataset.root / EPISODES_STATS_PATH).unlink()
convert_stats(dataset, num_workers=num_workers)
ref_stats = load_stats(dataset.root)
check_aggregate_stats(dataset, ref_stats)
dataset.meta.info["codebase_version"] = CODEBASE_VERSION
write_info(dataset.meta.info, dataset.root)
dataset.push_to_hub(branch=branch, allow_patterns="meta/")
# delete old stats.json file
if (dataset.root / STATS_PATH).is_file():
(dataset.root / STATS_PATH).unlink()
hub_api = HfApi()
if hub_api.file_exists(
repo_id=dataset.repo_id, filename=STATS_PATH, revision=branch, repo_type="dataset"
):
hub_api.delete_file(
path_in_repo=STATS_PATH, repo_id=dataset.repo_id, revision=branch, repo_type="dataset"
)
hub_api.create_tag(repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--repo-id",
type=str,
required=True,
help="Repository identifier on Hugging Face: a community or a user name `/` the name of the dataset "
"(e.g. `lerobot/pusht`, `cadene/aloha_sim_insertion_human`).",
)
parser.add_argument(
"--branch",
type=str,
default=None,
help="Repo branch to push your dataset. Defaults to the main branch.",
)
parser.add_argument(
"--num-workers",
type=int,
default=4,
help="Number of workers for parallelizing stats compute. Defaults to 4.",
)
args = parser.parse_args()
convert_dataset(**vars(args))

View File

@@ -0,0 +1,85 @@
from concurrent.futures import ThreadPoolExecutor, as_completed
import numpy as np
from tqdm import tqdm
from lerobot.common.datasets.compute_stats import aggregate_stats, get_feature_stats, sample_indices
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.utils import write_episode_stats
def sample_episode_video_frames(dataset: LeRobotDataset, episode_index: int, ft_key: str) -> np.ndarray:
ep_len = dataset.meta.episodes[episode_index]["length"]
sampled_indices = sample_indices(ep_len)
query_timestamps = dataset._get_query_timestamps(0.0, {ft_key: sampled_indices})
video_frames = dataset._query_videos(query_timestamps, episode_index)
return video_frames[ft_key].numpy()
def convert_episode_stats(dataset: LeRobotDataset, ep_idx: int):
ep_start_idx = dataset.episode_data_index["from"][ep_idx]
ep_end_idx = dataset.episode_data_index["to"][ep_idx]
ep_data = dataset.hf_dataset.select(range(ep_start_idx, ep_end_idx))
ep_stats = {}
for key, ft in dataset.features.items():
if ft["dtype"] == "video":
# We sample only for videos
ep_ft_data = sample_episode_video_frames(dataset, ep_idx, key)
else:
ep_ft_data = np.array(ep_data[key])
axes_to_reduce = (0, 2, 3) if ft["dtype"] in ["image", "video"] else 0
keepdims = True if ft["dtype"] in ["image", "video"] else ep_ft_data.ndim == 1
ep_stats[key] = get_feature_stats(ep_ft_data, axis=axes_to_reduce, keepdims=keepdims)
if ft["dtype"] in ["image", "video"]: # remove batch dim
ep_stats[key] = {
k: v if k == "count" else np.squeeze(v, axis=0) for k, v in ep_stats[key].items()
}
dataset.meta.episodes_stats[ep_idx] = ep_stats
def convert_stats(dataset: LeRobotDataset, num_workers: int = 0):
assert dataset.episodes is None
print("Computing episodes stats")
total_episodes = dataset.meta.total_episodes
if num_workers > 0:
with ThreadPoolExecutor(max_workers=num_workers) as executor:
futures = {
executor.submit(convert_episode_stats, dataset, ep_idx): ep_idx
for ep_idx in range(total_episodes)
}
for future in tqdm(as_completed(futures), total=total_episodes):
future.result()
else:
for ep_idx in tqdm(range(total_episodes)):
convert_episode_stats(dataset, ep_idx)
for ep_idx in tqdm(range(total_episodes)):
write_episode_stats(ep_idx, dataset.meta.episodes_stats[ep_idx], dataset.root)
def check_aggregate_stats(
dataset: LeRobotDataset,
reference_stats: dict[str, dict[str, np.ndarray]],
video_rtol_atol: tuple[float, float] = (1e-2, 1e-2),
default_rtol_atol: tuple[float, float] = (5e-6, 6e-5),
):
"""Verifies that the aggregated stats from episodes_stats are close to reference stats."""
agg_stats = aggregate_stats(list(dataset.meta.episodes_stats.values()))
for key, ft in dataset.features.items():
# These values might need some fine-tuning
if ft["dtype"] == "video":
# to account for image sub-sampling
rtol, atol = video_rtol_atol
else:
rtol, atol = default_rtol_atol
for stat, val in agg_stats[key].items():
if key in reference_stats and stat in reference_stats[key]:
err_msg = f"feature='{key}' stats='{stat}'"
np.testing.assert_allclose(
val, reference_stats[key][stat], rtol=rtol, atol=atol, err_msg=err_msg
)

View File

@@ -69,11 +69,11 @@ def decode_video_frames_torchvision(
# set the first and last requested timestamps # set the first and last requested timestamps
# Note: previous timestamps are usually loaded, since we need to access the previous key frame # Note: previous timestamps are usually loaded, since we need to access the previous key frame
first_ts = timestamps[0] first_ts = min(timestamps)
last_ts = timestamps[-1] last_ts = max(timestamps)
# access closest key frame of the first requested frame # access closest key frame of the first requested frame
# Note: closest key frame timestamp is usally smaller than `first_ts` (e.g. key frame can be the first frame of the video) # Note: closest key frame timestamp is usually smaller than `first_ts` (e.g. key frame can be the first frame of the video)
# for details on what `seek` is doing see: https://pyav.basswood-io.com/docs/stable/api/container.html?highlight=inputcontainer#av.container.InputContainer.seek # for details on what `seek` is doing see: https://pyav.basswood-io.com/docs/stable/api/container.html?highlight=inputcontainer#av.container.InputContainer.seek
reader.seek(first_ts, keyframes_only=keyframes_only) reader.seek(first_ts, keyframes_only=keyframes_only)

View File

@@ -0,0 +1 @@
from .configs import AlohaEnv, EnvConfig, PushtEnv, XarmEnv # noqa: F401

View File

@@ -0,0 +1,142 @@
import abc
from dataclasses import dataclass, field
import draccus
from lerobot.common.constants import ACTION, OBS_ENV, OBS_IMAGE, OBS_IMAGES, OBS_ROBOT
from lerobot.configs.types import FeatureType, PolicyFeature
@dataclass
class EnvConfig(draccus.ChoiceRegistry, abc.ABC):
task: str | None = None
fps: int = 30
features: dict[str, PolicyFeature] = field(default_factory=dict)
features_map: dict[str, str] = field(default_factory=dict)
@property
def type(self) -> str:
return self.get_choice_name(self.__class__)
@property
@abc.abstractmethod
def gym_kwargs(self) -> dict:
raise NotImplementedError()
@EnvConfig.register_subclass("aloha")
@dataclass
class AlohaEnv(EnvConfig):
task: str = "AlohaInsertion-v0"
fps: int = 50
episode_length: int = 400
obs_type: str = "pixels_agent_pos"
render_mode: str = "rgb_array"
features: dict[str, PolicyFeature] = field(
default_factory=lambda: {
"action": PolicyFeature(type=FeatureType.ACTION, shape=(14,)),
}
)
features_map: dict[str, str] = field(
default_factory=lambda: {
"action": ACTION,
"agent_pos": OBS_ROBOT,
"top": f"{OBS_IMAGE}.top",
"pixels/top": f"{OBS_IMAGES}.top",
}
)
def __post_init__(self):
if self.obs_type == "pixels":
self.features["top"] = PolicyFeature(type=FeatureType.VISUAL, shape=(480, 640, 3))
elif self.obs_type == "pixels_agent_pos":
self.features["agent_pos"] = PolicyFeature(type=FeatureType.STATE, shape=(14,))
self.features["pixels/top"] = PolicyFeature(type=FeatureType.VISUAL, shape=(480, 640, 3))
@property
def gym_kwargs(self) -> dict:
return {
"obs_type": self.obs_type,
"render_mode": self.render_mode,
"max_episode_steps": self.episode_length,
}
@EnvConfig.register_subclass("pusht")
@dataclass
class PushtEnv(EnvConfig):
task: str = "PushT-v0"
fps: int = 10
episode_length: int = 300
obs_type: str = "pixels_agent_pos"
render_mode: str = "rgb_array"
visualization_width: int = 384
visualization_height: int = 384
features: dict[str, PolicyFeature] = field(
default_factory=lambda: {
"action": PolicyFeature(type=FeatureType.ACTION, shape=(2,)),
"agent_pos": PolicyFeature(type=FeatureType.STATE, shape=(2,)),
}
)
features_map: dict[str, str] = field(
default_factory=lambda: {
"action": ACTION,
"agent_pos": OBS_ROBOT,
"environment_state": OBS_ENV,
"pixels": OBS_IMAGE,
}
)
def __post_init__(self):
if self.obs_type == "pixels_agent_pos":
self.features["pixels"] = PolicyFeature(type=FeatureType.VISUAL, shape=(384, 384, 3))
elif self.obs_type == "environment_state_agent_pos":
self.features["environment_state"] = PolicyFeature(type=FeatureType.ENV, shape=(16,))
@property
def gym_kwargs(self) -> dict:
return {
"obs_type": self.obs_type,
"render_mode": self.render_mode,
"visualization_width": self.visualization_width,
"visualization_height": self.visualization_height,
"max_episode_steps": self.episode_length,
}
@EnvConfig.register_subclass("xarm")
@dataclass
class XarmEnv(EnvConfig):
task: str = "XarmLift-v0"
fps: int = 15
episode_length: int = 200
obs_type: str = "pixels_agent_pos"
render_mode: str = "rgb_array"
visualization_width: int = 384
visualization_height: int = 384
features: dict[str, PolicyFeature] = field(
default_factory=lambda: {
"action": PolicyFeature(type=FeatureType.ACTION, shape=(4,)),
"pixels": PolicyFeature(type=FeatureType.VISUAL, shape=(84, 84, 3)),
}
)
features_map: dict[str, str] = field(
default_factory=lambda: {
"action": ACTION,
"agent_pos": OBS_ROBOT,
"pixels": OBS_IMAGE,
}
)
def __post_init__(self):
if self.obs_type == "pixels_agent_pos":
self.features["agent_pos"] = PolicyFeature(type=FeatureType.STATE, shape=(4,))
@property
def gym_kwargs(self) -> dict:
return {
"obs_type": self.obs_type,
"render_mode": self.render_mode,
"visualization_width": self.visualization_width,
"visualization_height": self.visualization_height,
"max_episode_steps": self.episode_length,
}
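Since these are plain draccus dataclasses, a config can be instantiated and inspected directly; a quick sketch:

```python
# Sketch: build a config and read the kwargs that will be passed to gym.make().
cfg = PushtEnv(obs_type="pixels_agent_pos")
print(cfg.type)        # "pusht", resolved through the draccus choice registry
print(cfg.gym_kwargs)  # {"obs_type": "pixels_agent_pos", "render_mode": "rgb_array", ...}
```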

View File

@@ -16,43 +16,54 @@
import importlib import importlib
import gymnasium as gym import gymnasium as gym
from omegaconf import DictConfig
from lerobot.common.envs.configs import AlohaEnv, EnvConfig, PushtEnv, XarmEnv
def make_env(cfg: DictConfig, n_envs: int | None = None) -> gym.vector.VectorEnv | None: def make_env_config(env_type: str, **kwargs) -> EnvConfig:
"""Makes a gym vector environment according to the evaluation config. if env_type == "aloha":
return AlohaEnv(**kwargs)
elif env_type == "pusht":
return PushtEnv(**kwargs)
elif env_type == "xarm":
return XarmEnv(**kwargs)
else:
raise ValueError(f"Policy type '{env_type}' is not available.")
n_envs can be used to override eval.batch_size in the configuration. Must be at least 1.
def make_env(cfg: EnvConfig, n_envs: int = 1, use_async_envs: bool = False) -> gym.vector.VectorEnv | None:
"""Makes a gym vector environment according to the config.
Args:
cfg (EnvConfig): the config of the environment to instantiate.
n_envs (int, optional): The number of parallelized env to return. Defaults to 1.
use_async_envs (bool, optional): Whether to return an AsyncVectorEnv or a SyncVectorEnv. Defaults to
False.
Raises:
ValueError: if n_envs < 1
ModuleNotFoundError: If the requested env package is not installed
Returns:
gym.vector.VectorEnv: The parallelized gym.env instance.
""" """
if n_envs is not None and n_envs < 1: if n_envs < 1:
raise ValueError("`n_envs must be at least 1") raise ValueError("`n_envs must be at least 1")
if cfg.env.name == "real_world": package_name = f"gym_{cfg.type}"
return
package_name = f"gym_{cfg.env.name}"
try: try:
importlib.import_module(package_name) importlib.import_module(package_name)
except ModuleNotFoundError as e: except ModuleNotFoundError as e:
print( print(f"{package_name} is not installed. Please install it with `pip install 'lerobot[{cfg.type}]'`")
f"{package_name} is not installed. Please install it with `pip install 'lerobot[{cfg.env.name}]'`"
)
raise e raise e
gym_handle = f"{package_name}/{cfg.env.task}" gym_handle = f"{package_name}/{cfg.task}"
gym_kwgs = dict(cfg.env.get("gym", {}))
if cfg.env.get("episode_length"):
gym_kwgs["max_episode_steps"] = cfg.env.episode_length
# batched version of the env that returns an observation of shape (b, c) # batched version of the env that returns an observation of shape (b, c)
env_cls = gym.vector.AsyncVectorEnv if cfg.eval.use_async_envs else gym.vector.SyncVectorEnv env_cls = gym.vector.AsyncVectorEnv if use_async_envs else gym.vector.SyncVectorEnv
env = env_cls( env = env_cls(
[ [lambda: gym.make(gym_handle, disable_env_checker=True, **cfg.gym_kwargs) for _ in range(n_envs)]
lambda: gym.make(gym_handle, disable_env_checker=True, **gym_kwgs)
for _ in range(n_envs if n_envs is not None else cfg.eval.batch_size)
]
) )
return env return env
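Putting the two helpers together, usage could look like the sketch below (assuming the `gym_pusht` extra is installed):

```python
# Sketch: build an env config by name, then a vectorized environment from it.
cfg = make_env_config("pusht")
env = make_env(cfg, n_envs=2)  # SyncVectorEnv wrapping 2 copies of gym_pusht/PushT-v0
obs, info = env.reset(seed=0)
```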

View File

@@ -18,8 +18,13 @@ import numpy as np
import torch import torch
from torch import Tensor from torch import Tensor
from lerobot.common.envs.configs import EnvConfig
from lerobot.common.utils.utils import get_channel_first_image_shape
from lerobot.configs.types import FeatureType, PolicyFeature
def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Tensor]: def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Tensor]:
# TODO(aliberts, rcadene): refactor this to use features from the environment (no hardcoding)
"""Convert environment observation to LeRobot format observation. """Convert environment observation to LeRobot format observation.
Args: Args:
observation: Dictionary of observation batches from a Gym vector environment. observation: Dictionary of observation batches from a Gym vector environment.
@@ -35,6 +40,7 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
imgs = {"observation.image": observations["pixels"]} imgs = {"observation.image": observations["pixels"]}
for imgkey, img in imgs.items(): for imgkey, img in imgs.items():
# TODO(aliberts, rcadene): use transforms.ToTensor()?
img = torch.from_numpy(img) img = torch.from_numpy(img)
# sanity check that images are channel last # sanity check that images are channel last
@@ -60,3 +66,23 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
# requirement for "agent_pos" # requirement for "agent_pos"
return_observations["observation.state"] = torch.from_numpy(observations["agent_pos"]).float() return_observations["observation.state"] = torch.from_numpy(observations["agent_pos"]).float()
return return_observations return return_observations
def env_to_policy_features(env_cfg: EnvConfig) -> dict[str, PolicyFeature]:
# TODO(aliberts, rcadene): remove this hardcoding of keys and just use the nested keys as is
# (need to also refactor preprocess_observation and externalize normalization from policies)
policy_features = {}
for key, ft in env_cfg.features.items():
if ft.type is FeatureType.VISUAL:
if len(ft.shape) != 3:
raise ValueError(f"Number of dimensions of {key} != 3 (shape={ft.shape})")
shape = get_channel_first_image_shape(ft.shape)
feature = PolicyFeature(type=ft.type, shape=shape)
else:
feature = ft
policy_key = env_cfg.features_map[key]
policy_features[policy_key] = feature
return policy_features
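For instance, with the AlohaEnv config above, the (480, 640, 3) visual feature comes out channel-first under its mapped policy key; a sketch (the exact key strings depend on the OBS_* constants):

```python
# Sketch: visual shapes are converted to channel-first, other features pass through.
env_cfg = AlohaEnv(obs_type="pixels_agent_pos")
features = env_to_policy_features(env_cfg)
# e.g. features["observation.images.top"].shape == (3, 480, 640)
#      features["observation.state"].shape == (14,)
```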

View File

@@ -1,246 +0,0 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Borrowed from https://github.com/fyhMer/fowm/blob/main/src/logger.py
# TODO(rcadene, alexander-soare): clean this file
"""
import logging
import os
import re
from glob import glob
from pathlib import Path
import torch
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
from omegaconf import DictConfig, OmegaConf
from termcolor import colored
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler
from lerobot.common.policies.policy_protocol import Policy
from lerobot.common.utils.utils import get_global_random_state, set_global_random_state
def log_output_dir(out_dir):
logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {out_dir}")
def cfg_to_group(cfg: DictConfig, return_list: bool = False) -> list[str] | str:
"""Return a group name for logging. Optionally returns group name as list."""
lst = [
f"policy:{cfg.policy.name}",
f"dataset:{cfg.dataset_repo_id}",
f"env:{cfg.env.name}",
f"seed:{cfg.seed}",
]
return lst if return_list else "-".join(lst)
def get_wandb_run_id_from_filesystem(checkpoint_dir: Path) -> str:
# Get the WandB run ID.
paths = glob(str(checkpoint_dir / "../wandb/latest-run/run-*"))
if len(paths) != 1:
raise RuntimeError("Couldn't get the previous WandB run ID for run resumption.")
match = re.search(r"run-([^\.]+).wandb", paths[0].split("/")[-1])
if match is None:
raise RuntimeError("Couldn't get the previous WandB run ID for run resumption.")
wandb_run_id = match.groups(0)[0]
return wandb_run_id
class Logger:
"""Primary logger object. Logs either locally or using wandb.
The logger creates the following directory structure:
provided_log_dir
├── .hydra # hydra's configuration cache
├── checkpoints
│ ├── specific_checkpoint_name
│ │ ├── pretrained_model # Hugging Face pretrained model directory
│ │ │ ├── ...
│ │ └── training_state.pth # optimizer, scheduler, and random states + training step
| ├── another_specific_checkpoint_name
│ │ ├── ...
| ├── ...
│ └── last # a softlink to the last logged checkpoint
"""
pretrained_model_dir_name = "pretrained_model"
training_state_file_name = "training_state.pth"
def __init__(self, cfg: DictConfig, log_dir: str, wandb_job_name: str | None = None):
"""
Args:
log_dir: The directory to save all logs and training outputs to.
job_name: The WandB job name.
"""
self._cfg = cfg
self.log_dir = Path(log_dir)
self.log_dir.mkdir(parents=True, exist_ok=True)
self.checkpoints_dir = self.get_checkpoints_dir(log_dir)
self.last_checkpoint_dir = self.get_last_checkpoint_dir(log_dir)
self.last_pretrained_model_dir = self.get_last_pretrained_model_dir(log_dir)
# Set up WandB.
self._group = cfg_to_group(cfg)
project = cfg.get("wandb", {}).get("project")
entity = cfg.get("wandb", {}).get("entity")
enable_wandb = cfg.get("wandb", {}).get("enable", False)
run_offline = not enable_wandb or not project
if run_offline:
logging.info(colored("Logs will be saved locally.", "yellow", attrs=["bold"]))
self._wandb = None
else:
os.environ["WANDB_SILENT"] = "true"
import wandb
wandb_run_id = None
if cfg.resume:
wandb_run_id = get_wandb_run_id_from_filesystem(self.checkpoints_dir)
wandb.init(
id=wandb_run_id,
project=project,
entity=entity,
name=wandb_job_name,
notes=cfg.get("wandb", {}).get("notes"),
tags=cfg_to_group(cfg, return_list=True),
dir=log_dir,
config=OmegaConf.to_container(cfg, resolve=True),
# TODO(rcadene): try set to True
save_code=False,
# TODO(rcadene): split train and eval, and run async eval with job_type="eval"
job_type="train_eval",
resume="must" if cfg.resume else None,
)
print(colored("Logs will be synced with wandb.", "blue", attrs=["bold"]))
logging.info(f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}")
self._wandb = wandb
@classmethod
def get_checkpoints_dir(cls, log_dir: str | Path) -> Path:
"""Given the log directory, get the sub-directory in which checkpoints will be saved."""
return Path(log_dir) / "checkpoints"
@classmethod
def get_last_checkpoint_dir(cls, log_dir: str | Path) -> Path:
"""Given the log directory, get the sub-directory in which the last checkpoint will be saved."""
return cls.get_checkpoints_dir(log_dir) / "last"
@classmethod
def get_last_pretrained_model_dir(cls, log_dir: str | Path) -> Path:
"""
Given the log directory, get the sub-directory in which the last checkpoint's pretrained weights will
be saved.
"""
return cls.get_last_checkpoint_dir(log_dir) / cls.pretrained_model_dir_name
def save_model(self, save_dir: Path, policy: Policy, wandb_artifact_name: str | None = None):
"""Save the weights of the Policy model using PyTorchModelHubMixin.
The weights are saved in a folder called "pretrained_model" under the checkpoint directory.
Optionally also upload the model to WandB.
"""
self.checkpoints_dir.mkdir(parents=True, exist_ok=True)
policy.save_pretrained(save_dir)
# Also save the full Hydra config for the env configuration.
OmegaConf.save(self._cfg, save_dir / "config.yaml")
if self._wandb and not self._cfg.wandb.disable_artifact:
# note wandb artifact does not accept ":" or "/" in its name
artifact = self._wandb.Artifact(wandb_artifact_name, type="model")
artifact.add_file(save_dir / SAFETENSORS_SINGLE_FILE)
self._wandb.log_artifact(artifact)
if self.last_checkpoint_dir.exists():
os.remove(self.last_checkpoint_dir)
def save_training_state(
self,
save_dir: Path,
train_step: int,
optimizer: Optimizer,
scheduler: LRScheduler | None,
):
"""Checkpoint the global training_step, optimizer state, scheduler state, and random state.
All of these are saved as "training_state.pth" under the checkpoint directory.
"""
training_state = {
"step": train_step,
"optimizer": optimizer.state_dict(),
**get_global_random_state(),
}
if scheduler is not None:
training_state["scheduler"] = scheduler.state_dict()
torch.save(training_state, save_dir / self.training_state_file_name)
def save_checkpoint(
self,
train_step: int,
policy: Policy,
optimizer: Optimizer,
scheduler: LRScheduler | None,
identifier: str,
):
"""Checkpoint the model weights and the training state."""
checkpoint_dir = self.checkpoints_dir / str(identifier)
wandb_artifact_name = (
None
if self._wandb is None
else f"{self._group.replace(':', '_').replace('/', '_')}-{self._cfg.seed}-{identifier}"
)
self.save_model(
checkpoint_dir / self.pretrained_model_dir_name, policy, wandb_artifact_name=wandb_artifact_name
)
self.save_training_state(checkpoint_dir, train_step, optimizer, scheduler)
os.symlink(checkpoint_dir.absolute(), self.last_checkpoint_dir)
def load_last_training_state(self, optimizer: Optimizer, scheduler: LRScheduler | None) -> int:
"""
Given the last checkpoint in the logging directory, load the optimizer state, scheduler state, and
random state, and return the global training step.
"""
training_state = torch.load(self.last_checkpoint_dir / self.training_state_file_name)
optimizer.load_state_dict(training_state["optimizer"])
if scheduler is not None:
scheduler.load_state_dict(training_state["scheduler"])
elif "scheduler" in training_state:
raise ValueError(
"The checkpoint contains a scheduler state_dict, but no LRScheduler was provided."
)
# Small hack to get the expected keys: use `get_global_random_state`.
set_global_random_state({k: training_state[k] for k in get_global_random_state()})
return training_state["step"]
def log_dict(self, d, step, mode="train"):
assert mode in {"train", "eval"}
# TODO(alexander-soare): Add local text log.
if self._wandb is not None:
for k, v in d.items():
if not isinstance(v, (int, float, str)):
logging.warning(
f'WandB logging of key "{k}" was ignored as its type is not handled by this wrapper.'
)
continue
self._wandb.log({f"{mode}/{k}": v}, step=step)
def log_video(self, video_path: str, step: int, mode: str = "train"):
assert mode in {"train", "eval"}
assert self._wandb is not None
wandb_video = self._wandb.Video(video_path, fps=self._cfg.fps, format="mp4")
self._wandb.log({f"{mode}/video": wandb_video}, step=step)

View File

@@ -0,0 +1 @@
from .optimizers import OptimizerConfig as OptimizerConfig

View File

@@ -0,0 +1,40 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler
from lerobot.common.policies.pretrained import PreTrainedPolicy
from lerobot.configs.train import TrainPipelineConfig
def make_optimizer_and_scheduler(
cfg: TrainPipelineConfig, policy: PreTrainedPolicy
) -> tuple[Optimizer, LRScheduler | None]:
"""Generates the optimizer and scheduler based on configs.
Args:
cfg (TrainPipelineConfig): The training config that contains optimizer and scheduler configs
policy (PreTrainedPolicy): The policy from which parameters and training presets are taken.
Returns:
tuple[Optimizer, LRScheduler | None]: The couple (Optimizer, Scheduler). Scheduler can be `None`.
"""
params = policy.get_optim_params() if cfg.use_policy_training_preset else policy.parameters()
optimizer = cfg.optimizer.build(params)
lr_scheduler = cfg.scheduler.build(optimizer, cfg.steps) if cfg.scheduler is not None else None
return optimizer, lr_scheduler

View File

@@ -0,0 +1,118 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
from dataclasses import asdict, dataclass
from pathlib import Path
import draccus
import torch
from safetensors.torch import load_file, save_file
from lerobot.common.constants import (
OPTIMIZER_PARAM_GROUPS,
OPTIMIZER_STATE,
)
from lerobot.common.datasets.utils import flatten_dict, unflatten_dict, write_json
from lerobot.common.utils.io_utils import deserialize_json_into_object
@dataclass
class OptimizerConfig(draccus.ChoiceRegistry, abc.ABC):
lr: float
weight_decay: float
grad_clip_norm: float
@property
def type(self) -> str:
return self.get_choice_name(self.__class__)
@classmethod
def default_choice_name(cls) -> str | None:
return "adam"
@abc.abstractmethod
def build(self) -> torch.optim.Optimizer:
raise NotImplementedError
@OptimizerConfig.register_subclass("adam")
@dataclass
class AdamConfig(OptimizerConfig):
lr: float = 1e-3
betas: tuple[float, float] = (0.9, 0.999)
eps: float = 1e-8
weight_decay: float = 0.0
grad_clip_norm: float = 10.0
def build(self, params: dict) -> torch.optim.Optimizer:
kwargs = asdict(self)
kwargs.pop("grad_clip_norm")
return torch.optim.Adam(params, **kwargs)
@OptimizerConfig.register_subclass("adamw")
@dataclass
class AdamWConfig(OptimizerConfig):
lr: float = 1e-3
betas: tuple[float, float] = (0.9, 0.999)
eps: float = 1e-8
weight_decay: float = 1e-2
grad_clip_norm: float = 10.0
def build(self, params: dict) -> torch.optim.Optimizer:
kwargs = asdict(self)
kwargs.pop("grad_clip_norm")
return torch.optim.AdamW(params, **kwargs)
@OptimizerConfig.register_subclass("sgd")
@dataclass
class SGDConfig(OptimizerConfig):
lr: float = 1e-3
momentum: float = 0.0
dampening: float = 0.0
nesterov: bool = False
weight_decay: float = 0.0
grad_clip_norm: float = 10.0
def build(self, params: dict) -> torch.optim.Optimizer:
kwargs = asdict(self)
kwargs.pop("grad_clip_norm")
return torch.optim.SGD(params, **kwargs)
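Each config strips `grad_clip_norm` (it is consumed by the training loop rather than the optimizer) before forwarding its remaining fields to the torch constructor; a sketch:

```python
import torch.nn as nn

# Sketch: build a torch optimizer from one of the configs above.
model = nn.Linear(4, 2)
opt = AdamWConfig(lr=1e-4, weight_decay=1e-2).build(model.parameters())
print(type(opt))  # <class 'torch.optim.AdamW'>
```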
def save_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Path) -> None:
state = optimizer.state_dict()
param_groups = state.pop("param_groups")
flat_state = flatten_dict(state)
save_file(flat_state, save_dir / OPTIMIZER_STATE)
write_json(param_groups, save_dir / OPTIMIZER_PARAM_GROUPS)
def load_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Path) -> torch.optim.Optimizer:
current_state_dict = optimizer.state_dict()
flat_state = load_file(save_dir / OPTIMIZER_STATE)
state = unflatten_dict(flat_state)
loaded_state_dict = {"state": {int(k): v for k, v in state["state"].items()}}
if "param_groups" in current_state_dict:
param_groups = deserialize_json_into_object(
save_dir / OPTIMIZER_PARAM_GROUPS, current_state_dict["param_groups"]
)
loaded_state_dict["param_groups"] = param_groups
optimizer.load_state_dict(loaded_state_dict)
return optimizer
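A roundtrip sketch of the (de)serialization above, reusing `model` and `opt` from the previous sketch (the path is illustrative):

```python
import torch
from pathlib import Path

# Take one step so the optimizer has per-parameter state worth serializing.
model(torch.randn(1, 4)).sum().backward()
opt.step()

save_dir = Path("/tmp/ckpt")
save_dir.mkdir(parents=True, exist_ok=True)
save_optimizer_state(opt, save_dir)  # writes OPTIMIZER_STATE and OPTIMIZER_PARAM_GROUPS

# Restore into a freshly built optimizer.
new_opt = AdamWConfig(lr=1e-4, weight_decay=1e-2).build(model.parameters())
new_opt = load_optimizer_state(new_opt, save_dir)
```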

View File

@@ -0,0 +1,122 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
import math
from dataclasses import asdict, dataclass
from pathlib import Path
import draccus
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR, LRScheduler
from lerobot.common.constants import SCHEDULER_STATE
from lerobot.common.datasets.utils import write_json
from lerobot.common.utils.io_utils import deserialize_json_into_object
@dataclass
class LRSchedulerConfig(draccus.ChoiceRegistry, abc.ABC):
num_warmup_steps: int
@property
def type(self) -> str:
return self.get_choice_name(self.__class__)
@abc.abstractmethod
def build(self, optimizer: Optimizer, num_training_steps: int) -> LRScheduler | None:
raise NotImplementedError
@LRSchedulerConfig.register_subclass("diffuser")
@dataclass
class DiffuserSchedulerConfig(LRSchedulerConfig):
name: str = "cosine"
num_warmup_steps: int | None = None
def build(self, optimizer: Optimizer, num_training_steps: int) -> LambdaLR:
from diffusers.optimization import get_scheduler
kwargs = {**asdict(self), "num_training_steps": num_training_steps, "optimizer": optimizer}
return get_scheduler(**kwargs)
@LRSchedulerConfig.register_subclass("vqbet")
@dataclass
class VQBeTSchedulerConfig(LRSchedulerConfig):
num_warmup_steps: int
num_vqvae_training_steps: int
num_cycles: float = 0.5
def build(self, optimizer: Optimizer, num_training_steps: int) -> LambdaLR:
def lr_lambda(current_step):
if current_step < self.num_vqvae_training_steps:
return float(1)
else:
adjusted_step = current_step - self.num_vqvae_training_steps
if adjusted_step < self.num_warmup_steps:
return float(adjusted_step) / float(max(1, self.num_warmup_steps))
progress = float(adjusted_step - self.num_warmup_steps) / float(
max(1, num_training_steps - self.num_warmup_steps)
)
return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(self.num_cycles) * 2.0 * progress)))
return LambdaLR(optimizer, lr_lambda, -1)
@LRSchedulerConfig.register_subclass("cosine_decay_with_warmup")
@dataclass
class CosineDecayWithWarmupSchedulerConfig(LRSchedulerConfig):
"""Used by Physical Intelligence to train Pi0"""
num_warmup_steps: int
num_decay_steps: int
peak_lr: float
decay_lr: float
def build(self, optimizer: Optimizer, num_training_steps: int) -> LambdaLR:
del num_training_steps
def lr_lambda(current_step):
def linear_warmup_schedule(current_step):
if current_step <= 0:
return 1 / (self.num_warmup_steps + 1)
frac = 1 - current_step / self.num_warmup_steps
return (1 / (self.num_warmup_steps + 1) - 1) * frac + 1
def cosine_decay_schedule(current_step):
step = min(current_step, self.num_decay_steps)
cosine_decay = 0.5 * (1 + math.cos(math.pi * step / self.num_decay_steps))
alpha = self.decay_lr / self.peak_lr
decayed = (1 - alpha) * cosine_decay + alpha
return decayed
if current_step < self.num_warmup_steps:
return linear_warmup_schedule(current_step)
return cosine_decay_schedule(current_step)
return LambdaLR(optimizer, lr_lambda, -1)
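The resulting multiplier ramps roughly linearly from 1/(num_warmup_steps + 1) up to 1 at peak, then cosine-decays toward decay_lr/peak_lr. A usage sketch (hyperparameters are illustrative, not Pi0's actual values):

```python
import torch

# Sketch: pair the schedule with an optimizer whose base lr is the peak lr.
cfg = CosineDecayWithWarmupSchedulerConfig(
    num_warmup_steps=1_000, num_decay_steps=30_000, peak_lr=2.5e-5, decay_lr=2.5e-6
)
params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.AdamW(params, lr=cfg.peak_lr)
scheduler = cfg.build(optimizer, num_training_steps=30_000)
for _ in range(3):
    optimizer.step()
    scheduler.step()
```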
def save_scheduler_state(scheduler: LRScheduler, save_dir: Path) -> None:
state_dict = scheduler.state_dict()
write_json(state_dict, save_dir / SCHEDULER_STATE)
def load_scheduler_state(scheduler: LRScheduler, save_dir: Path) -> LRScheduler:
state_dict = deserialize_json_into_object(save_dir / SCHEDULER_STATE, scheduler.state_dict())
scheduler.load_state_dict(state_dict)
return scheduler

View File

@@ -0,0 +1,5 @@
from .act.configuration_act import ACTConfig as ACTConfig
from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig
from .pi0.configuration_pi0 import PI0Config as PI0Config
from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig
from .vqbet.configuration_vqbet import VQBeTConfig as VQBeTConfig

View File

@@ -15,9 +15,14 @@
# limitations under the License. # limitations under the License.
from dataclasses import dataclass, field from dataclasses import dataclass, field
from lerobot.common.optim.optimizers import AdamWConfig
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import NormalizationMode
@PreTrainedConfig.register_subclass("act")
@dataclass @dataclass
class ACTConfig: class ACTConfig(PreTrainedConfig):
"""Configuration class for the Action Chunking Transformers policy. """Configuration class for the Action Chunking Transformers policy.
Defaults are configured for training on bimanual Aloha tasks like "insertion" or "transfer". Defaults are configured for training on bimanual Aloha tasks like "insertion" or "transfer".
@@ -59,7 +64,7 @@ class ACTConfig:
output_normalization_modes: Similar dictionary as `normalize_input_modes`, but to unnormalize to the output_normalization_modes: Similar dictionary as `normalize_input_modes`, but to unnormalize to the
original scale. Note that this is also used for normalizing the training targets. original scale. Note that this is also used for normalizing the training targets.
vision_backbone: Name of the torchvision resnet backbone to use for encoding images. vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
pretrained_backbone_weights: Pretrained weights from torchvision to initalize the backbone. pretrained_backbone_weights: Pretrained weights from torchvision to initialize the backbone.
`None` means no pretrained weights. `None` means no pretrained weights.
replace_final_stride_with_dilation: Whether to replace the ResNet's final 2x2 stride with a dilated replace_final_stride_with_dilation: Whether to replace the ResNet's final 2x2 stride with a dilated
convolution. convolution.
@@ -90,28 +95,11 @@ class ACTConfig:
chunk_size: int = 100 chunk_size: int = 100
n_action_steps: int = 100 n_action_steps: int = 100
input_shapes: dict[str, list[int]] = field( normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: { default_factory=lambda: {
"observation.images.top": [3, 480, 640], "VISUAL": NormalizationMode.MEAN_STD,
"observation.state": [14], "STATE": NormalizationMode.MEAN_STD,
} "ACTION": NormalizationMode.MEAN_STD,
)
output_shapes: dict[str, list[int]] = field(
default_factory=lambda: {
"action": [14],
}
)
# Normalization / Unnormalization
input_normalization_modes: dict[str, str] = field(
default_factory=lambda: {
"observation.images.top": "mean_std",
"observation.state": "mean_std",
}
)
output_normalization_modes: dict[str, str] = field(
default_factory=lambda: {
"action": "mean_std",
} }
) )
@@ -144,7 +132,14 @@ class ACTConfig:
dropout: float = 0.1 dropout: float = 0.1
kl_weight: float = 10.0 kl_weight: float = 10.0
# Training preset
optimizer_lr: float = 1e-5
optimizer_weight_decay: float = 1e-4
optimizer_lr_backbone: float = 1e-5
def __post_init__(self): def __post_init__(self):
super().__post_init__()
"""Input validation (not exhaustive).""" """Input validation (not exhaustive)."""
if not self.vision_backbone.startswith("resnet"): if not self.vision_backbone.startswith("resnet"):
raise ValueError( raise ValueError(
@@ -164,8 +159,28 @@ class ACTConfig:
raise ValueError( raise ValueError(
f"Multiple observation steps not handled yet. Got `nobs_steps={self.n_obs_steps}`" f"Multiple observation steps not handled yet. Got `nobs_steps={self.n_obs_steps}`"
) )
if (
not any(k.startswith("observation.image") for k in self.input_shapes) def get_optimizer_preset(self) -> AdamWConfig:
and "observation.environment_state" not in self.input_shapes return AdamWConfig(
): lr=self.optimizer_lr,
weight_decay=self.optimizer_weight_decay,
)
def get_scheduler_preset(self) -> None:
return None
def validate_features(self) -> None:
if not self.image_features and not self.env_state_feature:
raise ValueError("You must provide at least one image or the environment state among the inputs.") raise ValueError("You must provide at least one image or the environment state among the inputs.")
@property
def observation_delta_indices(self) -> None:
return None
@property
def action_delta_indices(self) -> list:
return list(range(self.chunk_size))
@property
def reward_delta_indices(self) -> None:
return None
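With these presets and properties in place, the training pipeline can derive the optimizer and the temporal windows straight from the policy config; a sketch:

```python
# Sketch: query the training preset and delta indices from the config.
cfg = ACTConfig()
cfg.get_optimizer_preset()     # AdamWConfig(lr=1e-5, weight_decay=1e-4, ...)
cfg.action_delta_indices       # [0, 1, ..., 99] for the default chunk_size=100
cfg.observation_delta_indices  # None: a single observation step
```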

View File

@@ -29,32 +29,27 @@ import numpy as np
import torch
import torch.nn.functional as F  # noqa: N812
import torchvision
-from huggingface_hub import PyTorchModelHubMixin
from torch import Tensor, nn
from torchvision.models._utils import IntermediateLayerGetter
from torchvision.ops.misc import FrozenBatchNorm2d

from lerobot.common.policies.act.configuration_act import ACTConfig
from lerobot.common.policies.normalize import Normalize, Unnormalize
+from lerobot.common.policies.pretrained import PreTrainedPolicy


-class ACTPolicy(
-    nn.Module,
-    PyTorchModelHubMixin,
-    library_name="lerobot",
-    repo_url="https://github.com/huggingface/lerobot",
-    tags=["robotics", "act"],
-):
+class ACTPolicy(PreTrainedPolicy):
    """
    Action Chunking Transformer Policy as per Learning Fine-Grained Bimanual Manipulation with Low-Cost
    Hardware (paper: https://arxiv.org/abs/2304.13705, code: https://github.com/tonyzhaozh/act)
    """

+    config_class = ACTConfig
    name = "act"

    def __init__(
        self,
-        config: ACTConfig | None = None,
+        config: ACTConfig,
        dataset_stats: dict[str, dict[str, Tensor]] | None = None,
    ):
        """
@@ -64,30 +59,46 @@ class ACTPolicy(
            dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
                that they will be passed with a call to `load_state_dict` before the policy is used.
        """
-        super().__init__()
-        if config is None:
-            config = ACTConfig()
-        self.config: ACTConfig = config
+        super().__init__(config)
+        config.validate_features()
+        self.config = config

-        self.normalize_inputs = Normalize(
-            config.input_shapes, config.input_normalization_modes, dataset_stats
-        )
+        self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
        self.normalize_targets = Normalize(
-            config.output_shapes, config.output_normalization_modes, dataset_stats
+            config.output_features, config.normalization_mapping, dataset_stats
        )
        self.unnormalize_outputs = Unnormalize(
-            config.output_shapes, config.output_normalization_modes, dataset_stats
+            config.output_features, config.normalization_mapping, dataset_stats
        )

        self.model = ACT(config)

-        self.expected_image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
-
        if config.temporal_ensemble_coeff is not None:
            self.temporal_ensembler = ACTTemporalEnsembler(config.temporal_ensemble_coeff, config.chunk_size)

        self.reset()

+    def get_optim_params(self) -> dict:
+        # TODO(aliberts, rcadene): As of now, lr_backbone == lr
+        # Should we remove this and just `return self.parameters()`?
+        return [
+            {
+                "params": [
+                    p
+                    for n, p in self.named_parameters()
+                    if not n.startswith("model.backbone") and p.requires_grad
+                ]
+            },
+            {
+                "params": [
+                    p
+                    for n, p in self.named_parameters()
+                    if n.startswith("model.backbone") and p.requires_grad
+                ],
+                "lr": self.config.optimizer_lr_backbone,
+            },
+        ]
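The two parameter groups returned by `get_optim_params` let the vision backbone train at its own learning rate. A hedged sketch of how a training loop could hand them to a torch optimizer (the `AdamW` construction is illustrative, not lifted from this diff; groups without an explicit "lr" fall back to the optimizer default):

policy = ACTPolicy(cfg)  # cfg: an ACTConfig whose input/output features are already populated
optimizer = torch.optim.AdamW(
    policy.get_optim_params(),  # [{"params": non-backbone}, {"params": backbone, "lr": optimizer_lr_backbone}]
    lr=cfg.optimizer_lr,
    weight_decay=cfg.optimizer_weight_decay,
)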
    def reset(self):
        """This should be called whenever the environment is reset."""
        if self.config.temporal_ensemble_coeff is not None:
@@ -106,9 +117,11 @@ class ACTPolicy(
        self.eval()

        batch = self.normalize_inputs(batch)
-        if len(self.expected_image_keys) > 0:
+        if self.config.image_features:
            batch = dict(batch)  # shallow copy so that adding a key doesn't modify the original
-            batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
+            batch["observation.images"] = torch.stack(
+                [batch[key] for key in self.config.image_features], dim=-4
+            )

        # If we are doing temporal ensembling, do online updates where we keep track of the number of actions
        # we are ensembling over.
@@ -131,12 +144,14 @@ class ACTPolicy(
        self._action_queue.extend(actions.transpose(0, 1))
        return self._action_queue.popleft()

-    def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
+    def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
        """Run the batch through the model and compute the loss for training or validation."""
        batch = self.normalize_inputs(batch)
-        if len(self.expected_image_keys) > 0:
+        if self.config.image_features:
            batch = dict(batch)  # shallow copy so that adding a key doesn't modify the original
-            batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
+            batch["observation.images"] = torch.stack(
+                [batch[key] for key in self.config.image_features], dim=-4
+            )

        batch = self.normalize_targets(batch)
        actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)
@@ -154,11 +169,11 @@ class ACTPolicy(
                (-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean()
            )
            loss_dict["kld_loss"] = mean_kld.item()
-            loss_dict["loss"] = l1_loss + mean_kld * self.config.kl_weight
+            loss = l1_loss + mean_kld * self.config.kl_weight
        else:
-            loss_dict["loss"] = l1_loss
+            loss = l1_loss

-        return loss_dict
+        return loss, loss_dict
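With `forward` now returning a `(loss, loss_dict)` tuple, the scalar used for backprop is a plain tensor and the dict only carries logged metrics. A minimal training-step sketch under that contract, continuing the optimizer example above:

loss, loss_dict = policy.forward(batch)  # loss: Tensor; loss_dict: float metrics such as "kld_loss"
loss.backward()
optimizer.step()
optimizer.zero_grad()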
class ACTTemporalEnsembler:
@@ -288,31 +303,30 @@ class ACT(nn.Module):
    """

    def __init__(self, config: ACTConfig):
-        super().__init__()
-        self.config = config
        # BERT style VAE encoder with input tokens [cls, robot_state, *action_sequence].
        # The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]).
-        self.use_robot_state = "observation.state" in config.input_shapes
-        self.use_images = any(k.startswith("observation.image") for k in config.input_shapes)
-        self.use_env_state = "observation.environment_state" in config.input_shapes
+        super().__init__()
+        self.config = config

        if self.config.use_vae:
            self.vae_encoder = ACTEncoder(config, is_vae_encoder=True)
            self.vae_encoder_cls_embed = nn.Embedding(1, config.dim_model)
            # Projection layer for joint-space configuration to hidden dimension.
-            if self.use_robot_state:
+            if self.config.robot_state_feature:
                self.vae_encoder_robot_state_input_proj = nn.Linear(
-                    config.input_shapes["observation.state"][0], config.dim_model
+                    self.config.robot_state_feature.shape[0], config.dim_model
                )
            # Projection layer for action (joint-space target) to hidden dimension.
            self.vae_encoder_action_input_proj = nn.Linear(
-                config.output_shapes["action"][0], config.dim_model
+                self.config.action_feature.shape[0],
+                config.dim_model,
            )
            # Projection layer from the VAE encoder's output to the latent distribution's parameter space.
            self.vae_encoder_latent_output_proj = nn.Linear(config.dim_model, config.latent_dim * 2)
            # Fixed sinusoidal positional embedding for the input to the VAE encoder. Unsqueeze for batch
            # dimension.
            num_input_token_encoder = 1 + config.chunk_size
-            if self.use_robot_state:
+            if self.config.robot_state_feature:
                num_input_token_encoder += 1
            self.register_buffer(
                "vae_encoder_pos_enc",
@@ -320,7 +334,7 @@ class ACT(nn.Module):
        )

        # Backbone for image feature extraction.
-        if self.use_images:
+        if self.config.image_features:
            backbone_model = getattr(torchvision.models, config.vision_backbone)(
                replace_stride_with_dilation=[False, False, config.replace_final_stride_with_dilation],
                weights=config.pretrained_backbone_weights,
@@ -337,27 +351,27 @@ class ACT(nn.Module):
        # Transformer encoder input projections. The tokens will be structured like
        # [latent, (robot_state), (env_state), (image_feature_map_pixels)].
-        if self.use_robot_state:
+        if self.config.robot_state_feature:
            self.encoder_robot_state_input_proj = nn.Linear(
-                config.input_shapes["observation.state"][0], config.dim_model
+                self.config.robot_state_feature.shape[0], config.dim_model
            )
-        if self.use_env_state:
+        if self.config.env_state_feature:
            self.encoder_env_state_input_proj = nn.Linear(
-                config.input_shapes["observation.environment_state"][0], config.dim_model
+                self.config.env_state_feature.shape[0], config.dim_model
            )
        self.encoder_latent_input_proj = nn.Linear(config.latent_dim, config.dim_model)
-        if self.use_images:
+        if self.config.image_features:
            self.encoder_img_feat_input_proj = nn.Conv2d(
                backbone_model.fc.in_features, config.dim_model, kernel_size=1
            )
        # Transformer encoder positional embeddings.
        n_1d_tokens = 1  # for the latent
-        if self.use_robot_state:
+        if self.config.robot_state_feature:
            n_1d_tokens += 1
-        if self.use_env_state:
+        if self.config.env_state_feature:
            n_1d_tokens += 1
        self.encoder_1d_feature_pos_embed = nn.Embedding(n_1d_tokens, config.dim_model)
-        if self.use_images:
+        if self.config.image_features:
            self.encoder_cam_feat_pos_embed = ACTSinusoidalPositionEmbedding2d(config.dim_model // 2)

        # Transformer decoder.
@@ -365,7 +379,7 @@ class ACT(nn.Module):
        self.decoder_pos_embed = nn.Embedding(config.chunk_size, config.dim_model)

        # Final action regression head on the output of the transformer's decoder.
-        self.action_head = nn.Linear(config.dim_model, config.output_shapes["action"][0])
+        self.action_head = nn.Linear(config.dim_model, self.config.action_feature.shape[0])

        self._reset_parameters()
@@ -380,13 +394,13 @@ class ACT(nn.Module):
        `batch` should have the following structure:
        {
-            "observation.state" (optional): (B, state_dim) batch of robot states.
+            [robot_state_feature] (optional): (B, state_dim) batch of robot states.

-            "observation.images": (B, n_cameras, C, H, W) batch of images.
+            [image_features]: (B, n_cameras, C, H, W) batch of images.
                AND/OR
-            "observation.environment_state": (B, env_dim) batch of environment states.
+            [env_state_feature]: (B, env_dim) batch of environment states.

-            "action" (optional, only if training with VAE): (B, chunk_size, action dim) batch of actions.
+            [action_feature] (optional, only if training with VAE): (B, chunk_size, action dim) batch of actions.
        }

        Returns:
@@ -395,9 +409,9 @@ class ACT(nn.Module):
            latent dimension.
        """
        if self.config.use_vae and self.training:
-            assert (
-                "action" in batch
-            ), "actions must be provided when using the variational objective in training mode."
+            assert "action" in batch, (
+                "actions must be provided when using the variational objective in training mode."
+            )

        batch_size = (
            batch["observation.images"]
@@ -411,12 +425,12 @@ class ACT(nn.Module):
            cls_embed = einops.repeat(
                self.vae_encoder_cls_embed.weight, "1 d -> b 1 d", b=batch_size
            )  # (B, 1, D)
-            if self.use_robot_state:
+            if self.config.robot_state_feature:
                robot_state_embed = self.vae_encoder_robot_state_input_proj(batch["observation.state"])
                robot_state_embed = robot_state_embed.unsqueeze(1)  # (B, 1, D)
            action_embed = self.vae_encoder_action_input_proj(batch["action"])  # (B, S, D)

-            if self.use_robot_state:
+            if self.config.robot_state_feature:
                vae_encoder_input = [cls_embed, robot_state_embed, action_embed]  # (B, S+2, D)
            else:
                vae_encoder_input = [cls_embed, action_embed]
@@ -430,7 +444,7 @@ class ACT(nn.Module):
            # sequence depending whether we use the input states or not (cls and robot state)
            # False means not a padding token.
            cls_joint_is_pad = torch.full(
-                (batch_size, 2 if self.use_robot_state else 1),
+                (batch_size, 2 if self.config.robot_state_feature else 1),
                False,
                device=batch["observation.state"].device,
            )
@@ -463,16 +477,16 @@ class ACT(nn.Module):
        encoder_in_tokens = [self.encoder_latent_input_proj(latent_sample)]
        encoder_in_pos_embed = list(self.encoder_1d_feature_pos_embed.weight.unsqueeze(1))
        # Robot state token.
-        if self.use_robot_state:
+        if self.config.robot_state_feature:
            encoder_in_tokens.append(self.encoder_robot_state_input_proj(batch["observation.state"]))
        # Environment state token.
-        if self.use_env_state:
+        if self.config.env_state_feature:
            encoder_in_tokens.append(
                self.encoder_env_state_input_proj(batch["observation.environment_state"])
            )

        # Camera observation features and positional embeddings.
-        if self.use_images:
+        if self.config.image_features:
            all_cam_features = []
            all_cam_pos_embeds = []
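Both before and after this refactor, per-camera images are stacked into one `(B, n_cameras, C, H, W)` tensor; only the source of the keys changed (`config.image_features` instead of `expected_image_keys`). A small shape-check sketch of the `dim=-4` stacking (the second camera key is hypothetical):

import torch

B, C, H, W = 2, 3, 480, 640
batch = {
    "observation.images.top": torch.zeros(B, C, H, W),
    "observation.images.wrist": torch.zeros(B, C, H, W),  # hypothetical extra camera
}
image_keys = ["observation.images.top", "observation.images.wrist"]  # stands in for config.image_features
stacked = torch.stack([batch[k] for k in image_keys], dim=-4)
assert stacked.shape == (B, 2, C, H, W)  # camera axis lands just before (C, H, W)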

View File

@@ -16,9 +16,15 @@
# limitations under the License.
from dataclasses import dataclass, field

+from lerobot.common.optim.optimizers import AdamConfig
+from lerobot.common.optim.schedulers import DiffuserSchedulerConfig
+from lerobot.configs.policies import PreTrainedConfig
+from lerobot.configs.types import NormalizationMode


+@PreTrainedConfig.register_subclass("diffusion")
@dataclass
-class DiffusionConfig:
+class DiffusionConfig(PreTrainedConfig):
    """Configuration class for DiffusionPolicy.

    Defaults are configured for training with PushT providing proprioceptive and single camera observations.
@@ -62,7 +68,7 @@ class DiffusionConfig:
            within the image size. If None, no cropping is done.
        crop_is_random: Whether the crop should be random at training time (it's always a center crop in eval
            mode).
-        pretrained_backbone_weights: Pretrained weights from torchvision to initalize the backbone.
+        pretrained_backbone_weights: Pretrained weights from torchvision to initialize the backbone.
            `None` means no pretrained weights.
        use_group_norm: Whether to replace batch normalization with group normalization in the backbone.
            The group sizes are set to be about 16 (to be precise, feature_dim // 16).
@@ -93,7 +99,7 @@ class DiffusionConfig:
        num_inference_steps: Number of reverse diffusion steps to use at inference time (steps are evenly
            spaced). If not provided, this defaults to be the same as `num_train_timesteps`.
        do_mask_loss_for_padding: Whether to mask the loss when there are copy-padded actions. See
-            `LeRobotDataset` and `load_previous_and_future_frames` for mor information. Note, this defaults
+            `LeRobotDataset` and `load_previous_and_future_frames` for more information. Note, this defaults
            to False as the original Diffusion Policy implementation does the same.
    """
@@ -102,26 +108,17 @@ class DiffusionConfig:
    horizon: int = 16
    n_action_steps: int = 8

-    input_shapes: dict[str, list[int]] = field(
-        default_factory=lambda: {
-            "observation.image": [3, 96, 96],
-            "observation.state": [2],
-        }
-    )
-    output_shapes: dict[str, list[int]] = field(
-        default_factory=lambda: {
-            "action": [2],
-        }
-    )
-
-    # Normalization / Unnormalization
-    input_normalization_modes: dict[str, str] = field(
-        default_factory=lambda: {
-            "observation.image": "mean_std",
-            "observation.state": "min_max",
-        }
-    )
-    output_normalization_modes: dict[str, str] = field(default_factory=lambda: {"action": "min_max"})
+    normalization_mapping: dict[str, NormalizationMode] = field(
+        default_factory=lambda: {
+            "VISUAL": NormalizationMode.MEAN_STD,
+            "STATE": NormalizationMode.MIN_MAX,
+            "ACTION": NormalizationMode.MIN_MAX,
+        }
+    )
+
+    # The original implementation doesn't sample frames for the last 7 steps,
+    # which avoids excessive padding and leads to improved training results.
+    drop_n_last_frames: int = 7  # horizon - n_action_steps - n_obs_steps + 1

    # Architecture / modeling.
    # Vision backbone.
@@ -154,39 +151,23 @@ class DiffusionConfig:
    # Loss computation
    do_mask_loss_for_padding: bool = False

+    # Training presets
+    optimizer_lr: float = 1e-4
+    optimizer_betas: tuple = (0.95, 0.999)
+    optimizer_eps: float = 1e-8
+    optimizer_weight_decay: float = 1e-6
+    scheduler_name: str = "cosine"
+    scheduler_warmup_steps: int = 500

    def __post_init__(self):
+        super().__post_init__()
        """Input validation (not exhaustive)."""
        if not self.vision_backbone.startswith("resnet"):
            raise ValueError(
                f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
            )
-        image_keys = {k for k in self.input_shapes if k.startswith("observation.image")}
-        if len(image_keys) == 0 and "observation.environment_state" not in self.input_shapes:
-            raise ValueError("You must provide at least one image or the environment state among the inputs.")
-        if len(image_keys) > 0:
-            if self.crop_shape is not None:
-                for image_key in image_keys:
-                    if (
-                        self.crop_shape[0] > self.input_shapes[image_key][1]
-                        or self.crop_shape[1] > self.input_shapes[image_key][2]
-                    ):
-                        raise ValueError(
-                            f"`crop_shape` should fit within `input_shapes[{image_key}]`. Got {self.crop_shape} "
-                            f"for `crop_shape` and {self.input_shapes[image_key]} for "
-                            "`input_shapes[{image_key}]`."
-                        )
-            # Check that all input images have the same shape.
-            first_image_key = next(iter(image_keys))
-            for image_key in image_keys:
-                if self.input_shapes[image_key] != self.input_shapes[first_image_key]:
-                    raise ValueError(
-                        f"`input_shapes[{image_key}]` does not match `input_shapes[{first_image_key}]`, but we "
-                        "expect all image shapes to match."
-                    )
        supported_prediction_types = ["epsilon", "sample"]
        if self.prediction_type not in supported_prediction_types:
            raise ValueError(
@@ -207,3 +188,50 @@ class DiffusionConfig:
                "The horizon should be an integer multiple of the downsampling factor (which is determined "
                f"by `len(down_dims)`). Got {self.horizon=} and {self.down_dims=}"
            )

+    def get_optimizer_preset(self) -> AdamConfig:
+        return AdamConfig(
+            lr=self.optimizer_lr,
+            betas=self.optimizer_betas,
+            eps=self.optimizer_eps,
+            weight_decay=self.optimizer_weight_decay,
+        )
+
+    def get_scheduler_preset(self) -> DiffuserSchedulerConfig:
+        return DiffuserSchedulerConfig(
+            name=self.scheduler_name,
+            num_warmup_steps=self.scheduler_warmup_steps,
+        )
+
+    def validate_features(self) -> None:
+        if len(self.image_features) == 0 and self.env_state_feature is None:
+            raise ValueError("You must provide at least one image or the environment state among the inputs.")
+
+        if self.crop_shape is not None:
+            for key, image_ft in self.image_features.items():
+                if self.crop_shape[0] > image_ft.shape[1] or self.crop_shape[1] > image_ft.shape[2]:
+                    raise ValueError(
+                        f"`crop_shape` should fit within the images shapes. Got {self.crop_shape} "
+                        f"for `crop_shape` and {image_ft.shape} for "
+                        f"`{key}`."
+                    )
+
+        # Check that all input images have the same shape.
+        first_image_key, first_image_ft = next(iter(self.image_features.items()))
+        for key, image_ft in self.image_features.items():
+            if image_ft.shape != first_image_ft.shape:
+                raise ValueError(
+                    f"`{key}` does not match `{first_image_key}`, but we expect all image shapes to match."
+                )
+
+    @property
+    def observation_delta_indices(self) -> list:
+        return list(range(1 - self.n_obs_steps, 1))
+
+    @property
+    def action_delta_indices(self) -> list:
+        return list(range(1 - self.n_obs_steps, 1 - self.n_obs_steps + self.horizon))
+
+    @property
+    def reward_delta_indices(self) -> None:
+        return None
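The new delta-index properties tell the dataloader which relative frame offsets to fetch around each sample. A quick numeric check, using illustrative values for `n_obs_steps` and `horizon` rather than anything pinned by this diff:

cfg = DiffusionConfig(n_obs_steps=2, horizon=16, n_action_steps=8)  # illustrative values
assert cfg.observation_delta_indices == [-1, 0]  # previous frame plus the current one
assert cfg.action_delta_indices == list(range(-1, 15))  # a horizon-length window aligned to the first observation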

View File

@@ -31,35 +31,32 @@ import torch.nn.functional as F  # noqa: N812
import torchvision
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
-from huggingface_hub import PyTorchModelHubMixin
from torch import Tensor, nn

+from lerobot.common.constants import OBS_ENV, OBS_ROBOT
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.common.policies.normalize import Normalize, Unnormalize
+from lerobot.common.policies.pretrained import PreTrainedPolicy
from lerobot.common.policies.utils import (
    get_device_from_parameters,
    get_dtype_from_parameters,
+    get_output_shape,
    populate_queues,
)


-class DiffusionPolicy(
-    nn.Module,
-    PyTorchModelHubMixin,
-    library_name="lerobot",
-    repo_url="https://github.com/huggingface/lerobot",
-    tags=["robotics", "diffusion-policy"],
-):
+class DiffusionPolicy(PreTrainedPolicy):
    """
    Diffusion Policy as per "Diffusion Policy: Visuomotor Policy Learning via Action Diffusion"
    (paper: https://arxiv.org/abs/2303.04137, code: https://github.com/real-stanford/diffusion_policy).
    """

+    config_class = DiffusionConfig
    name = "diffusion"

    def __init__(
        self,
-        config: DiffusionConfig | None = None,
+        config: DiffusionConfig,
        dataset_stats: dict[str, dict[str, Tensor]] | None = None,
    ):
        """
@@ -69,18 +66,16 @@ class DiffusionPolicy(
            dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
                that they will be passed with a call to `load_state_dict` before the policy is used.
        """
-        super().__init__()
-        if config is None:
-            config = DiffusionConfig()
+        super().__init__(config)
+        config.validate_features()
        self.config = config

-        self.normalize_inputs = Normalize(
-            config.input_shapes, config.input_normalization_modes, dataset_stats
-        )
+        self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
        self.normalize_targets = Normalize(
-            config.output_shapes, config.output_normalization_modes, dataset_stats
+            config.output_features, config.normalization_mapping, dataset_stats
        )
        self.unnormalize_outputs = Unnormalize(
-            config.output_shapes, config.output_normalization_modes, dataset_stats
+            config.output_features, config.normalization_mapping, dataset_stats
        )

        # queues are populated during rollout of the policy, they contain the n latest observations and actions
@@ -88,20 +83,20 @@ class DiffusionPolicy(

        self.diffusion = DiffusionModel(config)

-        self.expected_image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
-        self.use_env_state = "observation.environment_state" in config.input_shapes
-
        self.reset()

+    def get_optim_params(self) -> dict:
+        return self.diffusion.parameters()

    def reset(self):
        """Clear observation and action queues. Should be called on `env.reset()`"""
        self._queues = {
            "observation.state": deque(maxlen=self.config.n_obs_steps),
            "action": deque(maxlen=self.config.n_action_steps),
        }
-        if len(self.expected_image_keys) > 0:
+        if self.config.image_features:
            self._queues["observation.images"] = deque(maxlen=self.config.n_obs_steps)
-        if self.use_env_state:
+        if self.config.env_state_feature:
            self._queues["observation.environment_state"] = deque(maxlen=self.config.n_obs_steps)

    @torch.no_grad
@@ -127,9 +122,11 @@ class DiffusionPolicy(
            actually measured from the first observation which (if `n_obs_steps` > 1) happened in the past.
        """
        batch = self.normalize_inputs(batch)
-        if len(self.expected_image_keys) > 0:
+        if self.config.image_features:
            batch = dict(batch)  # shallow copy so that adding a key doesn't modify the original
-            batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
+            batch["observation.images"] = torch.stack(
+                [batch[key] for key in self.config.image_features], dim=-4
+            )
        # Note: It's important that this happens after stacking the images into a single key.
        self._queues = populate_queues(self._queues, batch)
@@ -146,15 +143,18 @@ class DiffusionPolicy(
            action = self._queues["action"].popleft()
        return action

-    def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
+    def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, None]:
        """Run the batch through the model and compute the loss for training or validation."""
        batch = self.normalize_inputs(batch)
-        if len(self.expected_image_keys) > 0:
+        if self.config.image_features:
            batch = dict(batch)  # shallow copy so that adding a key doesn't modify the original
-            batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
+            batch["observation.images"] = torch.stack(
+                [batch[key] for key in self.config.image_features], dim=-4
+            )
        batch = self.normalize_targets(batch)
        loss = self.diffusion.compute_loss(batch)
-        return {"loss": loss}
+        # no output_dict so returning None
+        return loss, None


def _make_noise_scheduler(name: str, **kwargs: dict) -> DDPMScheduler | DDIMScheduler:
@@ -176,12 +176,9 @@ class DiffusionModel(nn.Module):
        self.config = config

        # Build observation encoders (depending on which observations are provided).
-        global_cond_dim = config.input_shapes["observation.state"][0]
-        num_images = len([k for k in config.input_shapes if k.startswith("observation.image")])
-        self._use_images = False
-        self._use_env_state = False
-        if num_images > 0:
-            self._use_images = True
+        global_cond_dim = self.config.robot_state_feature.shape[0]
+        if self.config.image_features:
+            num_images = len(self.config.image_features)
            if self.config.use_separate_rgb_encoder_per_camera:
                encoders = [DiffusionRgbEncoder(config) for _ in range(num_images)]
                self.rgb_encoder = nn.ModuleList(encoders)
@@ -189,9 +186,8 @@ class DiffusionModel(nn.Module):
            else:
                self.rgb_encoder = DiffusionRgbEncoder(config)
                global_cond_dim += self.rgb_encoder.feature_dim * num_images
-        if "observation.environment_state" in config.input_shapes:
-            self._use_env_state = True
-            global_cond_dim += config.input_shapes["observation.environment_state"][0]
+        if self.config.env_state_feature:
+            global_cond_dim += self.config.env_state_feature.shape[0]

        self.unet = DiffusionConditionalUnet1d(config, global_cond_dim=global_cond_dim * config.n_obs_steps)
@@ -220,7 +216,7 @@ class DiffusionModel(nn.Module):
        # Sample prior.
        sample = torch.randn(
-            size=(batch_size, self.config.horizon, self.config.output_shapes["action"][0]),
+            size=(batch_size, self.config.horizon, self.config.action_feature.shape[0]),
            dtype=dtype,
            device=device,
            generator=generator,
@@ -242,10 +238,10 @@ class DiffusionModel(nn.Module):

    def _prepare_global_conditioning(self, batch: dict[str, Tensor]) -> Tensor:
        """Encode image features and concatenate them all together along with the state vector."""
-        batch_size, n_obs_steps = batch["observation.state"].shape[:2]
-        global_cond_feats = [batch["observation.state"]]
+        batch_size, n_obs_steps = batch[OBS_ROBOT].shape[:2]
+        global_cond_feats = [batch[OBS_ROBOT]]
        # Extract image features.
-        if self._use_images:
+        if self.config.image_features:
            if self.config.use_separate_rgb_encoder_per_camera:
                # Combine batch and sequence dims while rearranging to make the camera index dimension first.
                images_per_camera = einops.rearrange(batch["observation.images"], "b s n ... -> n (b s) ...")
@@ -272,8 +268,8 @@ class DiffusionModel(nn.Module):
            )
            global_cond_feats.append(img_features)

-        if self._use_env_state:
-            global_cond_feats.append(batch["observation.environment_state"])
+        if self.config.env_state_feature:
+            global_cond_feats.append(batch[OBS_ENV])

        # Concatenate features then flatten to (B, global_cond_dim).
        return torch.cat(global_cond_feats, dim=-1).flatten(start_dim=1)
@@ -443,7 +439,7 @@ class SpatialSoftmax(nn.Module):

class DiffusionRgbEncoder(nn.Module):
-    """Encoder an RGB image into a 1D feature vector.
+    """Encodes an RGB image into a 1D feature vector.

    Includes the ability to normalize and crop the image first.
    """
@@ -482,19 +478,16 @@ class DiffusionRgbEncoder(nn.Module):
        # Set up pooling and final layers.
        # Use a dry run to get the feature map shape.
-        # The dummy input should take the number of image channels from `config.input_shapes` and it should
+        # The dummy input should take the number of image channels from `config.image_features` and it should
        # use the height and width from `config.crop_shape` if it is provided, otherwise it should use the
-        # height and width from `config.input_shapes`.
-        image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
+        # height and width from `config.image_features`.
        # Note: we have a check in the config class to make sure all images have the same shape.
-        image_key = image_keys[0]
-        dummy_input_h_w = (
-            config.crop_shape if config.crop_shape is not None else config.input_shapes[image_key][1:]
-        )
-        dummy_input = torch.zeros(size=(1, config.input_shapes[image_key][0], *dummy_input_h_w))
-        with torch.inference_mode():
-            dummy_feature_map = self.backbone(dummy_input)
-        feature_map_shape = tuple(dummy_feature_map.shape[1:])
+        images_shape = next(iter(config.image_features.values())).shape
+        dummy_shape_h_w = config.crop_shape if config.crop_shape is not None else images_shape[1:]
+        dummy_shape = (1, images_shape[0], *dummy_shape_h_w)
+        feature_map_shape = get_output_shape(self.backbone, dummy_shape)[1:]

        self.pool = SpatialSoftmax(feature_map_shape, num_kp=config.spatial_softmax_num_keypoints)
        self.feature_dim = config.spatial_softmax_num_keypoints * 2
        self.out = nn.Linear(config.spatial_softmax_num_keypoints * 2, self.feature_dim)
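The inline dry run through the backbone is replaced by a `get_output_shape` helper imported from `lerobot.common.policies.utils`. Its body isn't shown in this diff; a plausible sketch, consistent with the call site `get_output_shape(self.backbone, dummy_shape)[1:]` and the code it replaces, would be:

import torch
from torch import nn

def get_output_shape(module: nn.Module, input_shape: tuple) -> tuple:
    # Assumed behavior (the real helper lives in lerobot.common.policies.utils):
    # feed a zero tensor of `input_shape` through `module` under inference mode
    # and report the resulting output shape.
    dummy = torch.zeros(input_shape)
    with torch.inference_mode():
        output = module(dummy)
    return tuple(output.shape)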
@@ -611,7 +604,7 @@ class DiffusionConditionalUnet1d(nn.Module):
        # In channels / out channels for each downsampling block in the Unet's encoder. For the decoder, we
        # just reverse these.
-        in_out = [(config.output_shapes["action"][0], config.down_dims[0])] + list(
+        in_out = [(config.action_feature.shape[0], config.down_dims[0])] + list(
            zip(config.down_dims[:-1], config.down_dims[1:], strict=True)
        )
@@ -666,7 +659,7 @@ class DiffusionConditionalUnet1d(nn.Module):
        self.final_conv = nn.Sequential(
            DiffusionConv1dBlock(config.down_dims[0], config.down_dims[0], kernel_size=config.kernel_size),
-            nn.Conv1d(config.down_dims[0], config.output_shapes["action"][0], 1),
+            nn.Conv1d(config.down_dims[0], config.action_feature.shape[0], 1),
        )

    def forward(self, x: Tensor, timestep: Tensor | int, global_cond=None) -> Tensor:

View File

@@ -13,99 +13,141 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import inspect
import logging

-from omegaconf import DictConfig, OmegaConf
-from torch import nn
+import torch

-from lerobot.common.policies.policy_protocol import Policy
-from lerobot.common.utils.utils import get_safe_torch_device
+from lerobot.common.datasets.lerobot_dataset import LeRobotDatasetMetadata
+from lerobot.common.datasets.utils import dataset_to_policy_features
+from lerobot.common.envs.configs import EnvConfig
+from lerobot.common.envs.utils import env_to_policy_features
+from lerobot.common.policies.act.configuration_act import ACTConfig
+from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
+from lerobot.common.policies.pi0.configuration_pi0 import PI0Config
+from lerobot.common.policies.pretrained import PreTrainedPolicy
+from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig
+from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig
+from lerobot.configs.policies import PreTrainedConfig
+from lerobot.configs.types import FeatureType


-def _policy_cfg_from_hydra_cfg(policy_cfg_class, hydra_cfg):
-    expected_kwargs = set(inspect.signature(policy_cfg_class).parameters)
-    if not set(hydra_cfg.policy).issuperset(expected_kwargs):
-        logging.warning(
-            f"Hydra config is missing arguments: {set(expected_kwargs).difference(hydra_cfg.policy)}"
-        )
-
-    # OmegaConf.to_container returns lists where sequences are found, but our dataclasses use tuples to avoid
-    # issues with mutable defaults. This filter changes all lists to tuples.
-    def list_to_tuple(item):
-        return tuple(item) if isinstance(item, list) else item
-
-    policy_cfg = policy_cfg_class(
-        **{
-            k: list_to_tuple(v)
-            for k, v in OmegaConf.to_container(hydra_cfg.policy, resolve=True).items()
-            if k in expected_kwargs
-        }
-    )
-    return policy_cfg
-
-
-def get_policy_and_config_classes(name: str) -> tuple[Policy, object]:
+def get_policy_class(name: str) -> PreTrainedPolicy:
    """Get the policy's class and config class given a name (matching the policy class' `name` attribute)."""
    if name == "tdmpc":
-        from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig
        from lerobot.common.policies.tdmpc.modeling_tdmpc import TDMPCPolicy

-        return TDMPCPolicy, TDMPCConfig
+        return TDMPCPolicy
    elif name == "diffusion":
-        from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
        from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy

-        return DiffusionPolicy, DiffusionConfig
+        return DiffusionPolicy
    elif name == "act":
-        from lerobot.common.policies.act.configuration_act import ACTConfig
        from lerobot.common.policies.act.modeling_act import ACTPolicy

-        return ACTPolicy, ACTConfig
+        return ACTPolicy
    elif name == "vqbet":
-        from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig
        from lerobot.common.policies.vqbet.modeling_vqbet import VQBeTPolicy

-        return VQBeTPolicy, VQBeTConfig
+        return VQBeTPolicy
+    elif name == "pi0":
+        from lerobot.common.policies.pi0.modeling_pi0 import PI0Policy
+
+        return PI0Policy
    else:
        raise NotImplementedError(f"Policy with name {name} is not implemented.")
+def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
+    if policy_type == "tdmpc":
+        return TDMPCConfig(**kwargs)
+    elif policy_type == "diffusion":
+        return DiffusionConfig(**kwargs)
+    elif policy_type == "act":
+        return ACTConfig(**kwargs)
+    elif policy_type == "vqbet":
+        return VQBeTConfig(**kwargs)
+    elif policy_type == "pi0":
+        return PI0Config(**kwargs)
+    else:
+        raise ValueError(f"Policy type '{policy_type}' is not available.")
def make_policy(
-    hydra_cfg: DictConfig, pretrained_policy_name_or_path: str | None = None, dataset_stats=None
-) -> Policy:
+    cfg: PreTrainedConfig,
+    device: str | torch.device,
+    ds_meta: LeRobotDatasetMetadata | None = None,
+    env_cfg: EnvConfig | None = None,
+) -> PreTrainedPolicy:
    """Make an instance of a policy class.

+    This function exists because (for now) we need to parse features from either a dataset or an environment
+    in order to properly dimension and instantiate a policy for that dataset or environment.
+
    Args:
-        hydra_cfg: A parsed Hydra configuration (see scripts). If `pretrained_policy_name_or_path` is
-            provided, only `hydra_cfg.policy.name` is used while everything else is ignored.
-        pretrained_policy_name_or_path: Either the repo ID of a model hosted on the Hub or a path to a
-            directory containing weights saved using `Policy.save_pretrained`. Note that providing this
-            argument overrides everything in `hydra_cfg.policy` apart from `hydra_cfg.policy.name`.
-        dataset_stats: Dataset statistics to use for (un)normalization of inputs/outputs in the policy. Must
-            be provided when initializing a new policy, and must not be provided when loading a pretrained
-            policy. Therefore, this argument is mutually exclusive with `pretrained_policy_name_or_path`.
+        cfg (PreTrainedConfig): The config of the policy to make. If `pretrained_path` is set, the policy will
+            be loaded with the weights from that path.
+        device (str): the device to load the policy onto.
+        ds_meta (LeRobotDatasetMetadata | None, optional): Dataset metadata to take input/output shapes and
+            statistics to use for (un)normalization of inputs/outputs in the policy. Defaults to None.
+        env_cfg (EnvConfig | None, optional): The config of a gym environment to parse features from. Must be
+            provided if ds_meta is not. Defaults to None.
+
+    Raises:
+        ValueError: Either ds_meta or env and env_cfg must be provided.
+        NotImplementedError: if the policy.type is 'vqbet' and the device 'mps' (due to an incompatibility)
+
+    Returns:
+        PreTrainedPolicy: _description_
    """
-    if not (pretrained_policy_name_or_path is None) ^ (dataset_stats is None):
-        raise ValueError(
-            "Exactly one of `pretrained_policy_name_or_path` and `dataset_stats` must be provided."
-        )
+    if bool(ds_meta) == bool(env_cfg):
+        raise ValueError("Either one of a dataset metadata or a sim env must be provided.")
+
+    # NOTE: Currently, if you try to run vqbet with mps backend, you'll get this error.
+    # TODO(aliberts, rcadene): Implement a check_backend_compatibility in policies?
+    # NotImplementedError: The operator 'aten::unique_dim' is not currently implemented for the MPS device. If
+    # you want this op to be added in priority during the prototype phase of this feature, please comment on
+    # https://github.com/pytorch/pytorch/issues/77764. As a temporary fix, you can set the environment
+    # variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be
+    # slower than running natively on MPS.
+    if cfg.type == "vqbet" and str(device) == "mps":
+        raise NotImplementedError(
+            "Current implementation of VQBeT does not support `mps` backend. "
+            "Please use `cpu` or `cuda` backend."
+        )

-    policy_cls, policy_cfg_class = get_policy_and_config_classes(hydra_cfg.policy.name)
+    policy_cls = get_policy_class(cfg.type)

-    policy_cfg = _policy_cfg_from_hydra_cfg(policy_cfg_class, hydra_cfg)
-    if pretrained_policy_name_or_path is None:
-        # Make a fresh policy.
-        policy = policy_cls(policy_cfg, dataset_stats)
-    else:
+    kwargs = {}
+    if ds_meta is not None:
+        features = dataset_to_policy_features(ds_meta.features)
+        kwargs["dataset_stats"] = ds_meta.stats
+    else:
+        if not cfg.pretrained_path:
+            logging.warning(
+                "You are instantiating a policy from scratch and its features are parsed from an environment "
+                "rather than a dataset. Normalization modules inside the policy will have infinite values "
+                "by default without stats from a dataset."
+            )
+        features = env_to_policy_features(env_cfg)
+
+    cfg.output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
+    cfg.input_features = {key: ft for key, ft in features.items() if key not in cfg.output_features}
+    kwargs["config"] = cfg
+
+    if cfg.pretrained_path:
        # Load a pretrained policy and override the config if needed (for example, if there are inference-time
        # hyperparameters that we want to vary).
-        # TODO(alexander-soare): This hack makes use of huggingface_hub's tooling to load the policy with,
-        # pretrained weights which are then loaded into a fresh policy with the desired config. This PR in
-        # huggingface_hub should make it possible to avoid the hack:
-        # https://github.com/huggingface/huggingface_hub/pull/2274.
-        policy = policy_cls(policy_cfg)
-        policy.load_state_dict(policy_cls.from_pretrained(pretrained_policy_name_or_path).state_dict())
+        kwargs["pretrained_name_or_path"] = cfg.pretrained_path
+        policy = policy_cls.from_pretrained(**kwargs)
+    else:
+        # Make a fresh policy.
+        policy = policy_cls(**kwargs)

-    policy.to(get_safe_torch_device(hydra_cfg.device))
-    assert isinstance(policy, nn.Module)
-    # policy = torch.compile(policy, mode="reduce-overhead")
+    policy.to(device)

    return policy
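End to end, the new `make_policy` takes a config plus either dataset metadata or an env config, derives `input_features`/`output_features`, and only then instantiates the policy (or loads it via `from_pretrained`). A hedged usage sketch ("lerobot/pusht" is just an example repo id):

from lerobot.common.datasets.lerobot_dataset import LeRobotDatasetMetadata

ds_meta = LeRobotDatasetMetadata("lerobot/pusht")  # example dataset; supplies features and stats
cfg = make_policy_config("diffusion")
policy = make_policy(cfg, device="cuda", ds_meta=ds_meta)
policy.eval()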

View File

@@ -13,13 +13,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import numpy as np
import torch
from torch import Tensor, nn

+from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature


def create_stats_buffers(
-    shapes: dict[str, list[int]],
-    modes: dict[str, str],
+    features: dict[str, PolicyFeature],
+    norm_map: dict[str, NormalizationMode],
    stats: dict[str, dict[str, Tensor]] | None = None,
) -> dict[str, dict[str, nn.ParameterDict]]:
    """
@@ -34,12 +37,16 @@ def create_stats_buffers(
    """
    stats_buffers = {}

-    for key, mode in modes.items():
-        assert mode in ["mean_std", "min_max"]
+    for key, ft in features.items():
+        norm_mode = norm_map.get(ft.type, NormalizationMode.IDENTITY)
+        if norm_mode is NormalizationMode.IDENTITY:
+            continue

-        shape = tuple(shapes[key])
+        assert isinstance(norm_mode, NormalizationMode)

-        if "image" in key:
+        shape = tuple(ft.shape)
+
+        if ft.type is FeatureType.VISUAL:
            # sanity checks
            assert len(shape) == 3, f"number of dimensions of {key} != 3 ({shape=}"
            c, h, w = shape
@@ -52,7 +59,7 @@ def create_stats_buffers(
        # we assert they are not infinity anymore.

        buffer = {}
-        if mode == "mean_std":
+        if norm_mode is NormalizationMode.MEAN_STD:
            mean = torch.ones(shape, dtype=torch.float32) * torch.inf
            std = torch.ones(shape, dtype=torch.float32) * torch.inf
            buffer = nn.ParameterDict(
@@ -61,7 +68,7 @@ def create_stats_buffers(
                    "std": nn.Parameter(std, requires_grad=False),
                }
            )
-        elif mode == "min_max":
+        elif norm_mode is NormalizationMode.MIN_MAX:
            min = torch.ones(shape, dtype=torch.float32) * torch.inf
            max = torch.ones(shape, dtype=torch.float32) * torch.inf
            buffer = nn.ParameterDict(
@@ -71,17 +78,29 @@ def create_stats_buffers(
                }
            )

-        if stats is not None:
-            # Note: The clone is needed to make sure that the logic in save_pretrained doesn't see duplicated
-            # tensors anywhere (for example, when we use the same stats for normalization and
-            # unnormalization). See the logic here
-            # https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97.
-            if mode == "mean_std":
-                buffer["mean"].data = stats[key]["mean"].clone()
-                buffer["std"].data = stats[key]["std"].clone()
-            elif mode == "min_max":
-                buffer["min"].data = stats[key]["min"].clone()
-                buffer["max"].data = stats[key]["max"].clone()
+        # TODO(aliberts, rcadene): harmonize this to only use one framework (np or torch)
+        if stats:
+            if isinstance(stats[key]["mean"], np.ndarray):
+                if norm_mode is NormalizationMode.MEAN_STD:
+                    buffer["mean"].data = torch.from_numpy(stats[key]["mean"]).to(dtype=torch.float32)
+                    buffer["std"].data = torch.from_numpy(stats[key]["std"]).to(dtype=torch.float32)
+                elif norm_mode is NormalizationMode.MIN_MAX:
+                    buffer["min"].data = torch.from_numpy(stats[key]["min"]).to(dtype=torch.float32)
+                    buffer["max"].data = torch.from_numpy(stats[key]["max"]).to(dtype=torch.float32)
+            elif isinstance(stats[key]["mean"], torch.Tensor):
+                # Note: The clone is needed to make sure that the logic in save_pretrained doesn't see duplicated
+                # tensors anywhere (for example, when we use the same stats for normalization and
+                # unnormalization). See the logic here
+                # https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97.
+                if norm_mode is NormalizationMode.MEAN_STD:
+                    buffer["mean"].data = stats[key]["mean"].clone().to(dtype=torch.float32)
+                    buffer["std"].data = stats[key]["std"].clone().to(dtype=torch.float32)
+                elif norm_mode is NormalizationMode.MIN_MAX:
+                    buffer["min"].data = stats[key]["min"].clone().to(dtype=torch.float32)
+                    buffer["max"].data = stats[key]["max"].clone().to(dtype=torch.float32)
+            else:
+                type_ = type(stats[key]["mean"])
+                raise ValueError(f"np.ndarray or torch.Tensor expected, but type is '{type_}' instead.")

        stats_buffers[key] = buffer
    return stats_buffers
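`create_stats_buffers` is now keyed by `PolicyFeature` and accepts either numpy or torch statistics (note the type dispatch inspects `stats[key]["mean"]`, so the stats dict is expected to carry all statistics even on the MIN_MAX path). A small sketch, assuming the `PolicyFeature(type, shape)` dataclass from `lerobot.configs.types`:

import numpy as np
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature

features = {"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(2,))}
norm_map = {FeatureType.STATE: NormalizationMode.MIN_MAX}
stats = {
    "observation.state": {
        "mean": np.array([0.5, 0.5], dtype=np.float32),
        "std": np.array([0.2, 0.2], dtype=np.float32),
        "min": np.zeros(2, dtype=np.float32),
        "max": np.ones(2, dtype=np.float32),
    }
}
buffers = create_stats_buffers(features, norm_map, stats)
# buffers["observation.state"]["min"] / ["max"] are float32 nn.Parameters with requires_grad=False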
@@ -99,8 +118,8 @@ class Normalize(nn.Module):

    def __init__(
        self,
-        shapes: dict[str, list[int]],
-        modes: dict[str, str],
+        features: dict[str, PolicyFeature],
+        norm_map: dict[str, NormalizationMode],
        stats: dict[str, dict[str, Tensor]] | None = None,
    ):
        """
@@ -122,10 +141,10 @@ class Normalize(nn.Module):
            dataset is not needed to get the stats, since they are already in the policy state_dict.
        """
        super().__init__()
-        self.shapes = shapes
-        self.modes = modes
+        self.features = features
+        self.norm_map = norm_map
        self.stats = stats
-        stats_buffers = create_stats_buffers(shapes, modes, stats)
+        stats_buffers = create_stats_buffers(features, norm_map, stats)
        for key, buffer in stats_buffers.items():
            setattr(self, "buffer_" + key.replace(".", "_"), buffer)

@@ -133,16 +152,24 @@ class Normalize(nn.Module):
    @torch.no_grad
    def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
        batch = dict(batch)  # shallow copy avoids mutating the input batch
-        for key, mode in self.modes.items():
+        for key, ft in self.features.items():
+            if key not in batch:
+                # FIXME(aliberts, rcadene): This might lead to silent fail!
+                continue
+
+            norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY)
+            if norm_mode is NormalizationMode.IDENTITY:
+                continue
+
            buffer = getattr(self, "buffer_" + key.replace(".", "_"))

-            if mode == "mean_std":
+            if norm_mode is NormalizationMode.MEAN_STD:
                mean = buffer["mean"]
                std = buffer["std"]
                assert not torch.isinf(mean).any(), _no_stats_error_str("mean")
                assert not torch.isinf(std).any(), _no_stats_error_str("std")
                batch[key] = (batch[key] - mean) / (std + 1e-8)
-            elif mode == "min_max":
+            elif norm_mode is NormalizationMode.MIN_MAX:
                min = buffer["min"]
                max = buffer["max"]
                assert not torch.isinf(min).any(), _no_stats_error_str("min")
@@ -152,7 +179,7 @@ class Normalize(nn.Module):
                # normalize to [-1, 1]
                batch[key] = batch[key] * 2 - 1
            else:
-                raise ValueError(mode)
+                raise ValueError(norm_mode)
        return batch
@@ -164,8 +191,8 @@ class Unnormalize(nn.Module):
def __init__( def __init__(
self, self,
shapes: dict[str, list[int]], features: dict[str, PolicyFeature],
modes: dict[str, str], norm_map: dict[str, NormalizationMode],
stats: dict[str, dict[str, Tensor]] | None = None, stats: dict[str, dict[str, Tensor]] | None = None,
): ):
""" """
@@ -187,11 +214,11 @@ class Unnormalize(nn.Module):
dataset is not needed to get the stats, since they are already in the policy state_dict. dataset is not needed to get the stats, since they are already in the policy state_dict.
""" """
super().__init__() super().__init__()
self.shapes = shapes self.features = features
self.modes = modes self.norm_map = norm_map
self.stats = stats self.stats = stats
# `self.buffer_observation_state["mean"]` contains `torch.tensor(state_dim)` # `self.buffer_observation_state["mean"]` contains `torch.tensor(state_dim)`
stats_buffers = create_stats_buffers(shapes, modes, stats) stats_buffers = create_stats_buffers(features, norm_map, stats)
for key, buffer in stats_buffers.items(): for key, buffer in stats_buffers.items():
setattr(self, "buffer_" + key.replace(".", "_"), buffer) setattr(self, "buffer_" + key.replace(".", "_"), buffer)
@@ -199,16 +226,23 @@ class Unnormalize(nn.Module):
@torch.no_grad @torch.no_grad
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
batch = dict(batch) # shallow copy avoids mutating the input batch batch = dict(batch) # shallow copy avoids mutating the input batch
for key, mode in self.modes.items(): for key, ft in self.features.items():
if key not in batch:
continue
norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY)
if norm_mode is NormalizationMode.IDENTITY:
continue
buffer = getattr(self, "buffer_" + key.replace(".", "_")) buffer = getattr(self, "buffer_" + key.replace(".", "_"))
if mode == "mean_std": if norm_mode is NormalizationMode.MEAN_STD:
mean = buffer["mean"] mean = buffer["mean"]
std = buffer["std"] std = buffer["std"]
assert not torch.isinf(mean).any(), _no_stats_error_str("mean") assert not torch.isinf(mean).any(), _no_stats_error_str("mean")
assert not torch.isinf(std).any(), _no_stats_error_str("std") assert not torch.isinf(std).any(), _no_stats_error_str("std")
batch[key] = batch[key] * std + mean batch[key] = batch[key] * std + mean
elif mode == "min_max": elif norm_mode is NormalizationMode.MIN_MAX:
min = buffer["min"] min = buffer["min"]
max = buffer["max"] max = buffer["max"]
assert not torch.isinf(min).any(), _no_stats_error_str("min") assert not torch.isinf(min).any(), _no_stats_error_str("min")
@@ -216,5 +250,5 @@ class Unnormalize(nn.Module):
batch[key] = (batch[key] + 1) / 2 batch[key] = (batch[key] + 1) / 2
batch[key] = batch[key] * (max - min) + min batch[key] = batch[key] * (max - min) + min
else: else:
raise ValueError(mode) raise ValueError(norm_mode)
return batch return batch
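With this refactor, normalization is keyed by feature type rather than by per-key mode strings. A minimal usage sketch, assuming the `PolicyFeature`/`NormalizationMode` types imported in the pi0 config below (the stats values are illustrative):

```python
import torch

from lerobot.common.policies.normalize import Normalize
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature

features = {"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(2,))}
norm_map = {FeatureType.STATE: NormalizationMode.MEAN_STD}
stats = {"observation.state": {"mean": torch.zeros(2), "std": torch.ones(2)}}

normalize = Normalize(features, norm_map, stats)
batch = normalize({"observation.state": torch.randn(8, 2)})
```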

View File

@@ -0,0 +1,134 @@
from dataclasses import dataclass, field
from lerobot.common.optim.optimizers import AdamWConfig
from lerobot.common.optim.schedulers import (
CosineDecayWithWarmupSchedulerConfig,
)
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
@PreTrainedConfig.register_subclass("pi0")
@dataclass
class PI0Config(PreTrainedConfig):
# Input / output structure.
n_obs_steps: int = 1
chunk_size: int = 50
n_action_steps: int = 50
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
"VISUAL": NormalizationMode.IDENTITY,
"STATE": NormalizationMode.MEAN_STD,
"ACTION": NormalizationMode.MEAN_STD,
}
)
# Shorter state and action vectors will be padded
max_state_dim: int = 32
max_action_dim: int = 32
# Image preprocessing
resize_imgs_with_padding: tuple[int, int] = (224, 224)
# Add empty images. Used by pi0_aloha_sim which adds the empty
# left and right wrist cameras in addition to the top camera.
empty_cameras: int = 0
# Converts the joint and gripper values from the standard Aloha space to
# the space used by the pi internal runtime which was used to train the base model.
adapt_to_pi_aloha: bool = False
# Converts joint dimensions to deltas with respect to the current state before passing to the model.
# Gripper dimensions will remain in absolute values.
use_delta_joint_actions_aloha: bool = False
# Tokenizer
tokenizer_max_length: int = 48
# Projector
proj_width: int = 1024
# Decoding
num_steps: int = 10
# Attention utils
use_cache: bool = True
attention_implementation: str = "eager" # or fa2, flex
# Finetuning settings
freeze_vision_encoder: bool = True
train_expert_only: bool = False
train_state_proj: bool = True
# Training presets
optimizer_lr: float = 2.5e-5
optimizer_betas: tuple[float, float] = (0.9, 0.95)
optimizer_eps: float = 1e-8
optimizer_weight_decay: float = 1e-10
scheduler_warmup_steps: int = 1_000
scheduler_decay_steps: int = 30_000
scheduler_decay_lr: float = 2.5e-6
# TODO: Add EMA
    def __post_init__(self):
        """Input validation (not exhaustive)."""
        super().__post_init__()
if self.n_action_steps > self.chunk_size:
raise ValueError(
f"The chunk size is the upper bound for the number of action steps per model invocation. Got "
f"{self.n_action_steps} for `n_action_steps` and {self.chunk_size} for `chunk_size`."
)
if self.n_obs_steps != 1:
raise ValueError(
f"Multiple observation steps not handled yet. Got `nobs_steps={self.n_obs_steps}`"
)
if self.use_delta_joint_actions_aloha:
raise NotImplementedError(
"`use_delta_joint_actions_aloha` is used by pi0 for aloha real models. It is not ported yet in LeRobot."
)
def validate_features(self) -> None:
# TODO: implement value error
# if not self.image_features and not self.env_state_feature:
# raise ValueError("You must provide at least one image or the environment state among the inputs.")
for i in range(self.empty_cameras):
key = f"observation.images.empty_camera_{i}"
empty_camera = PolicyFeature(
type=FeatureType.VISUAL,
shape=(3, 480, 640),
)
self.input_features[key] = empty_camera
def get_optimizer_preset(self) -> AdamWConfig:
return AdamWConfig(
lr=self.optimizer_lr,
betas=self.optimizer_betas,
eps=self.optimizer_eps,
weight_decay=self.optimizer_weight_decay,
)
def get_scheduler_preset(self):
return CosineDecayWithWarmupSchedulerConfig(
peak_lr=self.optimizer_lr,
decay_lr=self.scheduler_decay_lr,
num_warmup_steps=self.scheduler_warmup_steps,
num_decay_steps=self.scheduler_decay_steps,
)
@property
def observation_delta_indices(self) -> None:
return None
@property
def action_delta_indices(self) -> list:
return list(range(self.chunk_size))
@property
def reward_delta_indices(self) -> None:
return None
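The `__post_init__` checks above can be exercised directly; `n_action_steps` may not exceed `chunk_size` (a quick sketch):

```python
from lerobot.common.policies.pi0.configuration_pi0 import PI0Config

cfg = PI0Config()                 # defaults: chunk_size=50, n_action_steps=50
try:
    PI0Config(n_action_steps=60)  # exceeds chunk_size -> ValueError in __post_init__
except ValueError as err:
    print(err)
```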

View File

@@ -0,0 +1,68 @@
import torch
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.policies.factory import make_policy
from lerobot.configs.policies import PreTrainedConfig
torch.backends.cudnn.benchmark = True
def main():
device = "cuda"
dataset_repo_id = "danaaubakirova/koch_test"
# model_name = "pi0_base"
# ckpt_torch_dir = Path.home() / f".cache/openpi/openpi-assets/checkpoints/{model_name}_pytorch"
ckpt_torch_dir = "lerobot/pi0"
dataset = LeRobotDataset(dataset_repo_id, episodes=[0])
dataloader = torch.utils.data.DataLoader(
dataset,
num_workers=0,
batch_size=1,
)
batch = next(iter(dataloader))
# To device
for k in batch:
if isinstance(batch[k], torch.Tensor):
batch[k] = batch[k].to(device=device, dtype=torch.float32)
cfg = PreTrainedConfig.from_pretrained(ckpt_torch_dir)
cfg.pretrained_path = ckpt_torch_dir
policy = make_policy(cfg, device, ds_meta=dataset.meta)
# policy = torch.compile(policy, mode="reduce-overhead")
warmup_iters = 10
benchmark_iters = 30
# Warmup
for _ in range(warmup_iters):
torch.cuda.synchronize()
policy.select_action(batch)
policy.reset()
torch.cuda.synchronize()
# Benchmark
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
for _ in range(benchmark_iters):
policy.select_action(batch)
policy.reset()
end_event.record()
# Synchronize and measure time
torch.cuda.synchronize()
elapsed_time_ms = start_event.elapsed_time(end_event)
avg_time_per_iter = elapsed_time_ms / benchmark_iters
print(f"Average execution time per iteration: {avg_time_per_iter:.3f} ms")
if __name__ == "__main__":
with torch.inference_mode():
main()
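The warmup-then-CUDA-events pattern above is the standard way to time GPU work without counting cold-start overhead; the same idea as a reusable helper (hypothetical, not part of this PR):

```python
import torch

def cuda_time_ms(fn, iters: int = 30) -> float:
    """Average wall time of fn() in milliseconds, measured with CUDA events."""
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()  # wait for all kernels before reading the timer
    return start.elapsed_time(end) / iters
```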

View File

@@ -0,0 +1,117 @@
import json
import pickle
from pathlib import Path
import torch
from lerobot.common.datasets.lerobot_dataset import LeRobotDatasetMetadata
from lerobot.common.policies.factory import make_policy
from lerobot.configs.policies import PreTrainedConfig
def display(tensor: torch.Tensor):
if tensor.dtype == torch.bool:
tensor = tensor.float()
print(f"Shape: {tensor.shape}")
print(f"Mean: {tensor.mean().item()}")
print(f"Std: {tensor.std().item()}")
print(f"Min: {tensor.min().item()}")
print(f"Max: {tensor.max().item()}")
def main():
num_motors = 14
device = "cuda"
# model_name = "pi0_aloha_towel"
model_name = "pi0_aloha_sim"
if model_name == "pi0_aloha_towel":
dataset_repo_id = "lerobot/aloha_static_towel"
else:
dataset_repo_id = "lerobot/aloha_sim_transfer_cube_human"
ckpt_torch_dir = Path.home() / f".cache/openpi/openpi-assets/checkpoints/{model_name}_pytorch"
ckpt_jax_dir = Path.home() / f".cache/openpi/openpi-assets/checkpoints/{model_name}"
save_dir = Path(f"../openpi/data/{model_name}/save")
with open(save_dir / "example.pkl", "rb") as f:
example = pickle.load(f)
with open(save_dir / "outputs.pkl", "rb") as f:
outputs = pickle.load(f)
with open(save_dir / "noise.pkl", "rb") as f:
noise = pickle.load(f)
with open(ckpt_jax_dir / "assets/norm_stats.json") as f:
norm_stats = json.load(f)
# Override stats
dataset_meta = LeRobotDatasetMetadata(dataset_repo_id)
dataset_meta.stats["observation.state"]["mean"] = torch.tensor(
norm_stats["norm_stats"]["state"]["mean"][:num_motors], dtype=torch.float32
)
dataset_meta.stats["observation.state"]["std"] = torch.tensor(
norm_stats["norm_stats"]["state"]["std"][:num_motors], dtype=torch.float32
)
# Create LeRobot batch from Jax
batch = {}
for cam_key, uint_chw_array in example["images"].items():
batch[f"observation.images.{cam_key}"] = torch.from_numpy(uint_chw_array) / 255.0
batch["observation.state"] = torch.from_numpy(example["state"])
batch["action"] = torch.from_numpy(outputs["actions"])
batch["task"] = example["prompt"]
if model_name == "pi0_aloha_towel":
del batch["observation.images.cam_low"]
elif model_name == "pi0_aloha_sim":
batch["observation.images.top"] = batch["observation.images.cam_high"]
del batch["observation.images.cam_high"]
# Batchify
for key in batch:
if isinstance(batch[key], torch.Tensor):
batch[key] = batch[key].unsqueeze(0)
elif isinstance(batch[key], str):
batch[key] = [batch[key]]
else:
raise ValueError(f"{key}, {batch[key]}")
# To device
for k in batch:
if isinstance(batch[k], torch.Tensor):
batch[k] = batch[k].to(device=device, dtype=torch.float32)
noise = torch.from_numpy(noise).to(device=device, dtype=torch.float32)
from lerobot.common import policies # noqa
cfg = PreTrainedConfig.from_pretrained(ckpt_torch_dir)
cfg.pretrained_path = ckpt_torch_dir
policy = make_policy(cfg, device, dataset_meta)
# loss_dict = policy.forward(batch, noise=noise, time=time_beta)
# loss_dict["loss"].backward()
# print("losses")
# display(loss_dict["losses_after_forward"])
# print("pi_losses")
# display(pi_losses)
actions = []
for _ in range(50):
action = policy.select_action(batch, noise=noise)
actions.append(action)
actions = torch.stack(actions, dim=1)
pi_actions = batch["action"]
print("actions")
display(actions)
print()
print("pi_actions")
display(pi_actions)
print("atol=3e-2", torch.allclose(actions, pi_actions, atol=3e-2))
print("atol=2e-2", torch.allclose(actions, pi_actions, atol=2e-2))
print("atol=1e-2", torch.allclose(actions, pi_actions, atol=1e-2))
if __name__ == "__main__":
main()
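When the `allclose` checks above fail, the boolean alone does not say by how much; a small helper for debugging (hypothetical, not part of this PR):

```python
import torch

def max_abs_diff(a: torch.Tensor, b: torch.Tensor) -> float:
    """Largest elementwise deviation, e.g. max_abs_diff(actions, pi_actions)."""
    return (a - b).abs().max().item()
```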

View File

@@ -0,0 +1,70 @@
from transformers import GemmaConfig, PaliGemmaConfig
def get_paligemma_config(precision: str):
config = {
"image_token_index": None,
"pad_token_id": 0,
"bos_token_id": 2,
"eos_token_id": 1,
}
# image_sizes = {"2b-test": 224, "3b-224px": 224, "3b-448px": 448, "3b-896px": 896}
image_size = 224 # image_sizes[variant]
patch_size = 14
num_image_tokens = (image_size**2) // (patch_size**2)
config["image_token_index"] = 257152
text_config = {
"vocab_size": 257152,
"num_hidden_layers": 18,
"num_key_value_heads": 1,
"head_dim": 256,
"torch_dtype": precision,
"hidden_size": 2048,
"hidden_activation": "gelu_pytorch_tanh",
"num_attention_heads": 8,
"intermediate_size": 16384,
"is_encoder_decoder": False,
}
vision_config = {
"torch_dtype": precision,
"image_size": image_size,
"patch_size": patch_size,
"num_image_tokens": num_image_tokens,
"hidden_size": 1152,
"intermediate_size": 4304,
"num_hidden_layers": 27,
"num_attention_heads": 16,
"projector_hidden_act": "gelu_fast",
"vision_use_head": False,
}
final_config = PaliGemmaConfig(text_config=text_config, vision_config=vision_config, **config)
return final_config
def get_gemma_config(precision: str):
config = {
"image_token_index": None,
"pad_token_id": 0,
"bos_token_id": 2,
"eos_token_id": 1,
}
config["image_token_index"] = 257152
text_config = {
"vocab_size": 257152,
"num_hidden_layers": 18,
"num_key_value_heads": 1,
"head_dim": 256,
"torch_dtype": precision,
"hidden_size": 1024,
"hidden_activation": "gelu_pytorch_tanh",
"num_attention_heads": 8,
"intermediate_size": 4096,
"is_encoder_decoder": False,
}
final_config = GemmaConfig()
final_config.update(text_config)
return final_config
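As a quick sanity check of the constants above (assuming `transformers` is installed): 224px images with 14px patches give (224/14)^2 = 256 image tokens, and the two text stacks differ only in width:

```python
cfg = get_paligemma_config("float32")
print((224 ** 2) // (14 ** 2))                  # 256 image tokens per frame
print(cfg.text_config.hidden_size)              # 2048 (PaliGemma decoder)
print(get_gemma_config("float32").hidden_size)  # 1024 (action expert)
```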

View File

@@ -0,0 +1,423 @@
"""
Convert pi0 parameters from Jax to Pytorch
Follow [README of openpi](https://github.com/Physical-Intelligence/openpi) to create a new environment
and install the required libraries.
```bash
cd ~/code/openpi
source .venv/bin/activate
```
Example downloading parameters:
```bash
python
>>> import openpi.shared.download as download
>>> path='s3://openpi-assets/checkpoints/pi0_base/params'
>>> download.maybe_download(path)
```
Converting pi0_base:
```bash
python lerobot/common/policies/pi0/conversion_scripts/convert_pi0_to_hf_lerobot.py \
--checkpoint_dir /home/remi_cadene/.cache/openpi/openpi-assets/checkpoints/pi0_base/params \
--output_path /home/remi_cadene/.cache/openpi/openpi-assets/checkpoints/pi0_base_pytorch
```
```bash
python lerobot/common/policies/pi0/conversion_scripts/convert_pi0_to_hf_lerobot.py \
--checkpoint_dir /home/remi_cadene/.cache/openpi/openpi-assets/checkpoints/pi0_aloha_sim/params \
--output_path /home/remi_cadene/.cache/openpi/openpi-assets/checkpoints/pi0_aloha_sim_pytorch
```
"""
import argparse
import pathlib
import jax
import numpy as np
import orbax.checkpoint as ocp
import torch
from jax.sharding import SingleDeviceSharding
from lerobot.common.policies.pi0.configuration_pi0 import PI0Config
from lerobot.common.policies.pi0.conversion_scripts.conversion_utils import (
get_gemma_config,
get_paligemma_config,
)
from lerobot.common.policies.pi0.modeling_pi0 import PI0Policy
PRECISIONS = {"bfloat16": torch.bfloat16, "float32": torch.float32, "float16": torch.float16}
def slice_paligemma_state_dict(state_dict, config):
suffix = "/value" if "img/embedding/kernel/value" in state_dict else ""
# fmt: off
# patch embeddings
state_dict["paligemma.vision_tower.vision_model.embeddings.patch_embedding.weight"] = state_dict.pop(f"img/embedding/kernel{suffix}").transpose(
3, 2, 0, 1
)
state_dict["paligemma.vision_tower.vision_model.embeddings.patch_embedding.bias"] = state_dict.pop(f"img/embedding/bias{suffix}")
# positional embeddings
state_dict["paligemma.vision_tower.vision_model.embeddings.position_embedding.weight"] = state_dict.pop(f"img/pos_embedding{suffix}").reshape(
-1, config.vision_config.hidden_size
)
# extract vision layers to be sliced at index 0. There are 27 layers in the base model.
encoderblock_layernorm0_scale = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_0/scale{suffix}")
encoderblock_layernorm0_bias = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_0/bias{suffix}")
encoderblock_layernorm1_scale = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_1/scale{suffix}")
encoderblock_layernorm1_bias = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_1/bias{suffix}")
encoderblock_mlp_dense0_kernel= state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_0/kernel{suffix}")
encoderblock_mlp_dense0_bias= state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_0/bias{suffix}")
encoderblock_mlp_dense1_kernel= state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_1/kernel{suffix}")
encoderblock_mlp_dense1_bias= state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_1/bias{suffix}")
encoderblock_attention_0_key_kernel = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/key/kernel{suffix}")
encoderblock_attention_0_key_bias = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/key/bias{suffix}")
encoderblock_attention_0_value_kernel = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/value/kernel{suffix}")
encoderblock_attention_0_value_bias = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/value/bias{suffix}")
encoderblock_attention_0_query_kernel = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/query/kernel{suffix}")
encoderblock_attention_0_query_bias = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/query/bias{suffix}")
encoderblock_attention_0_out_kernel = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/out/kernel{suffix}")
encoderblock_attention_0_out_bias = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/out/bias{suffix}")
for i in range(config.vision_config.num_hidden_layers):
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.layer_norm1.weight"] = encoderblock_layernorm0_scale[i].transpose()
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.layer_norm1.bias"] = encoderblock_layernorm0_bias[i]
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.layer_norm2.weight"] = encoderblock_layernorm1_scale[i].transpose()
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.layer_norm2.bias"] = encoderblock_layernorm1_bias[i]
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.mlp.fc1.weight"] = encoderblock_mlp_dense0_kernel[i].transpose()
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.mlp.fc1.bias"] = encoderblock_mlp_dense0_bias[i]
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.mlp.fc2.weight"] = encoderblock_mlp_dense1_kernel[i].transpose()
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.mlp.fc2.bias"] = encoderblock_mlp_dense1_bias[i]
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.k_proj.weight"] = encoderblock_attention_0_key_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose()
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.k_proj.bias"] = encoderblock_attention_0_key_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1)
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.v_proj.weight"] = encoderblock_attention_0_value_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose()
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.v_proj.bias"] = encoderblock_attention_0_value_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1)
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.q_proj.weight"] = encoderblock_attention_0_query_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose()
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.q_proj.bias"] = encoderblock_attention_0_query_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1)
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.out_proj.weight"] = encoderblock_attention_0_out_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose()
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.out_proj.bias"] = encoderblock_attention_0_out_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1)
state_dict["paligemma.vision_tower.vision_model.post_layernorm.weight"] = state_dict.pop(f"img/Transformer/encoder_norm/scale{suffix}").transpose()
state_dict["paligemma.vision_tower.vision_model.post_layernorm.bias"] = state_dict.pop(f"img/Transformer/encoder_norm/bias{suffix}")
# multimodal projector
state_dict['paligemma.multi_modal_projector.linear.weight'] = state_dict.pop(f"img/head/kernel{suffix}").transpose()
state_dict['paligemma.multi_modal_projector.linear.bias'] = state_dict.pop(f"img/head/bias{suffix}")
# text decoder (gemma)
embedding_vector = state_dict.pop(f"llm/embedder/input_embedding{suffix}")
state_dict["paligemma.language_model.model.embed_tokens.weight"] = embedding_vector
# pop the einsum attention + mlp representations. There are 18 layers in gemma-2b.
llm_attention_attn_vec_einsum = state_dict.pop(f"llm/layers/attn/attn_vec_einsum/w{suffix}")
llm_attention_kv_einsum = state_dict.pop(f"llm/layers/attn/kv_einsum/w{suffix}")
llm_attention_q_einsum = state_dict.pop(f"llm/layers/attn/q_einsum/w{suffix}")
llm_mlp_gating_einsum = state_dict.pop(f"llm/layers/mlp/gating_einsum{suffix}")
llm_mlp_linear = state_dict.pop(f"llm/layers/mlp/linear{suffix}")
# TODO verify correctness of layer norm loading
llm_input_layernorm = state_dict.pop(f"llm/layers/pre_attention_norm/scale{suffix}")
llm_post_attention_layernorm = state_dict.pop(f"llm/layers/pre_ffw_norm/scale{suffix}")
for i in range(config.text_config.num_hidden_layers):
# llm_attention_q_einsum[i].shape = (8, 2048, 256)
q_proj_weight_reshaped = llm_attention_q_einsum[i].transpose(0, 2, 1).reshape(config.text_config.num_attention_heads * config.text_config.head_dim, config.text_config.hidden_size)
state_dict[f"paligemma.language_model.model.layers.{i}.self_attn.q_proj.weight"] = q_proj_weight_reshaped
# llm_attention_kv_einsum[i, 0, 0].shape = (2048, 256)
k_proj_weight_reshaped = llm_attention_kv_einsum[i, 0, 0].transpose()
state_dict[f"paligemma.language_model.model.layers.{i}.self_attn.k_proj.weight"] = k_proj_weight_reshaped
# llm_attention_kv_einsum[i, 1, 0].shape = (2048, 256)
v_proj_weight_reshaped = llm_attention_kv_einsum[i, 1, 0].transpose()
state_dict[f"paligemma.language_model.model.layers.{i}.self_attn.v_proj.weight"] = v_proj_weight_reshaped
# output projection.
# llm_attention_attn_vec_einsum[i].shape = (8, 256, 2048)
o_proj_weight_reshaped = llm_attention_attn_vec_einsum[i].transpose(2, 0, 1).reshape(config.text_config.num_attention_heads * config.text_config.head_dim, config.text_config.hidden_size)
state_dict[f"paligemma.language_model.model.layers.{i}.self_attn.o_proj.weight"] = o_proj_weight_reshaped
# mlp layers
gate_proj_weight = llm_mlp_gating_einsum[i, 0]
state_dict[f"paligemma.language_model.model.layers.{i}.mlp.gate_proj.weight"] = gate_proj_weight.transpose()
up_proj_weight = llm_mlp_gating_einsum[i, 1]
state_dict[f"paligemma.language_model.model.layers.{i}.mlp.up_proj.weight"] = up_proj_weight.transpose()
state_dict[f"paligemma.language_model.model.layers.{i}.mlp.down_proj.weight"] = llm_mlp_linear[i].transpose()
state_dict[f"paligemma.language_model.model.layers.{i}.input_layernorm.weight"] = llm_input_layernorm[i]
state_dict[f"paligemma.language_model.model.layers.{i}.post_attention_layernorm.weight"] = llm_post_attention_layernorm[i]
state_dict["paligemma.language_model.model.norm.weight"] = state_dict.pop(f"llm/final_norm/scale{suffix}")
state_dict["paligemma.language_model.lm_head.weight"] = embedding_vector # weights are tied.
# fmt: on
expert_dict = {}
final_state_dict = {}
for key, value in state_dict.items():
if key not in [
f"llm/final_norm_1/scale{suffix}",
f"llm/layers/attn/attn_vec_einsum_1/w{suffix}",
f"llm/layers/attn/kv_einsum_1/w{suffix}",
f"llm/layers/attn/q_einsum_1/w{suffix}",
f"llm/layers/mlp_1/gating_einsum{suffix}",
f"llm/layers/mlp_1/linear{suffix}",
f"llm/layers/pre_attention_norm_1/scale{suffix}",
f"llm/layers/pre_ffw_norm_1/scale{suffix}",
]:
final_state_dict[key] = torch.from_numpy(value)
else:
expert_dict[key] = value
return final_state_dict, expert_dict
def slice_gemma_state_dict(state_dict, config, num_expert=1):
# fmt: off
# text decoder (gemma)
# no embedding vector, the expert just has the decoder layers
embedding_vector = torch.zeros([config.vocab_size, config.hidden_size])
state_dict["gemma_expert.model.embed_tokens.weight"] = embedding_vector
# pop the einsum attention + mlp representations. There are 18 layers in gemma-2b.
suffix = "/value" if f"llm/layers/attn/attn_vec_einsum_{num_expert}/w/value" in state_dict else ""
llm_attention_attn_vec_einsum = state_dict.pop(f"llm/layers/attn/attn_vec_einsum_{num_expert}/w{suffix}")
llm_attention_kv_einsum = state_dict.pop(f"llm/layers/attn/kv_einsum_{num_expert}/w{suffix}")
llm_attention_q_einsum = state_dict.pop(f"llm/layers/attn/q_einsum_{num_expert}/w{suffix}")
llm_mlp_gating_einsum = state_dict.pop(f"llm/layers/mlp_{num_expert}/gating_einsum{suffix}")
llm_mlp_linear = state_dict.pop(f"llm/layers/mlp_{num_expert}/linear{suffix}")
# TODO verify correctness of layer norm loading
llm_input_layernorm = state_dict.pop(f"llm/layers/pre_attention_norm_{num_expert}/scale{suffix}")
llm_post_attention_layernorm = state_dict.pop(f"llm/layers/pre_ffw_norm_{num_expert}/scale{suffix}")
for i in range(config.num_hidden_layers):
q_proj_weight_reshaped = llm_attention_q_einsum[i].transpose(0, 2, 1).reshape(config.num_attention_heads * config.head_dim, config.hidden_size)
state_dict[f"gemma_expert.model.layers.{i}.self_attn.q_proj.weight"] = q_proj_weight_reshaped
k_proj_weight_reshaped = llm_attention_kv_einsum[i, 0, 0].transpose()
state_dict[f"gemma_expert.model.layers.{i}.self_attn.k_proj.weight"] = k_proj_weight_reshaped
v_proj_weight_reshaped = llm_attention_kv_einsum[i, 1, 0].transpose()
state_dict[f"gemma_expert.model.layers.{i}.self_attn.v_proj.weight"] = v_proj_weight_reshaped
# output projection.
# llm_attention_attn_vec_einsum[i].shape = (8, 256, 1024)
o_proj_weight_reshaped = llm_attention_attn_vec_einsum[i].reshape(config.num_attention_heads * config.head_dim, config.hidden_size).transpose(1,0)# .transpose(2, 0, 1).reshape(config.num_attention_heads * config.head_dim, config.hidden_size).transpose(1, 0)
state_dict[f"gemma_expert.model.layers.{i}.self_attn.o_proj.weight"] = o_proj_weight_reshaped
# mlp layers
gate_proj_weight = llm_mlp_gating_einsum[i, 0]
state_dict[f"gemma_expert.model.layers.{i}.mlp.gate_proj.weight"] = gate_proj_weight.transpose()
up_proj_weight = llm_mlp_gating_einsum[i, 1]
state_dict[f"gemma_expert.model.layers.{i}.mlp.up_proj.weight"] = up_proj_weight.transpose()
state_dict[f"gemma_expert.model.layers.{i}.mlp.down_proj.weight"] = llm_mlp_linear[i].transpose()
state_dict[f"gemma_expert.model.layers.{i}.input_layernorm.weight"] = llm_input_layernorm[i]
state_dict[f"gemma_expert.model.layers.{i}.post_attention_layernorm.weight"] = llm_post_attention_layernorm[i]
state_dict["gemma_expert.model.norm.weight"] = state_dict.pop(f"llm/final_norm_{num_expert}/scale{suffix}")
state_dict["gemma_expert.lm_head.weight"] = embedding_vector # weights are tied. (and zeros here)
# fmt: on
final_state_dict = {}
for key, value in state_dict.items():
if not isinstance(value, torch.Tensor):
final_state_dict[key] = torch.from_numpy(value)
else:
final_state_dict[key] = value
return final_state_dict
def flatten_for_memory(tree, parent_key=""):
out = {}
for k, v in tree.items():
new_key = f"{parent_key}/{k}" if parent_key else k
if isinstance(v, dict):
out.update(flatten_for_memory(v, new_key))
else:
out[new_key] = np.array(v) # Ensure conversion to np.array for consistency
return out
def flatten_for_npz(tree, parent_key=""):
out = {}
for k, v in tree.items():
new_key = f"{parent_key}/{k}" if parent_key else k
if isinstance(v, dict):
out.update(flatten_for_npz(v, new_key))
else:
# bf16/f32 here?
out[new_key] = np.array(v)
return out
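Both flatten helpers turn the nested Jax param tree into slash-separated keys; for example:

```python
print(flatten_for_npz({"img": {"embedding": {"kernel": 1}}, "scale": 2}))
# {'img/embedding/kernel': array(1), 'scale': array(2)}
```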
def slice_initial_orbax_checkpoint(checkpoint_dir: str):
params_path = pathlib.Path(checkpoint_dir).resolve()
checkpointer = ocp.PyTreeCheckpointer()
metadata = checkpointer.metadata(params_path)
print("Metadata keys:", list(metadata.keys()))
params_name = "params"
item = {params_name: metadata[params_name]}
device = jax.local_devices()[0] # Use the first local device
sharding = SingleDeviceSharding(device)
restored = checkpointer.restore(
params_path,
ocp.args.PyTreeRestore(
item=item,
restore_args=jax.tree_util.tree_map(
lambda _: ocp.ArrayRestoreArgs(
restore_type=jax.Array, # or np.ndarray, but bf16 is annoying about it
sharding=sharding,
),
item,
),
transforms={},
),
)
params = restored[params_name]
# get params for PaliGemma
pali_params = params["PaliGemma"]
del params["PaliGemma"]
pali_params_flat = flatten_for_npz(pali_params)
return {"paligemma_params": pali_params_flat, "projection_params": params}
def update_keys_with_prefix(d: dict, prefix: str) -> dict:
"""Update dictionary keys by adding a prefix."""
return {f"{prefix}{key}": value for key, value in d.items()}
def convert_pi0_checkpoint(checkpoint_dir: str, precision: str, tokenizer_id: str, output_path: str):
# Break down orbax ckpts - they are in OCDBT
initial_params = slice_initial_orbax_checkpoint(checkpoint_dir=checkpoint_dir)
# process projection params
keys = [
"state_proj",
"action_in_proj",
"action_out_proj",
"action_time_mlp_in",
"action_time_mlp_out",
]
projection_params = {}
for key in keys:
kernel_params = initial_params["projection_params"][key]["kernel"]
bias_params = initial_params["projection_params"][key]["bias"]
if isinstance(kernel_params, dict):
weight = kernel_params["value"]
bias = bias_params["value"]
else:
weight = kernel_params
bias = bias_params
projection_params[f"{key}.weight"] = torch.from_numpy(np.array(weight)).T
projection_params[f"{key}.bias"] = torch.from_numpy(np.array(bias))
# Process PaliGemma weights
paligemma_config = get_paligemma_config(precision)
paligemma_params, gemma_raw_dictionary = slice_paligemma_state_dict(
initial_params["paligemma_params"], paligemma_config
)
# Process Gemma weights (at this stage they are unused)
gemma_config = get_gemma_config(precision)
gemma_params = slice_gemma_state_dict(gemma_raw_dictionary, config=gemma_config)
# Instantiate model from configs
if "pi0_aloha_sim" in checkpoint_dir:
pi0_config = PI0Config(
empty_cameras=2,
adapt_to_pi_aloha=True,
use_delta_joint_actions_aloha=False,
)
elif "pi0_aloha_towel" in checkpoint_dir:
pi0_config = PI0Config(
adapt_to_pi_aloha=True,
use_delta_joint_actions_aloha=True,
)
elif "pi0_base" in checkpoint_dir:
pi0_config = PI0Config(
empty_cameras=0,
adapt_to_pi_aloha=False,
use_delta_joint_actions_aloha=False,
)
else:
raise ValueError(f"Unknown checkpoint: {checkpoint_dir}")
# gemma_config=gemma_config, paligemma_config=paligemma_config)
pi0_model = PI0Policy(pi0_config)
paligemma_params = update_keys_with_prefix(paligemma_params, "model.paligemma_with_expert.")
gemma_params = update_keys_with_prefix(gemma_params, "model.paligemma_with_expert.")
projection_params = update_keys_with_prefix(projection_params, "model.")
# load state dict
torch_dtype = PRECISIONS[precision]
pi0_model.load_state_dict({**paligemma_params, **gemma_params, **projection_params})
pi0_model = pi0_model.to(torch_dtype)
# pi0_tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)
pi0_model.save_pretrained(output_path, safe_serialization=True)
# pi0_tokenizer.save_pretrained(output_path, dtype=torch_dtype)
# assert that model loads properly
del pi0_model
PI0Policy.from_pretrained(output_path)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--checkpoint_dir",
default="/raid/pablo/.cache/openpi/openpi-assets/checkpoints/pi0_aloha_sim/params",
type=str,
help="Path to the ocdbt checkpoint",
)
parser.add_argument(
"--precision",
choices=["float32", "bfloat16", "float16"],
default="float32",
type=str,
help="Precision identifier for model conversion - should match the base checkpoint precision.",
)
# tokenizer is identical to paligemma, it appears
parser.add_argument(
"--tokenizer_hub_id",
default="google/paligemma-3b-pt-224",
type=str,
help="Hub path to the tokenizer to save",
)
parser.add_argument(
"--output_path",
required=True,
type=str,
help="Path to save converted weights to",
)
args = parser.parse_args()
convert_pi0_checkpoint(
checkpoint_dir=args.checkpoint_dir,
precision=args.precision,
tokenizer_id=args.tokenizer_hub_id,
output_path=args.output_path,
)

View File

@@ -0,0 +1,127 @@
import torch
import torch.nn.functional as F # noqa: N812
from packaging.version import Version
if Version(torch.__version__) > Version("2.5.0"):
# Flex attention is only available from torch 2.5 onwards
from torch.nn.attention.flex_attention import (
_mask_mod_signature,
_round_up_to_multiple,
create_block_mask,
create_mask,
flex_attention,
)
# @torch.compile(dynamic=False)
def flex_attention_forward(
attention_mask: torch.Tensor,
batch_size: int,
head_dim: int,
query_states: torch.Tensor,
key_states: torch.Tensor,
value_states: torch.Tensor,
scaling=None,
):
"""
This is defined outside the class to make torch.compile happy.
"""
original_dtype = query_states.dtype
num_att_heads = 8
num_key_value_heads = 1
num_key_value_groups = num_att_heads // num_key_value_heads
key_states = key_states[:, :, :, None, :]
key_states = key_states.expand(
batch_size, key_states.shape[1], num_key_value_heads, num_key_value_groups, head_dim
)
key_states = key_states.reshape(
batch_size, key_states.shape[1], num_key_value_heads * num_key_value_groups, head_dim
)
value_states = value_states[:, :, :, None, :]
value_states = value_states.expand(
batch_size, value_states.shape[1], num_key_value_heads, num_key_value_groups, head_dim
)
value_states = value_states.reshape(
batch_size, value_states.shape[1], num_key_value_heads * num_key_value_groups, head_dim
)
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
query_states = query_states.to(torch.float32)
key_states = key_states.to(torch.float32)
value_states = value_states.to(torch.float32)
causal_mask = attention_mask
if causal_mask is not None:
causal_mask = causal_mask[:, None, :, : key_states.shape[2]]
if causal_mask.shape[1] == 1 and query_states.shape[1] > 1:
causal_mask = causal_mask.expand(-1, query_states.shape[1], -1, -1)
def precomputed_mask_factory(precomputed_mask: torch.Tensor) -> _mask_mod_signature:
def mask_mod(b, h, q_idx, kv_idx):
# Danger zone: if b,h,q_idx,kv_idx exceed the shape, device-side assert occurs.
return precomputed_mask[b][h][q_idx][kv_idx]
return mask_mod
b_mask, h_mask, q_len, kv_len = causal_mask.shape # The shape of your mask
block_size = 128
q_len_rounded = _round_up_to_multiple(q_len, block_size)
kv_len_rounded = _round_up_to_multiple(kv_len, block_size)
# *CRITICAL* we do need to expand here, else we get a CUDA index error
pad_q = q_len_rounded - q_len
pad_k = kv_len_rounded - kv_len
padded_causal_mask = F.pad(causal_mask, (0, pad_k, 0, pad_q), value=0.0)
mask_mod_fn_orig = precomputed_mask_factory(padded_causal_mask)
mask_4d = create_mask(
mod_fn=mask_mod_fn_orig,
B=b_mask,
H=h_mask,
Q_LEN=q_len_rounded,
KV_LEN=kv_len_rounded,
device=causal_mask.device,
_compile=False,
)
mask_mod_fn_padded = precomputed_mask_factory(mask_4d)
block_mask = create_block_mask(
mask_mod=mask_mod_fn_padded,
B=b_mask,
H=h_mask,
Q_LEN=q_len_rounded,
KV_LEN=kv_len_rounded,
BLOCK_SIZE=block_size,
device=causal_mask.device,
_compile=False,
)
# mask is applied inside the kernel, ideally more efficiently than score_mod.
attn_output, attention_weights = flex_attention(
query_states,
key_states,
value_states,
block_mask=block_mask,
enable_gqa=True, # because we shaped query/key states for GQA
scale=head_dim**-0.5 if scaling is None else scaling,
return_lse=True,
)
attn_output = attn_output.to(dtype=original_dtype)
attn_output = attn_output.transpose(1, 2).contiguous() # [B, Q_LEN, H, head_dim]
attn_output = attn_output.reshape(
batch_size,
-1,
attn_output.shape[2] * attn_output.shape[3], # merges [H, head_dim]
)
return attn_output
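The expand/reshape at the top of the function is the usual grouped-query-attention (GQA) trick: one KV head is broadcast across the query-head groups before the attention call. Standalone, under the same 8-query-head / 1-KV-head assumption:

```python
import torch

B, S, D = 2, 16, 32
num_heads, num_kv_heads = 8, 1
groups = num_heads // num_kv_heads

k = torch.randn(B, S, num_kv_heads, D)
k = k[:, :, :, None, :].expand(B, S, num_kv_heads, groups, D)  # view, no copy
k = k.reshape(B, S, num_kv_heads * groups, D)                  # materializes
print(k.shape)  # torch.Size([2, 16, 8, 32])
```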

View File

@@ -0,0 +1,732 @@
#!/usr/bin/env python
# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
π0: A Vision-Language-Action Flow Model for General Robot Control
[Paper](https://www.physicalintelligence.company/download/pi0.pdf)
[Jax code](https://github.com/Physical-Intelligence/openpi)
Designed by Physical Intelligence. Ported from Jax by Hugging Face.
Install pi0 extra dependencies:
```bash
pip install -e ".[pi0]"
```
Example of finetuning the pi0 pretrained model (`pi0_base` in `openpi`):
```bash
python lerobot/scripts/train.py \
--policy.path=lerobot/pi0 \
--dataset.repo_id=danaaubakirova/koch_test
```
Example of finetuning the pi0 neural network with PaliGemma and expert Gemma
pretrained with VLM default parameters before pi0 finetuning:
```bash
python lerobot/scripts/train.py \
--policy.type=pi0 \
--dataset.repo_id=danaaubakirova/koch_test
```
Example of using the pi0 pretrained model outside LeRobot training framework:
```python
policy = PI0Policy.from_pretrained("lerobot/pi0")
```
"""
import math
from collections import deque
import torch
import torch.nn.functional as F # noqa: N812
from torch import Tensor, nn
from transformers import AutoTokenizer
from lerobot.common.constants import ACTION, OBS_ROBOT
from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.pi0.configuration_pi0 import PI0Config
from lerobot.common.policies.pi0.paligemma_with_expert import (
PaliGemmaWithExpertConfig,
PaliGemmaWithExpertModel,
)
from lerobot.common.policies.pretrained import PreTrainedPolicy
from lerobot.common.utils.utils import get_safe_dtype
def create_sinusoidal_pos_embedding(
    time: torch.Tensor, dimension: int, min_period: float, max_period: float, device="cpu"
) -> Tensor:
"""Computes sine-cosine positional embedding vectors for scalar positions."""
if dimension % 2 != 0:
raise ValueError(f"dimension ({dimension}) must be divisible by 2")
if time.ndim != 1:
raise ValueError("The time tensor is expected to be of shape `(batch_size, )`.")
    device = torch.device(device)  # accept str or torch.device; `.type` below needs a device object
    dtype = get_safe_dtype(torch.float64, device.type)
fraction = torch.linspace(0.0, 1.0, dimension // 2, dtype=dtype, device=device)
period = min_period * (max_period / min_period) ** fraction
# Compute the outer product
scaling_factor = 1.0 / period * 2 * math.pi
sin_input = scaling_factor[None, :] * time[:, None]
pos_emb = torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
return pos_emb
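For reference, the embedding maps a batch of scalar timesteps to `(batch, dimension)` vectors; e.g.:

```python
t = torch.tensor([0.0, 0.5, 1.0])
emb = create_sinusoidal_pos_embedding(t, 8, min_period=4e-3, max_period=4.0, device=torch.device("cpu"))
print(emb.shape)  # torch.Size([3, 8])
```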
def sample_beta(alpha, beta, bsize, device):
gamma1 = torch.empty((bsize,), device=device).uniform_(0, 1).pow(1 / alpha)
gamma2 = torch.empty((bsize,), device=device).uniform_(0, 1).pow(1 / beta)
return gamma1 / (gamma1 + gamma2)
def make_att_2d_masks(pad_masks, att_masks):
"""Copied from big_vision.
Tokens can attend to valid inputs tokens which have a cumulative mask_ar
smaller or equal to theirs. This way `mask_ar` int[B, N] can be used to
setup several types of attention, for example:
[[1 1 1 1 1 1]]: pure causal attention.
[[0 0 0 1 1 1]]: prefix-lm attention. The first 3 tokens can attend between
themselves and the last 3 tokens have a causal attention. The first
entry could also be a 1 without changing behaviour.
[[1 0 1 0 1 0 0 1 0 0]]: causal attention between 4 blocks. Tokens of a
block can attend all previous blocks and all tokens on the same block.
Args:
input_mask: bool[B, N] true if its part of the input, false if padding.
mask_ar: int32[B, N] mask that's 1 where previous tokens cannot depend on
it and 0 where it shares the same attention mask as the previous token.
"""
if att_masks.ndim != 2:
raise ValueError(att_masks.ndim)
if pad_masks.ndim != 2:
raise ValueError(pad_masks.ndim)
cumsum = torch.cumsum(att_masks, dim=1)
att_2d_masks = cumsum[:, None, :] <= cumsum[:, :, None]
pad_2d_masks = pad_masks[:, None, :] * pad_masks[:, :, None]
att_2d_masks = att_2d_masks & pad_2d_masks
return att_2d_masks
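A tiny concrete check of the scheme described in the docstring (prefix-lm case: two prefix tokens followed by two causal tokens):

```python
import torch

pad = torch.ones(1, 4, dtype=torch.bool)
att = torch.tensor([[0, 0, 1, 1]])
print(make_att_2d_masks(pad, att).int())
# tensor([[[1, 1, 0, 0],
#          [1, 1, 0, 0],
#          [1, 1, 1, 0],
#          [1, 1, 1, 1]]])
```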
def resize_with_pad(img, width, height, pad_value=-1):
# assume no-op when width height fits already
if img.ndim != 4:
raise ValueError(f"(b,c,h,w) expected, but {img.shape}")
cur_height, cur_width = img.shape[2:]
ratio = max(cur_width / width, cur_height / height)
resized_height = int(cur_height / ratio)
resized_width = int(cur_width / ratio)
resized_img = F.interpolate(
img, size=(resized_height, resized_width), mode="bilinear", align_corners=False
)
pad_height = max(0, int(height - resized_height))
pad_width = max(0, int(width - resized_width))
# pad on left and top of image
padded_img = F.pad(resized_img, (pad_width, 0, pad_height, 0), value=pad_value)
return padded_img
def pad_vector(vector, new_dim):
"""Can be (batch_size x sequence_length x features_dimension)
or (batch_size x features_dimension)
"""
if vector.shape[-1] == new_dim:
return vector
shape = list(vector.shape)
current_dim = shape[-1]
shape[-1] = new_dim
new_vector = torch.zeros(*shape, dtype=vector.dtype, device=vector.device)
new_vector[..., :current_dim] = vector
return new_vector
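For example, a 14-motor Aloha state padded up to the model's `max_state_dim` of 32:

```python
x = torch.randn(2, 14)
print(pad_vector(x, 32).shape)  # torch.Size([2, 32]); zeros appended on the last dim
```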
def normalize(x, min_val, max_val):
return (x - min_val) / (max_val - min_val)
def unnormalize(x, min_val, max_val):
return x * (max_val - min_val) + min_val
def safe_arcsin(value):
# This ensures that the input stays within
# [-1, 1] to avoid invalid values for arcsin
return torch.arcsin(torch.clamp(value, -1.0, 1.0))
def aloha_gripper_to_angular(value):
# Aloha transforms the gripper positions into a linear space. The following code
# reverses this transformation to be consistent with pi0 which is pretrained in
# angular space.
#
# These values are coming from the Aloha code:
# PUPPET_GRIPPER_POSITION_OPEN, PUPPET_GRIPPER_POSITION_CLOSED
value = unnormalize(value, min_val=0.01844, max_val=0.05800)
# This is the inverse of the angular to linear transformation inside the Interbotix code.
def linear_to_radian(linear_position, arm_length, horn_radius):
value = (horn_radius**2 + linear_position**2 - arm_length**2) / (2 * horn_radius * linear_position)
return safe_arcsin(value)
# The constants are taken from the Interbotix code.
value = linear_to_radian(value, arm_length=0.036, horn_radius=0.022)
# Normalize to [0, 1].
# The values 0.4 and 1.5 were measured on an actual Trossen robot.
return normalize(value, min_val=0.4, max_val=1.5)
def aloha_gripper_from_angular(value):
# Convert from the gripper position used by pi0 to the gripper position that is used by Aloha.
# Note that the units are still angular but the range is different.
# The values 0.4 and 1.5 were measured on an actual Trossen robot.
value = unnormalize(value, min_val=0.4, max_val=1.5)
# These values are coming from the Aloha code:
# PUPPET_GRIPPER_JOINT_OPEN, PUPPET_GRIPPER_JOINT_CLOSE
return normalize(value, min_val=-0.6213, max_val=1.4910)
def aloha_gripper_from_angular_inv(value):
# Directly inverts the gripper_from_angular function.
value = unnormalize(value, min_val=-0.6213, max_val=1.4910)
return normalize(value, min_val=0.4, max_val=1.5)
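The two gripper mappings are exact inverses on [0, 1], which the `_inv` variant relies on; a quick round-trip check:

```python
x = torch.tensor([0.5])
print(aloha_gripper_from_angular_inv(aloha_gripper_from_angular(x)))  # tensor([0.5000])
```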
class PI0Policy(PreTrainedPolicy):
"""Wrapper class around PI0FlowMatching model to train and run inference within LeRobot."""
config_class = PI0Config
name = "pi0"
def __init__(
self,
config: PI0Config,
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
):
"""
Args:
config: Policy configuration class instance or None, in which case the default instantiation of
the configuration class is used.
dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
that they will be passed with a call to `load_state_dict` before the policy is used.
"""
super().__init__(config)
config.validate_features()
self.config = config
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
self.normalize_targets = Normalize(
config.output_features, config.normalization_mapping, dataset_stats
)
self.unnormalize_outputs = Unnormalize(
config.output_features, config.normalization_mapping, dataset_stats
)
self.language_tokenizer = AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224")
self.model = PI0FlowMatching(config)
self.reset()
def reset(self):
"""This should be called whenever the environment is reset."""
self._action_queue = deque([], maxlen=self.config.n_action_steps)
def get_optim_params(self) -> dict:
return self.parameters()
@torch.no_grad
def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
"""Select a single action given environment observations.
This method wraps `select_actions` in order to return one action at a time for execution in the
environment. It works by managing the actions in a queue and only calling `select_actions` when the
queue is empty.
"""
self.eval()
if self.config.adapt_to_pi_aloha:
batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT])
batch = self.normalize_inputs(batch)
# Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by
# querying the policy.
if len(self._action_queue) == 0:
images, img_masks = self.prepare_images(batch)
state = self.prepare_state(batch)
lang_tokens, lang_masks = self.prepare_language(batch)
actions = self.model.sample_actions(
images, img_masks, lang_tokens, lang_masks, state, noise=noise
)
# Unpad actions
original_action_dim = self.config.action_feature.shape[0]
actions = actions[:, :, :original_action_dim]
actions = self.unnormalize_outputs({"action": actions})["action"]
if self.config.adapt_to_pi_aloha:
actions = self._pi_aloha_encode_actions(actions)
# `self.model.sample_actions` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue
# effectively has shape (n_action_steps, batch_size, *), hence the transpose.
self._action_queue.extend(actions.transpose(0, 1))
return self._action_queue.popleft()
def forward(self, batch: dict[str, Tensor], noise=None, time=None) -> tuple[Tensor, dict[str, Tensor]]:
"""Do a full training forward pass to compute the loss"""
if self.config.adapt_to_pi_aloha:
batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT])
batch[ACTION] = self._pi_aloha_encode_actions_inv(batch[ACTION])
batch = self.normalize_inputs(batch)
batch = self.normalize_targets(batch)
images, img_masks = self.prepare_images(batch)
state = self.prepare_state(batch)
lang_tokens, lang_masks = self.prepare_language(batch)
actions = self.prepare_action(batch)
actions_is_pad = batch.get("action_is_pad")
loss_dict = {}
losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions, noise, time)
loss_dict["losses_after_forward"] = losses.clone()
if actions_is_pad is not None:
in_episode_bound = ~actions_is_pad
losses = losses * in_episode_bound.unsqueeze(-1)
loss_dict["losses_after_in_ep_bound"] = losses.clone()
# Remove padding
losses = losses[:, :, : self.config.max_action_dim]
loss_dict["losses_after_rm_padding"] = losses.clone()
# For backward pass
loss = losses.mean()
# For logging
loss_dict["l2_loss"] = loss.item()
return loss, loss_dict
def prepare_images(self, batch):
"""Apply Pi0 preprocessing to the images, like resizing to 224x224 and padding to keep aspect ratio, and
convert pixel range from [0.0, 1.0] to [-1.0, 1.0] as requested by SigLIP.
"""
images = []
img_masks = []
present_img_keys = [key for key in self.config.image_features if key in batch]
missing_img_keys = [key for key in self.config.image_features if key not in batch]
if len(present_img_keys) == 0:
raise ValueError(
f"All image features are missing from the batch. At least one expected. (batch: {batch.keys()}) (image_features:{self.config.image_features})"
)
# Preprocess image features present in the batch
for key in present_img_keys:
img = batch[key]
if self.config.resize_imgs_with_padding is not None:
img = resize_with_pad(img, *self.config.resize_imgs_with_padding, pad_value=0)
# Normalize from range [0,1] to [-1,1] as expected by SigLIP
img = img * 2.0 - 1.0
bsize = img.shape[0]
device = img.device
mask = torch.ones(bsize, dtype=torch.bool, device=device)
images.append(img)
img_masks.append(mask)
# Create image features not present in the batch
# as fully padded images (constant -1, i.e. black in SigLIP's [-1, 1] range).
for num_empty_cameras in range(len(missing_img_keys)):
if num_empty_cameras >= self.config.empty_cameras:
break
img = torch.ones_like(img) * -1
mask = torch.zeros_like(mask)
images.append(img)
img_masks.append(mask)
return images, img_masks
def prepare_language(self, batch) -> tuple[Tensor, Tensor]:
"""Tokenize the text input"""
device = batch[OBS_ROBOT].device
tasks = batch["task"]
# PaliGemma prompt has to end with a new line
tasks = [task if task.endswith("\n") else f"{task}\n" for task in tasks]
tokenized_prompt = self.language_tokenizer.__call__(
tasks,
padding="max_length",
padding_side="right",
max_length=self.config.tokenizer_max_length,
return_tensors="pt",
)
lang_tokens = tokenized_prompt["input_ids"].to(device=device)
lang_masks = tokenized_prompt["attention_mask"].to(device=device, dtype=torch.bool)
return lang_tokens, lang_masks
def _pi_aloha_decode_state(self, state):
# Flip the joints.
for motor_idx in [1, 2, 8, 9]:
state[:, motor_idx] *= -1
# Reverse the gripper transformation that is being applied by the Aloha runtime.
for motor_idx in [6, 13]:
state[:, motor_idx] = aloha_gripper_to_angular(state[:, motor_idx])
return state
def _pi_aloha_encode_actions(self, actions):
# Flip the joints.
for motor_idx in [1, 2, 8, 9]:
actions[:, :, motor_idx] *= -1
# Reverse the gripper transformation that is being applied by the Aloha runtime.
for motor_idx in [6, 13]:
actions[:, :, motor_idx] = aloha_gripper_from_angular(actions[:, :, motor_idx])
return actions
def _pi_aloha_encode_actions_inv(self, actions):
# Flip the joints again.
for motor_idx in [1, 2, 8, 9]:
actions[:, :, motor_idx] *= -1
# Reverse the gripper transformation that is being applied by the Aloha runtime.
for motor_idx in [6, 13]:
actions[:, :, motor_idx] = aloha_gripper_from_angular_inv(actions[:, :, motor_idx])
return actions
def prepare_state(self, batch):
"""Pad state"""
state = pad_vector(batch[OBS_ROBOT], self.config.max_state_dim)
return state
def prepare_action(self, batch):
"""Pad action"""
actions = pad_vector(batch[ACTION], self.config.max_action_dim)
return actions
class PI0FlowMatching(nn.Module):
"""
π0: A Vision-Language-Action Flow Model for General Robot Control
[Paper](https://www.physicalintelligence.company/download/pi0.pdf)
[Jax code](https://github.com/Physical-Intelligence/openpi)
Designed by Physical Intelligence. Ported from Jax by Hugging Face.
┌──────────────────────────────┐
│ actions │
│ ▲ │
│ ┌┴─────┐ │
│ kv cache │Gemma │ │
│ ┌──────────►│Expert│ │
│ │ │ │ │
│ ┌┴────────┐ │x 10 │ │
│ │ │ └▲──▲──┘ │
│ │PaliGemma│ │ │ │
│ │ │ │ robot state │
│ │ │ noise │
│ └▲──▲─────┘ │
│ │ │ │
│ │ image(s) │
│ language tokens │
└──────────────────────────────┘
"""
def __init__(self, config):
super().__init__()
self.config = config
paligemma_with_expert_config = PaliGemmaWithExpertConfig(
freeze_vision_encoder=self.config.freeze_vision_encoder,
train_expert_only=self.config.train_expert_only,
attention_implementation=self.config.attention_implementation,
)
self.paligemma_with_expert = PaliGemmaWithExpertModel(paligemma_with_expert_config)
# Projections are float32
self.state_proj = nn.Linear(self.config.max_state_dim, self.config.proj_width)
self.action_in_proj = nn.Linear(self.config.max_action_dim, self.config.proj_width)
self.action_out_proj = nn.Linear(self.config.proj_width, self.config.max_action_dim)
self.action_time_mlp_in = nn.Linear(self.config.proj_width * 2, self.config.proj_width)
self.action_time_mlp_out = nn.Linear(self.config.proj_width, self.config.proj_width)
self.set_requires_grad()
def set_requires_grad(self):
for params in self.state_proj.parameters():
params.requires_grad = self.config.train_state_proj
def sample_noise(self, shape, device):
noise = torch.normal(
mean=0.0,
std=1.0,
size=shape,
dtype=torch.float32,
device=device,
)
return noise
def sample_time(self, bsize, device):
time_beta = sample_beta(1.5, 1.0, bsize, device)
time = time_beta * 0.999 + 0.001
return time.to(dtype=torch.float32, device=device)
def embed_prefix(
self, images, img_masks, lang_tokens, lang_masks
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Embed images with SigLIP and language tokens with embedding layer to prepare
for PaliGemma transformer processing.
"""
# TODO: avoid list in python and torch.cat ; prefer pre-allocation with torch.empty
embs = []
pad_masks = []
att_masks = []
# TODO: remove for loop
for (
img,
img_mask,
) in zip(images, img_masks, strict=False):
img_emb = self.paligemma_with_expert.embed_image(img)
img_emb = img_emb.to(dtype=torch.bfloat16)
# Normalize image embeddings
img_emb_dim = img_emb.shape[-1]
img_emb = img_emb * torch.tensor(img_emb_dim**0.5, dtype=img_emb.dtype, device=img_emb.device)
bsize, num_img_embs = img_emb.shape[:2]
img_mask = img_mask[:, None].expand(bsize, num_img_embs)
embs.append(img_emb)
pad_masks.append(img_mask)
# Create attention masks so that image tokens attend to each other
att_masks += [0] * num_img_embs
lang_emb = self.paligemma_with_expert.embed_language_tokens(lang_tokens)
# Normalize language embeddings
lang_emb_dim = lang_emb.shape[-1]
lang_emb = lang_emb * math.sqrt(lang_emb_dim)
embs.append(lang_emb)
pad_masks.append(lang_masks)
# full attention between image and language inputs
num_lang_embs = lang_emb.shape[1]
att_masks += [0] * num_lang_embs
embs = torch.cat(embs, dim=1)
pad_masks = torch.cat(pad_masks, dim=1)
att_masks = torch.tensor(att_masks, dtype=torch.bool, device=pad_masks.device)
att_masks = att_masks[None, :].expand(bsize, len(att_masks))
return embs, pad_masks, att_masks
def embed_suffix(self, state, noisy_actions, timestep):
"""Embed state, noisy_actions, timestep to prepare for Expert Gemma processing."""
embs = []
pad_masks = []
att_masks = []
# Embed state
state_emb = self.state_proj(state)
state_emb = state_emb.to(dtype=torch.bfloat16)
embs.append(state_emb[:, None, :])
bsize = state_emb.shape[0]
dtype = state_emb.dtype
device = state_emb.device
state_mask = torch.ones(bsize, 1, dtype=torch.bool, device=device)
pad_masks.append(state_mask)
# Set attention masks so that image and language inputs do not attend to state or actions
att_masks += [1]
# Embed timestep using sine-cosine positional encoding with sensitivity in the range [0, 1]
time_emb = create_sinusoidal_pos_embedding(
timestep, self.config.proj_width, min_period=4e-3, max_period=4.0, device=device
)
time_emb = time_emb.type(dtype=dtype)
# Fuse timestep + action information using an MLP
action_emb = self.action_in_proj(noisy_actions)
time_emb = time_emb[:, None, :].expand_as(action_emb)
action_time_emb = torch.cat([action_emb, time_emb], dim=2)
action_time_emb = self.action_time_mlp_in(action_time_emb)
action_time_emb = F.silu(action_time_emb) # swish == silu
action_time_emb = self.action_time_mlp_out(action_time_emb)
# Add to input tokens
embs.append(action_time_emb)
bsize, action_time_dim = action_time_emb.shape[:2]
action_time_mask = torch.ones(bsize, action_time_dim, dtype=torch.bool, device=device)
pad_masks.append(action_time_mask)
# Set attention masks so that image, language and state inputs do not attend to action tokens
att_masks += [1] + ([0] * (self.config.n_action_steps - 1))
embs = torch.cat(embs, dim=1)
pad_masks = torch.cat(pad_masks, dim=1)
att_masks = torch.tensor(att_masks, dtype=embs.dtype, device=embs.device)
att_masks = att_masks[None, :].expand(bsize, len(att_masks))
return embs, pad_masks, att_masks
def forward(
self, images, img_masks, lang_tokens, lang_masks, state, actions, noise=None, time=None
) -> Tensor:
"""Do a full training forward pass and compute the loss (batch_size x num_steps x num_motors)"""
if noise is None:
noise = self.sample_noise(actions.shape, actions.device)
if time is None:
time = self.sample_time(actions.shape[0], actions.device)
time_expanded = time[:, None, None]
x_t = time_expanded * noise + (1 - time_expanded) * actions
u_t = noise - actions
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
images, img_masks, lang_tokens, lang_masks
)
suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(state, x_t, time)
pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1)
att_masks = torch.cat([prefix_att_masks, suffix_att_masks], dim=1)
att_2d_masks = make_att_2d_masks(pad_masks, att_masks)
position_ids = torch.cumsum(pad_masks, dim=1) - 1
(_, suffix_out), _ = self.paligemma_with_expert.forward(
attention_mask=att_2d_masks,
position_ids=position_ids,
past_key_values=None,
inputs_embeds=[prefix_embs, suffix_embs],
use_cache=False,
fill_kv_cache=False,
)
suffix_out = suffix_out[:, -self.config.n_action_steps :]
# Original openpi code, upcast attention output
suffix_out = suffix_out.to(dtype=torch.float32)
v_t = self.action_out_proj(suffix_out)
losses = F.mse_loss(u_t, v_t, reduction="none")
return losses
def sample_actions(self, images, img_masks, lang_tokens, lang_masks, state, noise=None) -> Tensor:
"""Do a full inference forward and compute the action (batch_size x num_steps x num_motors)"""
bsize = state.shape[0]
device = state.device
if noise is None:
actions_shape = (bsize, self.config.n_action_steps, self.config.max_action_dim)
noise = self.sample_noise(actions_shape, device)
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
images, img_masks, lang_tokens, lang_masks
)
prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
# Compute image and language key value cache
_, past_key_values = self.paligemma_with_expert.forward(
attention_mask=prefix_att_2d_masks,
position_ids=prefix_position_ids,
past_key_values=None,
inputs_embeds=[prefix_embs, None],
use_cache=self.config.use_cache,
fill_kv_cache=True,
)
dt = -1.0 / self.config.num_steps
dt = torch.tensor(dt, dtype=torch.float32, device=device)
x_t = noise
time = torch.tensor(1.0, dtype=torch.float32, device=device)
while time >= -dt / 2:
expanded_time = time.expand(bsize)
v_t = self.denoise_step(
state,
prefix_pad_masks,
past_key_values,
x_t,
expanded_time,
)
# Euler step
x_t += dt * v_t
time += dt
return x_t
def denoise_step(
self,
state,
prefix_pad_masks,
past_key_values,
x_t,
timestep,
):
"""Apply one denoising step of the noise `x_t` at a given timestep."""
suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(state, x_t, timestep)
suffix_len = suffix_pad_masks.shape[1]
batch_size = prefix_pad_masks.shape[0]
prefix_len = prefix_pad_masks.shape[1]
prefix_pad_2d_masks = prefix_pad_masks[:, None, :].expand(batch_size, suffix_len, prefix_len)
suffix_att_2d_masks = make_att_2d_masks(suffix_pad_masks, suffix_att_masks)
full_att_2d_masks = torch.cat([prefix_pad_2d_masks, suffix_att_2d_masks], dim=2)
prefix_offsets = torch.sum(prefix_pad_masks, dim=-1)[:, None]
position_ids = prefix_offsets + torch.cumsum(suffix_pad_masks, dim=1) - 1
outputs_embeds, _ = self.paligemma_with_expert.forward(
attention_mask=full_att_2d_masks,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=[None, suffix_embs],
use_cache=self.config.use_cache,
fill_kv_cache=False,
)
suffix_out = outputs_embeds[1]
suffix_out = suffix_out[:, -self.config.n_action_steps :]
suffix_out = suffix_out.to(dtype=torch.float32)
v_t = self.action_out_proj(suffix_out)
return v_t
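
To make the inference loop above concrete, here is a minimal, self-contained sketch of the same Euler integration scheme, with a hypothetical `toy_denoise_step` standing in for the model call in `denoise_step`. Under the flow matching parameterization used in `forward` (x_t = t * noise + (1 - t) * actions, target u_t = noise - actions), integrating from t=1 (pure noise) down to t=0 with dt = -1/num_steps recovers an action sample:

import torch

def toy_denoise_step(x_t: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in: if the true actions were all zeros, the exact
    # velocity would be u_t = noise - actions = x_t / t.
    return x_t / t

num_steps = 10
x_t = torch.randn(2, 50, 32)  # (batch, n_action_steps, action_dim), pure noise at t=1
dt = torch.tensor(-1.0 / num_steps)
t = torch.tensor(1.0)
while t >= -dt / 2:
    x_t = x_t + dt * toy_denoise_step(x_t, t)  # Euler step
    t = t + dt
# For this toy velocity field, x_t contracts to the zero "actions" at t=0.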

View File

@@ -0,0 +1,403 @@
from typing import List, Optional, Union
import torch
from torch import nn
from transformers import (
AutoConfig,
Cache,
GemmaForCausalLM,
PaliGemmaForConditionalGeneration,
PretrainedConfig,
PreTrainedModel,
)
from transformers.models.auto import CONFIG_MAPPING
from lerobot.common.policies.pi0.flex_attention import flex_attention_forward
def apply_rope(x, positions, max_wavelength=10_000):
"""
Applies RoPE positions [B, L] to x [B, L, H, D].
"""
d_half = x.shape[-1] // 2
device = x.device
dtype = x.dtype
x = x.to(torch.float32)
freq_exponents = (2.0 / x.shape[-1]) * torch.arange(d_half, dtype=torch.float32, device=device)
timescale = max_wavelength**freq_exponents
radians = positions[..., None].to(torch.float32) / timescale[None, None, :].to(torch.float32)
radians = radians[..., None, :]
sin = torch.sin(radians) # .to(dtype=dtype)
cos = torch.cos(radians) # .to(dtype=dtype)
x1, x2 = x.split(d_half, dim=-1)
res = torch.empty_like(x)
res[..., :d_half] = x1 * cos - x2 * sin
res[..., d_half:] = x2 * cos + x1 * sin
return res.to(dtype)
class PaliGemmaWithExpertConfig(PretrainedConfig):
model_type = "PaliGemmaWithExpertModel"
sub_configs = {"paligemma_config": AutoConfig, "gemma_expert_config": AutoConfig}
def __init__(
self,
paligemma_config: dict | None = None,
gemma_expert_config: dict | None = None,
freeze_vision_encoder: bool = True,
train_expert_only: bool = True,
attention_implementation: str = "eager",
**kwargs,
):
self.freeze_vision_encoder = freeze_vision_encoder
self.train_expert_only = train_expert_only
self.attention_implementation = attention_implementation
if paligemma_config is None:
# Default config from Pi0
self.paligemma_config = CONFIG_MAPPING["paligemma"](
transformers_version="4.48.1",
_vocab_size=257152,
bos_token_id=2,
eos_token_id=1,
hidden_size=2048,
image_token_index=257152,
model_type="paligemma",
pad_token_id=0,
projection_dim=2048,
text_config={
"hidden_activation": "gelu_pytorch_tanh",
"hidden_size": 2048,
"intermediate_size": 16384,
"model_type": "gemma",
"num_attention_heads": 8,
"num_hidden_layers": 18,
"num_image_tokens": 256,
"num_key_value_heads": 1,
"torch_dtype": "float32",
"vocab_size": 257152,
},
vision_config={
"hidden_size": 1152,
"intermediate_size": 4304,
"model_type": "siglip_vision_model",
"num_attention_heads": 16,
"num_hidden_layers": 27,
"num_image_tokens": 256,
"patch_size": 14,
"projection_dim": 2048,
"projector_hidden_act": "gelu_fast",
"torch_dtype": "float32",
"vision_use_head": False,
},
)
elif isinstance(paligemma_config, dict):
# Override Pi0 default config for PaliGemma
if "model_type" not in paligemma_config:
paligemma_config["model_type"] = "paligemma"
cfg_cls = CONFIG_MAPPING[paligemma_config["model_type"]]
self.paligemma_config = cfg_cls(**paligemma_config)
if gemma_expert_config is None:
# Default config from Pi0
self.gemma_expert_config = CONFIG_MAPPING["gemma"](
attention_bias=False,
attention_dropout=0.0,
bos_token_id=2,
eos_token_id=1,
head_dim=256,
hidden_act="gelu_pytorch_tanh",
hidden_activation="gelu_pytorch_tanh",
hidden_size=1024,
initializer_range=0.02,
intermediate_size=4096,
max_position_embeddings=8192,
model_type="gemma",
num_attention_heads=8,
num_hidden_layers=18,
num_key_value_heads=1,
pad_token_id=0,
rms_norm_eps=1e-06,
rope_theta=10000.0,
torch_dtype="float32",
transformers_version="4.48.1",
use_cache=True,
vocab_size=257152,
)
elif isinstance(gemma_expert_config, dict):
# Override Pi0 default config for Gemma Expert
if "model_type" not in gemma_expert_config:
gemma_expert_config["model_type"] = "gemma"
cfg_cls = CONFIG_MAPPING[gemma_expert_config["model_type"]]
self.gemma_expert_config = cfg_cls(**gemma_expert_config)
super().__init__(**kwargs)
def __post_init__(self):
super().__post_init__()
if self.train_expert_only and not self.freeze_vision_encoder:
raise ValueError(
"You set `freeze_vision_encoder=False` and `train_expert_only=True` which are not compatible."
)
if self.attention_implementation not in ["eager", "fa2", "flex"]:
raise ValueError(
f"Wrong value provided for `attention_implementation` ({self.attention_implementation}). Expected 'eager', 'fa2' or 'flex'."
)
class PaliGemmaWithExpertModel(PreTrainedModel):
config_class = PaliGemmaWithExpertConfig
def __init__(self, config: PaliGemmaWithExpertConfig):
super().__init__(config=config)
self.config = config
self.paligemma = PaliGemmaForConditionalGeneration(config=config.paligemma_config)
self.gemma_expert = GemmaForCausalLM(config=config.gemma_expert_config)
# Remove unused embed_tokens
self.gemma_expert.model.embed_tokens = None
self.to_bfloat16_like_physical_intelligence()
self.set_requires_grad()
def set_requires_grad(self):
if self.config.freeze_vision_encoder:
self.paligemma.vision_tower.eval()
for params in self.paligemma.vision_tower.parameters():
params.requires_grad = False
if self.config.train_expert_only:
self.paligemma.eval()
for params in self.paligemma.parameters():
params.requires_grad = False
def train(self, mode: bool = True):
super().train(mode)
if self.config.freeze_vision_encoder:
self.paligemma.vision_tower.eval()
if self.config.train_expert_only:
self.paligemma.eval()
def to_bfloat16_like_physical_intelligence(self):
self.paligemma = self.paligemma.to(dtype=torch.bfloat16)
params_to_change_dtype = [
"language_model.model.layers",
"gemma_expert.model.layers",
"vision_tower",
"multi_modal",
]
for name, param in self.named_parameters():
if any(selector in name for selector in params_to_change_dtype):
param.data = param.data.to(dtype=torch.bfloat16)
def embed_image(self, image: torch.Tensor):
return self.paligemma.get_image_features(image)
def embed_language_tokens(self, tokens: torch.Tensor):
return self.paligemma.language_model.model.embed_tokens(tokens)
# TODO: break down this huge forward into modules or functions
def forward(
self,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None,
inputs_embeds: Optional[List[torch.FloatTensor]] = None,
use_cache: Optional[bool] = None,
fill_kv_cache: Optional[bool] = None,
):
models = [self.paligemma.language_model.model, self.gemma_expert.model]
for hidden_states in inputs_embeds:
# TODO this is very inefficient
# dtype is always the same, batch size too (if > 1 len)
# device could be trickier in multi gpu edge cases but that's it
if hidden_states is None:
continue
batch_size = hidden_states.shape[0]
# RMSNorm
num_layers = self.paligemma.config.text_config.num_hidden_layers
head_dim = self.paligemma.config.text_config.head_dim
for layer_idx in range(num_layers):
query_states = []
key_states = []
value_states = []
for i, hidden_states in enumerate(inputs_embeds):
if hidden_states is None:
continue
layer = models[i].layers[layer_idx]
# normalizer = torch.tensor(models[i].config.hidden_size**0.5, dtype=hidden_states.dtype)
# hidden_states = hidden_states * normalizer
hidden_states = layer.input_layernorm(hidden_states)
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, layer.self_attn.head_dim)
hidden_states = hidden_states.to(dtype=torch.bfloat16)
query_state = layer.self_attn.q_proj(hidden_states).view(hidden_shape)
key_state = layer.self_attn.k_proj(hidden_states).view(hidden_shape)
value_state = layer.self_attn.v_proj(hidden_states).view(hidden_shape)
query_states.append(query_state)
key_states.append(key_state)
value_states.append(value_state)
# B,L,H,D with L sequence length, H number of heads, D head dim
# concatenate on the number of embeddings/tokens
query_states = torch.cat(query_states, dim=1)
key_states = torch.cat(key_states, dim=1)
value_states = torch.cat(value_states, dim=1)
query_states = apply_rope(query_states, position_ids)
key_states = apply_rope(key_states, position_ids)
if use_cache and past_key_values is None:
past_key_values = {}
if use_cache:
if fill_kv_cache:
past_key_values[layer_idx] = {
"key_states": key_states,
"value_states": value_states,
}
else:
# TODO here, some optimization can be done - similar to a `StaticCache` we can declare the `max_len` before.
# so we create an empty cache, with just one cuda malloc, and if (in autoregressive case) we reach
# the max len, then we (for instance) double the cache size. This implementation already exists
# in `transformers`. (molbap)
key_states = torch.cat([past_key_values[layer_idx]["key_states"], key_states], dim=1)
value_states = torch.cat(
[past_key_values[layer_idx]["value_states"], value_states], dim=1
)
attention_interface = self.get_attention_interface()
att_output = attention_interface(
attention_mask, batch_size, head_dim, query_states, key_states, value_states
)
att_output = att_output.to(dtype=torch.bfloat16)
# first part of att_output is prefix (up to sequence length, [:, 0:prefix_seq_len])
outputs_embeds = []
start = 0
for i, hidden_states in enumerate(inputs_embeds):
layer = models[i].layers[layer_idx]
if hidden_states is not None:
end = start + hidden_states.shape[1]
if att_output.dtype != layer.self_attn.o_proj.weight.dtype:
att_output = att_output.to(layer.self_attn.o_proj.weight.dtype)
out_emb = layer.self_attn.o_proj(att_output[:, start:end])
# TODO: first dropout (by default 0.0)
# first residual
out_emb += hidden_states
after_first_residual = out_emb.clone()
out_emb = layer.post_attention_layernorm(out_emb)
out_emb = layer.mlp(out_emb)
# TODO: second dropout (by default 0.0)
# second residual
out_emb += after_first_residual
outputs_embeds.append(out_emb)
start = end
else:
outputs_embeds.append(None)
inputs_embeds = outputs_embeds
# final norm
outputs_embeds = []
for i, hidden_states in enumerate(inputs_embeds):
if hidden_states is not None:
out_emb = models[i].norm(hidden_states)
outputs_embeds.append(out_emb)
else:
outputs_embeds.append(None)
return outputs_embeds, past_key_values
def get_attention_interface(self):
if self.config.attention_implementation == "fa2":
attention_interface = self.flash_attention_forward
elif self.config.attention_implementation == "flex":
attention_interface = flex_attention_forward
else:
attention_interface = self.eager_attention_forward
return attention_interface
def flash_attention_forward(
self, attention_mask, batch_size, head_dim, query_states, key_states, value_states
):
raise NotImplementedError("FA2 is not implemented (yet)")
def eager_attention_forward(
self, attention_mask, batch_size, head_dim, query_states, key_states, value_states
):
num_att_heads = self.config.paligemma_config.text_config.num_attention_heads
num_key_value_heads = self.config.paligemma_config.text_config.num_key_value_heads
num_key_value_groups = num_att_heads // num_key_value_heads
# query_states: batch_size, sequence_length, num_att_head, head_dim
# key_states: batch_size, sequence_length, num_key_value_head, head_dim
# value_states: batch_size, sequence_length, num_key_value_head, head_dim
sequence_length = key_states.shape[1]
key_states = key_states[:, :, :, None, :].expand(
batch_size, sequence_length, num_key_value_heads, num_key_value_groups, head_dim
)
key_states = key_states.reshape(
batch_size, sequence_length, num_key_value_heads * num_key_value_groups, head_dim
)
value_states = value_states[:, :, :, None, :].expand(
batch_size, sequence_length, num_key_value_heads, num_key_value_groups, head_dim
)
value_states = value_states.reshape(
batch_size, sequence_length, num_key_value_heads * num_key_value_groups, head_dim
)
# Attention here is upcasted to float32 to match the original eager implementation.
query_states = query_states.to(dtype=torch.float32)
key_states = key_states.to(dtype=torch.float32)
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
att_weights = torch.matmul(query_states, key_states.transpose(2, 3))
att_weights *= head_dim**-0.5
big_neg = -2.3819763e38 # See gemma/modules.py
masked_att_weights = torch.where(attention_mask[:, None, :, :], att_weights, big_neg)
probs = nn.functional.softmax(masked_att_weights, dim=-1)
probs = probs.to(dtype=value_states.dtype)
# probs: batch_size, num_att_heads, sequence_length, sequence_length
# value_states: batch_size, sequence_length, num_att_heads, head_dim
att_output = torch.matmul(probs, value_states.permute(0, 2, 1, 3))
att_output = att_output.permute(0, 2, 1, 3)
# we use -1 because sequence length can change
att_output = att_output.reshape(batch_size, -1, num_key_value_heads * num_key_value_groups * head_dim)
return att_output
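
A quick shape and sanity check for `apply_rope` above (a sketch, assuming the [B, L, H, D] layout described in its docstring; rotation by position 0 is the identity, so the first token must come back unchanged):

import torch

B, L, H, D = 2, 7, 8, 256
x = torch.randn(B, L, H, D, dtype=torch.bfloat16)
positions = torch.arange(L)[None, :].expand(B, L)  # RoPE positions, shape [B, L]
out = apply_rope(x, positions)
assert out.shape == (B, L, H, D) and out.dtype == x.dtype
torch.testing.assert_close(out[:, 0], x[:, 0])  # position 0 => identity rotation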

View File

@@ -1,75 +0,0 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A protocol that all policies should follow.
This provides a mechanism for type-hinting and isinstance checks without requiring that policy classes
subclass a base class.
The protocol structure, method signatures, and docstrings should be used by developers as a reference for
how to implement new policies.
"""
from typing import Protocol, runtime_checkable
from torch import Tensor
@runtime_checkable
class Policy(Protocol):
"""The required interface for implementing a policy.
We also expect all policies to subclass torch.nn.Module and PyTorchModelHubMixin.
"""
name: str
def __init__(self, cfg, dataset_stats: dict[str, dict[str, Tensor]] | None = None):
"""
Args:
cfg: Policy configuration class instance or None, in which case the default instantiation of the
configuration class is used.
dataset_stats: Dataset statistics to be used for normalization.
"""
def reset(self):
"""To be called whenever the environment is reset.
Does things like clearing caches.
"""
def forward(self, batch: dict[str, Tensor]) -> dict:
"""Run the batch through the model and compute the loss for training or validation.
Returns a dictionary with "loss" and potentially other information. Apart from "loss" which is a Tensor, all
other items should be logging-friendly, native Python types.
"""
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
"""Return one action to run in the environment (potentially in batch mode).
When the model uses a history of observations, or outputs a sequence of actions, this method deals
with caching.
"""
@runtime_checkable
class PolicyWithUpdate(Policy, Protocol):
def update(self):
"""An update method that is to be called after a training optimization step.
Implements any additional updates the model parameters may need (for example, doing an EMA step for a
target model, or incrementing an internal buffer).
"""

View File

@@ -0,0 +1,187 @@
import abc
import logging
import os
from pathlib import Path
from typing import Type, TypeVar
import packaging.version
import safetensors
from huggingface_hub import hf_hub_download
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
from huggingface_hub.errors import HfHubHTTPError
from safetensors.torch import load_model as load_model_as_safetensor
from safetensors.torch import save_model as save_model_as_safetensor
from torch import Tensor, nn
from lerobot.common.utils.hub import HubMixin
from lerobot.configs.policies import PreTrainedConfig
T = TypeVar("T", bound="PreTrainedPolicy")
DEFAULT_POLICY_CARD = """
---
# For reference on model card metadata, see the spec: https://github.com/huggingface/hub-docs/blob/main/modelcard.md?plain=1
# Doc / guide: https://huggingface.co/docs/hub/model-cards
{{ card_data }}
---
This policy has been pushed to the Hub using [LeRobot](https://github.com/huggingface/lerobot):
- Docs: {{ docs_url | default("[More Information Needed]", true) }}
"""
class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
"""
Base class for policy models.
"""
config_class: Type[PreTrainedConfig]
name: str
def __init__(self, config: PreTrainedConfig, *inputs, **kwargs):
super().__init__()
if not isinstance(config, PreTrainedConfig):
raise ValueError(
f"Parameter config in `{self.__class__.__name__}(config)` should be an instance of class "
"`PreTrainedConfig`. To create a model from a pretrained model use "
f"`model = {self.__class__.__name__}.from_pretrained(PRETRAINED_MODEL_NAME)`"
)
self.config = config
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
if not getattr(cls, "config_class", None):
raise TypeError(f"Class {cls.__name__} must define 'config_class'")
if not getattr(cls, "name", None):
raise TypeError(f"Class {cls.__name__} must define 'name'")
def _save_pretrained(self, save_directory: Path) -> None:
self.config._save_pretrained(save_directory)
model_to_save = self.module if hasattr(self, "module") else self
save_model_as_safetensor(model_to_save, str(save_directory / SAFETENSORS_SINGLE_FILE))
@classmethod
def from_pretrained(
cls: Type[T],
pretrained_name_or_path: str | Path,
*,
config: PreTrainedConfig | None = None,
force_download: bool = False,
resume_download: bool | None = None,
proxies: dict | None = None,
token: str | bool | None = None,
cache_dir: str | Path | None = None,
local_files_only: bool = False,
revision: str | None = None,
map_location: str = "cpu",
strict: bool = False,
**kwargs,
) -> T:
"""
The policy is set in evaluation mode by default using `policy.eval()` (dropout modules are
deactivated). To train it, you should first set it back in training mode with `policy.train()`.
"""
if config is None:
config = PreTrainedConfig.from_pretrained(
pretrained_name_or_path=pretrained_name_or_path,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
token=token,
cache_dir=cache_dir,
local_files_only=local_files_only,
revision=revision,
**kwargs,
)
model_id = str(pretrained_name_or_path)
instance = cls(config, **kwargs)
if os.path.isdir(model_id):
print("Loading weights from local directory")
model_file = os.path.join(model_id, SAFETENSORS_SINGLE_FILE)
policy = cls._load_as_safetensor(instance, model_file, map_location, strict)
else:
try:
model_file = hf_hub_download(
repo_id=model_id,
filename=SAFETENSORS_SINGLE_FILE,
revision=revision,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
token=token,
local_files_only=local_files_only,
)
policy = cls._load_as_safetensor(instance, model_file, map_location, strict)
except HfHubHTTPError as e:
raise FileNotFoundError(
f"{SAFETENSORS_SINGLE_FILE} not found on the HuggingFace Hub in {model_id}"
) from e
policy.to(map_location)
policy.eval()
return policy
@classmethod
def _load_as_safetensor(cls, model: T, model_file: str, map_location: str, strict: bool) -> T:
if packaging.version.parse(safetensors.__version__) < packaging.version.parse("0.4.3"):
load_model_as_safetensor(model, model_file, strict=strict)
if map_location != "cpu":
logging.warning(
"Loading model weights on other devices than 'cpu' is not supported natively in your version of safetensors."
" This means that the model is loaded on 'cpu' first and then copied to the device."
" This leads to a slower loading time."
" Please update safetensors to version 0.4.3 or above for improved performance."
)
model.to(map_location)
else:
safetensors.torch.load_model(model, model_file, strict=strict, device=map_location)
return model
# def generate_model_card(self, *args, **kwargs) -> ModelCard:
# card = ModelCard.from_template(
# card_data=self._hub_mixin_info.model_card_data,
# template_str=self._hub_mixin_info.model_card_template,
# repo_url=self._hub_mixin_info.repo_url,
# docs_url=self._hub_mixin_info.docs_url,
# **kwargs,
# )
# return card
@abc.abstractmethod
def get_optim_params(self) -> dict:
"""
Returns the policy-specific parameters dict to be passed on to the optimizer.
"""
raise NotImplementedError
@abc.abstractmethod
def reset(self):
"""To be called whenever the environment is reset.
Does things like clearing caches.
"""
raise NotImplementedError
# TODO(aliberts, rcadene): split into 'forward' and 'compute_loss'?
@abc.abstractmethod
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict | None]:
"""_summary_
Args:
batch (dict[str, Tensor]): _description_
Returns:
tuple[Tensor, dict | None]: The loss and potentially other information. Apart from the loss which
is a Tensor, all other items should be logging-friendly, native Python types.
"""
raise NotImplementedError
@abc.abstractmethod
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
"""Return one action to run in the environment (potentially in batch mode).
When the model uses a history of observations, or outputs a sequence of actions, this method deals
with caching.
"""
raise NotImplementedError
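
A sketch of the subclass contract enforced above: `__init_subclass__` makes every concrete policy declare `config_class` and `name` at class-definition time, and the `abc.abstractmethod`s must be overridden before instantiation. `MyConfig` and `MyPolicy` are hypothetical names for illustration:

class MyPolicy(PreTrainedPolicy):
    config_class = MyConfig  # some PreTrainedConfig subclass
    name = "my_policy"

    def get_optim_params(self) -> dict:
        return self.parameters()
    def reset(self):
        pass
    def forward(self, batch):
        ...
    def select_action(self, batch):
        ...

# Omitting `name` (or `config_class`) fails immediately at class creation:
# TypeError: Class MyPolicy must define 'name'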

View File

@@ -16,9 +16,14 @@
# limitations under the License. # limitations under the License.
from dataclasses import dataclass, field from dataclasses import dataclass, field
from lerobot.common.optim.optimizers import AdamConfig
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import NormalizationMode
@PreTrainedConfig.register_subclass("tdmpc")
@dataclass @dataclass
class TDMPCConfig: class TDMPCConfig(PreTrainedConfig):
"""Configuration class for TDMPCPolicy. """Configuration class for TDMPCPolicy.
Defaults are configured for training with xarm_lift_medium_replay providing proprioceptive and single Defaults are configured for training with xarm_lift_medium_replay providing proprioceptive and single
@@ -71,7 +76,7 @@ class TDMPCConfig:
n_pi_samples: Number of samples to draw from the policy / world model rollout every CEM iteration. Can n_pi_samples: Number of samples to draw from the policy / world model rollout every CEM iteration. Can
be zero. be zero.
uncertainty_regularizer_coeff: Coefficient for the uncertainty regularization used when estimating uncertainty_regularizer_coeff: Coefficient for the uncertainty regularization used when estimating
trajectory values (this is the λ coeffiecient in eqn 4 of FOWM). trajectory values (this is the λ coefficient in eqn 4 of FOWM).
n_elites: The number of elite samples to use for updating the gaussian parameters every CEM iteration. n_elites: The number of elite samples to use for updating the gaussian parameters every CEM iteration.
elite_weighting_temperature: The temperature to use for softmax weighting (by trajectory value) of the elite_weighting_temperature: The temperature to use for softmax weighting (by trajectory value) of the
elites, when updating the gaussian parameters for CEM. elites, when updating the gaussian parameters for CEM.
@@ -102,27 +107,19 @@ class TDMPCConfig:
""" """
# Input / output structure. # Input / output structure.
n_obs_steps: int = 1
n_action_repeats: int = 2 n_action_repeats: int = 2
horizon: int = 5 horizon: int = 5
n_action_steps: int = 1 n_action_steps: int = 1
input_shapes: dict[str, list[int]] = field( normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: { default_factory=lambda: {
"observation.image": [3, 84, 84], "VISUAL": NormalizationMode.IDENTITY,
"observation.state": [4], "STATE": NormalizationMode.IDENTITY,
"ENV": NormalizationMode.IDENTITY,
"ACTION": NormalizationMode.MIN_MAX,
} }
) )
output_shapes: dict[str, list[int]] = field(
default_factory=lambda: {
"action": [4],
}
)
# Normalization / Unnormalization
input_normalization_modes: dict[str, str] | None = None
output_normalization_modes: dict[str, str] = field(
default_factory=lambda: {"action": "min_max"},
)
# Architecture / modeling. # Architecture / modeling.
# Neural networks. # Neural networks.
@@ -159,32 +156,27 @@ class TDMPCConfig:
# Target model. # Target model.
target_model_momentum: float = 0.995 target_model_momentum: float = 0.995
# Training presets
optimizer_lr: float = 3e-4
def __post_init__(self): def __post_init__(self):
super().__post_init__()
"""Input validation (not exhaustive).""" """Input validation (not exhaustive)."""
# There should only be one image key.
image_keys = {k for k in self.input_shapes if k.startswith("observation.image")}
if len(image_keys) > 1:
raise ValueError(
f"{self.__class__.__name__} handles at most one image for now. Got image keys {image_keys}."
)
if len(image_keys) > 0:
image_key = next(iter(image_keys))
if self.input_shapes[image_key][-2] != self.input_shapes[image_key][-1]:
# TODO(alexander-soare): This limitation is solely because of code in the random shift
# augmentation. It should be able to be removed.
raise ValueError(
f"Only square images are handled now. Got image shape {self.input_shapes[image_key]}."
)
if self.n_gaussian_samples <= 0: if self.n_gaussian_samples <= 0:
raise ValueError( raise ValueError(
f"The number of guassian samples for CEM should be non-zero. Got `{self.n_gaussian_samples=}`" f"The number of gaussian samples for CEM should be non-zero. Got `{self.n_gaussian_samples=}`"
) )
if self.output_normalization_modes != {"action": "min_max"}: if self.normalization_mapping["ACTION"] is not NormalizationMode.MIN_MAX:
raise ValueError( raise ValueError(
"TD-MPC assumes the action space dimensions to all be in [-1, 1]. Therefore it is strongly " "TD-MPC assumes the action space dimensions to all be in [-1, 1]. Therefore it is strongly "
f"advised that you stick with the default. See {self.__class__.__name__} docstring for more " f"advised that you stick with the default. See {self.__class__.__name__} docstring for more "
"information." "information."
) )
if self.n_obs_steps != 1:
raise ValueError(
f"Multiple observation steps not handled yet. Got `nobs_steps={self.n_obs_steps}`"
)
if self.n_action_steps > 1: if self.n_action_steps > 1:
if self.n_action_repeats != 1: if self.n_action_repeats != 1:
raise ValueError( raise ValueError(
@@ -194,3 +186,35 @@ class TDMPCConfig:
raise ValueError("If `n_action_steps > 1`, `use_mpc` must be set to `True`.") raise ValueError("If `n_action_steps > 1`, `use_mpc` must be set to `True`.")
if self.n_action_steps > self.horizon: if self.n_action_steps > self.horizon:
raise ValueError("`n_action_steps` must be less than or equal to `horizon`.") raise ValueError("`n_action_steps` must be less than or equal to `horizon`.")
def get_optimizer_preset(self) -> AdamConfig:
return AdamConfig(lr=self.optimizer_lr)
def get_scheduler_preset(self) -> None:
return None
def validate_features(self) -> None:
# There should only be one image key.
if len(self.image_features) > 1:
raise ValueError(
f"{self.__class__.__name__} handles at most one image for now. Got image keys {self.image_features}."
)
if len(self.image_features) > 0:
image_ft = next(iter(self.image_features.values()))
if image_ft.shape[-2] != image_ft.shape[-1]:
# TODO(alexander-soare): This limitation is solely because of code in the random shift
# augmentation. It should be able to be removed.
raise ValueError(f"Only square images are handled now. Got image shape {image_ft.shape}.")
@property
def observation_delta_indices(self) -> list:
return list(range(self.horizon + 1))
@property
def action_delta_indices(self) -> list:
return list(range(self.horizon))
@property
def reward_delta_indices(self) -> list:
return list(range(self.horizon))
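
With the defaults shown above (horizon=5), the new delta-index properties tell the dataloader to fetch horizon+1 observations and horizon actions/rewards per sample. A small sketch, assuming `TDMPCConfig` is instantiable with its defaults:

cfg = TDMPCConfig()
assert cfg.observation_delta_indices == list(range(6))  # [0, 1, 2, 3, 4, 5]
assert cfg.action_delta_indices == list(range(5))       # [0, 1, 2, 3, 4]
assert cfg.reward_delta_indices == list(range(5))       # [0, 1, 2, 3, 4]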

View File

@@ -33,21 +33,16 @@ import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F # noqa: N812 import torch.nn.functional as F # noqa: N812
from huggingface_hub import PyTorchModelHubMixin
from torch import Tensor from torch import Tensor
from lerobot.common.constants import OBS_ENV, OBS_ROBOT
from lerobot.common.policies.normalize import Normalize, Unnormalize from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.pretrained import PreTrainedPolicy
from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig
from lerobot.common.policies.utils import get_device_from_parameters, populate_queues from lerobot.common.policies.utils import get_device_from_parameters, get_output_shape, populate_queues
class TDMPCPolicy( class TDMPCPolicy(PreTrainedPolicy):
nn.Module,
PyTorchModelHubMixin,
library_name="lerobot",
repo_url="https://github.com/huggingface/lerobot",
tags=["robotics", "tdmpc"],
):
"""Implementation of TD-MPC learning + inference. """Implementation of TD-MPC learning + inference.
Please note several warnings for this policy. Please note several warnings for this policy.
@@ -65,11 +60,10 @@ class TDMPCPolicy(
match our xarm environment. match our xarm environment.
""" """
config_class = TDMPCConfig
name = "tdmpc" name = "tdmpc"
def __init__( def __init__(self, config: TDMPCConfig, dataset_stats: dict[str, dict[str, Tensor]] | None = None):
self, config: TDMPCConfig | None = None, dataset_stats: dict[str, dict[str, Tensor]] | None = None
):
""" """
Args: Args:
config: Policy configuration class instance or None, in which case the default instantiation of config: Policy configuration class instance or None, in which case the default instantiation of
@@ -77,42 +71,28 @@ class TDMPCPolicy(
dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
that they will be passed with a call to `load_state_dict` before the policy is used. that they will be passed with a call to `load_state_dict` before the policy is used.
""" """
super().__init__() super().__init__(config)
config.validate_features()
if config is None:
config = TDMPCConfig()
self.config = config self.config = config
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
self.normalize_targets = Normalize(
config.output_features, config.normalization_mapping, dataset_stats
)
self.unnormalize_outputs = Unnormalize(
config.output_features, config.normalization_mapping, dataset_stats
)
self.model = TDMPCTOLD(config) self.model = TDMPCTOLD(config)
self.model_target = deepcopy(self.model) self.model_target = deepcopy(self.model)
for param in self.model_target.parameters(): for param in self.model_target.parameters():
param.requires_grad = False param.requires_grad = False
if config.input_normalization_modes is not None:
self.normalize_inputs = Normalize(
config.input_shapes, config.input_normalization_modes, dataset_stats
)
else:
self.normalize_inputs = nn.Identity()
self.normalize_targets = Normalize(
config.output_shapes, config.output_normalization_modes, dataset_stats
)
self.unnormalize_outputs = Unnormalize(
config.output_shapes, config.output_normalization_modes, dataset_stats
)
image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
# Note: This check is covered in the post-init of the config but have a sanity check just in case.
self._use_image = False
self._use_env_state = False
if len(image_keys) > 0:
assert len(image_keys) == 1
self._use_image = True
self.input_image_key = image_keys[0]
if "observation.environment_state" in config.input_shapes:
self._use_env_state = True
self.reset() self.reset()
def get_optim_params(self) -> dict:
return self.parameters()
def reset(self): def reset(self):
""" """
Clear observation and action queues. Clear previous means for warm starting of MPPI/CEM. Should be Clear observation and action queues. Clear previous means for warm starting of MPPI/CEM. Should be
@@ -122,9 +102,9 @@ class TDMPCPolicy(
"observation.state": deque(maxlen=1), "observation.state": deque(maxlen=1),
"action": deque(maxlen=max(self.config.n_action_steps, self.config.n_action_repeats)), "action": deque(maxlen=max(self.config.n_action_steps, self.config.n_action_repeats)),
} }
if self._use_image: if self.config.image_features:
self._queues["observation.image"] = deque(maxlen=1) self._queues["observation.image"] = deque(maxlen=1)
if self._use_env_state: if self.config.env_state_feature:
self._queues["observation.environment_state"] = deque(maxlen=1) self._queues["observation.environment_state"] = deque(maxlen=1)
# Previous mean obtained from the cross-entropy method (CEM) used during MPC. It is used to warm start # Previous mean obtained from the cross-entropy method (CEM) used during MPC. It is used to warm start
# CEM for the next step. # CEM for the next step.
@@ -134,9 +114,9 @@ class TDMPCPolicy(
def select_action(self, batch: dict[str, Tensor]) -> Tensor: def select_action(self, batch: dict[str, Tensor]) -> Tensor:
"""Select a single action given environment observations.""" """Select a single action given environment observations."""
batch = self.normalize_inputs(batch) batch = self.normalize_inputs(batch)
if self._use_image: if self.config.image_features:
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch["observation.image"] = batch[self.input_image_key] batch["observation.image"] = batch[next(iter(self.config.image_features))]
self._queues = populate_queues(self._queues, batch) self._queues = populate_queues(self._queues, batch)
@@ -151,9 +131,9 @@ class TDMPCPolicy(
# NOTE: Order of observations matters here. # NOTE: Order of observations matters here.
encode_keys = [] encode_keys = []
if self._use_image: if self.config.image_features:
encode_keys.append("observation.image") encode_keys.append("observation.image")
if self._use_env_state: if self.config.env_state_feature:
encode_keys.append("observation.environment_state") encode_keys.append("observation.environment_state")
encode_keys.append("observation.state") encode_keys.append("observation.state")
z = self.model.encode({k: batch[k] for k in encode_keys}) z = self.model.encode({k: batch[k] for k in encode_keys})
@@ -196,7 +176,7 @@ class TDMPCPolicy(
self.config.horizon, self.config.horizon,
self.config.n_pi_samples, self.config.n_pi_samples,
batch_size, batch_size,
self.config.output_shapes["action"][0], self.config.action_feature.shape[0],
device=device, device=device,
) )
if self.config.n_pi_samples > 0: if self.config.n_pi_samples > 0:
@@ -215,7 +195,7 @@ class TDMPCPolicy(
# algorithm. # algorithm.
# The initial mean and standard deviation for the cross-entropy method (CEM). # The initial mean and standard deviation for the cross-entropy method (CEM).
mean = torch.zeros( mean = torch.zeros(
self.config.horizon, batch_size, self.config.output_shapes["action"][0], device=device self.config.horizon, batch_size, self.config.action_feature.shape[0], device=device
) )
# Maybe warm start CEM with the mean from the previous step. # Maybe warm start CEM with the mean from the previous step.
if self._prev_mean is not None: if self._prev_mean is not None:
@@ -228,7 +208,7 @@ class TDMPCPolicy(
self.config.horizon, self.config.horizon,
self.config.n_gaussian_samples, self.config.n_gaussian_samples,
batch_size, batch_size,
self.config.output_shapes["action"][0], self.config.action_feature.shape[0],
device=std.device, device=std.device,
) )
gaussian_actions = torch.clamp(mean.unsqueeze(1) + std.unsqueeze(1) * std_normal_noise, -1, 1) gaussian_actions = torch.clamp(mean.unsqueeze(1) + std.unsqueeze(1) * std_normal_noise, -1, 1)
@@ -322,7 +302,7 @@ class TDMPCPolicy(
G -= running_discount * self.config.uncertainty_regularizer_coeff * terminal_values.std(0) G -= running_discount * self.config.uncertainty_regularizer_coeff * terminal_values.std(0)
return G return G
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]: def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
"""Run the batch through the model and compute the loss. """Run the batch through the model and compute the loss.
Returns a dictionary with loss as a tensor, and other information as native floats. Returns a dictionary with loss as a tensor, and other information as native floats.
@@ -330,16 +310,16 @@ class TDMPCPolicy(
device = get_device_from_parameters(self) device = get_device_from_parameters(self)
batch = self.normalize_inputs(batch) batch = self.normalize_inputs(batch)
if self._use_image: if self.config.image_features:
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch["observation.image"] = batch[self.input_image_key] batch["observation.image"] = batch[next(iter(self.config.image_features))]
batch = self.normalize_targets(batch) batch = self.normalize_targets(batch)
info = {} info = {}
# (b, t) -> (t, b) # (b, t) -> (t, b)
for key in batch: for key in batch:
if batch[key].ndim > 1: if isinstance(batch[key], torch.Tensor) and batch[key].ndim > 1:
batch[key] = batch[key].transpose(1, 0) batch[key] = batch[key].transpose(1, 0)
action = batch["action"] # (t, b, action_dim) action = batch["action"] # (t, b, action_dim)
@@ -347,7 +327,7 @@ class TDMPCPolicy(
observations = {k: v for k, v in batch.items() if k.startswith("observation.")} observations = {k: v for k, v in batch.items() if k.startswith("observation.")}
# Apply random image augmentations. # Apply random image augmentations.
if self._use_image and self.config.max_random_shift_ratio > 0: if self.config.image_features and self.config.max_random_shift_ratio > 0:
observations["observation.image"] = flatten_forward_unflatten( observations["observation.image"] = flatten_forward_unflatten(
partial(random_shifts_aug, max_random_shift_ratio=self.config.max_random_shift_ratio), partial(random_shifts_aug, max_random_shift_ratio=self.config.max_random_shift_ratio),
observations["observation.image"], observations["observation.image"],
@@ -360,7 +340,7 @@ class TDMPCPolicy(
current_observation[k] = observations[k][0] current_observation[k] = observations[k][0]
next_observations[k] = observations[k][1:] next_observations[k] = observations[k][1:]
horizon, batch_size = next_observations[ horizon, batch_size = next_observations[
"observation.image" if self._use_image else "observation.environment_state" "observation.image" if self.config.image_features else "observation.environment_state"
].shape[:2] ].shape[:2]
# Run latent rollout using the latent dynamics model and policy model. # Run latent rollout using the latent dynamics model and policy model.
@@ -515,17 +495,16 @@ class TDMPCPolicy(
"Q_value_loss": q_value_loss.item(), "Q_value_loss": q_value_loss.item(),
"V_value_loss": v_value_loss.item(), "V_value_loss": v_value_loss.item(),
"pi_loss": pi_loss.item(), "pi_loss": pi_loss.item(),
"loss": loss,
"sum_loss": loss.item() * self.config.horizon, "sum_loss": loss.item() * self.config.horizon,
} }
) )
# Undo (b, t) -> (t, b). # Undo (b, t) -> (t, b).
for key in batch: for key in batch:
if batch[key].ndim > 1: if isinstance(batch[key], torch.Tensor) and batch[key].ndim > 1:
batch[key] = batch[key].transpose(1, 0) batch[key] = batch[key].transpose(1, 0)
return info return loss, info
def update(self): def update(self):
"""Update the target model's parameters with an EMA step.""" """Update the target model's parameters with an EMA step."""
@@ -543,7 +522,7 @@ class TDMPCTOLD(nn.Module):
self.config = config self.config = config
self._encoder = TDMPCObservationEncoder(config) self._encoder = TDMPCObservationEncoder(config)
self._dynamics = nn.Sequential( self._dynamics = nn.Sequential(
nn.Linear(config.latent_dim + config.output_shapes["action"][0], config.mlp_dim), nn.Linear(config.latent_dim + config.action_feature.shape[0], config.mlp_dim),
nn.LayerNorm(config.mlp_dim), nn.LayerNorm(config.mlp_dim),
nn.Mish(), nn.Mish(),
nn.Linear(config.mlp_dim, config.mlp_dim), nn.Linear(config.mlp_dim, config.mlp_dim),
@@ -554,7 +533,7 @@ class TDMPCTOLD(nn.Module):
nn.Sigmoid(), nn.Sigmoid(),
) )
self._reward = nn.Sequential( self._reward = nn.Sequential(
nn.Linear(config.latent_dim + config.output_shapes["action"][0], config.mlp_dim), nn.Linear(config.latent_dim + config.action_feature.shape[0], config.mlp_dim),
nn.LayerNorm(config.mlp_dim), nn.LayerNorm(config.mlp_dim),
nn.Mish(), nn.Mish(),
nn.Linear(config.mlp_dim, config.mlp_dim), nn.Linear(config.mlp_dim, config.mlp_dim),
@@ -569,12 +548,12 @@ class TDMPCTOLD(nn.Module):
nn.Linear(config.mlp_dim, config.mlp_dim), nn.Linear(config.mlp_dim, config.mlp_dim),
nn.LayerNorm(config.mlp_dim), nn.LayerNorm(config.mlp_dim),
nn.Mish(), nn.Mish(),
nn.Linear(config.mlp_dim, config.output_shapes["action"][0]), nn.Linear(config.mlp_dim, config.action_feature.shape[0]),
) )
self._Qs = nn.ModuleList( self._Qs = nn.ModuleList(
[ [
nn.Sequential( nn.Sequential(
nn.Linear(config.latent_dim + config.output_shapes["action"][0], config.mlp_dim), nn.Linear(config.latent_dim + config.action_feature.shape[0], config.mlp_dim),
nn.LayerNorm(config.mlp_dim), nn.LayerNorm(config.mlp_dim),
nn.Tanh(), nn.Tanh(),
nn.Linear(config.mlp_dim, config.mlp_dim), nn.Linear(config.mlp_dim, config.mlp_dim),
@@ -615,9 +594,9 @@ class TDMPCTOLD(nn.Module):
self.apply(_apply_fn) self.apply(_apply_fn)
for m in [self._reward, *self._Qs]: for m in [self._reward, *self._Qs]:
assert isinstance( assert isinstance(m[-1], nn.Linear), (
m[-1], nn.Linear "Sanity check. The last linear layer needs 0 initialization on weights."
), "Sanity check. The last linear layer needs 0 initialization on weights." )
nn.init.zeros_(m[-1].weight) nn.init.zeros_(m[-1].weight)
nn.init.zeros_(m[-1].bias) # this has already been done, but keep this line here for good measure nn.init.zeros_(m[-1].bias) # this has already been done, but keep this line here for good measure
@@ -714,10 +693,13 @@ class TDMPCObservationEncoder(nn.Module):
super().__init__() super().__init__()
self.config = config self.config = config
if "observation.image" in config.input_shapes: if config.image_features:
self.image_enc_layers = nn.Sequential( self.image_enc_layers = nn.Sequential(
nn.Conv2d( nn.Conv2d(
config.input_shapes["observation.image"][0], config.image_encoder_hidden_dim, 7, stride=2 next(iter(config.image_features.values())).shape[0],
config.image_encoder_hidden_dim,
7,
stride=2,
), ),
nn.ReLU(), nn.ReLU(),
nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 5, stride=2), nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 5, stride=2),
@@ -727,9 +709,8 @@ class TDMPCObservationEncoder(nn.Module):
nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 3, stride=2), nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 3, stride=2),
nn.ReLU(), nn.ReLU(),
) )
dummy_batch = torch.zeros(1, *config.input_shapes["observation.image"]) dummy_shape = (1, *next(iter(config.image_features.values())).shape)
with torch.inference_mode(): out_shape = get_output_shape(self.image_enc_layers, dummy_shape)[1:]
out_shape = self.image_enc_layers(dummy_batch).shape[1:]
self.image_enc_layers.extend( self.image_enc_layers.extend(
nn.Sequential( nn.Sequential(
nn.Flatten(), nn.Flatten(),
@@ -738,19 +719,19 @@ class TDMPCObservationEncoder(nn.Module):
nn.Sigmoid(), nn.Sigmoid(),
) )
) )
if "observation.state" in config.input_shapes:
if config.robot_state_feature:
self.state_enc_layers = nn.Sequential( self.state_enc_layers = nn.Sequential(
nn.Linear(config.input_shapes["observation.state"][0], config.state_encoder_hidden_dim), nn.Linear(config.robot_state_feature.shape[0], config.state_encoder_hidden_dim),
nn.ELU(), nn.ELU(),
nn.Linear(config.state_encoder_hidden_dim, config.latent_dim), nn.Linear(config.state_encoder_hidden_dim, config.latent_dim),
nn.LayerNorm(config.latent_dim), nn.LayerNorm(config.latent_dim),
nn.Sigmoid(), nn.Sigmoid(),
) )
if "observation.environment_state" in config.input_shapes:
if config.env_state_feature:
self.env_state_enc_layers = nn.Sequential( self.env_state_enc_layers = nn.Sequential(
nn.Linear( nn.Linear(config.env_state_feature.shape[0], config.state_encoder_hidden_dim),
config.input_shapes["observation.environment_state"][0], config.state_encoder_hidden_dim
),
nn.ELU(), nn.ELU(),
nn.Linear(config.state_encoder_hidden_dim, config.latent_dim), nn.Linear(config.state_encoder_hidden_dim, config.latent_dim),
nn.LayerNorm(config.latent_dim), nn.LayerNorm(config.latent_dim),
@@ -765,12 +746,16 @@ class TDMPCObservationEncoder(nn.Module):
""" """
feat = [] feat = []
# NOTE: Order of observations matters here. # NOTE: Order of observations matters here.
if "observation.image" in self.config.input_shapes: if self.config.image_features:
feat.append(flatten_forward_unflatten(self.image_enc_layers, obs_dict["observation.image"])) feat.append(
if "observation.environment_state" in self.config.input_shapes: flatten_forward_unflatten(
feat.append(self.env_state_enc_layers(obs_dict["observation.environment_state"])) self.image_enc_layers, obs_dict[next(iter(self.config.image_features))]
if "observation.state" in self.config.input_shapes: )
feat.append(self.state_enc_layers(obs_dict["observation.state"])) )
if self.config.env_state_feature:
feat.append(self.env_state_enc_layers(obs_dict[OBS_ENV]))
if self.config.robot_state_feature:
feat.append(self.state_enc_layers(obs_dict[OBS_ROBOT]))
return torch.stack(feat, dim=0).mean(0) return torch.stack(feat, dim=0).mean(0)
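
The refactor above replaces explicit `config.input_shapes["observation.image"]` lookups with the first entry of `config.image_features`. A self-contained sketch of that idiom, with a hypothetical `FeatureSpec` standing in for the real feature spec type:

from dataclasses import dataclass

@dataclass
class FeatureSpec:  # hypothetical stand-in for the real feature spec type
    shape: tuple

image_features = {"observation.image": FeatureSpec((3, 84, 84))}
image_key = next(iter(image_features))            # "observation.image"
image_ft = next(iter(image_features.values()))    # FeatureSpec(shape=(3, 84, 84))
assert image_ft.shape[-2] == image_ft.shape[-1]   # square-image check, as in validate_features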

View File

@@ -13,6 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import torch import torch
from torch import nn from torch import nn
@@ -47,3 +48,20 @@ def get_dtype_from_parameters(module: nn.Module) -> torch.dtype:
Note: assumes that all parameters have the same dtype. Note: assumes that all parameters have the same dtype.
""" """
return next(iter(module.parameters())).dtype return next(iter(module.parameters())).dtype
def get_output_shape(module: nn.Module, input_shape: tuple) -> tuple:
"""
Calculates the output shape of a PyTorch module given an input shape.
Args:
module (nn.Module): a PyTorch module
input_shape (tuple): A tuple representing the input shape, e.g., (batch_size, channels, height, width)
Returns:
tuple: The output shape of the module.
"""
dummy_input = torch.zeros(size=input_shape)
with torch.inference_mode():
output = module(dummy_input)
return tuple(output.shape)
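
Usage sketch for `get_output_shape`: probe a small conv stack with a dummy input instead of hand-deriving its output size (shapes below are illustrative):

import torch
from torch import nn

encoder = nn.Sequential(
    nn.Conv2d(3, 32, kernel_size=7, stride=2),
    nn.ReLU(),
    nn.Conv2d(32, 32, kernel_size=5, stride=2),
    nn.ReLU(),
)
out_shape = get_output_shape(encoder, (1, 3, 84, 84))[1:]  # (32, 18, 18), batch dim dropped
head = nn.Linear(out_shape[0] * out_shape[1] * out_shape[2], 50)  # flatten -> latent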

View File

@@ -18,9 +18,15 @@
from dataclasses import dataclass, field from dataclasses import dataclass, field
from lerobot.common.optim.optimizers import AdamConfig
from lerobot.common.optim.schedulers import VQBeTSchedulerConfig
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import NormalizationMode
@PreTrainedConfig.register_subclass("vqbet")
@dataclass @dataclass
class VQBeTConfig: class VQBeTConfig(PreTrainedConfig):
"""Configuration class for VQ-BeT. """Configuration class for VQ-BeT.
Defaults are configured for training with PushT providing proprioceptive and single camera observations. Defaults are configured for training with PushT providing proprioceptive and single camera observations.
@@ -60,7 +66,7 @@ class VQBeTConfig:
within the image size. If None, no cropping is done. within the image size. If None, no cropping is done.
crop_is_random: Whether the crop should be random at training time (it's always a center crop in eval crop_is_random: Whether the crop should be random at training time (it's always a center crop in eval
mode). mode).
pretrained_backbone_weights: Pretrained weights from torchvision to initalize the backbone. pretrained_backbone_weights: Pretrained weights from torchvision to initialize the backbone.
`None` means no pretrained weights. `None` means no pretrained weights.
use_group_norm: Whether to replace batch normalization with group normalization in the backbone. use_group_norm: Whether to replace batch normalization with group normalization in the backbone.
The group sizes are set to be about 16 (to be precise, feature_dim // 16). The group sizes are set to be about 16 (to be precise, feature_dim // 16).
@@ -90,26 +96,13 @@ class VQBeTConfig:
n_action_pred_token: int = 3 n_action_pred_token: int = 3
action_chunk_size: int = 5 action_chunk_size: int = 5
input_shapes: dict[str, list[int]] = field( normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: { default_factory=lambda: {
"observation.image": [3, 96, 96], "VISUAL": NormalizationMode.IDENTITY,
"observation.state": [2], "STATE": NormalizationMode.MIN_MAX,
"ACTION": NormalizationMode.MIN_MAX,
} }
) )
output_shapes: dict[str, list[int]] = field(
default_factory=lambda: {
"action": [2],
}
)
# Normalization / Unnormalization
input_normalization_modes: dict[str, str] = field(
default_factory=lambda: {
"observation.image": "mean_std",
"observation.state": "min_max",
}
)
output_normalization_modes: dict[str, str] = field(default_factory=lambda: {"action": "min_max"})
# Architecture / modeling. # Architecture / modeling.
# Vision backbone. # Vision backbone.
@@ -139,29 +132,69 @@ class VQBeTConfig:
bet_softmax_temperature: float = 0.1 bet_softmax_temperature: float = 0.1
sequentially_select: bool = False sequentially_select: bool = False
# Training presets
optimizer_lr: float = 1e-4
optimizer_betas: tuple = (0.95, 0.999)
optimizer_eps: float = 1e-8
optimizer_weight_decay: float = 1e-6
optimizer_vqvae_lr: float = 1e-3
optimizer_vqvae_weight_decay: float = 1e-4
scheduler_warmup_steps: int = 500
def __post_init__(self): def __post_init__(self):
super().__post_init__()
"""Input validation (not exhaustive).""" """Input validation (not exhaustive)."""
if not self.vision_backbone.startswith("resnet"): if not self.vision_backbone.startswith("resnet"):
raise ValueError( raise ValueError(
f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}." f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
) )
image_keys = {k for k in self.input_shapes if k.startswith("observation.image")}
def get_optimizer_preset(self) -> AdamConfig:
return AdamConfig(
lr=self.optimizer_lr,
betas=self.optimizer_betas,
eps=self.optimizer_eps,
weight_decay=self.optimizer_weight_decay,
)
def get_scheduler_preset(self) -> VQBeTSchedulerConfig:
return VQBeTSchedulerConfig(
num_warmup_steps=self.scheduler_warmup_steps,
num_vqvae_training_steps=self.n_vqvae_training_steps,
)
def validate_features(self) -> None:
# Note: this check was previously performed inside VQBeTRgbEncoder in the form of
# assert len(image_keys) == 1
if not len(self.image_features) == 1:
raise ValueError("You must provide only one image among the inputs.")
if self.crop_shape is not None: if self.crop_shape is not None:
for image_key in image_keys: for key, image_ft in self.image_features.items():
if ( if self.crop_shape[0] > image_ft.shape[1] or self.crop_shape[1] > image_ft.shape[2]:
self.crop_shape[0] > self.input_shapes[image_key][1]
or self.crop_shape[1] > self.input_shapes[image_key][2]
):
raise ValueError( raise ValueError(
f"`crop_shape` should fit within `input_shapes[{image_key}]`. Got {self.crop_shape} " f"`crop_shape` should fit within the images shapes. Got {self.crop_shape} "
f"for `crop_shape` and {self.input_shapes[image_key]} for " f"for `crop_shape` and {image_ft.shape} for "
"`input_shapes[{image_key}]`." f"`{key}`."
) )
# Check that all input images have the same shape. # Check that all input images have the same shape.
first_image_key = next(iter(image_keys)) first_image_key, first_image_ft = next(iter(self.image_features.items()))
for image_key in image_keys: for key, image_ft in self.image_features.items():
if self.input_shapes[image_key] != self.input_shapes[first_image_key]: if image_ft.shape != first_image_ft.shape:
raise ValueError( raise ValueError(
f"`input_shapes[{image_key}]` does not match `input_shapes[{first_image_key}]`, but we " f"`{key}` does not match `{first_image_key}`, but we expect all image shapes to match."
"expect all image shapes to match."
) )
@property
def observation_delta_indices(self) -> list:
return list(range(1 - self.n_obs_steps, 1))
@property
def action_delta_indices(self) -> list:
return list(range(1 - self.n_obs_steps, self.n_action_pred_token + self.action_chunk_size - 1))
@property
def reward_delta_indices(self) -> None:
return None
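
A sketch of the new VQ-BeT delta indices, using the defaults shown above (n_action_pred_token=3, action_chunk_size=5) and an illustrative n_obs_steps=5 (assuming `n_obs_steps` is a config field, as the properties reference it):

cfg = VQBeTConfig(n_obs_steps=5)
assert cfg.observation_delta_indices == [-4, -3, -2, -1, 0]  # 1 - n_obs_steps .. 0
assert cfg.action_delta_indices == list(range(-4, 7))        # up to n_action_pred_token + action_chunk_size - 2
assert cfg.reward_delta_indices is None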

View File

@@ -16,7 +16,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from collections import deque
from typing import Callable, List
@@ -26,29 +25,23 @@ import numpy as np
import torch
import torch.nn.functional as F  # noqa: N812
import torchvision
from torch import Tensor, nn

from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.pretrained import PreTrainedPolicy
from lerobot.common.policies.utils import get_device_from_parameters, get_output_shape, populate_queues
from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig
from lerobot.common.policies.vqbet.vqbet_utils import GPT, ResidualVQ

# ruff: noqa: N806


class VQBeTPolicy(PreTrainedPolicy):
    """
    VQ-BeT Policy as per "Behavior Generation with Latent Actions"
    """

    config_class = VQBeTConfig
    name = "vqbet"

    def __init__(
@@ -63,26 +56,62 @@ class VQBeTPolicy(
            dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
                that they will be passed with a call to `load_state_dict` before the policy is used.
        """
        super().__init__(config)
        config.validate_features()
        self.config = config

        self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
        self.normalize_targets = Normalize(
            config.output_features, config.normalization_mapping, dataset_stats
        )
        self.unnormalize_outputs = Unnormalize(
            config.output_features, config.normalization_mapping, dataset_stats
        )

        self.vqbet = VQBeTModel(config)

        self.reset()
def get_optim_params(self) -> dict:
vqvae_params = (
list(self.vqbet.action_head.vqvae_model.encoder.parameters())
+ list(self.vqbet.action_head.vqvae_model.decoder.parameters())
+ list(self.vqbet.action_head.vqvae_model.vq_layer.parameters())
)
decay_params, no_decay_params = self.vqbet.policy.configure_parameters()
decay_params = (
decay_params
+ list(self.vqbet.rgb_encoder.parameters())
+ list(self.vqbet.state_projector.parameters())
+ list(self.vqbet.rgb_feature_projector.parameters())
+ [self.vqbet.action_token]
+ list(self.vqbet.action_head.map_to_cbet_preds_offset.parameters())
)
if self.config.sequentially_select:
decay_params = (
decay_params
+ list(self.vqbet.action_head.map_to_cbet_preds_primary_bin.parameters())
+ list(self.vqbet.action_head.map_to_cbet_preds_secondary_bin.parameters())
)
else:
decay_params = decay_params + list(self.vqbet.action_head.map_to_cbet_preds_bin.parameters())
return [
{
"params": decay_params,
},
{
"params": vqvae_params,
"weight_decay": self.config.optimizer_vqvae_weight_decay,
"lr": self.config.optimizer_vqvae_lr,
},
{
"params": no_decay_params,
"weight_decay": 0.0,
},
]
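The hunk above moves optimizer-group construction into the policy itself, replacing the `VQBeTOptimizer` class that is removed further down in this diff. A hedged sketch of how a training script might consume these groups; the default `lr` and `weight_decay` values shown are placeholders, not taken from the diff:

```python
import torch

# policy is assumed to be an already-constructed VQBeTPolicy. Groups that set
# their own "lr" or "weight_decay" (the VQ-VAE group above) override these defaults.
optimizer = torch.optim.Adam(
    policy.get_optim_params(),
    lr=1e-4,            # assumed default lr for groups that don't override it
    weight_decay=1e-4,  # assumed default weight decay
)
```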
    def reset(self):
        """
        Clear observation and action queues. Should be called on `env.reset()`
@@ -105,7 +134,7 @@ class VQBeTPolicy(
        batch = self.normalize_inputs(batch)
        batch = dict(batch)  # shallow copy so that adding a key doesn't modify the original
        batch["observation.images"] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
        # Note: It's important that this happens after stacking the images into a single key.
        self._queues = populate_queues(self._queues, batch)
@@ -127,11 +156,11 @@ class VQBeTPolicy(
        action = self._queues["action"].popleft()
        return action

    def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
        """Run the batch through the model and compute the loss for training or validation."""
        batch = self.normalize_inputs(batch)
        batch = dict(batch)  # shallow copy so that adding a key doesn't modify the original
        batch["observation.images"] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
        batch = self.normalize_targets(batch)
        # VQ-BeT discretizes action using VQ-VAE before training BeT (please refer to section 3.2 in the VQ-BeT paper https://arxiv.org/pdf/2403.03181)
        if not self.vqbet.action_head.vqvae_model.discretized.item():
@@ -141,16 +170,16 @@ class VQBeTPolicy(
            loss, n_different_codes, n_different_combinations, recon_l1_error = (
                self.vqbet.action_head.discretize(self.config.n_vqvae_training_steps, batch["action"])
            )
            return loss, {
                "n_different_codes": n_different_codes,
                "n_different_combinations": n_different_combinations,
                "recon_l1_error": recon_l1_error,
            }
        # if Residual VQ is already trained, VQ-BeT trains its GPT and bin prediction head / offset prediction head parts.
        _, loss_dict = self.vqbet(batch, rollout=False)
        loss = loss_dict.pop("loss")
        return loss, loss_dict
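With this change, `forward` returns the loss tensor first and a metrics dict second, matching the `(loss, output_dict)` convention of the refactored training loop. A minimal sketch of the call site, with `dataloader` and `optimizer` assumed to exist:

```python
batch = next(iter(dataloader))  # assumed: yields dict[str, Tensor] batches
loss, metrics = policy.forward(batch)
loss.backward()
optimizer.step()
optimizer.zero_grad()
print(metrics)  # e.g. VQ-VAE pretraining metrics, or the BeT loss terms
```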
class SpatialSoftmax(nn.Module):
@@ -288,14 +317,14 @@ class VQBeTModel(nn.Module):
        self.config = config

        self.rgb_encoder = VQBeTRgbEncoder(config)
        self.num_images = len(self.config.image_features)
        # This action query token is used as a prompt for querying action chunks. Please refer to "A_Q" in the image above.
        # Note: During the forward pass, this token is repeated as many times as needed. The authors also experimented with initializing the necessary number of tokens independently and observed inferior results.
        self.action_token = nn.Parameter(torch.randn(1, 1, self.config.gpt_input_dim))

        # To input state and observation features into GPT layers, we first project the features to fit the GPT input size.
        self.state_projector = MLP(
            config.robot_state_feature.shape[0], hidden_channels=[self.config.gpt_input_dim]
        )
        self.rgb_feature_projector = MLP(
            self.rgb_encoder.feature_dim, hidden_channels=[self.config.gpt_input_dim]
        )
@@ -313,7 +342,7 @@ class VQBeTModel(nn.Module):
            torch.row_stack([torch.arange(i, i + self.config.action_chunk_size) for i in range(num_tokens)]),
        )

    def forward(self, batch: dict[str, Tensor], rollout: bool) -> tuple[dict, dict]:
        # Input validation.
        assert set(batch).issuperset({"observation.state", "observation.images"})
        batch_size, n_obs_steps = batch["observation.state"].shape[:2]
@@ -350,10 +379,10 @@ class VQBeTModel(nn.Module):
        # get action features (pass through GPT)
        features = self.policy(input_tokens)
        # len(self.config.input_features) is the number of different observation modes.
        # this line gets the index of action prompt tokens.
        historical_act_pred_index = np.arange(0, n_obs_steps) * (len(self.config.input_features) + 1) + len(
            self.config.input_features
        )

        # only extract the output tokens at the position of action query:
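To make the index arithmetic above concrete, here is an illustrative computation (the feature count and step count are assumed, not taken from the diff): each observation step contributes `len(input_features) + 1` tokens, the `+1` being the appended action query token, so the action-query positions land at the end of each step's token group.

```python
import numpy as np

n_obs_steps = 2  # assumed
num_input_features = 2  # assumed: e.g. one image feature plus one state feature
historical_act_pred_index = np.arange(0, n_obs_steps) * (num_input_features + 1) + num_input_features
print(historical_act_pred_index)  # [2 5]: token positions of the action queries
```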
@@ -392,7 +421,7 @@ class VQBeTHead(nn.Module):
        self.map_to_cbet_preds_offset: output the predicted offsets for all the codes in all the layers.
            The input dimension of `self.map_to_cbet_preds_offset` is the same as the output of GPT,
            and the output dimension of `self.map_to_cbet_preds_offset` is `self.vqvae_model.vqvae_num_layers (=fixed as 2) * self.config.vqvae_n_embed * config.action_chunk_size * config.action_feature.shape[0]`.
        """
        super().__init__()
@@ -419,7 +448,7 @@ class VQBeTHead(nn.Module):
                self.vqvae_model.vqvae_num_layers
                * self.config.vqvae_n_embed
                * config.action_chunk_size
                * config.action_feature.shape[0],
            ],
        )
        # loss
@@ -453,10 +482,10 @@ class VQBeTHead(nn.Module):
                param.requires_grad = False
        return loss, n_different_codes, n_different_combinations, recon_l1_error

    def forward(self, x, **kwargs) -> dict:
        # N is the batch size, and T is the number of action query tokens, which are processed through the same GPT
        N, T, _ = x.shape
        # we calculate the N and T dimensions in parallel. Thus, the dimensions would be
        # (batch size * number of action query tokens, action chunk size, action dimension)
        x = einops.rearrange(x, "N T WA -> (N T) WA")
@@ -623,84 +652,6 @@ class VQBeTHead(nn.Module):
        return loss_dict
class VQBeTOptimizer(torch.optim.Adam):
def __init__(self, policy, cfg):
vqvae_params = (
list(policy.vqbet.action_head.vqvae_model.encoder.parameters())
+ list(policy.vqbet.action_head.vqvae_model.decoder.parameters())
+ list(policy.vqbet.action_head.vqvae_model.vq_layer.parameters())
)
decay_params, no_decay_params = policy.vqbet.policy.configure_parameters()
decay_params = (
decay_params
+ list(policy.vqbet.rgb_encoder.parameters())
+ list(policy.vqbet.state_projector.parameters())
+ list(policy.vqbet.rgb_feature_projector.parameters())
+ [policy.vqbet.action_token]
+ list(policy.vqbet.action_head.map_to_cbet_preds_offset.parameters())
)
if cfg.policy.sequentially_select:
decay_params = (
decay_params
+ list(policy.vqbet.action_head.map_to_cbet_preds_primary_bin.parameters())
+ list(policy.vqbet.action_head.map_to_cbet_preds_secondary_bin.parameters())
)
else:
decay_params = decay_params + list(policy.vqbet.action_head.map_to_cbet_preds_bin.parameters())
optim_groups = [
{
"params": decay_params,
"weight_decay": cfg.training.adam_weight_decay,
"lr": cfg.training.lr,
},
{
"params": vqvae_params,
"weight_decay": 0.0001,
"lr": cfg.training.vqvae_lr,
},
{
"params": no_decay_params,
"weight_decay": 0.0,
"lr": cfg.training.lr,
},
]
super().__init__(
optim_groups,
cfg.training.lr,
cfg.training.adam_betas,
cfg.training.adam_eps,
)
class VQBeTScheduler(nn.Module):
def __init__(self, optimizer, cfg):
super().__init__()
n_vqvae_training_steps = cfg.training.n_vqvae_training_steps
num_warmup_steps = cfg.training.lr_warmup_steps
num_training_steps = cfg.training.offline_steps
num_cycles = 0.5
def lr_lambda(current_step):
if current_step < n_vqvae_training_steps:
return float(1)
else:
current_step = current_step - n_vqvae_training_steps
if current_step < num_warmup_steps:
return float(current_step) / float(max(1, num_warmup_steps))
progress = float(current_step - num_warmup_steps) / float(
max(1, num_training_steps - num_warmup_steps)
)
return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))
self.lr_scheduler = LambdaLR(optimizer, lr_lambda, -1)
def step(self):
self.lr_scheduler.step()
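Both classes above are removed in favor of presets: the policy now exposes `get_optim_params()` (shown earlier in this diff), and the config returns a `VQBeTSchedulerConfig` (the hunk at the top of this section). A hedged sketch of the equivalent wiring; the preset method and build signature below are assumptions for illustration, not taken from the diff:

```python
import torch

# Assumed wiring: parameter groups come from the policy, the schedule from a config preset.
optimizer = torch.optim.Adam(policy.get_optim_params(), lr=1e-4)  # default lr assumed
scheduler_cfg = policy.config.get_scheduler_preset()  # method name assumed; returns VQBeTSchedulerConfig
scheduler = scheduler_cfg.build(optimizer, num_training_steps=100_000)  # build signature assumed
```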
class VQBeTRgbEncoder(nn.Module):
    """Encode an RGB image into a 1D feature vector.
@@ -743,19 +694,15 @@ class VQBeTRgbEncoder(nn.Module):
        # Set up pooling and final layers.
        # Use a dry run to get the feature map shape.
        # The dummy input should take the number of image channels from `config.image_features` and it should
        # use the height and width from `config.crop_shape` if it is provided, otherwise it should use the
        # height and width from `config.image_features`.
        images_shape = next(iter(config.image_features.values())).shape
        dummy_shape_h_w = config.crop_shape if config.crop_shape is not None else images_shape[1:]
        dummy_shape = (1, images_shape[0], *dummy_shape_h_w)
        feature_map_shape = get_output_shape(self.backbone, dummy_shape)[1:]

        self.pool = SpatialSoftmax(feature_map_shape, num_kp=config.spatial_softmax_num_keypoints)
        self.feature_dim = config.spatial_softmax_num_keypoints * 2
        self.out = nn.Linear(config.spatial_softmax_num_keypoints * 2, self.feature_dim)
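The dry run is now factored into `get_output_shape` from `lerobot.common.policies.utils`. Based on the inline code it replaces (the left side of this hunk), the helper plausibly looks like the following sketch:

```python
import torch
from torch import nn


def get_output_shape(module: nn.Module, input_shape: tuple) -> tuple:
    # Run a zero tensor through the module without tracking gradients and
    # report the resulting shape, mirroring the removed inline dry run.
    dummy_input = torch.zeros(size=input_shape)
    with torch.inference_mode():
        output = module(dummy_input)
    return tuple(output.shape)
```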
@@ -825,7 +772,7 @@ class VqVae(nn.Module):
    Encoder and decoder are MLPs consisting of an input layer, an output layer, and hidden layers, respectively.
    The vq_layer uses residual VQs.

    This class contains functions for training the encoder and decoder along with the residual VQ layer (for training phase 1),
    as well as functions to help the BeT training part in training phase 2.
    """
@@ -844,7 +791,7 @@ class VqVae(nn.Module):
        )

        self.encoder = MLP(
            in_channels=self.config.action_feature.shape[0] * self.config.action_chunk_size,
            hidden_channels=[
                config.vqvae_enc_hidden_dim,
                config.vqvae_enc_hidden_dim,
@@ -856,7 +803,7 @@ class VqVae(nn.Module):
            hidden_channels=[
                config.vqvae_enc_hidden_dim,
                config.vqvae_enc_hidden_dim,
                self.config.action_feature.shape[0] * self.config.action_chunk_size,
            ],
        )
@@ -872,9 +819,9 @@ class VqVae(nn.Module):
        # given latent vector, this function outputs the decoded action.
        output = self.decoder(latent)
        if self.config.action_chunk_size == 1:
            return einops.rearrange(output, "N (T A) -> N T A", A=self.config.action_feature.shape[0])
        else:
            return einops.rearrange(output, "N (T A) -> N T A", A=self.config.action_feature.shape[0])

    def get_code(self, state):
        # in phase 2 of VQ-BeT training, we need `ground truth labels of action data` to calculate the Focal loss for the code prediction head. (please refer to section 3.3 in the paper https://arxiv.org/pdf/2403.03181)

View File

@@ -38,7 +38,7 @@ from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig
This file is part of a VQ-BeT that utilizes code from the following repositories:
    - Vector Quantize PyTorch code is licensed under the MIT License:
        Original source: https://github.com/lucidrains/vector-quantize-pytorch
    - nanoGPT part is an adaptation of Andrej Karpathy's nanoGPT implementation in PyTorch.
        Original source: https://github.com/karpathy/nanoGPT
@@ -203,9 +203,9 @@ class GPT(nn.Module):
    def forward(self, input, targets=None):
        device = input.device
        b, t, d = input.size()
        assert t <= self.config.gpt_block_size, (
            f"Cannot forward sequence of length {t}, block size is only {self.config.gpt_block_size}"
        )

        # positional encodings that are added to the input embeddings
        pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0)  # shape (1, t)
@@ -273,10 +273,10 @@ class GPT(nn.Module):
        assert len(inter_params) == 0, "parameters {} made it into both decay/no_decay sets!".format(
            str(inter_params)
        )
        assert len(param_dict.keys() - union_params) == 0, (
            "parameters {} were not separated into either decay/no_decay set!".format(
                str(param_dict.keys() - union_params),
            )
        )

        decay = [param_dict[pn] for pn in sorted(decay)]
@@ -289,7 +289,7 @@ class GPT(nn.Module):
This file is part of a Residual Vector Quantization implementation that utilizes code from the following repository:
    - Phil Wang's vector-quantize-pytorch implementation in PyTorch.
        Original source: https://github.com/lucidrains/vector-quantize-pytorch
    - The vector-quantize-pytorch code is licensed under the MIT License:
@@ -419,9 +419,9 @@ class ResidualVQ(nn.Module):
            # and the network should be able to reconstruct
            if quantize_dim < self.num_quantizers:
                assert self.quantize_dropout > 0.0, (
                    "quantize dropout must be greater than 0 if you wish to reconstruct from a signal with less fine quantizations"
                )
                indices = F.pad(indices, (0, self.num_quantizers - quantize_dim), value=-1)

        # get ready for gathering
@@ -472,9 +472,9 @@ class ResidualVQ(nn.Module):
        all_indices = []

        if return_loss:
            assert not torch.any(indices == -1), (
                "some of the residual vq indices were dropped out. please use indices derived when the module is in eval mode to derive cross entropy loss"
            )
            ce_losses = []

        should_quantize_dropout = self.training and self.quantize_dropout and not return_loss
@@ -887,9 +887,9 @@ class VectorQuantize(nn.Module):
            # only calculate orthogonal loss for the activated codes for this batch
            if self.orthogonal_reg_active_codes_only:
                assert not (is_multiheaded and self.separate_codebook_per_head), (
                    "orthogonal regularization for only active codes not compatible with multi-headed with separate codebooks yet"
                )
                unique_code_ids = torch.unique(embed_ind)
                codebook = codebook[:, unique_code_ids]
@@ -999,9 +999,9 @@ def gumbel_sample(
        ind = sampling_logits.argmax(dim=dim)
    one_hot = F.one_hot(ind, size).type(dtype)

    assert not (reinmax and not straight_through), (
        "reinmax can only be turned on if using straight through gumbel softmax"
    )

    if not straight_through or temperature <= 0.0 or not training:
        return ind, one_hot
@@ -1209,9 +1209,9 @@ class EuclideanCodebook(nn.Module):
        self.gumbel_sample = gumbel_sample
        self.sample_codebook_temp = sample_codebook_temp

        assert not (use_ddp and num_codebooks > 1 and kmeans_init), (
            "kmeans init is not compatible with multiple codebooks in distributed environment for now"
        )

        self.sample_fn = sample_vectors_distributed if use_ddp and sync_kmeans else batched_sample_vectors
        self.kmeans_all_reduce_fn = distributed.all_reduce if use_ddp and sync_kmeans else noop
@@ -1349,9 +1349,9 @@ class EuclideanCodebook(nn.Module):
        # calculate distributed variance
        variance_number = reduce((data - batch_mean) ** 2, "h n d -> h 1 d", "sum")
        distributed.all_reduce(variance_number)
        batch_variance = variance_number / num_vectors

        self.update_with_decay("batch_variance", batch_variance, self.affine_param_batch_decay)

View File

@@ -0,0 +1,100 @@
import abc
from dataclasses import dataclass
import draccus
@dataclass
class CameraConfig(draccus.ChoiceRegistry, abc.ABC):
@property
def type(self) -> str:
return self.get_choice_name(self.__class__)
@CameraConfig.register_subclass("opencv")
@dataclass
class OpenCVCameraConfig(CameraConfig):
"""
Example of tested options:
```python
OpenCVCameraConfig(0, 30, 640, 480)
OpenCVCameraConfig(0, 60, 640, 480)
OpenCVCameraConfig(0, 90, 640, 480)
OpenCVCameraConfig(0, 30, 1280, 720)
```
"""
camera_index: int
fps: int | None = None
width: int | None = None
height: int | None = None
color_mode: str = "rgb"
channels: int | None = None
rotation: int | None = None
mock: bool = False
def __post_init__(self):
if self.color_mode not in ["rgb", "bgr"]:
raise ValueError(
f"`color_mode` is expected to be 'rgb' or 'bgr', but {self.color_mode} is provided."
)
self.channels = 3
if self.rotation not in [-90, None, 90, 180]:
raise ValueError(f"`rotation` must be in [-90, None, 90, 180] (got {self.rotation})")
@CameraConfig.register_subclass("intelrealsense")
@dataclass
class IntelRealSenseCameraConfig(CameraConfig):
"""
Example of tested options for Intel Real Sense D405:
```python
IntelRealSenseCameraConfig(128422271347, 30, 640, 480)
IntelRealSenseCameraConfig(128422271347, 60, 640, 480)
IntelRealSenseCameraConfig(128422271347, 90, 640, 480)
IntelRealSenseCameraConfig(128422271347, 30, 1280, 720)
IntelRealSenseCameraConfig(128422271347, 30, 640, 480, use_depth=True)
IntelRealSenseCameraConfig(128422271347, 30, 640, 480, rotation=90)
```
"""
name: str | None = None
serial_number: int | None = None
fps: int | None = None
width: int | None = None
height: int | None = None
color_mode: str = "rgb"
channels: int | None = None
use_depth: bool = False
force_hardware_reset: bool = True
rotation: int | None = None
mock: bool = False
def __post_init__(self):
# Using bool() rather than `is not None` here also rejects empty strings
if bool(self.name) and bool(self.serial_number):
raise ValueError(
f"One of them must be set: name or serial_number, but {self.name=} and {self.serial_number=} provided."
)
if self.color_mode not in ["rgb", "bgr"]:
raise ValueError(
f"`color_mode` is expected to be 'rgb' or 'bgr', but {self.color_mode} is provided."
)
self.channels = 3
at_least_one_is_not_none = self.fps is not None or self.width is not None or self.height is not None
at_least_one_is_none = self.fps is None or self.width is None or self.height is None
if at_least_one_is_not_none and at_least_one_is_none:
raise ValueError(
"For `fps`, `width` and `height`, either all of them need to be set, or none of them, "
f"but {self.fps=}, {self.width=}, {self.height=} were provided."
)
if self.rotation not in [-90, None, 90, 180]:
raise ValueError(f"`rotation` must be in [-90, None, 90, 180] (got {self.rotation})")

View File

@@ -11,13 +11,13 @@ import threading
import time
import traceback
from collections import Counter
from pathlib import Path
from threading import Thread

import numpy as np
from PIL import Image

from lerobot.common.robot_devices.cameras.configs import IntelRealSenseCameraConfig
from lerobot.common.robot_devices.utils import (
    RobotDeviceAlreadyConnectedError,
    RobotDeviceNotConnectedError,
@@ -94,7 +94,10 @@ def save_images_from_cameras(
    cameras = []
    for cam_sn in serial_numbers:
        print(f"{cam_sn=}")
        config = IntelRealSenseCameraConfig(
            serial_number=cam_sn, fps=fps, width=width, height=height, mock=mock
        )
        camera = IntelRealSenseCamera(config)
        camera.connect()
        print(
            f"IntelRealSenseCamera({camera.serial_number}, fps={camera.fps}, width={camera.width}, height={camera.height}, color_mode={camera.color_mode})"
@@ -149,51 +152,6 @@ def save_images_from_cameras(
        camera.disconnect()
@dataclass
class IntelRealSenseCameraConfig:
"""
Example of tested options for Intel Real Sense D405:
```python
IntelRealSenseCameraConfig(30, 640, 480)
IntelRealSenseCameraConfig(60, 640, 480)
IntelRealSenseCameraConfig(90, 640, 480)
IntelRealSenseCameraConfig(30, 1280, 720)
IntelRealSenseCameraConfig(30, 640, 480, use_depth=True)
IntelRealSenseCameraConfig(30, 640, 480, rotation=90)
```
"""
fps: int | None = None
width: int | None = None
height: int | None = None
color_mode: str = "rgb"
channels: int | None = None
use_depth: bool = False
force_hardware_reset: bool = True
rotation: int | None = None
mock: bool = False
def __post_init__(self):
if self.color_mode not in ["rgb", "bgr"]:
raise ValueError(
f"`color_mode` is expected to be 'rgb' or 'bgr', but {self.color_mode} is provided."
)
self.channels = 3
at_least_one_is_not_none = self.fps is not None or self.width is not None or self.height is not None
at_least_one_is_none = self.fps is None or self.width is None or self.height is None
if at_least_one_is_not_none and at_least_one_is_none:
raise ValueError(
"For `fps`, `width` and `height`, either all of them need to be set, or none of them, "
f"but {self.fps=}, {self.width=}, {self.height=} were provided."
)
if self.rotation not in [-90, None, 90, 180]:
raise ValueError(f"`rotation` must be in [-90, None, 90, 180] (got {self.rotation})")
class IntelRealSenseCamera:
    """
    The IntelRealSenseCamera class is similar to OpenCVCamera class but adds additional features for Intel Real Sense cameras:
@@ -209,33 +167,35 @@ class IntelRealSenseCamera:
    When an IntelRealSenseCamera is instantiated, if no specific config is provided, the default fps, width, height and color_mode
    of the given camera will be used.

    Example of instantiating with a serial number:
    ```python
    from lerobot.common.robot_devices.cameras.configs import IntelRealSenseCameraConfig

    config = IntelRealSenseCameraConfig(serial_number=128422271347)
    camera = IntelRealSenseCamera(config)
    camera.connect()
    color_image = camera.read()
    # when done using the camera, consider disconnecting
    camera.disconnect()
    ```

    Example of instantiating with a name if it's unique:
    ```
    config = IntelRealSenseCameraConfig(name="Intel RealSense D405")
    ```

    Example of changing default fps, width, height and color_mode:
    ```python
    config = IntelRealSenseCameraConfig(serial_number=128422271347, fps=30, width=1280, height=720)
    config = IntelRealSenseCameraConfig(serial_number=128422271347, fps=90, width=640, height=480)
    config = IntelRealSenseCameraConfig(serial_number=128422271347, fps=90, width=640, height=480, color_mode="bgr")
    # Note: might error out upon `camera.connect()` if these settings are not compatible with the camera
    ```

    Example of returning depth:
    ```python
    config = IntelRealSenseCameraConfig(serial_number=128422271347, use_depth=True)
    camera = IntelRealSenseCamera(config)
    camera.connect()
    color_image, depth_map = camera.read()
    ```
@@ -243,17 +203,13 @@ class IntelRealSenseCamera:
    def __init__(
        self,
        config: IntelRealSenseCameraConfig,
    ):
        self.config = config
        if config.name is not None:
            self.serial_number = self.find_serial_number_from_name(config.name)
        else:
            self.serial_number = config.serial_number
        self.fps = config.fps
        self.width = config.width
        self.height = config.height
@@ -285,8 +241,7 @@ class IntelRealSenseCamera:
        elif config.rotation == 180:
            self.rotation = cv2.ROTATE_180

    def find_serial_number_from_name(self, name):
        camera_infos = find_cameras()
        camera_names = [cam["name"] for cam in camera_infos]
        this_name_count = Counter(camera_names)[name]
@@ -299,13 +254,7 @@ class IntelRealSenseCamera:
        name_to_serial_dict = {cam["name"]: cam["serial_number"] for cam in camera_infos}
        cam_sn = name_to_serial_dict[name]
        return cam_sn

    def connect(self):
        if self.is_connected:

View File

@@ -9,13 +9,13 @@ import platform
import shutil
import threading
import time
from pathlib import Path
from threading import Thread

import numpy as np
from PIL import Image

from lerobot.common.robot_devices.cameras.configs import OpenCVCameraConfig
from lerobot.common.robot_devices.utils import (
    RobotDeviceAlreadyConnectedError,
    RobotDeviceNotConnectedError,
@@ -126,7 +126,8 @@ def save_images_from_cameras(
    print("Connecting cameras")
    cameras = []
    for cam_idx in camera_ids:
        config = OpenCVCameraConfig(camera_index=cam_idx, fps=fps, width=width, height=height, mock=mock)
        camera = OpenCVCamera(config)
        camera.connect()
        print(
            f"OpenCVCamera({camera.camera_index}, fps={camera.fps}, width={camera.width}, "
@@ -175,39 +176,6 @@ def save_images_from_cameras(
print(f"Images have been saved to {images_dir}") print(f"Images have been saved to {images_dir}")
@dataclass
class OpenCVCameraConfig:
"""
Example of tested options for Intel Real Sense D405:
```python
OpenCVCameraConfig(30, 640, 480)
OpenCVCameraConfig(60, 640, 480)
OpenCVCameraConfig(90, 640, 480)
OpenCVCameraConfig(30, 1280, 720)
```
"""
fps: int | None = None
width: int | None = None
height: int | None = None
color_mode: str = "rgb"
channels: int | None = None
rotation: int | None = None
mock: bool = False
def __post_init__(self):
if self.color_mode not in ["rgb", "bgr"]:
raise ValueError(
f"`color_mode` is expected to be 'rgb' or 'bgr', but {self.color_mode} is provided."
)
self.channels = 3
if self.rotation not in [-90, None, 90, 180]:
raise ValueError(f"`rotation` must be in [-90, None, 90, 180] (got {self.rotation})")
class OpenCVCamera:
    """
    The OpenCVCamera class allows to efficiently record images from cameras. It relies on opencv2 to communicate
@@ -227,7 +195,10 @@ class OpenCVCamera:
    Example of usage:
    ```python
    from lerobot.common.robot_devices.cameras.configs import OpenCVCameraConfig

    config = OpenCVCameraConfig(camera_index=0)
    camera = OpenCVCamera(config)
    camera.connect()
    color_image = camera.read()
    # when done using the camera, consider disconnecting
@@ -236,25 +207,16 @@ class OpenCVCamera:
    Example of changing default fps, width, height and color_mode:
    ```python
    config = OpenCVCameraConfig(camera_index=0, fps=30, width=1280, height=720)
    config = OpenCVCameraConfig(camera_index=0, fps=90, width=640, height=480)
    config = OpenCVCameraConfig(camera_index=0, fps=90, width=640, height=480, color_mode="bgr")
    # Note: might error out upon `camera.connect()` if these settings are not compatible with the camera
    ```
    """

    def __init__(self, config: OpenCVCameraConfig):
        self.config = config
        self.camera_index = config.camera_index
        self.port = None

        # Linux uses ports for connecting to cameras
@@ -266,7 +228,7 @@ class OpenCVCamera:
            # Retrieve the camera index from a potentially symlinked path
            self.camera_index = get_camera_index_from_unix_port(self.port)
        else:
            raise ValueError(f"Please check the provided camera_index: {self.camera_index}")

        self.fps = config.fps
        self.width = config.width

View File

@@ -2,6 +2,12 @@ from typing import Protocol
import numpy as np

from lerobot.common.robot_devices.cameras.configs import (
    CameraConfig,
    IntelRealSenseCameraConfig,
    OpenCVCameraConfig,
)

# Defines a camera type
class Camera(Protocol):
@@ -9,3 +15,39 @@ class Camera(Protocol):
    def read(self, temporary_color: str | None = None) -> np.ndarray: ...
    def async_read(self) -> np.ndarray: ...
    def disconnect(self): ...
def make_cameras_from_configs(camera_configs: dict[str, CameraConfig]) -> dict[str, Camera]:
cameras = {}
for key, cfg in camera_configs.items():
if cfg.type == "opencv":
from lerobot.common.robot_devices.cameras.opencv import OpenCVCamera
cameras[key] = OpenCVCamera(cfg)
elif cfg.type == "intelrealsense":
from lerobot.common.robot_devices.cameras.intelrealsense import IntelRealSenseCamera
cameras[key] = IntelRealSenseCamera(cfg)
else:
raise ValueError(f"The motor type '{cfg.type}' is not valid.")
return cameras
def make_camera(camera_type, **kwargs) -> Camera:
if camera_type == "opencv":
from lerobot.common.robot_devices.cameras.opencv import OpenCVCamera
config = OpenCVCameraConfig(**kwargs)
return OpenCVCamera(config)
elif camera_type == "intelrealsense":
from lerobot.common.robot_devices.cameras.intelrealsense import IntelRealSenseCamera
config = IntelRealSenseCameraConfig(**kwargs)
return IntelRealSenseCamera(config)
else:
raise ValueError(f"The camera type '{camera_type}' is not valid.")

View File

@@ -0,0 +1,144 @@
import logging
from dataclasses import dataclass
from pathlib import Path
import draccus
from lerobot.common.robot_devices.robots.configs import RobotConfig
from lerobot.common.utils.utils import auto_select_torch_device, is_amp_available, is_torch_device_available
from lerobot.configs import parser
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.train import TrainPipelineConfig
@dataclass
class ControlConfig(draccus.ChoiceRegistry):
pass
@ControlConfig.register_subclass("calibrate")
@dataclass
class CalibrateControlConfig(ControlConfig):
# List of arms to calibrate (e.g. `--arms='["left_follower","right_follower"]' left_leader`)
arms: list[str] | None = None
@ControlConfig.register_subclass("teleoperate")
@dataclass
class TeleoperateControlConfig(ControlConfig):
# Limit the maximum frames per second. By default, no limit.
fps: int | None = None
teleop_time_s: float | None = None
# Display all cameras on screen
display_cameras: bool = True
@ControlConfig.register_subclass("record")
@dataclass
class RecordControlConfig(ControlConfig):
# Dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`).
repo_id: str
# A short but accurate description of the task performed during the recording (e.g. "Pick the Lego block and drop it in the box on the right.")
single_task: str
# Root directory where the dataset will be stored (e.g. 'dataset/path').
root: str | Path | None = None
policy: PreTrainedConfig | None = None
# TODO(rcadene, aliberts): By default, use device and use_amp values from policy checkpoint.
device: str | None = None # cuda | cpu | mps
# `use_amp` determines whether to use Automatic Mixed Precision (AMP) for training and evaluation. With AMP,
# automatic gradient scaling is used.
use_amp: bool | None = None
# Limit the frames per second. By default, uses the policy fps.
fps: int | None = None
# Number of seconds before starting data collection. It allows the robot devices to warmup and synchronize.
warmup_time_s: int | float = 10
# Number of seconds for data recording for each episode.
episode_time_s: int | float = 60
# Number of seconds for resetting the environment after each episode.
reset_time_s: int | float = 60
# Number of episodes to record.
num_episodes: int = 50
# Encode frames in the dataset into video
video: bool = True
# Upload dataset to Hugging Face hub.
push_to_hub: bool = True
# Upload on private repository on the Hugging Face hub.
private: bool = False
# Add tags to your dataset on the hub.
tags: list[str] | None = None
# Number of subprocesses handling the saving of frames as PNG. Set to 0 to use threads only;
# set to ≥1 to use subprocesses, each using threads to write images. The best number of processes
# and threads depends on your system. We recommend 4 threads per camera with 0 processes.
# If fps is unstable, adjust the thread count. If still unstable, try using 1 or more subprocesses.
num_image_writer_processes: int = 0
# Number of threads writing the frames as png images on disk, per camera.
# Too many threads might cause unstable teleoperation fps due to main thread being blocked.
# Not enough threads might cause low camera fps.
num_image_writer_threads_per_camera: int = 4
# Display all cameras on screen
display_cameras: bool = True
# Use vocal synthesis to read events.
play_sounds: bool = True
# Resume recording on an existing dataset.
resume: bool = False
def __post_init__(self):
# HACK: We parse again the cli args here to get the pretrained path if there was one.
policy_path = parser.get_path_arg("control.policy")
if policy_path:
cli_overrides = parser.get_cli_overrides("control.policy")
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
self.policy.pretrained_path = policy_path
# When no device or use_amp are given, use the one from training config.
if self.device is None or self.use_amp is None:
train_cfg = TrainPipelineConfig.from_pretrained(policy_path)
if self.device is None:
self.device = train_cfg.device
if self.use_amp is None:
self.use_amp = train_cfg.use_amp
# Automatically switch to available device if necessary
if not is_torch_device_available(self.device):
auto_device = auto_select_torch_device()
logging.warning(f"Device '{self.device}' is not available. Switching to '{auto_device}'.")
self.device = auto_device
# Automatically deactivate AMP if necessary
if self.use_amp and not is_amp_available(self.device):
logging.warning(
f"Automatic Mixed Precision (amp) is not available on device '{self.device}'. Deactivating AMP."
)
self.use_amp = False
@ControlConfig.register_subclass("replay")
@dataclass
class ReplayControlConfig(ControlConfig):
# Dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`).
repo_id: str
# Index of the episode to replay.
episode: int
# Root directory where the dataset will be stored (e.g. 'dataset/path').
root: str | Path | None = None
# Limit the frames per second. By default, uses the dataset fps.
fps: int | None = None
# Use vocal synthesis to read events.
play_sounds: bool = True
@ControlConfig.register_subclass("remote_robot")
@dataclass
class RemoteRobotConfig(ControlConfig):
log_interval: int = 100
@dataclass
class ControlPipelineConfig:
robot: RobotConfig
control: ControlConfig
@classmethod
def __get_path_fields__(cls) -> list[str]:
"""This enables the parser to load config from the policy using `--policy.path=local/dir`"""
return ["control.policy"]

View File

@@ -12,18 +12,15 @@ from functools import cache
import cv2
import torch
from deepdiff import DeepDiff
from termcolor import colored

from lerobot.common.datasets.image_writer import safe_stop_image_writer
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.utils import get_features_from_robot
from lerobot.common.robot_devices.robots.utils import Robot
from lerobot.common.robot_devices.utils import busy_wait
from lerobot.common.utils.utils import get_safe_torch_device, has_method


def log_control_info(robot: Robot, dt_s, episode_index=None, frame_index=None, fps=None):
@@ -35,7 +32,7 @@ def log_control_info(robot: Robot, dt_s, episode_index=None, frame_index=None, f
    def log_dt(shortname, dt_val_s):
        nonlocal log_items, fps
        info_str = f"{shortname}:{dt_val_s * 1000:5.2f} ({1 / dt_val_s:3.1f}hz)"
        if fps is not None:
            actual_fps = 1 / dt_val_s
            if actual_fps < fps - 1:
@@ -89,10 +86,6 @@ def is_headless():
    return True
def has_method(_object: object, method_name: str):
return hasattr(_object, method_name) and callable(getattr(_object, method_name))
def predict_action(observation, policy, device, use_amp):
    observation = copy(observation)
    with (
@@ -161,26 +154,6 @@ def init_keyboard_listener():
    return listener, events
def init_policy(pretrained_policy_name_or_path, policy_overrides):
"""Instantiate the policy and load fps, device and use_amp from config yaml"""
pretrained_policy_path = get_pretrained_policy_path(pretrained_policy_name_or_path)
hydra_cfg = init_hydra_config(pretrained_policy_path / "config.yaml", policy_overrides)
policy = make_policy(hydra_cfg=hydra_cfg, pretrained_policy_name_or_path=pretrained_policy_path)
# Check device is available
device = get_safe_torch_device(hydra_cfg.device, log=True)
use_amp = hydra_cfg.use_amp
policy_fps = hydra_cfg.env.fps
policy.eval()
policy.to(device)
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
set_global_seed(hydra_cfg.seed)
return policy, policy_fps, device, use_amp
def warmup_record(
    robot,
    events,
@@ -209,6 +182,7 @@ def record_episode(
    device,
    use_amp,
    fps,
    single_task,
):
    control_loop(
        robot=robot,
@@ -221,6 +195,7 @@ def record_episode(
        use_amp=use_amp,
        fps=fps,
        teleoperate=policy is None,
        single_task=single_task,
    )
@@ -233,9 +208,10 @@ def control_loop(
    dataset: LeRobotDataset | None = None,
    events=None,
    policy=None,
    device: torch.device | str | None = None,
    use_amp: bool | None = None,
    fps: int | None = None,
    single_task: str | None = None,
):
    # TODO(rcadene): Add option to record logs
    if not robot.is_connected:
@@ -250,9 +226,15 @@ def control_loop(
    if teleoperate and policy is not None:
        raise ValueError("When `teleoperate` is True, `policy` should be None.")

    if dataset is not None and single_task is None:
        raise ValueError("You need to provide a task as argument in `single_task`.")

    if dataset is not None and fps is not None and dataset.fps != fps:
        raise ValueError(f"The dataset fps should be equal to requested fps ({dataset.fps} != {fps}).")

    if isinstance(device, str):
        device = get_safe_torch_device(device)

    timestamp = 0
    start_episode_t = time.perf_counter()
    while timestamp < control_time_s:
@@ -271,7 +253,7 @@ def control_loop(
            action = {"action": action}

        if dataset is not None:
            frame = {**observation, **action, "task": single_task}
            dataset.add_frame(frame)

        if display_cameras and not is_headless():
@@ -293,24 +275,18 @@ def control_loop(
            break


def reset_environment(robot, events, reset_time_s, fps):
    # TODO(rcadene): refactor warmup_record and reset_environment
    if has_method(robot, "teleop_safety_stop"):
        robot.teleop_safety_stop()

    control_loop(
        robot=robot,
        control_time_s=reset_time_s,
        events=events,
        fps=fps,
        teleoperate=True,
    )


def stop_recording(robot, listener, display_cameras):
@@ -324,21 +300,21 @@ def stop_recording(robot, listener, display_cameras):
        cv2.destroyAllWindows()


def sanity_check_dataset_name(repo_id, policy_cfg):
    _, dataset_name = repo_id.split("/")
    # either repo_id doesn't start with "eval_" and there is no policy
    # or repo_id starts with "eval_" and there is a policy

    # Check if dataset_name starts with "eval_" but policy is missing
    if dataset_name.startswith("eval_") and policy_cfg is None:
        raise ValueError(
            f"Your dataset name begins with 'eval_' ({dataset_name}), but no policy is provided."
        )

    # Check if dataset_name does not start with "eval_" but policy is provided
    if not dataset_name.startswith("eval_") and policy_cfg is not None:
        raise ValueError(
            f"Your dataset name does not begin with 'eval_' ({dataset_name}), but a policy is provided ({policy_cfg.type})."
        )

View File

@@ -0,0 +1,27 @@
import abc
from dataclasses import dataclass
import draccus
@dataclass
class MotorsBusConfig(draccus.ChoiceRegistry, abc.ABC):
@property
def type(self) -> str:
return self.get_choice_name(self.__class__)
@MotorsBusConfig.register_subclass("dynamixel")
@dataclass
class DynamixelMotorsBusConfig(MotorsBusConfig):
port: str
motors: dict[str, tuple[int, str]]
mock: bool = False
@MotorsBusConfig.register_subclass("feetech")
@dataclass
class FeetechMotorsBusConfig(MotorsBusConfig):
port: str
motors: dict[str, tuple[int, str]]
mock: bool = False
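A minimal usage sketch for the registry above (the port and motor values are placeholders, not part of the diff):

```python
config = FeetechMotorsBusConfig(
    port="/dev/tty.usbmodem585A0076891",  # placeholder port
    motors={"gripper": (6, "sts3215")},   # name: (index, model)
)
print(config.type)  # -> "feetech", resolved via draccus.ChoiceRegistry
```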

View File

@@ -8,6 +8,7 @@ from copy import deepcopy
import numpy as np import numpy as np
import tqdm import tqdm
from lerobot.common.robot_devices.motors.configs import DynamixelMotorsBusConfig
from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError
from lerobot.common.utils.utils import capture_timestamp_utc from lerobot.common.utils.utils import capture_timestamp_utc
@@ -241,7 +242,7 @@ class DriveMode(enum.Enum):
class CalibrationMode(enum.Enum): class CalibrationMode(enum.Enum):
# Joints with rotational motions are expressed in degrees in nominal range of [-180, 180] # Joints with rotational motions are expressed in degrees in nominal range of [-180, 180]
DEGREE = 0 DEGREE = 0
# Joints with linear motions (like gripper of Aloha) are experessed in nominal range of [0, 100] # Joints with linear motions (like gripper of Aloha) are expressed in nominal range of [0, 100]
LINEAR = 1 LINEAR = 1
@@ -252,7 +253,6 @@ class JointOutOfRangeError(Exception):
class DynamixelMotorsBus: class DynamixelMotorsBus:
# TODO(rcadene): Add a script to find the motor indices without DynamixelWizzard2
""" """
The DynamixelMotorsBus class allows to efficiently read and write to the attached motors. It relies on The DynamixelMotorsBus class allows to efficiently read and write to the attached motors. It relies on
the python dynamixel sdk to communicate with the motors. For more info, see the [Dynamixel SDK Documentation](https://emanual.robotis.com/docs/en/software/dynamixel/dynamixel_sdk/sample_code/python_read_write_protocol_2_0/#python-read-write-protocol-20). the python dynamixel sdk to communicate with the motors. For more info, see the [Dynamixel SDK Documentation](https://emanual.robotis.com/docs/en/software/dynamixel/dynamixel_sdk/sample_code/python_read_write_protocol_2_0/#python-read-write-protocol-20).
@@ -274,10 +274,11 @@ class DynamixelMotorsBus:
motor_index = 6 motor_index = 6
motor_model = "xl330-m288" motor_model = "xl330-m288"
motors_bus = DynamixelMotorsBus( config = DynamixelMotorsBusConfig(
port="/dev/tty.usbmodem575E0031751", port="/dev/tty.usbmodem575E0031751",
motors={motor_name: (motor_index, motor_model)}, motors={motor_name: (motor_index, motor_model)},
) )
motors_bus = DynamixelMotorsBus(config)
motors_bus.connect() motors_bus.connect()
position = motors_bus.read("Present_Position") position = motors_bus.read("Present_Position")
@@ -293,23 +294,14 @@ class DynamixelMotorsBus:
def __init__( def __init__(
self, self,
port: str, config: DynamixelMotorsBusConfig,
motors: dict[str, tuple[int, str]],
extra_model_control_table: dict[str, list[tuple]] | None = None,
extra_model_resolution: dict[str, int] | None = None,
mock=False,
): ):
self.port = port self.port = config.port
self.motors = motors self.motors = config.motors
self.mock = mock self.mock = config.mock
self.model_ctrl_table = deepcopy(MODEL_CONTROL_TABLE) self.model_ctrl_table = deepcopy(MODEL_CONTROL_TABLE)
if extra_model_control_table:
self.model_ctrl_table.update(extra_model_control_table)
self.model_resolution = deepcopy(MODEL_RESOLUTION) self.model_resolution = deepcopy(MODEL_RESOLUTION)
if extra_model_resolution:
self.model_resolution.update(extra_model_resolution)
self.port_handler = None self.port_handler = None
self.packet_handler = None self.packet_handler = None
@@ -618,7 +610,7 @@ class DynamixelMotorsBus:
# 0-centered resolution range (e.g. [-2048, 2048] for resolution=4096) # 0-centered resolution range (e.g. [-2048, 2048] for resolution=4096)
values[i] = values[i] / HALF_TURN_DEGREE * (resolution // 2) values[i] = values[i] / HALF_TURN_DEGREE * (resolution // 2)
# Substract the homing offsets to come back to actual motor range of values # Subtract the homing offsets to come back to actual motor range of values
# which can be arbitrary. # which can be arbitrary.
values[i] -= homing_offset values[i] -= homing_offset

View File

@@ -8,6 +8,7 @@ from copy import deepcopy
import numpy as np import numpy as np
import tqdm import tqdm
from lerobot.common.robot_devices.motors.configs import FeetechMotorsBusConfig
from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError
from lerobot.common.utils.utils import capture_timestamp_utc from lerobot.common.utils.utils import capture_timestamp_utc
@@ -220,7 +221,7 @@ class DriveMode(enum.Enum):
class CalibrationMode(enum.Enum): class CalibrationMode(enum.Enum):
# Joints with rotational motions are expressed in degrees in nominal range of [-180, 180] # Joints with rotational motions are expressed in degrees in nominal range of [-180, 180]
DEGREE = 0 DEGREE = 0
# Joints with linear motions (like gripper of Aloha) are experessed in nominal range of [0, 100] # Joints with linear motions (like gripper of Aloha) are expressed in nominal range of [0, 100]
LINEAR = 1 LINEAR = 1
@@ -252,10 +253,11 @@ class FeetechMotorsBus:
motor_index = 6 motor_index = 6
motor_model = "sts3215" motor_model = "sts3215"
motors_bus = FeetechMotorsBus( config = FeetechMotorsBusConfig(
port="/dev/tty.usbmodem575E0031751", port="/dev/tty.usbmodem575E0031751",
motors={motor_name: (motor_index, motor_model)}, motors={motor_name: (motor_index, motor_model)},
) )
motors_bus = FeetechMotorsBus(config)
motors_bus.connect() motors_bus.connect()
position = motors_bus.read("Present_Position") position = motors_bus.read("Present_Position")
@@ -271,23 +273,14 @@ class FeetechMotorsBus:
def __init__( def __init__(
self, self,
port: str, config: FeetechMotorsBusConfig,
motors: dict[str, tuple[int, str]],
extra_model_control_table: dict[str, list[tuple]] | None = None,
extra_model_resolution: dict[str, int] | None = None,
mock=False,
): ):
self.port = port self.port = config.port
self.motors = motors self.motors = config.motors
self.mock = mock self.mock = config.mock
self.model_ctrl_table = deepcopy(MODEL_CONTROL_TABLE) self.model_ctrl_table = deepcopy(MODEL_CONTROL_TABLE)
if extra_model_control_table:
self.model_ctrl_table.update(extra_model_control_table)
self.model_resolution = deepcopy(MODEL_RESOLUTION) self.model_resolution = deepcopy(MODEL_RESOLUTION)
if extra_model_resolution:
self.model_resolution.update(extra_model_resolution)
self.port_handler = None self.port_handler = None
self.packet_handler = None self.packet_handler = None
@@ -598,7 +591,7 @@ class FeetechMotorsBus:
# 0-centered resolution range (e.g. [-2048, 2048] for resolution=4096) # 0-centered resolution range (e.g. [-2048, 2048] for resolution=4096)
values[i] = values[i] / HALF_TURN_DEGREE * (resolution // 2) values[i] = values[i] / HALF_TURN_DEGREE * (resolution // 2)
# Substract the homing offsets to come back to actual motor range of values # Subtract the homing offsets to come back to actual motor range of values
# which can be arbitrary. # which can be arbitrary.
values[i] -= homing_offset values[i] -= homing_offset
@@ -639,7 +632,7 @@ class FeetechMotorsBus:
track["prev"][idx] = values[i] track["prev"][idx] = values[i]
continue continue
# Detect a full rotation occured # Detect a full rotation occurred
if abs(track["prev"][idx] - values[i]) > 2048: if abs(track["prev"][idx] - values[i]) > 2048:
# Position went below 0 and got reset to 4095 # Position went below 0 and got reset to 4095
if track["prev"][idx] < values[i]: if track["prev"][idx] < values[i]:
@@ -724,6 +717,10 @@ class FeetechMotorsBus:
group_key = get_group_sync_key(data_name, motor_names) group_key = get_group_sync_key(data_name, motor_names)
if data_name not in self.group_readers: if data_name not in self.group_readers:
# Very important: flush stale bytes from the serial buffers before creating the reader!
self.port_handler.ser.reset_output_buffer()
self.port_handler.ser.reset_input_buffer()
# create new group reader # create new group reader
self.group_readers[group_key] = scs.GroupSyncRead( self.group_readers[group_key] = scs.GroupSyncRead(
self.port_handler, self.packet_handler, addr, bytes self.port_handler, self.packet_handler, addr, bytes

View File

@@ -1,5 +1,11 @@
from typing import Protocol from typing import Protocol
from lerobot.common.robot_devices.motors.configs import (
DynamixelMotorsBusConfig,
FeetechMotorsBusConfig,
MotorsBusConfig,
)
class MotorsBus(Protocol): class MotorsBus(Protocol):
def motor_names(self): ... def motor_names(self): ...
@@ -8,3 +14,40 @@ class MotorsBus(Protocol):
def revert_calibration(self): ... def revert_calibration(self): ...
def read(self): ... def read(self): ...
def write(self): ... def write(self): ...
def make_motors_buses_from_configs(motors_bus_configs: dict[str, MotorsBusConfig]) -> dict[str, MotorsBus]:
motors_buses = {}
for key, cfg in motors_bus_configs.items():
if cfg.type == "dynamixel":
from lerobot.common.robot_devices.motors.dynamixel import DynamixelMotorsBus
motors_buses[key] = DynamixelMotorsBus(cfg)
elif cfg.type == "feetech":
from lerobot.common.robot_devices.motors.feetech import FeetechMotorsBus
motors_buses[key] = FeetechMotorsBus(cfg)
else:
raise ValueError(f"The motor type '{cfg.type}' is not valid.")
return motors_buses
def make_motors_bus(motor_type: str, **kwargs) -> MotorsBus:
if motor_type == "dynamixel":
from lerobot.common.robot_devices.motors.dynamixel import DynamixelMotorsBus
config = DynamixelMotorsBusConfig(**kwargs)
return DynamixelMotorsBus(config)
elif motor_type == "feetech":
from lerobot.common.robot_devices.motors.feetech import FeetechMotorsBus
config = FeetechMotorsBusConfig(**kwargs)
return FeetechMotorsBus(config)
else:
raise ValueError(f"The motor type '{motor_type}' is not valid.")

View File

@@ -0,0 +1,599 @@
import abc
from dataclasses import dataclass, field
from typing import Sequence
import draccus
from lerobot.common.robot_devices.cameras.configs import (
CameraConfig,
IntelRealSenseCameraConfig,
OpenCVCameraConfig,
)
from lerobot.common.robot_devices.motors.configs import (
DynamixelMotorsBusConfig,
FeetechMotorsBusConfig,
MotorsBusConfig,
)
@dataclass
class RobotConfig(draccus.ChoiceRegistry, abc.ABC):
@property
def type(self) -> str:
return self.get_choice_name(self.__class__)
# TODO(rcadene, aliberts): remove ManipulatorRobotConfig abstraction
@dataclass
class ManipulatorRobotConfig(RobotConfig):
leader_arms: dict[str, MotorsBusConfig] = field(default_factory=lambda: {})
follower_arms: dict[str, MotorsBusConfig] = field(default_factory=lambda: {})
cameras: dict[str, CameraConfig] = field(default_factory=lambda: {})
# Optionally limit the magnitude of the relative positional target vector for safety purposes.
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length
# as the number of motors in your follower arms (assumes all follower arms have the same number of
# motors).
max_relative_target: list[float] | float | None = None
# Optionally set the leader arm in torque mode with the gripper motor set to this angle. This makes it
# possible to squeeze the gripper and have it spring back to an open position on its own. If None, the
# gripper is not put in torque mode.
gripper_open_degree: float | None = None
mock: bool = False
def __post_init__(self):
if self.mock:
for arm in self.leader_arms.values():
if not arm.mock:
arm.mock = True
for arm in self.follower_arms.values():
if not arm.mock:
arm.mock = True
for cam in self.cameras.values():
if not cam.mock:
cam.mock = True
if self.max_relative_target is not None and isinstance(self.max_relative_target, Sequence):
for name in self.follower_arms:
if len(self.follower_arms[name].motors) != len(self.max_relative_target):
raise ValueError(
f"len(max_relative_target)={len(self.max_relative_target)} but the follower arm with name {name} has "
f"{len(self.follower_arms[name].motors)} motors. Please make sure that the "
f"`max_relative_target` list has as many parameters as there are motors per arm. "
"Note: This feature does not yet work with robots where different follower arms have "
"different numbers of motors."
)
@RobotConfig.register_subclass("aloha")
@dataclass
class AlohaRobotConfig(ManipulatorRobotConfig):
# Specific to Aloha, LeRobot comes with default calibration files. Assuming the motors have been
# properly assembled, no manual calibration step is expected. If you need to run manual calibration,
# simply update this path to ".cache/calibration/aloha"
calibration_dir: str = ".cache/calibration/aloha_default"
# /!\ FOR SAFETY, READ THIS /!\
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
# the number of motors in your follower arms.
# For Aloha, for every goal position request, motor rotations are capped at 5 degrees by default.
# When you feel more confident with teleoperation or running the policy, you can extend
# this safety limit or even remove it by setting it to `null`.
# Also, everything is expected to work safely out-of-the-box, but we highly advise
# first trying to teleoperate the grippers only (by commenting out the rest of the motors in this config),
# then gradually adding more motors (by uncommenting), until you can teleoperate both arms fully.
max_relative_target: int | None = 5
leader_arms: dict[str, MotorsBusConfig] = field(
default_factory=lambda: {
"left": DynamixelMotorsBusConfig(
# window_x
port="/dev/ttyDXL_leader_left",
motors={
# name: (index, model)
"waist": [1, "xm430-w350"],
"shoulder": [2, "xm430-w350"],
"shoulder_shadow": [3, "xm430-w350"],
"elbow": [4, "xm430-w350"],
"elbow_shadow": [5, "xm430-w350"],
"forearm_roll": [6, "xm430-w350"],
"wrist_angle": [7, "xm430-w350"],
"wrist_rotate": [8, "xl430-w250"],
"gripper": [9, "xc430-w150"],
},
),
"right": DynamixelMotorsBusConfig(
# window_x
port="/dev/ttyDXL_leader_right",
motors={
# name: (index, model)
"waist": [1, "xm430-w350"],
"shoulder": [2, "xm430-w350"],
"shoulder_shadow": [3, "xm430-w350"],
"elbow": [4, "xm430-w350"],
"elbow_shadow": [5, "xm430-w350"],
"forearm_roll": [6, "xm430-w350"],
"wrist_angle": [7, "xm430-w350"],
"wrist_rotate": [8, "xl430-w250"],
"gripper": [9, "xc430-w150"],
},
),
}
)
follower_arms: dict[str, MotorsBusConfig] = field(
default_factory=lambda: {
"left": DynamixelMotorsBusConfig(
port="/dev/ttyDXL_follower_left",
motors={
# name: (index, model)
"waist": [1, "xm540-w270"],
"shoulder": [2, "xm540-w270"],
"shoulder_shadow": [3, "xm540-w270"],
"elbow": [4, "xm540-w270"],
"elbow_shadow": [5, "xm540-w270"],
"forearm_roll": [6, "xm540-w270"],
"wrist_angle": [7, "xm540-w270"],
"wrist_rotate": [8, "xm430-w350"],
"gripper": [9, "xm430-w350"],
},
),
"right": DynamixelMotorsBusConfig(
port="/dev/ttyDXL_follower_right",
motors={
# name: (index, model)
"waist": [1, "xm540-w270"],
"shoulder": [2, "xm540-w270"],
"shoulder_shadow": [3, "xm540-w270"],
"elbow": [4, "xm540-w270"],
"elbow_shadow": [5, "xm540-w270"],
"forearm_roll": [6, "xm540-w270"],
"wrist_angle": [7, "xm540-w270"],
"wrist_rotate": [8, "xm430-w350"],
"gripper": [9, "xm430-w350"],
},
),
}
)
# Troubleshooting: If one of your IntelRealSense cameras freezes during
# data recording due to the bandwidth limit, you might need to plug the camera
# into another USB hub or PCIe card.
cameras: dict[str, CameraConfig] = field(
default_factory=lambda: {
"cam_high": IntelRealSenseCameraConfig(
serial_number=128422271347,
fps=30,
width=640,
height=480,
),
"cam_low": IntelRealSenseCameraConfig(
serial_number=130322270656,
fps=30,
width=640,
height=480,
),
"cam_left_wrist": IntelRealSenseCameraConfig(
serial_number=218622272670,
fps=30,
width=640,
height=480,
),
"cam_right_wrist": IntelRealSenseCameraConfig(
serial_number=130322272300,
fps=30,
width=640,
height=480,
),
}
)
mock: bool = False
@RobotConfig.register_subclass("koch")
@dataclass
class KochRobotConfig(ManipulatorRobotConfig):
calibration_dir: str = ".cache/calibration/koch"
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
# the number of motors in your follower arms.
max_relative_target: int | None = None
leader_arms: dict[str, MotorsBusConfig] = field(
default_factory=lambda: {
"main": DynamixelMotorsBusConfig(
port="/dev/tty.usbmodem585A0085511",
motors={
# name: (index, model)
"shoulder_pan": [1, "xl330-m077"],
"shoulder_lift": [2, "xl330-m077"],
"elbow_flex": [3, "xl330-m077"],
"wrist_flex": [4, "xl330-m077"],
"wrist_roll": [5, "xl330-m077"],
"gripper": [6, "xl330-m077"],
},
),
}
)
follower_arms: dict[str, MotorsBusConfig] = field(
default_factory=lambda: {
"main": DynamixelMotorsBusConfig(
port="/dev/tty.usbmodem585A0076891",
motors={
# name: (index, model)
"shoulder_pan": [1, "xl430-w250"],
"shoulder_lift": [2, "xl430-w250"],
"elbow_flex": [3, "xl330-m288"],
"wrist_flex": [4, "xl330-m288"],
"wrist_roll": [5, "xl330-m288"],
"gripper": [6, "xl330-m288"],
},
),
}
)
cameras: dict[str, CameraConfig] = field(
default_factory=lambda: {
"laptop": OpenCVCameraConfig(
camera_index=0,
fps=30,
width=640,
height=480,
),
"phone": OpenCVCameraConfig(
camera_index=1,
fps=30,
width=640,
height=480,
),
}
)
# ~ Koch specific settings ~
# Sets the leader arm in torque mode with the gripper motor set to this angle. This makes it possible
# to squeeze the gripper and have it spring back to an open position on its own.
gripper_open_degree: float = 35.156
mock: bool = False
@RobotConfig.register_subclass("koch_bimanual")
@dataclass
class KochBimanualRobotConfig(ManipulatorRobotConfig):
calibration_dir: str = ".cache/calibration/koch_bimanual"
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
# the number of motors in your follower arms.
max_relative_target: int | None = None
leader_arms: dict[str, MotorsBusConfig] = field(
default_factory=lambda: {
"left": DynamixelMotorsBusConfig(
port="/dev/tty.usbmodem585A0085511",
motors={
# name: (index, model)
"shoulder_pan": [1, "xl330-m077"],
"shoulder_lift": [2, "xl330-m077"],
"elbow_flex": [3, "xl330-m077"],
"wrist_flex": [4, "xl330-m077"],
"wrist_roll": [5, "xl330-m077"],
"gripper": [6, "xl330-m077"],
},
),
"right": DynamixelMotorsBusConfig(
port="/dev/tty.usbmodem575E0031751",
motors={
# name: (index, model)
"shoulder_pan": [1, "xl330-m077"],
"shoulder_lift": [2, "xl330-m077"],
"elbow_flex": [3, "xl330-m077"],
"wrist_flex": [4, "xl330-m077"],
"wrist_roll": [5, "xl330-m077"],
"gripper": [6, "xl330-m077"],
},
),
}
)
follower_arms: dict[str, MotorsBusConfig] = field(
default_factory=lambda: {
"left": DynamixelMotorsBusConfig(
port="/dev/tty.usbmodem585A0076891",
motors={
# name: (index, model)
"shoulder_pan": [1, "xl430-w250"],
"shoulder_lift": [2, "xl430-w250"],
"elbow_flex": [3, "xl330-m288"],
"wrist_flex": [4, "xl330-m288"],
"wrist_roll": [5, "xl330-m288"],
"gripper": [6, "xl330-m288"],
},
),
"right": DynamixelMotorsBusConfig(
port="/dev/tty.usbmodem575E0032081",
motors={
# name: (index, model)
"shoulder_pan": [1, "xl430-w250"],
"shoulder_lift": [2, "xl430-w250"],
"elbow_flex": [3, "xl330-m288"],
"wrist_flex": [4, "xl330-m288"],
"wrist_roll": [5, "xl330-m288"],
"gripper": [6, "xl330-m288"],
},
),
}
)
cameras: dict[str, CameraConfig] = field(
default_factory=lambda: {
"laptop": OpenCVCameraConfig(
camera_index=0,
fps=30,
width=640,
height=480,
),
"phone": OpenCVCameraConfig(
camera_index=1,
fps=30,
width=640,
height=480,
),
}
)
# ~ Koch specific settings ~
# Sets the leader arm in torque mode with the gripper motor set to this angle. This makes it possible
# to squeeze the gripper and have it spring back to an open position on its own.
gripper_open_degree: float = 35.156
mock: bool = False
@RobotConfig.register_subclass("moss")
@dataclass
class MossRobotConfig(ManipulatorRobotConfig):
calibration_dir: str = ".cache/calibration/moss"
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
# the number of motors in your follower arms.
max_relative_target: int | None = None
leader_arms: dict[str, MotorsBusConfig] = field(
default_factory=lambda: {
"main": FeetechMotorsBusConfig(
port="/dev/tty.usbmodem58760431091",
motors={
# name: (index, model)
"shoulder_pan": [1, "sts3215"],
"shoulder_lift": [2, "sts3215"],
"elbow_flex": [3, "sts3215"],
"wrist_flex": [4, "sts3215"],
"wrist_roll": [5, "sts3215"],
"gripper": [6, "sts3215"],
},
),
}
)
follower_arms: dict[str, MotorsBusConfig] = field(
default_factory=lambda: {
"main": FeetechMotorsBusConfig(
port="/dev/tty.usbmodem585A0076891",
motors={
# name: (index, model)
"shoulder_pan": [1, "sts3215"],
"shoulder_lift": [2, "sts3215"],
"elbow_flex": [3, "sts3215"],
"wrist_flex": [4, "sts3215"],
"wrist_roll": [5, "sts3215"],
"gripper": [6, "sts3215"],
},
),
}
)
cameras: dict[str, CameraConfig] = field(
default_factory=lambda: {
"laptop": OpenCVCameraConfig(
camera_index=0,
fps=30,
width=640,
height=480,
),
"phone": OpenCVCameraConfig(
camera_index=1,
fps=30,
width=640,
height=480,
),
}
)
mock: bool = False
@RobotConfig.register_subclass("so100")
@dataclass
class So100RobotConfig(ManipulatorRobotConfig):
calibration_dir: str = ".cache/calibration/so100"
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
# the number of motors in your follower arms.
max_relative_target: int | None = None
leader_arms: dict[str, MotorsBusConfig] = field(
default_factory=lambda: {
"main": FeetechMotorsBusConfig(
port="/dev/tty.usbmodem58760431091",
motors={
# name: (index, model)
"shoulder_pan": [1, "sts3215"],
"shoulder_lift": [2, "sts3215"],
"elbow_flex": [3, "sts3215"],
"wrist_flex": [4, "sts3215"],
"wrist_roll": [5, "sts3215"],
"gripper": [6, "sts3215"],
},
),
}
)
follower_arms: dict[str, MotorsBusConfig] = field(
default_factory=lambda: {
"main": FeetechMotorsBusConfig(
port="/dev/tty.usbmodem585A0076891",
motors={
# name: (index, model)
"shoulder_pan": [1, "sts3215"],
"shoulder_lift": [2, "sts3215"],
"elbow_flex": [3, "sts3215"],
"wrist_flex": [4, "sts3215"],
"wrist_roll": [5, "sts3215"],
"gripper": [6, "sts3215"],
},
),
}
)
cameras: dict[str, CameraConfig] = field(
default_factory=lambda: {
"laptop": OpenCVCameraConfig(
camera_index=0,
fps=30,
width=640,
height=480,
),
"phone": OpenCVCameraConfig(
camera_index=1,
fps=30,
width=640,
height=480,
),
}
)
mock: bool = False
@RobotConfig.register_subclass("stretch")
@dataclass
class StretchRobotConfig(RobotConfig):
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
# the number of motors in your follower arms.
max_relative_target: int | None = None
cameras: dict[str, CameraConfig] = field(
default_factory=lambda: {
"navigation": OpenCVCameraConfig(
camera_index="/dev/hello-nav-head-camera",
fps=10,
width=1280,
height=720,
rotation=-90,
),
"head": IntelRealSenseCameraConfig(
name="Intel RealSense D435I",
fps=30,
width=640,
height=480,
rotation=90,
),
"wrist": IntelRealSenseCameraConfig(
name="Intel RealSense D405",
fps=30,
width=640,
height=480,
),
}
)
mock: bool = False
@RobotConfig.register_subclass("lekiwi")
@dataclass
class LeKiwiRobotConfig(RobotConfig):
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
# the number of motors in your follower arms.
max_relative_target: int | None = None
# Network Configuration
ip: str = "192.168.0.193"
port: int = 5555
video_port: int = 5556
cameras: dict[str, CameraConfig] = field(
default_factory=lambda: {
"front": OpenCVCameraConfig(
camera_index="/dev/video0", fps=30, width=640, height=480, rotation=90
),
"wrist": OpenCVCameraConfig(
camera_index="/dev/video2", fps=30, width=640, height=480, rotation=180
),
}
)
calibration_dir: str = ".cache/calibration/lekiwi"
leader_arms: dict[str, MotorsBusConfig] = field(
default_factory=lambda: {
"main": FeetechMotorsBusConfig(
port="/dev/tty.usbmodem585A0077581",
motors={
# name: (index, model)
"shoulder_pan": [1, "sts3215"],
"shoulder_lift": [2, "sts3215"],
"elbow_flex": [3, "sts3215"],
"wrist_flex": [4, "sts3215"],
"wrist_roll": [5, "sts3215"],
"gripper": [6, "sts3215"],
},
),
}
)
follower_arms: dict[str, MotorsBusConfig] = field(
default_factory=lambda: {
"main": FeetechMotorsBusConfig(
port="/dev/ttyACM0",
motors={
# name: (index, model)
"shoulder_pan": [1, "sts3215"],
"shoulder_lift": [2, "sts3215"],
"elbow_flex": [3, "sts3215"],
"wrist_flex": [4, "sts3215"],
"wrist_roll": [5, "sts3215"],
"gripper": [6, "sts3215"],
"left_wheel": (7, "sts3215"),
"back_wheel": (8, "sts3215"),
"right_wheel": (9, "sts3215"),
},
),
}
)
teleop_keys: dict[str, str] = field(
default_factory=lambda: {
# Movement
"forward": "w",
"backward": "s",
"left": "a",
"right": "d",
"rotate_left": "z",
"rotate_right": "x",
# Speed control
"speed_up": "r",
"speed_down": "f",
# quit teleop
"quit": "q",
}
)
mock: bool = False
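As a quick illustration of the `__post_init__` behavior defined at the top of this file (a sketch, not part of the diff):

```python
# mock=True on the robot config cascades to every arm and camera config.
cfg = So100RobotConfig(mock=True)
assert all(arm.mock for arm in cfg.leader_arms.values())
assert all(arm.mock for arm in cfg.follower_arms.values())
assert all(cam.mock for cam in cfg.cameras.values())

# A max_relative_target list whose length differs from the motor count fails fast:
# So100RobotConfig(max_relative_target=[5.0, 5.0])  # raises ValueError (6 motors expected)
```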

View File

@@ -87,7 +87,7 @@ def run_arm_calibration(arm: MotorsBus, robot_type: str, arm_name: str, arm_type
# For instance, if the motor rotates 90 degree, and its value is -90 after applying the homing offset, then we know its rotation direction # For instance, if the motor rotates 90 degree, and its value is -90 after applying the homing offset, then we know its rotation direction
# is inverted. However, for the calibration being successful, we need everyone to follow the same target position. # is inverted. However, for the calibration being successful, we need everyone to follow the same target position.
# Sometimes, there is only one possible rotation direction. For instance, if the gripper is closed, there is only one direction which # Sometimes, there is only one possible rotation direction. For instance, if the gripper is closed, there is only one direction which
# corresponds to opening the gripper. When the rotation direction is ambiguous, we arbitrarely rotate clockwise from the point of view # corresponds to opening the gripper. When the rotation direction is ambiguous, we arbitrarily rotate clockwise from the point of view
# of the previous motor in the kinetic chain. # of the previous motor in the kinetic chain.
print("\nMove arm to rotated target position") print("\nMove arm to rotated target position")
print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rotated")) print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rotated"))
@@ -115,7 +115,7 @@ def run_arm_calibration(arm: MotorsBus, robot_type: str, arm_name: str, arm_type
# TODO(rcadene): make type of joints (DEGREE or LINEAR) configurable from yaml? # TODO(rcadene): make type of joints (DEGREE or LINEAR) configurable from yaml?
if robot_type in ["aloha"] and "gripper" in arm.motor_names: if robot_type in ["aloha"] and "gripper" in arm.motor_names:
# Joints with linear motions (like gripper of Aloha) are experessed in nominal range of [0, 100] # Joints with linear motions (like gripper of Aloha) are expressed in nominal range of [0, 100]
calib_idx = arm.motor_names.index("gripper") calib_idx = arm.motor_names.index("gripper")
calib_mode[calib_idx] = CalibrationMode.LINEAR.name calib_mode[calib_idx] = CalibrationMode.LINEAR.name

View File

@@ -1,9 +0,0 @@
import hydra
from omegaconf import DictConfig
from lerobot.common.robot_devices.robots.utils import Robot
def make_robot(cfg: DictConfig) -> Robot:
robot = hydra.utils.instantiate(cfg)
return robot
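With this hydra-based factory deleted, robot construction flows through the typed `RobotConfig` registry added earlier in this PR. Per the updated `ManipulatorRobot` docstring further down, the pattern is now roughly (module paths assumed from the imports shown in this diff):

```python
from lerobot.common.robot_devices.robots.configs import KochRobotConfig
from lerobot.common.robot_devices.robots.manipulator import ManipulatorRobot  # assumed path

robot = ManipulatorRobot(KochRobotConfig())
```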

View File

@@ -443,7 +443,7 @@ def run_arm_manual_calibration(arm: MotorsBus, robot_type: str, arm_name: str, a
# For instance, if the motor rotates 90 degree, and its value is -90 after applying the homing offset, then we know its rotation direction # For instance, if the motor rotates 90 degree, and its value is -90 after applying the homing offset, then we know its rotation direction
# is inverted. However, for the calibration being successful, we need everyone to follow the same target position. # is inverted. However, for the calibration being successful, we need everyone to follow the same target position.
# Sometimes, there is only one possible rotation direction. For instance, if the gripper is closed, there is only one direction which # Sometimes, there is only one possible rotation direction. For instance, if the gripper is closed, there is only one direction which
# corresponds to opening the gripper. When the rotation direction is ambiguous, we arbitrarely rotate clockwise from the point of view # corresponds to opening the gripper. When the rotation direction is ambiguous, we arbitrarily rotate clockwise from the point of view
# of the previous motor in the kinetic chain. # of the previous motor in the kinetic chain.
print("\nMove arm to rotated target position") print("\nMove arm to rotated target position")
print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rotated")) print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rotated"))

View File

@@ -0,0 +1,210 @@
import base64
import json
import threading
import time
from pathlib import Path
import cv2
import zmq
from lerobot.common.robot_devices.robots.mobile_manipulator import LeKiwi
def setup_zmq_sockets(config):
context = zmq.Context()
cmd_socket = context.socket(zmq.PULL)
cmd_socket.setsockopt(zmq.CONFLATE, 1)
cmd_socket.bind(f"tcp://*:{config.port}")
video_socket = context.socket(zmq.PUSH)
video_socket.setsockopt(zmq.CONFLATE, 1)
video_socket.bind(f"tcp://*:{config.video_port}")
return context, cmd_socket, video_socket
def run_camera_capture(cameras, images_lock, latest_images_dict, stop_event):
while not stop_event.is_set():
local_dict = {}
for name, cam in cameras.items():
frame = cam.async_read()
ret, buffer = cv2.imencode(".jpg", frame, [int(cv2.IMWRITE_JPEG_QUALITY), 90])
if ret:
local_dict[name] = base64.b64encode(buffer).decode("utf-8")
else:
local_dict[name] = ""
with images_lock:
latest_images_dict.update(local_dict)
time.sleep(0.01)
def calibrate_follower_arm(motors_bus, calib_dir_str):
"""
Calibrates the follower arm. Attempts to load an existing calibration file;
if not found, runs manual calibration and saves the result.
"""
calib_dir = Path(calib_dir_str)
calib_dir.mkdir(parents=True, exist_ok=True)
calib_file = calib_dir / "main_follower.json"
try:
from lerobot.common.robot_devices.robots.feetech_calibration import run_arm_manual_calibration
except ImportError:
print("[WARNING] Calibration function not available. Skipping calibration.")
return
if calib_file.exists():
with open(calib_file) as f:
calibration = json.load(f)
print(f"[INFO] Loaded calibration from {calib_file}")
else:
print("[INFO] Calibration file not found. Running manual calibration...")
calibration = run_arm_manual_calibration(motors_bus, "lekiwi", "follower_arm", "follower")
print(f"[INFO] Calibration complete. Saving to {calib_file}")
with open(calib_file, "w") as f:
json.dump(calibration, f)
try:
motors_bus.set_calibration(calibration)
print("[INFO] Applied calibration for follower arm.")
except Exception as e:
print(f"[WARNING] Could not apply calibration: {e}")
def run_lekiwi(robot_config):
"""
Runs the LeKiwi robot:
- Sets up cameras and connects them.
- Initializes the follower arm motors.
- Calibrates the follower arm if necessary.
- Creates ZeroMQ sockets for receiving commands and streaming observations.
- Processes incoming commands (arm and wheel commands) and sends back sensor and camera data.
"""
# Import helper functions and classes
from lerobot.common.robot_devices.cameras.utils import make_cameras_from_configs
from lerobot.common.robot_devices.motors.feetech import FeetechMotorsBus, TorqueMode
# Initialize cameras from the robot configuration.
cameras = make_cameras_from_configs(robot_config.cameras)
for cam in cameras.values():
cam.connect()
# Initialize the motors bus using the follower arm configuration.
motor_config = robot_config.follower_arms.get("main")
if motor_config is None:
print("[ERROR] Follower arm 'main' configuration not found.")
return
motors_bus = FeetechMotorsBus(motor_config)
motors_bus.connect()
# Calibrate the follower arm.
calibrate_follower_arm(motors_bus, robot_config.calibration_dir)
# Create the LeKiwi robot instance.
robot = LeKiwi(motors_bus)
# Define the expected arm motor names (used as keys on the motors bus).
arm_motor_ids = ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"]
# Disable torque for each arm motor.
for motor in arm_motor_ids:
motors_bus.write("Torque_Enable", TorqueMode.DISABLED.value, motor)
# Set up ZeroMQ sockets.
context, cmd_socket, video_socket = setup_zmq_sockets(robot_config)
# Start the camera capture thread.
latest_images_dict = {}
images_lock = threading.Lock()
stop_event = threading.Event()
cam_thread = threading.Thread(
target=run_camera_capture, args=(cameras, images_lock, latest_images_dict, stop_event), daemon=True
)
cam_thread.start()
last_cmd_time = time.time()
print("LeKiwi robot server started. Waiting for commands...")
try:
while True:
loop_start_time = time.time()
# Process incoming commands (non-blocking).
while True:
try:
msg = cmd_socket.recv_string(zmq.NOBLOCK)
except zmq.Again:
break
try:
data = json.loads(msg)
# Process arm position commands.
if "arm_positions" in data:
arm_positions = data["arm_positions"]
if not isinstance(arm_positions, list):
print(f"[ERROR] Invalid arm_positions: {arm_positions}")
elif len(arm_positions) < len(arm_motor_ids):
print(
f"[WARNING] Received {len(arm_positions)} arm positions, expected {len(arm_motor_ids)}"
)
else:
for motor, pos in zip(arm_motor_ids, arm_positions, strict=False):
motors_bus.write("Goal_Position", pos, motor)
# Process wheel (base) commands.
if "raw_velocity" in data:
raw_command = data["raw_velocity"]
# Expect keys: "left_wheel", "back_wheel", "right_wheel".
command_speeds = [
int(raw_command.get("left_wheel", 0)),
int(raw_command.get("back_wheel", 0)),
int(raw_command.get("right_wheel", 0)),
]
robot.set_velocity(command_speeds)
last_cmd_time = time.time()
except Exception as e:
print(f"[ERROR] Parsing message failed: {e}")
# Watchdog: stop the robot if no command is received for over 0.5 seconds.
now = time.time()
if now - last_cmd_time > 0.5:
robot.stop()
last_cmd_time = now
# Read current wheel speeds from the robot.
current_velocity = robot.read_velocity()
# Read the follower arm state from the motors bus.
follower_arm_state = []
for motor in arm_motor_ids:
try:
pos = motors_bus.read("Present_Position", motor)
# Convert the position to a float (or use as is if already numeric).
follower_arm_state.append(float(pos) if not isinstance(pos, (int, float)) else pos)
except Exception as e:
print(f"[ERROR] Reading motor {motor} failed: {e}")
# Get the latest camera images.
with images_lock:
images_dict_copy = dict(latest_images_dict)
# Build the observation dictionary.
observation = {
"images": images_dict_copy,
"present_speed": current_velocity,
"follower_arm_state": follower_arm_state,
}
# Send the observation over the video socket.
video_socket.send_string(json.dumps(observation))
# Ensure a short sleep to avoid overloading the CPU.
elapsed = time.time() - loop_start_time
time.sleep(
max(0.033 - elapsed, 0)
) # If the robot jitters, increase the sleep and monitor CPU load with `top`
except KeyboardInterrupt:
print("Shutting down LeKiwi server.")
finally:
stop_event.set()
cam_thread.join()
robot.stop()
motors_bus.disconnect()
cmd_socket.close()
video_socket.close()
context.term()
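A hypothetical client-side sketch of the wire protocol handled by the loop above; the field names mirror the server code, and the IP/ports are just the LeKiwiRobotConfig defaults:

```python
import json
import zmq

context = zmq.Context()

# PUSH commands to the robot's PULL socket (port 5555 by default).
cmd_socket = context.socket(zmq.PUSH)
cmd_socket.connect("tcp://192.168.0.193:5555")
cmd_socket.send_string(json.dumps({
    "arm_positions": [2048, 2048, 2048, 2048, 2048, 2048],  # six goal positions
    "raw_velocity": {"left_wheel": 0, "back_wheel": 0, "right_wheel": 0},
}))

# PULL observations from the video socket (port 5556 by default).
video_socket = context.socket(zmq.PULL)
video_socket.connect("tcp://192.168.0.193:5556")
obs = json.loads(video_socket.recv_string())  # keys: images, present_speed, follower_arm_state
```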

View File

@@ -8,15 +8,14 @@ import json
import logging import logging
import time import time
import warnings import warnings
from dataclasses import dataclass, field, replace
from pathlib import Path from pathlib import Path
from typing import Sequence
import numpy as np import numpy as np
import torch import torch
from lerobot.common.robot_devices.cameras.utils import Camera from lerobot.common.robot_devices.cameras.utils import make_cameras_from_configs
from lerobot.common.robot_devices.motors.utils import MotorsBus from lerobot.common.robot_devices.motors.utils import MotorsBus, make_motors_buses_from_configs
from lerobot.common.robot_devices.robots.configs import ManipulatorRobotConfig
from lerobot.common.robot_devices.robots.utils import get_arm_id from lerobot.common.robot_devices.robots.utils import get_arm_id
from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError
@@ -41,65 +40,26 @@ def ensure_safe_goal_position(
return safe_goal_pos return safe_goal_pos
@dataclass
class ManipulatorRobotConfig:
"""
Example of usage:
```python
ManipulatorRobotConfig()
```
"""
# Define all components of the robot
robot_type: str = "koch"
leader_arms: dict[str, MotorsBus] = field(default_factory=lambda: {})
follower_arms: dict[str, MotorsBus] = field(default_factory=lambda: {})
cameras: dict[str, Camera] = field(default_factory=lambda: {})
# Optionally limit the magnitude of the relative positional target vector for safety purposes.
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length
# as the number of motors in your follower arms (assumes all follower arms have the same number of
# motors).
max_relative_target: list[float] | float | None = None
# Optionally set the leader arm in torque mode with the gripper motor set to this angle. This makes it
# possible to squeeze the gripper and have it spring back to an open position on its own. If None, the
# gripper is not put in torque mode.
gripper_open_degree: float | None = None
def __setattr__(self, prop: str, val):
if prop == "max_relative_target" and val is not None and isinstance(val, Sequence):
for name in self.follower_arms:
if len(self.follower_arms[name].motors) != len(val):
raise ValueError(
f"len(max_relative_target)={len(val)} but the follower arm with name {name} has "
f"{len(self.follower_arms[name].motors)} motors. Please make sure that the "
f"`max_relative_target` list has as many parameters as there are motors per arm. "
"Note: This feature does not yet work with robots where different follower arms have "
"different numbers of motors."
)
super().__setattr__(prop, val)
def __post_init__(self):
if self.robot_type not in ["koch", "koch_bimanual", "aloha", "so100", "moss"]:
raise ValueError(f"Provided robot type ({self.robot_type}) is not supported.")
class ManipulatorRobot: class ManipulatorRobot:
# TODO(rcadene): Implement force feedback # TODO(rcadene): Implement force feedback
"""This class allows to control any manipulator robot of various number of motors. """This class allows to control any manipulator robot of various number of motors.
Non exaustive list of robots: Non exhaustive list of robots:
- [Koch v1.0](https://github.com/AlexanderKoch-Koch/low_cost_robot), with and without the wrist-to-elbow expansion, developed - [Koch v1.0](https://github.com/AlexanderKoch-Koch/low_cost_robot), with and without the wrist-to-elbow expansion, developed
by Alexander Koch from [Tau Robotics](https://tau-robotics.com) by Alexander Koch from [Tau Robotics](https://tau-robotics.com)
- [Koch v1.1](https://github.com/jess-moss/koch-v1-1) developed by Jess Moss - [Koch v1.1](https://github.com/jess-moss/koch-v1-1) developed by Jess Moss
- [Aloha](https://www.trossenrobotics.com/aloha-kits) developed by Trossen Robotics - [Aloha](https://www.trossenrobotics.com/aloha-kits) developed by Trossen Robotics
Example of highest frequency teleoperation without camera: Example of instantiation, a pre-defined robot config is required:
```python
robot = ManipulatorRobot(KochRobotConfig())
```
Example of overwriting motors during instantiation:
```python ```python
# Defines how to communicate with the motors of the leader and follower arms # Defines how to communicate with the motors of the leader and follower arms
leader_arms = { leader_arms = {
"main": DynamixelMotorsBus( "main": DynamixelMotorsBusConfig(
port="/dev/tty.usbmodem575E0031751", port="/dev/tty.usbmodem575E0031751",
motors={ motors={
# name: (index, model) # name: (index, model)
@@ -113,7 +73,7 @@ class ManipulatorRobot:
), ),
} }
follower_arms = { follower_arms = {
"main": DynamixelMotorsBus( "main": DynamixelMotorsBusConfig(
port="/dev/tty.usbmodem575E0032081", port="/dev/tty.usbmodem575E0032081",
motors={ motors={
# name: (index, model) # name: (index, model)
@@ -126,35 +86,11 @@ class ManipulatorRobot:
}, },
), ),
} }
robot = ManipulatorRobot( robot_config = KochRobotConfig(leader_arms=leader_arms, follower_arms=follower_arms)
robot_type="koch", robot = ManipulatorRobot(robot_config)
calibration_dir=".cache/calibration/koch",
leader_arms=leader_arms,
follower_arms=follower_arms,
)
# Connect motors buses and cameras if any (Required)
robot.connect()
while True:
robot.teleop_step()
``` ```
Example of highest frequency data collection without camera: Example of overwriting cameras during instantiation:
```python
# Assumes leader and follower arms have been instantiated already (see first example)
robot = ManipulatorRobot(
robot_type="koch",
calibration_dir=".cache/calibration/koch",
leader_arms=leader_arms,
follower_arms=follower_arms,
)
robot.connect()
while True:
observation, action = robot.teleop_step(record_data=True)
```
Example of highest frequency data collection with cameras:
```python ```python
# Defines how to communicate with 2 cameras connected to the computer. # Defines how to communicate with 2 cameras connected to the computer.
# Here, the webcam of the laptop and the phone (connected in USB to the laptop) # Here, the webcam of the laptop and the phone (connected in USB to the laptop)
@@ -164,31 +100,28 @@ class ManipulatorRobot:
"laptop": OpenCVCamera(camera_index=0, fps=30, width=640, height=480), "laptop": OpenCVCamera(camera_index=0, fps=30, width=640, height=480),
"phone": OpenCVCamera(camera_index=1, fps=30, width=640, height=480), "phone": OpenCVCamera(camera_index=1, fps=30, width=640, height=480),
} }
robot = ManipulatorRobot(KochRobotConfig(cameras=cameras))
```
# Assumes leader and follower arms have been instantiated already (see first example) Once the robot is instantiated, connect motors buses and cameras if any (Required):
robot = ManipulatorRobot( ```python
robot_type="koch",
calibration_dir=".cache/calibration/koch",
leader_arms=leader_arms,
follower_arms=follower_arms,
cameras=cameras,
)
robot.connect() robot.connect()
```
Example of highest frequency teleoperation, which doesn't require cameras:
```python
while True:
robot.teleop_step()
```
Example of highest frequency data collection from motors and cameras (if any):
```python
while True: while True:
observation, action = robot.teleop_step(record_data=True) observation, action = robot.teleop_step(record_data=True)
``` ```
Example of controlling the robot with a policy (without running multiple policies in parallel to ensure highest frequency): Example of controlling the robot with a policy:
```python ```python
# Assumes leader and follower arms + cameras have been instantiated already (see previous example)
robot = ManipulatorRobot(
robot_type="koch",
calibration_dir=".cache/calibration/koch",
leader_arms=leader_arms,
follower_arms=follower_arms,
cameras=cameras,
)
robot.connect()
while True: while True:
# Uses the follower arms and cameras to capture an observation # Uses the follower arms and cameras to capture an observation
observation = robot.capture_observation() observation = robot.capture_observation()
@@ -209,20 +142,14 @@ class ManipulatorRobot:
def __init__( def __init__(
self, self,
config: ManipulatorRobotConfig | None = None, config: ManipulatorRobotConfig,
calibration_dir: Path = ".cache/calibration/koch",
**kwargs,
): ):
if config is None: self.config = config
config = ManipulatorRobotConfig() self.robot_type = self.config.type
# Overwrite config arguments using kwargs self.calibration_dir = Path(self.config.calibration_dir)
self.config = replace(config, **kwargs) self.leader_arms = make_motors_buses_from_configs(self.config.leader_arms)
self.calibration_dir = Path(calibration_dir) self.follower_arms = make_motors_buses_from_configs(self.config.follower_arms)
self.cameras = make_cameras_from_configs(self.config.cameras)
self.robot_type = self.config.robot_type
self.leader_arms = self.config.leader_arms
self.follower_arms = self.config.follower_arms
self.cameras = self.config.cameras
self.is_connected = False self.is_connected = False
self.logs = {} self.logs = {}
@@ -302,7 +229,7 @@ class ManipulatorRobot:
if self.robot_type in ["koch", "koch_bimanual", "aloha"]: if self.robot_type in ["koch", "koch_bimanual", "aloha"]:
from lerobot.common.robot_devices.motors.dynamixel import TorqueMode from lerobot.common.robot_devices.motors.dynamixel import TorqueMode
elif self.robot_type in ["so100", "moss"]: elif self.robot_type in ["so100", "moss", "lekiwi"]:
from lerobot.common.robot_devices.motors.feetech import TorqueMode from lerobot.common.robot_devices.motors.feetech import TorqueMode
# We assume that at connection time, arms are in a rest position, and torque can # We assume that at connection time, arms are in a rest position, and torque can
@@ -319,7 +246,7 @@ class ManipulatorRobot:
self.set_koch_robot_preset() self.set_koch_robot_preset()
elif self.robot_type == "aloha": elif self.robot_type == "aloha":
self.set_aloha_robot_preset() self.set_aloha_robot_preset()
elif self.robot_type in ["so100", "moss"]: elif self.robot_type in ["so100", "moss", "lekiwi"]:
self.set_so100_robot_preset() self.set_so100_robot_preset()
# Enable torque on all motors of the follower arms # Enable torque on all motors of the follower arms
@@ -372,7 +299,7 @@ class ManipulatorRobot:
calibration = run_arm_calibration(arm, self.robot_type, name, arm_type) calibration = run_arm_calibration(arm, self.robot_type, name, arm_type)
elif self.robot_type in ["so100", "moss"]: elif self.robot_type in ["so100", "moss", "lekiwi"]:
from lerobot.common.robot_devices.robots.feetech_calibration import ( from lerobot.common.robot_devices.robots.feetech_calibration import (
run_arm_manual_calibration, run_arm_manual_calibration,
) )
@@ -421,7 +348,7 @@ class ManipulatorRobot:
set_operating_mode_(self.follower_arms[name]) set_operating_mode_(self.follower_arms[name])
# Set better PID values to close the gap between recorded states and actions # Set better PID values to close the gap between recorded states and actions
# TODO(rcadene): Implement an automatic procedure to set optimial PID values for each motor # TODO(rcadene): Implement an automatic procedure to set optimal PID values for each motor
self.follower_arms[name].write("Position_P_Gain", 1500, "elbow_flex") self.follower_arms[name].write("Position_P_Gain", 1500, "elbow_flex")
self.follower_arms[name].write("Position_I_Gain", 0, "elbow_flex") self.follower_arms[name].write("Position_I_Gain", 0, "elbow_flex")
self.follower_arms[name].write("Position_D_Gain", 600, "elbow_flex") self.follower_arms[name].write("Position_D_Gain", 600, "elbow_flex")
@@ -573,7 +500,7 @@ class ManipulatorRobot:
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"] self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"]
self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t
# Populate output dictionnaries # Populate output dictionaries
obs_dict, action_dict = {}, {} obs_dict, action_dict = {}, {}
obs_dict["observation.state"] = state obs_dict["observation.state"] = state
action_dict["action"] = action action_dict["action"] = action
@@ -613,7 +540,7 @@ class ManipulatorRobot:
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"] self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"]
self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t
# Populate output dictionnaries and format to pytorch # Populate output dictionaries and format to pytorch
obs_dict = {} obs_dict = {}
obs_dict["observation.state"] = state obs_dict["observation.state"] = state
for name in self.cameras: for name in self.cameras:
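To summarize the constructor change in this file: keyword overrides via `replace(config, **kwargs)` are gone, and the robot is now built from a single typed config, with buses and cameras created internally by `make_motors_buses_from_configs` and `make_cameras_from_configs` (sketch):

```python
# Old (removed): ManipulatorRobot(robot_type="koch", calibration_dir=..., leader_arms=..., ...)
# New: one typed config object carries everything.
robot = ManipulatorRobot(KochRobotConfig())
robot.connect()
```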

View File

@@ -0,0 +1,691 @@
import base64
import json
import os
import sys
from pathlib import Path
import cv2
import numpy as np
import torch
import zmq
from lerobot.common.robot_devices.cameras.utils import make_cameras_from_configs
from lerobot.common.robot_devices.motors.feetech import TorqueMode
from lerobot.common.robot_devices.motors.utils import MotorsBus, make_motors_buses_from_configs
from lerobot.common.robot_devices.robots.configs import LeKiwiRobotConfig
from lerobot.common.robot_devices.robots.feetech_calibration import run_arm_manual_calibration
from lerobot.common.robot_devices.robots.utils import get_arm_id
from lerobot.common.robot_devices.utils import RobotDeviceNotConnectedError
PYNPUT_AVAILABLE = True
try:
# Only import pynput if a display is available; on headless Linux (e.g. a Pi over SSH), skip it.
if ("DISPLAY" not in os.environ) and ("linux" in sys.platform):
print("No DISPLAY set. Skipping pynput import.")
raise ImportError("pynput blocked intentionally due to no display.")
from pynput import keyboard
except ImportError:
keyboard = None
PYNPUT_AVAILABLE = False
except Exception as e:
keyboard = None
PYNPUT_AVAILABLE = False
print(f"Could not import pynput: {e}")
class MobileManipulator:
"""
MobileManipulator is a class for connecting to and controlling a remote mobile manipulator robot.
The robot includes a three-omniwheel mobile base and a remote follower arm.
The leader arm is connected locally (on the laptop) and its joint positions are recorded and then
forwarded to the remote follower arm (after applying a safety clamp).
In parallel, keyboard teleoperation is used to generate raw velocity commands for the wheels.
"""
def __init__(self, config: LeKiwiRobotConfig):
"""
Expected keys in config:
- ip, port, video_port for the remote connection.
- calibration_dir, leader_arms, follower_arms, max_relative_target, etc.
"""
self.robot_type = config.type
self.config = config
self.remote_ip = config.ip
self.remote_port = config.port
self.remote_port_video = config.video_port
self.calibration_dir = Path(self.config.calibration_dir)
self.logs = {}
self.teleop_keys = self.config.teleop_keys
# For teleoperation, the leader arm (local) is used to record the desired arm pose.
self.leader_arms = make_motors_buses_from_configs(self.config.leader_arms)
self.follower_arms = make_motors_buses_from_configs(self.config.follower_arms)
self.cameras = make_cameras_from_configs(self.config.cameras)
self.is_connected = False
self.last_frames = {}
self.last_present_speed = {}
self.last_remote_arm_state = torch.zeros(6, dtype=torch.float32)
# Define three speed levels and a current index
self.speed_levels = [
{"xy": 0.1, "theta": 30}, # slow
{"xy": 0.2, "theta": 60}, # medium
{"xy": 0.3, "theta": 90}, # fast
]
self.speed_index = 0 # Start at slow
# ZeroMQ context and sockets.
self.context = None
self.cmd_socket = None
self.video_socket = None
# Keyboard state for base teleoperation.
self.running = True
self.pressed_keys = {
"forward": False,
"backward": False,
"left": False,
"right": False,
"rotate_left": False,
"rotate_right": False,
}
if PYNPUT_AVAILABLE:
print("pynput is available - enabling local keyboard listener.")
self.listener = keyboard.Listener(
on_press=self.on_press,
on_release=self.on_release,
)
self.listener.start()
else:
print("pynput not available - skipping local keyboard listener.")
self.listener = None
def get_motor_names(self, arms: dict[str, MotorsBus]) -> list:
return [f"{arm}_{motor}" for arm, bus in arms.items() for motor in bus.motors]
@property
def camera_features(self) -> dict:
cam_ft = {}
for cam_key, cam in self.cameras.items():
key = f"observation.images.{cam_key}"
cam_ft[key] = {
"shape": (cam.height, cam.width, cam.channels),
"names": ["height", "width", "channels"],
"info": None,
}
return cam_ft
@property
def motor_features(self) -> dict:
follower_arm_names = [
"shoulder_pan",
"shoulder_lift",
"elbow_flex",
"wrist_flex",
"wrist_roll",
"gripper",
]
observations = ["x_mm", "y_mm", "theta"]
combined_names = follower_arm_names + observations
return {
"action": {
"dtype": "float32",
"shape": (len(combined_names),),
"names": combined_names,
},
"observation.state": {
"dtype": "float32",
"shape": (len(combined_names),),
"names": combined_names,
},
}
@property
def features(self):
return {**self.motor_features, **self.camera_features}
@property
def has_camera(self):
return len(self.cameras) > 0
@property
def num_cameras(self):
return len(self.cameras)
@property
def available_arms(self):
available = []
for name in self.leader_arms:
available.append(get_arm_id(name, "leader"))
for name in self.follower_arms:
available.append(get_arm_id(name, "follower"))
return available
def on_press(self, key):
try:
# Movement
if key.char == self.teleop_keys["forward"]:
self.pressed_keys["forward"] = True
elif key.char == self.teleop_keys["backward"]:
self.pressed_keys["backward"] = True
elif key.char == self.teleop_keys["left"]:
self.pressed_keys["left"] = True
elif key.char == self.teleop_keys["right"]:
self.pressed_keys["right"] = True
elif key.char == self.teleop_keys["rotate_left"]:
self.pressed_keys["rotate_left"] = True
elif key.char == self.teleop_keys["rotate_right"]:
self.pressed_keys["rotate_right"] = True
# Quit teleoperation
elif key.char == self.teleop_keys["quit"]:
self.running = False
return False
# Speed control
elif key.char == self.teleop_keys["speed_up"]:
self.speed_index = min(self.speed_index + 1, 2)
print(f"Speed index increased to {self.speed_index}")
elif key.char == self.teleop_keys["speed_down"]:
self.speed_index = max(self.speed_index - 1, 0)
print(f"Speed index decreased to {self.speed_index}")
except AttributeError:
# e.g., if key is special like Key.esc
if key == keyboard.Key.esc:
self.running = False
return False
def on_release(self, key):
try:
if hasattr(key, "char"):
if key.char == self.teleop_keys["forward"]:
self.pressed_keys["forward"] = False
elif key.char == self.teleop_keys["backward"]:
self.pressed_keys["backward"] = False
elif key.char == self.teleop_keys["left"]:
self.pressed_keys["left"] = False
elif key.char == self.teleop_keys["right"]:
self.pressed_keys["right"] = False
elif key.char == self.teleop_keys["rotate_left"]:
self.pressed_keys["rotate_left"] = False
elif key.char == self.teleop_keys["rotate_right"]:
self.pressed_keys["rotate_right"] = False
except AttributeError:
pass
def connect(self):
if not self.leader_arms:
raise ValueError("MobileManipulator has no leader arm to connect.")
for name in self.leader_arms:
print(f"Connecting {name} leader arm.")
self.calibrate_leader()
# Set up ZeroMQ sockets to communicate with the remote mobile robot.
self.context = zmq.Context()
self.cmd_socket = self.context.socket(zmq.PUSH)
connection_string = f"tcp://{self.remote_ip}:{self.remote_port}"
self.cmd_socket.connect(connection_string)
self.cmd_socket.setsockopt(zmq.CONFLATE, 1)
self.video_socket = self.context.socket(zmq.PULL)
video_connection = f"tcp://{self.remote_ip}:{self.remote_port_video}"
self.video_socket.connect(video_connection)
self.video_socket.setsockopt(zmq.CONFLATE, 1)
print(
f"[INFO] Connected to remote robot at {connection_string} and video stream at {video_connection}."
)
self.is_connected = True
def load_or_run_calibration_(self, name, arm, arm_type):
arm_id = get_arm_id(name, arm_type)
arm_calib_path = self.calibration_dir / f"{arm_id}.json"
if arm_calib_path.exists():
with open(arm_calib_path) as f:
calibration = json.load(f)
else:
print(f"Missing calibration file '{arm_calib_path}'")
calibration = run_arm_manual_calibration(arm, self.robot_type, name, arm_type)
print(f"Calibration is done! Saving calibration file '{arm_calib_path}'")
arm_calib_path.parent.mkdir(parents=True, exist_ok=True)
with open(arm_calib_path, "w") as f:
json.dump(calibration, f)
return calibration
def calibrate_leader(self):
for name, arm in self.leader_arms.items():
# Connect the bus
arm.connect()
# Disable torque on all motors
for motor_id in arm.motors:
arm.write("Torque_Enable", TorqueMode.DISABLED.value, motor_id)
# Now run calibration
calibration = self.load_or_run_calibration_(name, arm, "leader")
arm.set_calibration(calibration)
def calibrate_follower(self):
for name, bus in self.follower_arms.items():
bus.connect()
# Disable torque on all motors
for motor_id in bus.motors:
bus.write("Torque_Enable", 0, motor_id)
# Then filter out wheels
arm_only_dict = {k: v for k, v in bus.motors.items() if not k.startswith("wheel_")}
if not arm_only_dict:
continue
original_motors = bus.motors
bus.motors = arm_only_dict
calibration = self.load_or_run_calibration_(name, bus, "follower")
bus.set_calibration(calibration)
bus.motors = original_motors
def _get_data(self):
"""
Polls the video socket for up to 15 ms. If data arrives, decode only
the *latest* message and return frames, speed, and arm state. If
nothing arrives, reuse the last known values.
"""
frames = {}
present_speed = {}
remote_arm_state_tensor = torch.zeros(6, dtype=torch.float32)
# Poll up to 15 ms
poller = zmq.Poller()
poller.register(self.video_socket, zmq.POLLIN)
socks = dict(poller.poll(15))
if self.video_socket not in socks or socks[self.video_socket] != zmq.POLLIN:
# No new data arrived → reuse ALL old data
return (self.last_frames, self.last_present_speed, self.last_remote_arm_state)
# Drain all messages, keep only the last
last_msg = None
while True:
try:
obs_string = self.video_socket.recv_string(zmq.NOBLOCK)
last_msg = obs_string
except zmq.Again:
break
if not last_msg:
# No new message → also reuse old
return (self.last_frames, self.last_present_speed, self.last_remote_arm_state)
# Decode only the final message
try:
observation = json.loads(last_msg)
images_dict = observation.get("images", {})
new_speed = observation.get("present_speed", {})
new_arm_state = observation.get("follower_arm_state", None)
# Convert images
for cam_name, image_b64 in images_dict.items():
if image_b64:
jpg_data = base64.b64decode(image_b64)
np_arr = np.frombuffer(jpg_data, dtype=np.uint8)
frame_candidate = cv2.imdecode(np_arr, cv2.IMREAD_COLOR)
if frame_candidate is not None:
frames[cam_name] = frame_candidate
# Only update the cached values if the message contained both an arm state and at
# least one decoded frame; otherwise fall back to the previous data.
if new_arm_state is not None and frames:
self.last_frames = frames
remote_arm_state_tensor = torch.tensor(new_arm_state, dtype=torch.float32)
self.last_remote_arm_state = remote_arm_state_tensor
present_speed = new_speed
self.last_present_speed = new_speed
else:
frames = self.last_frames
remote_arm_state_tensor = self.last_remote_arm_state
present_speed = self.last_present_speed
except Exception as e:
print(f"[DEBUG] Error decoding video message: {e}")
# If decode fails, fall back to old data
return (self.last_frames, self.last_present_speed, self.last_remote_arm_state)
return frames, present_speed, remote_arm_state_tensor
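# For reference, a minimal sketch of the JSON payload this method expects from the
# remote robot (field names taken from the parsing above, values hypothetical):
#
#   payload = {
#       "images": {"front": "<base64-encoded JPEG bytes>"},
#       "present_speed": {"left_wheel": 2560, "back_wheel": 0, "right_wheel": 35328},
#       "follower_arm_state": [0.0, 90.0, 90.0, 0.0, 0.0, 10.0],
#   }
#
# which the remote side would send with something like
#   video_socket.send_string(json.dumps(payload))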
def _process_present_speed(self, present_speed: dict) -> torch.Tensor:
state_tensor = torch.zeros(3, dtype=torch.int32)
if present_speed:
decoded = {key: MobileManipulator.raw_to_degps(value) for key, value in present_speed.items()}
if "1" in decoded:
state_tensor[0] = decoded["1"]
if "2" in decoded:
state_tensor[1] = decoded["2"]
if "3" in decoded:
state_tensor[2] = decoded["3"]
return state_tensor
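# Example (hypothetical input): with present_speed = {"1": 2560, "2": 35328, "3": 0},
# raw_to_degps() decodes 2560 -> 225 deg/s and 35328 (sign bit set) -> -225 deg/s,
# so the method returns tensor([225, -225, 0], dtype=torch.int32).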
def teleop_step(
self, record_data: bool = False
) -> None | tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
if not self.is_connected:
raise RobotDeviceNotConnectedError("MobileManipulator is not connected. Run `connect()` first.")
speed_setting = self.speed_levels[self.speed_index]
xy_speed = speed_setting["xy"] # e.g. 0.1, 0.25, or 0.4
theta_speed = speed_setting["theta"] # e.g. 30, 60, or 90
# Prepare to assign the position of the leader to the follower
arm_positions = []
for name in self.leader_arms:
pos = self.leader_arms[name].read("Present_Position")
pos_tensor = torch.from_numpy(pos).float()
# Use tolist() rather than item() so that all joint positions are kept
arm_positions.extend(pos_tensor.tolist())
# Generate wheel commands from the currently pressed keys
x_cmd = 0.0 # m/s forward/backward
y_cmd = 0.0 # m/s lateral
theta_cmd = 0.0 # deg/s rotation
if self.pressed_keys["forward"]:
x_cmd += xy_speed
if self.pressed_keys["backward"]:
x_cmd -= xy_speed
if self.pressed_keys["left"]:
y_cmd += xy_speed
if self.pressed_keys["right"]:
y_cmd -= xy_speed
if self.pressed_keys["rotate_left"]:
theta_cmd += theta_speed
if self.pressed_keys["rotate_right"]:
theta_cmd -= theta_speed
wheel_commands = self.body_to_wheel_raw(x_cmd, y_cmd, theta_cmd)
message = {"raw_velocity": wheel_commands, "arm_positions": arm_positions}
self.cmd_socket.send_string(json.dumps(message))
if not record_data:
return
obs_dict = self.capture_observation()
arm_state_tensor = torch.tensor(arm_positions, dtype=torch.float32)
wheel_velocity_tuple = self.wheel_raw_to_body(wheel_commands)
wheel_velocity_mm = (
wheel_velocity_tuple[0] * 1000.0,
wheel_velocity_tuple[1] * 1000.0,
wheel_velocity_tuple[2],
)
wheel_tensor = torch.tensor(wheel_velocity_mm, dtype=torch.float32)
action_tensor = torch.cat([arm_state_tensor, wheel_tensor])
action_dict = {"action": action_tensor}
return obs_dict, action_dict
def capture_observation(self) -> dict:
"""
Capture observations from the remote robot: current follower arm positions,
present wheel speeds (converted to body-frame velocities: x, y, theta),
and a camera frame.
"""
if not self.is_connected:
raise RobotDeviceNotConnectedError("Not connected. Run `connect()` first.")
frames, present_speed, remote_arm_state_tensor = self._get_data()
body_state = self.wheel_raw_to_body(present_speed)
body_state_mm = (body_state[0] * 1000.0, body_state[1] * 1000.0, body_state[2]) # Convert x,y to mm/s
wheel_state_tensor = torch.tensor(body_state_mm, dtype=torch.float32)
combined_state_tensor = torch.cat((remote_arm_state_tensor, wheel_state_tensor), dim=0)
obs_dict = {"observation.state": combined_state_tensor}
# Loop over each configured camera
for cam_name, cam in self.cameras.items():
frame = frames.get(cam_name, None)
if frame is None:
# Create a black image using the camera's configured width, height, and channels
frame = np.zeros((cam.height, cam.width, cam.channels), dtype=np.uint8)
obs_dict[f"observation.images.{cam_name}"] = torch.from_numpy(frame)
return obs_dict
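# The resulting dict thus contains (camera name hypothetical):
#   obs_dict["observation.state"]        -> 9-element float32 tensor
#                                           (6 arm joint positions + x mm/s, y mm/s, theta deg/s)
#   obs_dict["observation.images.front"] -> uint8 tensor of shape (height, width, channels)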
def send_action(self, action: torch.Tensor) -> torch.Tensor:
if not self.is_connected:
raise RobotDeviceNotConnectedError("Not connected. Run `connect()` first.")
# Ensure the action tensor has at least 9 elements:
# - First 6: arm positions.
# - Last 3: base commands.
if action.numel() < 9:
# Pad with zeros if there are not enough elements.
padded = torch.zeros(9, dtype=action.dtype)
padded[: action.numel()] = action
action = padded
# Extract arm and base actions.
arm_actions = action[:6].flatten()
base_actions = action[6:].flatten()
x_cmd_mm = base_actions[0].item() # mm/s
y_cmd_mm = base_actions[1].item() # mm/s
theta_cmd = base_actions[2].item() # deg/s
# Convert mm/s to m/s for the kinematics calculations.
x_cmd = x_cmd_mm / 1000.0 # m/s
y_cmd = y_cmd_mm / 1000.0 # m/s
# Compute wheel commands from body commands.
wheel_commands = self.body_to_wheel_raw(x_cmd, y_cmd, theta_cmd)
arm_positions_list = arm_actions.tolist()
message = {"raw_velocity": wheel_commands, "arm_positions": arm_positions_list}
self.cmd_socket.send_string(json.dumps(message))
return action
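# Example (hypothetical values): the expected 9-element layout is 6 arm joint
# positions followed by the base command (x mm/s, y mm/s, theta deg/s), e.g.
#   action = torch.tensor([0.0, 90.0, 90.0, 0.0, 0.0, 10.0, 100.0, 0.0, 30.0])
#   robot.send_action(action)  # drive forward at 0.1 m/s while rotating at 30 deg/s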
def print_logs(self):
pass
def disconnect(self):
if not self.is_connected:
raise RobotDeviceNotConnectedError("Not connected.")
if self.cmd_socket:
stop_cmd = {
"raw_velocity": {"left_wheel": 0, "back_wheel": 0, "right_wheel": 0},
"arm_positions": {},
}
self.cmd_socket.send_string(json.dumps(stop_cmd))
self.cmd_socket.close()
if self.video_socket:
self.video_socket.close()
if self.context:
self.context.term()
if PYNPUT_AVAILABLE:
self.listener.stop()
self.is_connected = False
print("[INFO] Disconnected from remote robot.")
def __del__(self):
if getattr(self, "is_connected", False):
self.disconnect()
if PYNPUT_AVAILABLE:
self.listener.stop()
@staticmethod
def degps_to_raw(degps: float) -> int:
steps_per_deg = 4096.0 / 360.0
speed_in_steps = abs(degps) * steps_per_deg
speed_int = int(round(speed_in_steps))
if speed_int > 0x7FFF:
speed_int = 0x7FFF
if degps < 0:
return speed_int | 0x8000
else:
return speed_int & 0x7FFF
@staticmethod
def raw_to_degps(raw_speed: int) -> float:
steps_per_deg = 4096.0 / 360.0
magnitude = raw_speed & 0x7FFF
degps = magnitude / steps_per_deg
if raw_speed & 0x8000:
degps = -degps
return degps
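# Worked example of the sign-magnitude encoding used above: at 4096 ticks per
# revolution, 90 deg/s -> round(90 * 4096 / 360) = 1024, while -90 deg/s sets
# bit 15: 1024 | 0x8000 = 33792. raw_to_degps() inverts both exactly:
#   raw_to_degps(1024) == 90.0 and raw_to_degps(33792) == -90.0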
def body_to_wheel_raw(
self,
x_cmd: float,
y_cmd: float,
theta_cmd: float,
wheel_radius: float = 0.05,
base_radius: float = 0.125,
max_raw: int = 3000,
) -> dict:
"""
Convert desired body-frame velocities into wheel raw commands.
Parameters:
x_cmd : Linear velocity in x (m/s).
y_cmd : Linear velocity in y (m/s).
theta_cmd : Rotational velocity (deg/s).
wheel_radius: Radius of each wheel (meters).
base_radius : Distance from the center of rotation to each wheel (meters).
max_raw : Maximum allowed raw command (ticks) per wheel.
Returns:
A dictionary with wheel raw commands:
{"left_wheel": value, "back_wheel": value, "right_wheel": value}.
Notes:
- Internally, the method converts theta_cmd to rad/s for the kinematics.
- The raw command is computed from each wheel's angular speed in deg/s
using degps_to_raw(). If any command exceeds max_raw, all commands
are scaled down proportionally.
"""
# Convert rotational velocity from deg/s to rad/s.
theta_rad = theta_cmd * (np.pi / 180.0)
# Create the body velocity vector [x, y, theta_rad].
velocity_vector = np.array([x_cmd, y_cmd, theta_rad])
# Define the wheel mounting angles with a -90° offset.
angles = np.radians(np.array([240, 120, 0]) - 90)
# Build the kinematic matrix: each row maps body velocities to a wheel's linear speed.
# The third column (base_radius) accounts for the effect of rotation.
m = np.array([[np.cos(a), np.sin(a), base_radius] for a in angles])
# Compute each wheel's linear speed (m/s) and then its angular speed (rad/s).
wheel_linear_speeds = m.dot(velocity_vector)
wheel_angular_speeds = wheel_linear_speeds / wheel_radius
# Convert wheel angular speeds from rad/s to deg/s.
wheel_degps = wheel_angular_speeds * (180.0 / np.pi)
# Scaling
steps_per_deg = 4096.0 / 360.0
raw_floats = [abs(degps) * steps_per_deg for degps in wheel_degps]
max_raw_computed = max(raw_floats)
if max_raw_computed > max_raw:
scale = max_raw / max_raw_computed
wheel_degps = wheel_degps * scale
# Convert each wheel's angular speed (deg/s) to a raw integer.
wheel_raw = [MobileManipulator.degps_to_raw(deg) for deg in wheel_degps]
return {"left_wheel": wheel_raw[0], "back_wheel": wheel_raw[1], "right_wheel": wheel_raw[2]}
def wheel_raw_to_body(
self, wheel_raw: dict, wheel_radius: float = 0.05, base_radius: float = 0.125
) -> tuple:
"""
Convert wheel raw command feedback back into body-frame velocities.
Parameters:
wheel_raw : Dictionary with raw wheel commands (keys: "left_wheel", "back_wheel", "right_wheel").
wheel_radius: Radius of each wheel (meters).
base_radius : Distance from the robot center to each wheel (meters).
Returns:
A tuple (x_cmd, y_cmd, theta_cmd) where:
x_cmd : Linear velocity in x (m/s).
y_cmd : Linear velocity in y (m/s).
theta_cmd : Rotational velocity in deg/s.
"""
# Extract the raw values in order.
raw_list = [
int(wheel_raw.get("left_wheel", 0)),
int(wheel_raw.get("back_wheel", 0)),
int(wheel_raw.get("right_wheel", 0)),
]
# Convert each raw command back to an angular speed in deg/s.
wheel_degps = np.array([MobileManipulator.raw_to_degps(r) for r in raw_list])
# Convert from deg/s to rad/s.
wheel_radps = wheel_degps * (np.pi / 180.0)
# Compute each wheel's linear speed (m/s) from its angular speed.
wheel_linear_speeds = wheel_radps * wheel_radius
# Define the wheel mounting angles with a -90° offset.
angles = np.radians(np.array([240, 120, 0]) - 90)
m = np.array([[np.cos(a), np.sin(a), base_radius] for a in angles])
# Solve the inverse kinematics: body_velocity = M⁻¹ · wheel_linear_speeds.
m_inv = np.linalg.inv(m)
velocity_vector = m_inv.dot(wheel_linear_speeds)
x_cmd, y_cmd, theta_rad = velocity_vector
theta_cmd = theta_rad * (180.0 / np.pi)
return (x_cmd, y_cmd, theta_cmd)
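# Sanity check (a sketch, not part of the class): because wheel_raw_to_body() applies
# the inverse of the same kinematic matrix used in body_to_wheel_raw(), any command
# that does not hit the max_raw clamp round-trips up to encoding precision, e.g.
#   raw = robot.body_to_wheel_raw(0.1, 0.0, 30.0)
#   robot.wheel_raw_to_body(raw)  # ~ (0.1, 0.0, 30.0)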
class LeKiwi:
def __init__(self, motor_bus):
"""
Initializes the LeKiwi base with a Feetech motor bus.
"""
self.motor_bus = motor_bus
self.motor_ids = ["left_wheel", "back_wheel", "right_wheel"]
# Initialize motors in velocity mode.
self.motor_bus.write("Lock", 0)
self.motor_bus.write("Mode", [1, 1, 1], self.motor_ids)
self.motor_bus.write("Lock", 1)
print("Motors set to velocity mode.")
def read_velocity(self):
"""
Reads the raw speeds for all wheels and returns them in a dictionary keyed by motor name.
"""
raw_speeds = self.motor_bus.read("Present_Speed", self.motor_ids)
return {
"left_wheel": int(raw_speeds[0]),
"back_wheel": int(raw_speeds[1]),
"right_wheel": int(raw_speeds[2]),
}
def set_velocity(self, command_speeds):
"""
Sends raw velocity commands (16-bit sign-magnitude values) directly to the motor bus.
The order of speeds must correspond to self.motor_ids.
"""
self.motor_bus.write("Goal_Speed", command_speeds, self.motor_ids)
def stop(self):
"""Stops the robot by setting all motor speeds to zero."""
self.motor_bus.write("Goal_Speed", [0, 0, 0], self.motor_ids)
print("Motors stopped.")
