Compare commits
33 Commits
chore/bump
...
user/miche
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
daa1480a91 | ||
|
|
71ec721e48 | ||
|
|
bbb5ba0adf | ||
|
|
844bfcf484 | ||
|
|
13441f0d98 | ||
|
|
41b377211c | ||
|
|
9ceb68ee90 | ||
|
|
d1baa5a82f | ||
|
|
04da4dd3e3 | ||
|
|
b0e2fcdba7 | ||
|
|
1e2a757cd3 | ||
|
|
ab842ba6ae | ||
|
|
94a7221a94 | ||
|
|
00dadcace0 | ||
|
|
81a2f2958d | ||
|
|
68b4fb60ad | ||
|
|
96b2b62377 | ||
|
|
b5c98bbfd3 | ||
|
|
58e12cf2e8 | ||
|
|
d8b5fae622 | ||
|
|
67ac81d728 | ||
|
|
b5f1ea3140 | ||
|
|
4d854a1513 | ||
|
|
87da655eab | ||
|
|
a8fda9c61a | ||
|
|
55505ff817 | ||
|
|
20d31ab8e0 | ||
|
|
e5b83aab5e | ||
|
|
a9d5f62304 | ||
|
|
72e1ed7058 | ||
|
|
d8e67a2609 | ||
|
|
50e12376de | ||
|
|
73aa6c25f3 |
12
.github/workflows/build-docker-images.yml
vendored
12
.github/workflows/build-docker-images.yml
vendored
@@ -8,8 +8,6 @@ on:
|
||||
schedule:
|
||||
- cron: "0 1 * * *"
|
||||
|
||||
permissions: {}
|
||||
|
||||
env:
|
||||
PYTHON_VERSION: "3.10"
|
||||
|
||||
@@ -27,14 +25,11 @@ jobs:
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
with:
|
||||
cache-binary: false
|
||||
|
||||
- name: Check out code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
lfs: true
|
||||
persist-credentials: false
|
||||
|
||||
- name: Login to DockerHub
|
||||
uses: docker/login-action@v3
|
||||
@@ -65,14 +60,11 @@ jobs:
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
with:
|
||||
cache-binary: false
|
||||
|
||||
- name: Check out code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
lfs: true
|
||||
persist-credentials: false
|
||||
|
||||
- name: Login to DockerHub
|
||||
uses: docker/login-action@v3
|
||||
@@ -97,13 +89,9 @@ jobs:
|
||||
steps:
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
with:
|
||||
cache-binary: false
|
||||
|
||||
- name: Check out code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Login to DockerHub
|
||||
uses: docker/login-action@v3
|
||||
|
||||
2
.github/workflows/nightly-tests.yml
vendored
2
.github/workflows/nightly-tests.yml
vendored
@@ -7,8 +7,6 @@ on:
|
||||
schedule:
|
||||
- cron: "0 2 * * *"
|
||||
|
||||
permissions: {}
|
||||
|
||||
# env:
|
||||
# SLACK_API_TOKEN: ${{ secrets.SLACK_API_TOKEN }}
|
||||
jobs:
|
||||
|
||||
161
.github/workflows/pr_style_bot.yml
vendored
161
.github/workflows/pr_style_bot.yml
vendored
@@ -1,161 +0,0 @@
|
||||
# Adapted from https://github.com/huggingface/diffusers/blob/main/.github/workflows/pr_style_bot.yml
|
||||
name: PR Style Bot
|
||||
|
||||
on:
|
||||
issue_comment:
|
||||
types: [created]
|
||||
|
||||
permissions: {}
|
||||
|
||||
env:
|
||||
PYTHON_VERSION: "3.10"
|
||||
|
||||
jobs:
|
||||
check-permissions:
|
||||
if: >
|
||||
contains(github.event.comment.body, '@bot /style') &&
|
||||
github.event.issue.pull_request != null
|
||||
runs-on: ubuntu-latest
|
||||
outputs:
|
||||
is_authorized: ${{ steps.check_user_permission.outputs.has_permission }}
|
||||
steps:
|
||||
- name: Check user permission
|
||||
id: check_user_permission
|
||||
uses: actions/github-script@v6
|
||||
with:
|
||||
script: |
|
||||
const comment_user = context.payload.comment.user.login;
|
||||
const { data: permission } = await github.rest.repos.getCollaboratorPermissionLevel({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
username: comment_user
|
||||
});
|
||||
|
||||
const authorized =
|
||||
permission.permission === 'admin' ||
|
||||
permission.permission === 'write';
|
||||
|
||||
console.log(
|
||||
`User ${comment_user} has permission level: ${permission.permission}, ` +
|
||||
`authorized: ${authorized} (admins & maintainers allowed)`
|
||||
);
|
||||
|
||||
core.setOutput('has_permission', authorized);
|
||||
|
||||
run-style-bot:
|
||||
needs: check-permissions
|
||||
if: needs.check-permissions.outputs.is_authorized == 'true'
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
contents: write
|
||||
pull-requests: write
|
||||
steps:
|
||||
- name: Extract PR details
|
||||
id: pr_info
|
||||
uses: actions/github-script@v6
|
||||
with:
|
||||
script: |
|
||||
const prNumber = context.payload.issue.number;
|
||||
const { data: pr } = await github.rest.pulls.get({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
pull_number: prNumber
|
||||
});
|
||||
|
||||
// We capture both the branch ref and the "full_name" of the head repo
|
||||
// so that we can check out the correct repository & branch (including forks).
|
||||
core.setOutput("prNumber", prNumber);
|
||||
core.setOutput("headRef", pr.head.ref);
|
||||
core.setOutput("headRepoFullName", pr.head.repo.full_name);
|
||||
|
||||
- name: Check out PR branch
|
||||
uses: actions/checkout@v4
|
||||
env:
|
||||
HEADREPOFULLNAME: ${{ steps.pr_info.outputs.headRepoFullName }}
|
||||
HEADREF: ${{ steps.pr_info.outputs.headRef }}
|
||||
with:
|
||||
persist-credentials: true
|
||||
# Instead of checking out the base repo, use the contributor's repo name
|
||||
repository: ${{ env.HEADREPOFULLNAME }}
|
||||
ref: ${{ env.HEADREF }}
|
||||
# You may need fetch-depth: 0 for being able to push
|
||||
fetch-depth: 0
|
||||
token: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Debug
|
||||
env:
|
||||
HEADREPOFULLNAME: ${{ steps.pr_info.outputs.headRepoFullName }}
|
||||
HEADREF: ${{ steps.pr_info.outputs.headRef }}
|
||||
PRNUMBER: ${{ steps.pr_info.outputs.prNumber }}
|
||||
run: |
|
||||
echo "PR number: ${PRNUMBER}"
|
||||
echo "Head Ref: ${HEADREF}"
|
||||
echo "Head Repo Full Name: ${HEADREPOFULLNAME}"
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: ${{ env.PYTHON_VERSION }}
|
||||
|
||||
- name: Get Ruff Version from pre-commit-config.yaml
|
||||
id: get-ruff-version
|
||||
run: |
|
||||
RUFF_VERSION=$(awk '/repo: https:\/\/github.com\/astral-sh\/ruff-pre-commit/{flag=1;next}/rev:/{if(flag){print $2;exit}}' .pre-commit-config.yaml)
|
||||
echo "ruff_version=${RUFF_VERSION}" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Install Ruff
|
||||
env:
|
||||
RUFF_VERSION: ${{ steps.get-ruff-version.outputs.ruff_version }}
|
||||
run: python -m pip install "ruff==${RUFF_VERSION}"
|
||||
|
||||
- name: Ruff check
|
||||
run: ruff check --fix
|
||||
|
||||
- name: Ruff format
|
||||
run: ruff format
|
||||
|
||||
- name: Commit and push changes
|
||||
id: commit_and_push
|
||||
env:
|
||||
HEADREPOFULLNAME: ${{ steps.pr_info.outputs.headRepoFullName }}
|
||||
HEADREF: ${{ steps.pr_info.outputs.headRef }}
|
||||
PRNUMBER: ${{ steps.pr_info.outputs.prNumber }}
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
run: |
|
||||
echo "HEADREPOFULLNAME: ${HEADREPOFULLNAME}, HEADREF: ${HEADREF}"
|
||||
# Configure git with the Actions bot user
|
||||
git config user.name "github-actions[bot]"
|
||||
git config user.email "github-actions[bot]@users.noreply.github.com"
|
||||
git config --local lfs.https://github.com/.locksverify false
|
||||
|
||||
# Make sure your 'origin' remote is set to the contributor's fork
|
||||
git remote set-url origin "https://x-access-token:${GITHUB_TOKEN}@github.com/${HEADREPOFULLNAME}.git"
|
||||
|
||||
# If there are changes after running style/quality, commit them
|
||||
if [ -n "$(git status --porcelain)" ]; then
|
||||
git add .
|
||||
git commit -m "Apply style fixes"
|
||||
# Push to the original contributor's forked branch
|
||||
git push origin HEAD:${HEADREF}
|
||||
echo "changes_pushed=true" >> $GITHUB_OUTPUT
|
||||
else
|
||||
echo "No changes to commit."
|
||||
echo "changes_pushed=false" >> $GITHUB_OUTPUT
|
||||
fi
|
||||
|
||||
- name: Comment on PR with workflow run link
|
||||
if: steps.commit_and_push.outputs.changes_pushed == 'true'
|
||||
uses: actions/github-script@v6
|
||||
with:
|
||||
script: |
|
||||
const prNumber = parseInt(process.env.prNumber, 10);
|
||||
const runUrl = `${process.env.GITHUB_SERVER_URL}/${process.env.GITHUB_REPOSITORY}/actions/runs/${process.env.GITHUB_RUN_ID}`
|
||||
|
||||
await github.rest.issues.createComment({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
issue_number: prNumber,
|
||||
body: `Style fixes have been applied. [View the workflow run here](${runUrl}).`
|
||||
});
|
||||
env:
|
||||
prNumber: ${{ steps.pr_info.outputs.prNumber }}
|
||||
60
.github/workflows/quality.yml
vendored
60
.github/workflows/quality.yml
vendored
@@ -4,12 +4,12 @@ on:
|
||||
workflow_dispatch:
|
||||
workflow_call:
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
|
||||
permissions: {}
|
||||
|
||||
env:
|
||||
PYTHON_VERSION: "3.10"
|
||||
|
||||
@@ -19,9 +19,7 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout Repository
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
@@ -32,27 +30,55 @@ jobs:
|
||||
id: get-ruff-version
|
||||
run: |
|
||||
RUFF_VERSION=$(awk '/repo: https:\/\/github.com\/astral-sh\/ruff-pre-commit/{flag=1;next}/rev:/{if(flag){print $2;exit}}' .pre-commit-config.yaml)
|
||||
echo "ruff_version=${RUFF_VERSION}" >> $GITHUB_OUTPUT
|
||||
echo "RUFF_VERSION=${RUFF_VERSION}" >> $GITHUB_ENV
|
||||
|
||||
- name: Install Ruff
|
||||
env:
|
||||
RUFF_VERSION: ${{ steps.get-ruff-version.outputs.ruff_version }}
|
||||
run: python -m pip install "ruff==${RUFF_VERSION}"
|
||||
run: python -m pip install "ruff==${{ env.RUFF_VERSION }}"
|
||||
|
||||
- name: Ruff check
|
||||
run: ruff check --output-format=github
|
||||
run: ruff check
|
||||
|
||||
- name: Ruff format
|
||||
run: ruff format --diff
|
||||
|
||||
typos:
|
||||
name: Typos
|
||||
|
||||
poetry_check:
|
||||
name: Poetry check
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout Repository
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: typos-action
|
||||
uses: crate-ci/typos@v1.29.10
|
||||
- name: Install poetry
|
||||
run: pipx install "poetry<2.0.0"
|
||||
|
||||
- name: Poetry check
|
||||
run: poetry check
|
||||
|
||||
|
||||
poetry_relax:
|
||||
name: Poetry relax
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout Repository
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Install poetry
|
||||
run: pipx install "poetry<2.0.0"
|
||||
|
||||
- name: Install poetry-relax
|
||||
run: poetry self add poetry-relax
|
||||
|
||||
- name: Poetry relax
|
||||
id: poetry_relax
|
||||
run: |
|
||||
output=$(poetry relax --check 2>&1)
|
||||
if echo "$output" | grep -q "Proposing updates"; then
|
||||
echo "$output"
|
||||
echo ""
|
||||
echo "Some dependencies have caret '^' version requirement added by poetry by default."
|
||||
echo "Please replace them with '>='. You can do this by hand or use poetry-relax to do this."
|
||||
exit 1
|
||||
else
|
||||
echo "$output"
|
||||
fi
|
||||
|
||||
17
.github/workflows/test-docker-build.yml
vendored
17
.github/workflows/test-docker-build.yml
vendored
@@ -4,12 +4,12 @@ name: Test Dockerfiles
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
paths:
|
||||
# Run only when DockerFile files are modified
|
||||
- "docker/**"
|
||||
|
||||
permissions: {}
|
||||
|
||||
env:
|
||||
PYTHON_VERSION: "3.10"
|
||||
|
||||
@@ -22,8 +22,6 @@ jobs:
|
||||
steps:
|
||||
- name: Check out code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Get changed files
|
||||
id: changed-files
|
||||
@@ -32,18 +30,21 @@ jobs:
|
||||
files: docker/**
|
||||
json: "true"
|
||||
|
||||
- name: Run step if only the files listed above change # zizmor: ignore[template-injection]
|
||||
- name: Run step if only the files listed above change
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
id: set-matrix
|
||||
env:
|
||||
ALL_CHANGED_FILES: ${{ steps.changed-files.outputs.all_changed_files }}
|
||||
run: |
|
||||
echo "matrix=${{ steps.changed-files.outputs.all_changed_files}}" >> $GITHUB_OUTPUT
|
||||
|
||||
|
||||
build_modified_dockerfiles:
|
||||
name: Build modified Docker images
|
||||
needs: get_changed_files
|
||||
runs-on:
|
||||
group: aws-general-8-plus
|
||||
if: needs.get_changed_files.outputs.matrix != ''
|
||||
if: ${{ needs.get_changed_files.outputs.matrix }} != ''
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
@@ -51,13 +52,9 @@ jobs:
|
||||
steps:
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
with:
|
||||
cache-binary: false
|
||||
|
||||
- name: Check out code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Build Docker image
|
||||
uses: docker/build-push-action@v5
|
||||
|
||||
121
.github/workflows/test.yml
vendored
121
.github/workflows/test.yml
vendored
@@ -2,13 +2,14 @@ name: Tests
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
paths:
|
||||
- "lerobot/**"
|
||||
- "tests/**"
|
||||
- "examples/**"
|
||||
- ".github/**"
|
||||
- "pyproject.toml"
|
||||
- ".pre-commit-config.yaml"
|
||||
- "poetry.lock"
|
||||
- "Makefile"
|
||||
- ".cache/**"
|
||||
push:
|
||||
@@ -19,16 +20,10 @@ on:
|
||||
- "tests/**"
|
||||
- "examples/**"
|
||||
- ".github/**"
|
||||
- "pyproject.toml"
|
||||
- ".pre-commit-config.yaml"
|
||||
- "poetry.lock"
|
||||
- "Makefile"
|
||||
- ".cache/**"
|
||||
|
||||
permissions: {}
|
||||
|
||||
env:
|
||||
UV_VERSION: "0.6.0"
|
||||
|
||||
jobs:
|
||||
pytest:
|
||||
name: Pytest
|
||||
@@ -39,7 +34,6 @@ jobs:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
lfs: true # Ensure LFS files are pulled
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install apt dependencies
|
||||
# portaudio19-dev is needed to install pyaudio
|
||||
@@ -47,19 +41,25 @@ jobs:
|
||||
sudo apt-get update && \
|
||||
sudo apt-get install -y libegl1-mesa-dev ffmpeg portaudio19-dev
|
||||
|
||||
- name: Install uv and python
|
||||
uses: astral-sh/setup-uv@v5
|
||||
with:
|
||||
enable-cache: true
|
||||
version: ${{ env.UV_VERSION }}
|
||||
python-version: "3.10"
|
||||
- name: Install poetry
|
||||
run: |
|
||||
pipx install poetry && poetry config virtualenvs.in-project true
|
||||
echo "${{ github.workspace }}/.venv/bin" >> $GITHUB_PATH
|
||||
|
||||
- name: Install lerobot (all extras)
|
||||
run: uv sync --all-extras
|
||||
# TODO(rcadene, aliberts): python 3.12 seems to be used in the tests, not python 3.10
|
||||
- name: Set up Python 3.10
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.10"
|
||||
cache: "poetry"
|
||||
|
||||
- name: Install poetry dependencies
|
||||
run: |
|
||||
poetry install --all-extras
|
||||
|
||||
- name: Test with pytest
|
||||
run: |
|
||||
uv run pytest tests -v --cov=./lerobot --durations=0 \
|
||||
pytest tests -v --cov=./lerobot --durations=0 \
|
||||
-W ignore::DeprecationWarning:imageio_ffmpeg._utils:7 \
|
||||
-W ignore::UserWarning:torch.utils.data.dataloader:558 \
|
||||
-W ignore::UserWarning:gymnasium.utils.env_checker:247 \
|
||||
@@ -74,63 +74,66 @@ jobs:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
lfs: true # Ensure LFS files are pulled
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install apt dependencies
|
||||
run: sudo apt-get update && sudo apt-get install -y ffmpeg
|
||||
|
||||
- name: Install uv and python
|
||||
uses: astral-sh/setup-uv@v5
|
||||
- name: Install poetry
|
||||
run: |
|
||||
pipx install poetry && poetry config virtualenvs.in-project true
|
||||
echo "${{ github.workspace }}/.venv/bin" >> $GITHUB_PATH
|
||||
|
||||
# TODO(rcadene, aliberts): python 3.12 seems to be used in the tests, not python 3.10
|
||||
- name: Set up Python 3.10
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
enable-cache: true
|
||||
version: ${{ env.UV_VERSION }}
|
||||
python-version: "3.10"
|
||||
|
||||
- name: Install lerobot
|
||||
run: uv sync --extra "test"
|
||||
- name: Install poetry dependencies
|
||||
run: |
|
||||
poetry install --extras "test"
|
||||
|
||||
- name: Test with pytest
|
||||
run: |
|
||||
uv run pytest tests -v --cov=./lerobot --durations=0 \
|
||||
pytest tests -v --cov=./lerobot --durations=0 \
|
||||
-W ignore::DeprecationWarning:imageio_ffmpeg._utils:7 \
|
||||
-W ignore::UserWarning:torch.utils.data.dataloader:558 \
|
||||
-W ignore::UserWarning:gymnasium.utils.env_checker:247 \
|
||||
&& rm -rf tests/outputs outputs
|
||||
|
||||
end-to-end:
|
||||
name: End-to-end
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
MUJOCO_GL: egl
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
lfs: true # Ensure LFS files are pulled
|
||||
persist-credentials: false
|
||||
# TODO(aliberts, rcadene): redesign after v2 migration / removing hydra
|
||||
# end-to-end:
|
||||
# name: End-to-end
|
||||
# runs-on: ubuntu-latest
|
||||
# env:
|
||||
# MUJOCO_GL: egl
|
||||
# steps:
|
||||
# - uses: actions/checkout@v4
|
||||
# with:
|
||||
# lfs: true # Ensure LFS files are pulled
|
||||
|
||||
- name: Install apt dependencies
|
||||
# portaudio19-dev is needed to install pyaudio
|
||||
run: |
|
||||
sudo apt-get update && \
|
||||
sudo apt-get install -y libegl1-mesa-dev portaudio19-dev
|
||||
# - name: Install apt dependencies
|
||||
# # portaudio19-dev is needed to install pyaudio
|
||||
# run: |
|
||||
# sudo apt-get update && \
|
||||
# sudo apt-get install -y libegl1-mesa-dev portaudio19-dev
|
||||
|
||||
- name: Install uv and python
|
||||
uses: astral-sh/setup-uv@v5
|
||||
with:
|
||||
enable-cache: true
|
||||
version: ${{ env.UV_VERSION }}
|
||||
python-version: "3.10"
|
||||
# - name: Install poetry
|
||||
# run: |
|
||||
# pipx install poetry && poetry config virtualenvs.in-project true
|
||||
# echo "${{ github.workspace }}/.venv/bin" >> $GITHUB_PATH
|
||||
|
||||
- name: Install lerobot (all extras)
|
||||
run: |
|
||||
uv venv
|
||||
uv sync --all-extras
|
||||
# - name: Set up Python 3.10
|
||||
# uses: actions/setup-python@v5
|
||||
# with:
|
||||
# python-version: "3.10"
|
||||
# cache: "poetry"
|
||||
|
||||
- name: venv
|
||||
run: |
|
||||
echo "PYTHON_PATH=${{ github.workspace }}/.venv/bin/python" >> $GITHUB_ENV
|
||||
# - name: Install poetry dependencies
|
||||
# run: |
|
||||
# poetry install --all-extras
|
||||
|
||||
- name: Test end-to-end
|
||||
run: |
|
||||
make test-end-to-end \
|
||||
&& rm -rf outputs
|
||||
# - name: Test end-to-end
|
||||
# run: |
|
||||
# make test-end-to-end \
|
||||
# && rm -rf outputs
|
||||
|
||||
5
.github/workflows/trufflehog.yml
vendored
5
.github/workflows/trufflehog.yml
vendored
@@ -3,7 +3,8 @@ on:
|
||||
|
||||
name: Secret Leaks
|
||||
|
||||
permissions: {}
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
trufflehog:
|
||||
@@ -13,8 +14,6 @@ jobs:
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
persist-credentials: false
|
||||
|
||||
- name: Secret Scanning
|
||||
uses: trufflesecurity/trufflehog@main
|
||||
with:
|
||||
|
||||
4
.gitignore
vendored
4
.gitignore
vendored
@@ -49,10 +49,6 @@ share/python-wheels/
|
||||
*.egg
|
||||
MANIFEST
|
||||
|
||||
# uv/poetry lock files
|
||||
poetry.lock
|
||||
uv.lock
|
||||
|
||||
# PyInstaller
|
||||
# Usually these files are written by a python script from a template
|
||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
exclude: ^(tests/data)
|
||||
default_language_version:
|
||||
python: python3.12
|
||||
python: python3.10
|
||||
repos:
|
||||
##### Style / Misc. #####
|
||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||
rev: v5.0.0
|
||||
hooks:
|
||||
@@ -14,34 +13,25 @@ repos:
|
||||
- id: check-toml
|
||||
- id: end-of-file-fixer
|
||||
- id: trailing-whitespace
|
||||
- repo: https://github.com/crate-ci/typos
|
||||
rev: v1.30.0
|
||||
hooks:
|
||||
- id: typos
|
||||
args: [--force-exclude]
|
||||
- repo: https://github.com/asottile/pyupgrade
|
||||
rev: v3.19.1
|
||||
rev: v3.19.0
|
||||
hooks:
|
||||
- id: pyupgrade
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
rev: v0.9.9
|
||||
rev: v0.8.2
|
||||
hooks:
|
||||
- id: ruff
|
||||
args: [--fix]
|
||||
- id: ruff-format
|
||||
|
||||
##### Security #####
|
||||
- repo: https://github.com/python-poetry/poetry
|
||||
rev: 1.8.0
|
||||
hooks:
|
||||
- id: poetry-check
|
||||
- id: poetry-lock
|
||||
args:
|
||||
- "--check"
|
||||
- "--no-update"
|
||||
- repo: https://github.com/gitleaks/gitleaks
|
||||
rev: v8.24.0
|
||||
rev: v8.21.2
|
||||
hooks:
|
||||
- id: gitleaks
|
||||
- repo: https://github.com/woodruffw/zizmor-pre-commit
|
||||
rev: v1.4.1
|
||||
hooks:
|
||||
- id: zizmor
|
||||
- repo: https://github.com/PyCQA/bandit
|
||||
rev: 1.8.3
|
||||
hooks:
|
||||
- id: bandit
|
||||
args: ["-c", "pyproject.toml"]
|
||||
additional_dependencies: ["bandit[toml]"]
|
||||
|
||||
@@ -129,71 +129,38 @@ Follow these steps to start contributing:
|
||||
|
||||
🚨 **Do not** work on the `main` branch.
|
||||
|
||||
4. for development, we advise to use a tool like `poetry` or `uv` instead of just `pip` to easily track our dependencies.
|
||||
Follow the instructions to [install poetry](https://python-poetry.org/docs/#installation) (use a version >=2.1.0) or to [install uv](https://docs.astral.sh/uv/getting-started/installation/#installation-methods) if you don't have one of them already.
|
||||
4. for development, we use `poetry` instead of just `pip` to easily track our dependencies.
|
||||
If you don't have it already, follow the [instructions](https://python-poetry.org/docs/#installation) to install it.
|
||||
|
||||
Set up a development environment with conda or miniconda:
|
||||
```bash
|
||||
conda create -y -n lerobot-dev python=3.10 && conda activate lerobot-dev
|
||||
```
|
||||
|
||||
If you're using `uv`, it can manage python versions so you can instead do:
|
||||
```bash
|
||||
uv venv --python 3.10 && source .venv/bin/activate
|
||||
```
|
||||
|
||||
To develop on 🤗 LeRobot, you will at least need to install the `dev` and `test` extras dependencies along with the core library:
|
||||
|
||||
using `poetry`
|
||||
```bash
|
||||
poetry sync --extras "dev test"
|
||||
```
|
||||
|
||||
using `uv`
|
||||
```bash
|
||||
uv sync --extra dev --extra test
|
||||
poetry install --sync --extras "dev test"
|
||||
```
|
||||
|
||||
You can also install the project with all its dependencies (including environments):
|
||||
|
||||
using `poetry`
|
||||
```bash
|
||||
poetry sync --all-extras
|
||||
```
|
||||
|
||||
using `uv`
|
||||
```bash
|
||||
uv sync --all-extras
|
||||
poetry install --sync --all-extras
|
||||
```
|
||||
|
||||
> **Note:** If you don't install simulation environments with `--all-extras`, the tests that require them will be skipped when running the pytest suite locally. However, they *will* be tested in the CI. In general, we advise you to install everything and test locally before pushing.
|
||||
|
||||
Whichever command you chose to install the project (e.g. `poetry sync --all-extras`), you should run it again when pulling code with an updated version of `pyproject.toml` and `poetry.lock` in order to synchronize your virtual environment with the new dependencies.
|
||||
Whichever command you chose to install the project (e.g. `poetry install --sync --all-extras`), you should run it again when pulling code with an updated version of `pyproject.toml` and `poetry.lock` in order to synchronize your virtual environment with the new dependencies.
|
||||
|
||||
The equivalent of `pip install some-package`, would just be:
|
||||
|
||||
using `poetry`
|
||||
```bash
|
||||
poetry add some-package
|
||||
```
|
||||
|
||||
using `uv`
|
||||
```bash
|
||||
uv add some-package
|
||||
```
|
||||
|
||||
When making changes to the poetry sections of the `pyproject.toml`, you should run the following command to lock dependencies.
|
||||
using `poetry`
|
||||
```bash
|
||||
poetry lock
|
||||
poetry lock --no-update
|
||||
```
|
||||
|
||||
using `uv`
|
||||
```bash
|
||||
uv lock
|
||||
```
|
||||
|
||||
|
||||
5. Develop the features on your branch.
|
||||
|
||||
As you work on the features, you should make sure that the test suite
|
||||
@@ -228,7 +195,7 @@ Follow these steps to start contributing:
|
||||
git commit
|
||||
```
|
||||
|
||||
Note, if you already committed some changes that have a wrong formatting, you can use:
|
||||
Note, if you already commited some changes that have a wrong formatting, you can use:
|
||||
```bash
|
||||
pre-commit run --all-files
|
||||
```
|
||||
|
||||
236
Makefile
236
Makefile
@@ -2,10 +2,10 @@
|
||||
|
||||
PYTHON_PATH := $(shell which python)
|
||||
|
||||
# If uv is installed and a virtual environment exists, use it
|
||||
UV_CHECK := $(shell command -v uv)
|
||||
ifneq ($(UV_CHECK),)
|
||||
PYTHON_PATH := $(shell .venv/bin/python)
|
||||
# If Poetry is installed, redefine PYTHON_PATH to use the Poetry-managed Python
|
||||
POETRY_CHECK := $(shell command -v poetry)
|
||||
ifneq ($(POETRY_CHECK),)
|
||||
PYTHON_PATH := $(shell poetry run which python)
|
||||
endif
|
||||
|
||||
export PATH := $(dir $(PYTHON_PATH)):$(PATH)
|
||||
@@ -20,109 +20,171 @@ build-gpu:
|
||||
|
||||
test-end-to-end:
|
||||
${MAKE} DEVICE=$(DEVICE) test-act-ete-train
|
||||
${MAKE} DEVICE=$(DEVICE) test-act-ete-train-resume
|
||||
${MAKE} DEVICE=$(DEVICE) test-act-ete-eval
|
||||
${MAKE} DEVICE=$(DEVICE) test-act-ete-train-amp
|
||||
${MAKE} DEVICE=$(DEVICE) test-act-ete-eval-amp
|
||||
${MAKE} DEVICE=$(DEVICE) test-diffusion-ete-train
|
||||
${MAKE} DEVICE=$(DEVICE) test-diffusion-ete-eval
|
||||
${MAKE} DEVICE=$(DEVICE) test-tdmpc-ete-train
|
||||
${MAKE} DEVICE=$(DEVICE) test-tdmpc-ete-train-with-online
|
||||
${MAKE} DEVICE=$(DEVICE) test-tdmpc-ete-eval
|
||||
${MAKE} DEVICE=$(DEVICE) test-default-ete-eval
|
||||
${MAKE} DEVICE=$(DEVICE) test-act-pusht-tutorial
|
||||
|
||||
test-act-ete-train:
|
||||
python lerobot/scripts/train.py \
|
||||
--policy.type=act \
|
||||
--policy.dim_model=64 \
|
||||
--policy.n_action_steps=20 \
|
||||
--policy.chunk_size=20 \
|
||||
--env.type=aloha \
|
||||
--env.episode_length=5 \
|
||||
--dataset.repo_id=lerobot/aloha_sim_transfer_cube_human \
|
||||
--dataset.image_transforms.enable=true \
|
||||
--dataset.episodes="[0]" \
|
||||
--batch_size=2 \
|
||||
--steps=4 \
|
||||
--eval_freq=2 \
|
||||
--eval.n_episodes=1 \
|
||||
--eval.batch_size=1 \
|
||||
--save_freq=2 \
|
||||
--save_checkpoint=true \
|
||||
--log_freq=1 \
|
||||
--wandb.enable=false \
|
||||
--device=$(DEVICE) \
|
||||
--output_dir=tests/outputs/act/
|
||||
|
||||
test-act-ete-train-resume:
|
||||
python lerobot/scripts/train.py \
|
||||
--config_path=tests/outputs/act/checkpoints/000002/pretrained_model/train_config.json \
|
||||
--resume=true
|
||||
policy=act \
|
||||
policy.dim_model=64 \
|
||||
env=aloha \
|
||||
wandb.enable=False \
|
||||
training.offline_steps=2 \
|
||||
training.online_steps=0 \
|
||||
eval.n_episodes=1 \
|
||||
eval.batch_size=1 \
|
||||
device=$(DEVICE) \
|
||||
training.save_checkpoint=true \
|
||||
training.save_freq=2 \
|
||||
policy.n_action_steps=20 \
|
||||
policy.chunk_size=20 \
|
||||
training.batch_size=2 \
|
||||
training.image_transforms.enable=true \
|
||||
hydra.run.dir=tests/outputs/act/
|
||||
|
||||
test-act-ete-eval:
|
||||
python lerobot/scripts/eval.py \
|
||||
--policy.path=tests/outputs/act/checkpoints/000004/pretrained_model \
|
||||
--env.type=aloha \
|
||||
--env.episode_length=5 \
|
||||
--eval.n_episodes=1 \
|
||||
--eval.batch_size=1 \
|
||||
--device=$(DEVICE)
|
||||
-p tests/outputs/act/checkpoints/000002/pretrained_model \
|
||||
eval.n_episodes=1 \
|
||||
eval.batch_size=1 \
|
||||
env.episode_length=8 \
|
||||
device=$(DEVICE) \
|
||||
|
||||
test-act-ete-train-amp:
|
||||
python lerobot/scripts/train.py \
|
||||
policy=act \
|
||||
policy.dim_model=64 \
|
||||
env=aloha \
|
||||
wandb.enable=False \
|
||||
training.offline_steps=2 \
|
||||
training.online_steps=0 \
|
||||
eval.n_episodes=1 \
|
||||
eval.batch_size=1 \
|
||||
device=$(DEVICE) \
|
||||
training.save_checkpoint=true \
|
||||
training.save_freq=2 \
|
||||
policy.n_action_steps=20 \
|
||||
policy.chunk_size=20 \
|
||||
training.batch_size=2 \
|
||||
hydra.run.dir=tests/outputs/act_amp/ \
|
||||
training.image_transforms.enable=true \
|
||||
use_amp=true
|
||||
|
||||
test-act-ete-eval-amp:
|
||||
python lerobot/scripts/eval.py \
|
||||
-p tests/outputs/act_amp/checkpoints/000002/pretrained_model \
|
||||
eval.n_episodes=1 \
|
||||
eval.batch_size=1 \
|
||||
env.episode_length=8 \
|
||||
device=$(DEVICE) \
|
||||
use_amp=true
|
||||
|
||||
test-diffusion-ete-train:
|
||||
python lerobot/scripts/train.py \
|
||||
--policy.type=diffusion \
|
||||
--policy.down_dims='[64,128,256]' \
|
||||
--policy.diffusion_step_embed_dim=32 \
|
||||
--policy.num_inference_steps=10 \
|
||||
--env.type=pusht \
|
||||
--env.episode_length=5 \
|
||||
--dataset.repo_id=lerobot/pusht \
|
||||
--dataset.image_transforms.enable=true \
|
||||
--dataset.episodes="[0]" \
|
||||
--batch_size=2 \
|
||||
--steps=2 \
|
||||
--eval_freq=2 \
|
||||
--eval.n_episodes=1 \
|
||||
--eval.batch_size=1 \
|
||||
--save_checkpoint=true \
|
||||
--save_freq=2 \
|
||||
--log_freq=1 \
|
||||
--wandb.enable=false \
|
||||
--device=$(DEVICE) \
|
||||
--output_dir=tests/outputs/diffusion/
|
||||
policy=diffusion \
|
||||
policy.down_dims=\[64,128,256\] \
|
||||
policy.diffusion_step_embed_dim=32 \
|
||||
policy.num_inference_steps=10 \
|
||||
env=pusht \
|
||||
wandb.enable=False \
|
||||
training.offline_steps=2 \
|
||||
training.online_steps=0 \
|
||||
eval.n_episodes=1 \
|
||||
eval.batch_size=1 \
|
||||
device=$(DEVICE) \
|
||||
training.save_checkpoint=true \
|
||||
training.save_freq=2 \
|
||||
training.batch_size=2 \
|
||||
training.image_transforms.enable=true \
|
||||
hydra.run.dir=tests/outputs/diffusion/
|
||||
|
||||
test-diffusion-ete-eval:
|
||||
python lerobot/scripts/eval.py \
|
||||
--policy.path=tests/outputs/diffusion/checkpoints/000002/pretrained_model \
|
||||
--env.type=pusht \
|
||||
--env.episode_length=5 \
|
||||
--eval.n_episodes=1 \
|
||||
--eval.batch_size=1 \
|
||||
--device=$(DEVICE)
|
||||
-p tests/outputs/diffusion/checkpoints/000002/pretrained_model \
|
||||
eval.n_episodes=1 \
|
||||
eval.batch_size=1 \
|
||||
env.episode_length=8 \
|
||||
device=$(DEVICE) \
|
||||
|
||||
test-tdmpc-ete-train:
|
||||
python lerobot/scripts/train.py \
|
||||
--policy.type=tdmpc \
|
||||
--env.type=xarm \
|
||||
--env.task=XarmLift-v0 \
|
||||
--env.episode_length=5 \
|
||||
--dataset.repo_id=lerobot/xarm_lift_medium \
|
||||
--dataset.image_transforms.enable=true \
|
||||
--dataset.episodes="[0]" \
|
||||
--batch_size=2 \
|
||||
--steps=2 \
|
||||
--eval_freq=2 \
|
||||
--eval.n_episodes=1 \
|
||||
--eval.batch_size=1 \
|
||||
--save_checkpoint=true \
|
||||
--save_freq=2 \
|
||||
--log_freq=1 \
|
||||
--wandb.enable=false \
|
||||
--device=$(DEVICE) \
|
||||
--output_dir=tests/outputs/tdmpc/
|
||||
policy=tdmpc \
|
||||
env=xarm \
|
||||
env.task=XarmLift-v0 \
|
||||
dataset_repo_id=lerobot/xarm_lift_medium \
|
||||
wandb.enable=False \
|
||||
training.offline_steps=2 \
|
||||
training.online_steps=0 \
|
||||
eval.n_episodes=1 \
|
||||
eval.batch_size=1 \
|
||||
env.episode_length=2 \
|
||||
device=$(DEVICE) \
|
||||
training.save_checkpoint=true \
|
||||
training.save_freq=2 \
|
||||
training.batch_size=2 \
|
||||
training.image_transforms.enable=true \
|
||||
hydra.run.dir=tests/outputs/tdmpc/
|
||||
|
||||
test-tdmpc-ete-train-with-online:
|
||||
python lerobot/scripts/train.py \
|
||||
env=pusht \
|
||||
env.gym.obs_type=environment_state_agent_pos \
|
||||
policy=tdmpc_pusht_keypoints \
|
||||
eval.n_episodes=1 \
|
||||
eval.batch_size=1 \
|
||||
env.episode_length=10 \
|
||||
device=$(DEVICE) \
|
||||
training.offline_steps=2 \
|
||||
training.online_steps=20 \
|
||||
training.save_checkpoint=false \
|
||||
training.save_freq=10 \
|
||||
training.batch_size=2 \
|
||||
training.online_rollout_n_episodes=2 \
|
||||
training.online_rollout_batch_size=2 \
|
||||
training.online_steps_between_rollouts=10 \
|
||||
training.online_buffer_capacity=15 \
|
||||
eval.use_async_envs=true \
|
||||
hydra.run.dir=tests/outputs/tdmpc_online/
|
||||
|
||||
|
||||
test-tdmpc-ete-eval:
|
||||
python lerobot/scripts/eval.py \
|
||||
--policy.path=tests/outputs/tdmpc/checkpoints/000002/pretrained_model \
|
||||
--env.type=xarm \
|
||||
--env.episode_length=5 \
|
||||
--env.task=XarmLift-v0 \
|
||||
--eval.n_episodes=1 \
|
||||
--eval.batch_size=1 \
|
||||
--device=$(DEVICE)
|
||||
-p tests/outputs/tdmpc/checkpoints/000002/pretrained_model \
|
||||
eval.n_episodes=1 \
|
||||
eval.batch_size=1 \
|
||||
env.episode_length=8 \
|
||||
device=$(DEVICE) \
|
||||
|
||||
test-default-ete-eval:
|
||||
python lerobot/scripts/eval.py \
|
||||
--config lerobot/configs/default.yaml \
|
||||
eval.n_episodes=1 \
|
||||
eval.batch_size=1 \
|
||||
env.episode_length=8 \
|
||||
device=$(DEVICE) \
|
||||
|
||||
test-act-pusht-tutorial:
|
||||
cp examples/advanced/1_train_act_pusht/act_pusht.yaml lerobot/configs/policy/created_by_Makefile.yaml
|
||||
python lerobot/scripts/train.py \
|
||||
policy=created_by_Makefile.yaml \
|
||||
env=pusht \
|
||||
wandb.enable=False \
|
||||
training.offline_steps=2 \
|
||||
eval.n_episodes=1 \
|
||||
eval.batch_size=1 \
|
||||
env.episode_length=2 \
|
||||
device=$(DEVICE) \
|
||||
training.save_model=true \
|
||||
training.save_freq=2 \
|
||||
training.batch_size=2 \
|
||||
training.image_transforms.enable=true \
|
||||
hydra.run.dir=tests/outputs/act_pusht/
|
||||
rm lerobot/configs/policy/created_by_Makefile.yaml
|
||||
|
||||
80
README.md
80
README.md
@@ -122,7 +122,10 @@ wandb login
|
||||
├── examples # contains demonstration examples, start here to learn about LeRobot
|
||||
| └── advanced # contains even more examples for those who have mastered the basics
|
||||
├── lerobot
|
||||
| ├── configs # contains config classes with all options that you can override in the command line
|
||||
| ├── configs # contains hydra yaml files with all options that you can override in the command line
|
||||
| | ├── default.yaml # selected by default, it loads pusht environment and diffusion policy
|
||||
| | ├── env # various sim environments and their datasets: aloha.yaml, pusht.yaml, xarm.yaml
|
||||
| | └── policy # various policies: act.yaml, diffusion.yaml, tdmpc.yaml
|
||||
| ├── common # contains classes and utilities
|
||||
| | ├── datasets # various datasets of human demonstrations: aloha, pusht, xarm
|
||||
| | ├── envs # various sim environments: aloha, pusht, xarm
|
||||
@@ -210,7 +213,7 @@ A `LeRobotDataset` is serialised using several widespread file formats for each
|
||||
- videos are stored in mp4 format to save space
|
||||
- metadata are stored in plain json/jsonl files
|
||||
|
||||
Dataset can be uploaded/downloaded from the HuggingFace hub seamlessly. To work on a local dataset, you can specify its location with the `root` argument if it's not in the default `~/.cache/huggingface/lerobot` location.
|
||||
Dataset can be uploaded/downloaded from the HuggingFace hub seamlessly. To work on a local dataset, you can use the `local_files_only` argument and specify its location with the `root` argument if it's not in the default `~/.cache/huggingface/lerobot` location.
|
||||
|
||||
### Evaluate a pretrained policy
|
||||
|
||||
@@ -219,48 +222,87 @@ Check out [example 2](./examples/2_evaluate_pretrained_policy.py) that illustrat
|
||||
We also provide a more capable script to parallelize the evaluation over multiple environments during the same rollout. Here is an example with a pretrained model hosted on [lerobot/diffusion_pusht](https://huggingface.co/lerobot/diffusion_pusht):
|
||||
```bash
|
||||
python lerobot/scripts/eval.py \
|
||||
--policy.path=lerobot/diffusion_pusht \
|
||||
--env.type=pusht \
|
||||
--eval.batch_size=10 \
|
||||
--eval.n_episodes=10 \
|
||||
--use_amp=false \
|
||||
--device=cuda
|
||||
-p lerobot/diffusion_pusht \
|
||||
eval.n_episodes=10 \
|
||||
eval.batch_size=10
|
||||
```
|
||||
|
||||
Note: After training your own policy, you can re-evaluate the checkpoints with:
|
||||
|
||||
```bash
|
||||
python lerobot/scripts/eval.py --policy.path={OUTPUT_DIR}/checkpoints/last/pretrained_model
|
||||
python lerobot/scripts/eval.py -p {OUTPUT_DIR}/checkpoints/last/pretrained_model
|
||||
```
|
||||
|
||||
See `python lerobot/scripts/eval.py --help` for more instructions.
|
||||
|
||||
### Train your own policy
|
||||
|
||||
Check out [example 3](./examples/3_train_policy.py) that illustrate how to train a model using our core library in python, and [example 4](./examples/4_train_policy_with_script.md) that shows how to use our training script from command line.
|
||||
Check out [example 3](./examples/3_train_policy.py) that illustrates how to train a model using our core library in python, and [example 4](./examples/4_train_policy_with_script.md) that shows how to use our training script from command line.
|
||||
|
||||
To use wandb for logging training and evaluation curves, make sure you've run `wandb login` as a one-time setup step. Then, when running the training command above, enable WandB in the configuration by adding `--wandb.enable=true`.
|
||||
In general, you can use our training script to easily train any policy. Here is an example of training the ACT policy on trajectories collected by humans on the Aloha simulation environment for the insertion task:
|
||||
|
||||
A link to the wandb logs for the run will also show up in yellow in your terminal. Here is an example of what they look like in your browser. Please also check [here](./examples/4_train_policy_with_script.md#typical-logs-and-metrics) for the explanation of some commonly used metrics in logs.
|
||||
```bash
|
||||
python lerobot/scripts/train.py \
|
||||
policy=act \
|
||||
env=aloha \
|
||||
env.task=AlohaInsertion-v0 \
|
||||
dataset_repo_id=lerobot/aloha_sim_insertion_human \
|
||||
```
|
||||
|
||||
The experiment directory is automatically generated and will show up in yellow in your terminal. It looks like `outputs/train/2024-05-05/20-21-12_aloha_act_default`. You can manually specify an experiment directory by adding this argument to the `train.py` python command:
|
||||
```bash
|
||||
hydra.run.dir=your/new/experiment/dir
|
||||
```
|
||||
|
||||
In the experiment directory there will be a folder called `checkpoints` which will have the following structure:
|
||||
|
||||
```bash
|
||||
checkpoints
|
||||
├── 000250 # checkpoint_dir for training step 250
|
||||
│ ├── pretrained_model # Hugging Face pretrained model dir
|
||||
│ │ ├── config.json # Hugging Face pretrained model config
|
||||
│ │ ├── config.yaml # consolidated Hydra config
|
||||
│ │ ├── model.safetensors # model weights
|
||||
│ │ └── README.md # Hugging Face model card
|
||||
│ └── training_state.pth # optimizer/scheduler/rng state and training step
|
||||
```
|
||||
|
||||
To resume training from a checkpoint, you can add these to the `train.py` python command:
|
||||
```bash
|
||||
hydra.run.dir=your/original/experiment/dir resume=true
|
||||
```
|
||||
|
||||
It will load the pretrained model, optimizer and scheduler states for training. For more information please see our tutorial on training resumption [here](https://github.com/huggingface/lerobot/blob/main/examples/5_resume_training.md).
|
||||
|
||||
To use wandb for logging training and evaluation curves, make sure you've run `wandb login` as a one-time setup step. Then, when running the training command above, enable WandB in the configuration by adding:
|
||||
|
||||
```bash
|
||||
wandb.enable=true
|
||||
```
|
||||
|
||||
A link to the wandb logs for the run will also show up in yellow in your terminal. Here is an example of what they look like in your browser. Please also check [here](https://github.com/huggingface/lerobot/blob/main/examples/4_train_policy_with_script.md#typical-logs-and-metrics) for the explanation of some commonly used metrics in logs.
|
||||
|
||||

|
||||
|
||||
Note: For efficiency, during training every checkpoint is evaluated on a low number of episodes. You may use `--eval.n_episodes=500` to evaluate on more episodes than the default. Or, after training, you may want to re-evaluate your best checkpoints on more episodes or change the evaluation settings. See `python lerobot/scripts/eval.py --help` for more instructions.
|
||||
Note: For efficiency, during training every checkpoint is evaluated on a low number of episodes. You may use `eval.n_episodes=500` to evaluate on more episodes than the default. Or, after training, you may want to re-evaluate your best checkpoints on more episodes or change the evaluation settings. See `python lerobot/scripts/eval.py --help` for more instructions.
|
||||
|
||||
#### Reproduce state-of-the-art (SOTA)
|
||||
|
||||
We provide some pretrained policies on our [hub page](https://huggingface.co/lerobot) that can achieve state-of-the-art performances.
|
||||
You can reproduce their training by loading the config from their run. Simply running:
|
||||
We have organized our configuration files (found under [`lerobot/configs`](./lerobot/configs)) such that they reproduce SOTA results from a given model variant in their respective original works. Simply running:
|
||||
|
||||
```bash
|
||||
python lerobot/scripts/train.py --config_path=lerobot/diffusion_pusht
|
||||
python lerobot/scripts/train.py policy=diffusion env=pusht
|
||||
```
|
||||
|
||||
reproduces SOTA results for Diffusion Policy on the PushT task.
|
||||
|
||||
Pretrained policies, along with reproduction details, can be found under the "Models" section of https://huggingface.co/lerobot.
|
||||
|
||||
## Contribute
|
||||
|
||||
If you would like to contribute to 🤗 LeRobot, please check out our [contribution guide](https://github.com/huggingface/lerobot/blob/main/CONTRIBUTING.md).
|
||||
|
||||
<!-- ### Add a new dataset
|
||||
### Add a new dataset
|
||||
|
||||
To add a dataset to the hub, you need to login using a write-access token, which can be generated from the [Hugging Face settings](https://huggingface.co/settings/tokens):
|
||||
```bash
|
||||
@@ -278,7 +320,7 @@ python lerobot/scripts/push_dataset_to_hub.py \
|
||||
|
||||
See `python lerobot/scripts/push_dataset_to_hub.py --help` for more instructions.
|
||||
|
||||
If your dataset format is not supported, implement your own in `lerobot/common/datasets/push_dataset_to_hub/${raw_format}_format.py` by copying examples like [pusht_zarr](https://github.com/huggingface/lerobot/blob/main/lerobot/common/datasets/push_dataset_to_hub/pusht_zarr_format.py), [umi_zarr](https://github.com/huggingface/lerobot/blob/main/lerobot/common/datasets/push_dataset_to_hub/umi_zarr_format.py), [aloha_hdf5](https://github.com/huggingface/lerobot/blob/main/lerobot/common/datasets/push_dataset_to_hub/aloha_hdf5_format.py), or [xarm_pkl](https://github.com/huggingface/lerobot/blob/main/lerobot/common/datasets/push_dataset_to_hub/xarm_pkl_format.py). -->
|
||||
If your dataset format is not supported, implement your own in `lerobot/common/datasets/push_dataset_to_hub/${raw_format}_format.py` by copying examples like [pusht_zarr](https://github.com/huggingface/lerobot/blob/main/lerobot/common/datasets/push_dataset_to_hub/pusht_zarr_format.py), [umi_zarr](https://github.com/huggingface/lerobot/blob/main/lerobot/common/datasets/push_dataset_to_hub/umi_zarr_format.py), [aloha_hdf5](https://github.com/huggingface/lerobot/blob/main/lerobot/common/datasets/push_dataset_to_hub/aloha_hdf5_format.py), or [xarm_pkl](https://github.com/huggingface/lerobot/blob/main/lerobot/common/datasets/push_dataset_to_hub/xarm_pkl_format.py).
|
||||
|
||||
|
||||
### Add a pretrained policy
|
||||
@@ -288,7 +330,7 @@ Once you have trained a policy you may upload it to the Hugging Face hub using a
|
||||
You first need to find the checkpoint folder located inside your experiment directory (e.g. `outputs/train/2024-05-05/20-21-12_aloha_act_default/checkpoints/002500`). Within that there is a `pretrained_model` directory which should contain:
|
||||
- `config.json`: A serialized version of the policy configuration (following the policy's dataclass config).
|
||||
- `model.safetensors`: A set of `torch.nn.Module` parameters, saved in [Hugging Face Safetensors](https://huggingface.co/docs/safetensors/index) format.
|
||||
- `train_config.json`: A consolidated configuration containing all parameter userd for training. The policy configuration should match `config.json` exactly. Thisis useful for anyone who wants to evaluate your policy or for reproducibility.
|
||||
- `config.yaml`: A consolidated Hydra training configuration containing the policy, environment, and dataset configs. The policy configuration should match `config.json` exactly. The environment config is useful for anyone who wants to evaluate your policy. The dataset config just serves as a paper trail for reproducibility.
|
||||
|
||||
To upload these to the hub, run the following:
|
||||
```bash
|
||||
|
||||
@@ -114,7 +114,7 @@ We tried to measure the most impactful parameters for both encoding and decoding
|
||||
|
||||
Additional encoding parameters exist that are not included in this benchmark. In particular:
|
||||
- `-preset` which allows for selecting encoding presets. This represents a collection of options that will provide a certain encoding speed to compression ratio. By leaving this parameter unspecified, it is considered to be `medium` for libx264 and libx265 and `8` for libsvtav1.
|
||||
- `-tune` which allows to optimize the encoding for certain aspects (e.g. film quality, fast decoding, etc.).
|
||||
- `-tune` which allows to optimize the encoding for certains aspects (e.g. film quality, fast decoding, etc.).
|
||||
|
||||
See the documentation mentioned above for more detailed info on these settings and for a more comprehensive list of other parameters.
|
||||
|
||||
|
||||
@@ -1,29 +1,32 @@
|
||||
# Configure image
|
||||
ARG PYTHON_VERSION=3.10
|
||||
|
||||
FROM python:${PYTHON_VERSION}-slim
|
||||
|
||||
# Configure environment variables
|
||||
ARG PYTHON_VERSION
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
ENV MUJOCO_GL="egl"
|
||||
ENV PATH="/opt/venv/bin:$PATH"
|
||||
ARG DEBIAN_FRONTEND=noninteractive
|
||||
|
||||
# Install dependencies and set up Python in a single layer
|
||||
# Install apt dependencies
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
build-essential cmake git \
|
||||
build-essential cmake \
|
||||
libglib2.0-0 libgl1-mesa-glx libegl1-mesa ffmpeg \
|
||||
speech-dispatcher libgeos-dev \
|
||||
&& ln -s /usr/bin/python${PYTHON_VERSION} /usr/bin/python \
|
||||
&& python -m venv /opt/venv \
|
||||
&& apt-get clean && rm -rf /var/lib/apt/lists/* \
|
||||
&& echo "source /opt/venv/bin/activate" >> /root/.bashrc
|
||||
speech-dispatcher \
|
||||
&& apt-get clean && rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Clone repository and install LeRobot in a single layer
|
||||
# Create virtual environment
|
||||
RUN ln -s /usr/bin/python${PYTHON_VERSION} /usr/bin/python
|
||||
RUN python -m venv /opt/venv
|
||||
ENV PATH="/opt/venv/bin:$PATH"
|
||||
RUN echo "source /opt/venv/bin/activate" >> /root/.bashrc
|
||||
|
||||
# Install LeRobot
|
||||
COPY . /lerobot
|
||||
WORKDIR /lerobot
|
||||
RUN /opt/venv/bin/pip install --upgrade --no-cache-dir pip \
|
||||
&& /opt/venv/bin/pip install --no-cache-dir ".[test, aloha, xarm, pusht, dynamixel]" \
|
||||
--extra-index-url https://download.pytorch.org/whl/cpu
|
||||
RUN pip install --upgrade --no-cache-dir pip
|
||||
RUN pip install --no-cache-dir ".[test, aloha, xarm, pusht, dynamixel]" \
|
||||
--extra-index-url https://download.pytorch.org/whl/cpu
|
||||
|
||||
# Set EGL as the rendering backend for MuJoCo
|
||||
ENV MUJOCO_GL="egl"
|
||||
|
||||
# Execute in bash shell rather than python
|
||||
CMD ["/bin/bash"]
|
||||
|
||||
@@ -13,7 +13,7 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
sed gawk grep curl wget zip unzip \
|
||||
tcpdump sysstat screen tmux \
|
||||
libglib2.0-0 libgl1-mesa-glx libegl1-mesa \
|
||||
speech-dispatcher portaudio19-dev libgeos-dev \
|
||||
speech-dispatcher \
|
||||
python${PYTHON_VERSION} python${PYTHON_VERSION}-venv \
|
||||
&& apt-get clean && rm -rf /var/lib/apt/lists/*
|
||||
|
||||
|
||||
@@ -1,24 +1,30 @@
|
||||
FROM nvidia/cuda:12.4.1-base-ubuntu22.04
|
||||
|
||||
# Configure environment variables
|
||||
# Configure image
|
||||
ARG PYTHON_VERSION=3.10
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
ENV MUJOCO_GL="egl"
|
||||
ENV PATH="/opt/venv/bin:$PATH"
|
||||
ARG DEBIAN_FRONTEND=noninteractive
|
||||
|
||||
# Install dependencies and set up Python in a single layer
|
||||
|
||||
# Install apt dependencies
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
build-essential cmake git \
|
||||
build-essential cmake \
|
||||
libglib2.0-0 libgl1-mesa-glx libegl1-mesa ffmpeg \
|
||||
speech-dispatcher libgeos-dev \
|
||||
speech-dispatcher \
|
||||
python${PYTHON_VERSION}-dev python${PYTHON_VERSION}-venv \
|
||||
&& ln -s /usr/bin/python${PYTHON_VERSION} /usr/bin/python \
|
||||
&& python -m venv /opt/venv \
|
||||
&& apt-get clean && rm -rf /var/lib/apt/lists/* \
|
||||
&& echo "source /opt/venv/bin/activate" >> /root/.bashrc
|
||||
&& apt-get clean && rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Clone repository and install LeRobot in a single layer
|
||||
|
||||
# Create virtual environment
|
||||
RUN ln -s /usr/bin/python${PYTHON_VERSION} /usr/bin/python
|
||||
RUN python -m venv /opt/venv
|
||||
ENV PATH="/opt/venv/bin:$PATH"
|
||||
RUN echo "source /opt/venv/bin/activate" >> /root/.bashrc
|
||||
|
||||
# Install LeRobot
|
||||
COPY . /lerobot
|
||||
WORKDIR /lerobot
|
||||
RUN /opt/venv/bin/pip install --upgrade --no-cache-dir pip \
|
||||
&& /opt/venv/bin/pip install --no-cache-dir ".[test, aloha, xarm, pusht, dynamixel]"
|
||||
RUN pip install --upgrade --no-cache-dir pip
|
||||
RUN pip install --no-cache-dir ".[test, aloha, xarm, pusht, dynamixel]"
|
||||
|
||||
# Set EGL as the rendering backend for MuJoCo
|
||||
ENV MUJOCO_GL="egl"
|
||||
|
||||
@@ -1,91 +1,63 @@
|
||||
# Using the [SO-100](https://github.com/TheRobotStudio/SO-ARM100) with LeRobot
|
||||
|
||||
## Table of Contents
|
||||
|
||||
- [A. Source the parts](#a-source-the-parts)
|
||||
- [B. Install LeRobot](#b-install-lerobot)
|
||||
- [C. Configure the motors](#c-configure-the-motors)
|
||||
- [D. Assemble the arms](#d-assemble-the-arms)
|
||||
- [E. Calibrate](#e-calibrate)
|
||||
- [F. Teleoperate](#f-teleoperate)
|
||||
- [G. Record a dataset](#g-record-a-dataset)
|
||||
- [H. Visualize a dataset](#h-visualize-a-dataset)
|
||||
- [I. Replay an episode](#i-replay-an-episode)
|
||||
- [J. Train a policy](#j-train-a-policy)
|
||||
- [K. Evaluate your policy](#k-evaluate-your-policy)
|
||||
- [L. More Information](#l-more-information)
|
||||
|
||||
## A. Source the parts
|
||||
|
||||
Follow this [README](https://github.com/TheRobotStudio/SO-ARM100). It contains the bill of materials, with a link to source the parts, as well as the instructions to 3D print the parts,
|
||||
and advice if it's your first time printing or if you don't own a 3D printer.
|
||||
Follow this [README](https://github.com/TheRobotStudio/SO-ARM100). It contains the bill of materials, with link to source the parts, as well as the instructions to 3D print the parts, and advices if it's your first time printing or if you don't own a 3D printer already.
|
||||
|
||||
Before assembling, you will first need to configure your motors. To this end, we provide a nice script, so let's first install LeRobot. After configuration, we will also guide you through assembly.
|
||||
**Important**: Before assembling, you will first need to configure your motors. To this end, we provide a nice script, so let's first install LeRobot. After configuration, we will also guide you through assembly.
|
||||
|
||||
## B. Install LeRobot
|
||||
|
||||
> [!TIP]
|
||||
> We use the Command Prompt (cmd) quite a lot. If you are not comfortable using the cmd or want to brush up using the command line you can have a look here: [Command line crash course](https://developer.mozilla.org/en-US/docs/Learn_web_development/Getting_started/Environment_setup/Command_line)
|
||||
|
||||
On your computer:
|
||||
|
||||
#### 1. [Install Miniconda](https://docs.anaconda.com/miniconda/install/#quick-command-line-install):
|
||||
|
||||
#### 2. Restart shell
|
||||
Copy paste in your shell: `source ~/.bashrc` or for Mac: `source ~/.bash_profile` or `source ~/.zshrc` if you're using zshell
|
||||
|
||||
#### 3. Create and activate a fresh conda environment for lerobot
|
||||
|
||||
<details>
|
||||
<summary><strong>Video install instructions</strong></summary>
|
||||
|
||||
<video src="https://github.com/user-attachments/assets/17172d3b-3b64-4b80-9cf1-b2b7c5cbd236"></video>
|
||||
|
||||
</details>
|
||||
|
||||
1. [Install Miniconda](https://docs.anaconda.com/miniconda/#quick-command-line-install):
|
||||
```bash
|
||||
conda create -y -n lerobot python=3.10
|
||||
mkdir -p ~/miniconda3
|
||||
# Linux:
|
||||
wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda3/miniconda.sh
|
||||
# Mac M-series:
|
||||
# curl https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-arm64.sh -o ~/miniconda3/miniconda.sh
|
||||
# Mac Intel:
|
||||
# curl https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh -o ~/miniconda3/miniconda.sh
|
||||
bash ~/miniconda3/miniconda.sh -b -u -p ~/miniconda3
|
||||
rm ~/miniconda3/miniconda.sh
|
||||
~/miniconda3/bin/conda init bash
|
||||
```
|
||||
|
||||
Then activate your conda environment (do this each time you open a shell to use lerobot!):
|
||||
2. Restart shell or `source ~/.bashrc` (*Mac*: `source ~/.bash_profile`) or `source ~/.zshrc` if you're using zshell
|
||||
|
||||
3. Create and activate a fresh conda environment for lerobot
|
||||
```bash
|
||||
conda activate lerobot
|
||||
conda create -y -n lerobot python=3.10 && conda activate lerobot
|
||||
```
|
||||
|
||||
#### 4. Clone LeRobot:
|
||||
4. Clone LeRobot:
|
||||
```bash
|
||||
git clone https://github.com/huggingface/lerobot.git ~/lerobot
|
||||
```
|
||||
|
||||
#### 5. Install LeRobot with dependencies for the feetech motors:
|
||||
5. Install LeRobot with dependencies for the feetech motors:
|
||||
```bash
|
||||
cd ~/lerobot && pip install -e ".[feetech]"
|
||||
```
|
||||
|
||||
*EXTRA: For Linux only (not Mac)*: install extra dependencies for recording datasets:
|
||||
*For Linux only (not Mac)*: install extra dependencies for recording datasets:
|
||||
```bash
|
||||
conda install -y -c conda-forge ffmpeg
|
||||
pip uninstall -y opencv-python
|
||||
conda install -y -c conda-forge "opencv>=4.10.0"
|
||||
```
|
||||
Great :hugs:! You are now done installing LeRobot and we can begin assembling the SO100 arms :robot:.
|
||||
Every time you now want to use LeRobot you can go to the `~/lerobot` folder where we installed LeRobot and run one of the commands.
|
||||
## C. Configure the motors
|
||||
|
||||
> [!NOTE]
|
||||
> Throughout this tutorial you will find videos on how to do the steps, the full video tutorial can be found here: [assembly video](https://www.youtube.com/watch?v=FioA2oeFZ5I).
|
||||
## C. Configure the motors
|
||||
|
||||
### 1. Find the USB ports associated to each arm
|
||||
|
||||
Designate one bus servo adapter and 6 motors for your leader arm, and similarly the other bus servo adapter and 6 motors for the follower arm. It's convenient to label them and write on each motor if it's for the follower `F` or for the leader `L` and it's ID from 1 to 6 (F1...F6 and L1...L6).
|
||||
Designate one bus servo adapter and 6 motors for your leader arm, and similarly the other bus servo adapter and 6 motors for the follower arm.
|
||||
|
||||
#### a. Run the script to find port
|
||||
#### a. Run the script to find ports
|
||||
|
||||
<details>
|
||||
<summary><strong>Video finding port</strong></summary>
|
||||
<video src="https://github.com/user-attachments/assets/4a21a14d-2046-4805-93c4-ee97a30ba33f"></video>
|
||||
<video src="https://github.com/user-attachments/assets/1cc3aecf-c16d-4ff9-aec7-8c175afbbce2"></video>
|
||||
</details>
|
||||
Follow Step 1 of the [assembly video](https://www.youtube.com/watch?v=FioA2oeFZ5I), which illustrates the use of our scripts below.
|
||||
|
||||
To find the port for each bus servo adapter, run the utility script:
|
||||
```bash
|
||||
@@ -124,66 +96,14 @@ sudo chmod 666 /dev/ttyACM0
|
||||
sudo chmod 666 /dev/ttyACM1
|
||||
```
|
||||
|
||||
#### d. Update config file
|
||||
#### d. Update YAML file
|
||||
|
||||
IMPORTANTLY: Now that you have your ports, update the **port** default values of [`SO100RobotConfig`](../lerobot/common/robot_devices/robots/configs.py). You will find something like:
|
||||
```python
|
||||
@RobotConfig.register_subclass("so100")
|
||||
@dataclass
|
||||
class So100RobotConfig(ManipulatorRobotConfig):
|
||||
calibration_dir: str = ".cache/calibration/so100"
|
||||
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
|
||||
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
|
||||
# the number of motors in your follower arms.
|
||||
max_relative_target: int | None = None
|
||||
Now that you have the ports, modify the *port* sections in `so100.yaml`
|
||||
|
||||
leader_arms: dict[str, MotorsBusConfig] = field(
|
||||
default_factory=lambda: {
|
||||
"main": FeetechMotorsBusConfig(
|
||||
port="/dev/tty.usbmodem58760431091", <-- UPDATE HERE
|
||||
motors={
|
||||
# name: (index, model)
|
||||
"shoulder_pan": [1, "sts3215"],
|
||||
"shoulder_lift": [2, "sts3215"],
|
||||
"elbow_flex": [3, "sts3215"],
|
||||
"wrist_flex": [4, "sts3215"],
|
||||
"wrist_roll": [5, "sts3215"],
|
||||
"gripper": [6, "sts3215"],
|
||||
},
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
follower_arms: dict[str, MotorsBusConfig] = field(
|
||||
default_factory=lambda: {
|
||||
"main": FeetechMotorsBusConfig(
|
||||
port="/dev/tty.usbmodem585A0076891", <-- UPDATE HERE
|
||||
motors={
|
||||
# name: (index, model)
|
||||
"shoulder_pan": [1, "sts3215"],
|
||||
"shoulder_lift": [2, "sts3215"],
|
||||
"elbow_flex": [3, "sts3215"],
|
||||
"wrist_flex": [4, "sts3215"],
|
||||
"wrist_roll": [5, "sts3215"],
|
||||
"gripper": [6, "sts3215"],
|
||||
},
|
||||
),
|
||||
}
|
||||
)
|
||||
```
|
||||
|
||||
### 2. Assembling the Base
|
||||
Let's begin with assembling the follower arm base
|
||||
### 2. Configure the motors
|
||||
|
||||
#### a. Set IDs for all 12 motors
|
||||
|
||||
<details>
|
||||
<summary><strong>Video configuring motor</strong></summary>
|
||||
<video src="https://github.com/user-attachments/assets/ef9b3317-2e11-4858-b9d3-f0a02fb48ecf"></video>
|
||||
<video src="https://github.com/user-attachments/assets/f36b5ed5-c803-4ebe-8947-b39278776a0d"></video>
|
||||
</details>
|
||||
|
||||
Plug your first motor F1 and run this script to set its ID to 1. It will also set its present position to 2048, so expect your motor to rotate. Replace the text after --port to the corresponding follower control board port and run this command in cmd:
|
||||
Plug your first motor and run this script to set its ID to 1. It will also set its present position to 2048, so expect your motor to rotate:
|
||||
```bash
|
||||
python lerobot/scripts/configure_motor.py \
|
||||
--port /dev/tty.usbmodem58760432961 \
|
||||
@@ -193,8 +113,7 @@ python lerobot/scripts/configure_motor.py \
|
||||
--ID 1
|
||||
```
|
||||
|
||||
> [!NOTE]
|
||||
> These motors are currently limited. They can take values between 0 and 4096 only, which corresponds to a full turn. They can't turn more than that. 2048 is at the middle of this range, so we can take -2048 steps (180 degrees anticlockwise) and reach the maximum range, or take +2048 steps (180 degrees clockwise) and reach the maximum range. The configuration step also sets the homing offset to 0, so that if you misassembled the arm, you can always update the homing offset to account for a shift up to ± 2048 steps (± 180 degrees).
|
||||
*Note: These motors are currently limitated. They can take values between 0 and 4096 only, which corresponds to a full turn. They can't turn more than that. 2048 is at the middle of this range, so we can take -2048 steps (180 degrees anticlockwise) and reach the maximum range, or take +2048 steps (180 degrees clockwise) and reach the maximum range. The configuration step also sets the homing offset to 0, so that if you misassembled the arm, you can always update the homing offset to account for a shift up to ± 2048 steps (± 180 degrees).*
|
||||
|
||||
Then unplug your motor and plug the second motor and set its ID to 2.
|
||||
```bash
|
||||
@@ -211,47 +130,22 @@ Redo the process for all your motors until ID 6. Do the same for the 6 motors of
|
||||
|
||||
#### b. Remove the gears of the 6 leader motors
|
||||
|
||||
<details>
|
||||
<summary><strong>Video removing gears</strong></summary>
|
||||
|
||||
<video src="https://github.com/user-attachments/assets/0c95b88c-5b85-413d-ba19-aee2f864f2a7"></video>
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
Follow the video for removing gears. You need to remove the gear for the motors of the leader arm. As a result, you will only use the position encoding of the motor and reduce friction to more easily operate the leader arm.
|
||||
Follow step 2 of the [assembly video](https://youtu.be/FioA2oeFZ5I?t=248). You need to remove the gear for the motors of the leader arm. As a result, you will only use the position encoding of the motor and reduce friction to more easily operate the leader arm.
|
||||
|
||||
#### c. Add motor horn to all 12 motors
|
||||
|
||||
<details>
|
||||
<summary><strong>Video adding motor horn</strong></summary>
|
||||
|
||||
<video src="https://github.com/user-attachments/assets/ef3391a4-ad05-4100-b2bd-1699bf86c969"></video>
|
||||
|
||||
</details>
|
||||
|
||||
Follow the video for adding the motor horn. For SO-100, you need to align the holes on the motor horn to the motor spline to be approximately 1:30, 4:30, 7:30 and 10:30.
|
||||
Follow step 3 of the [assembly video](https://youtu.be/FioA2oeFZ5I?t=569). For SO-100, you need to align the holes on the motor horn to the motor spline to be approximately 1:30, 4:30, 7:30 and 10:30.
|
||||
Try to avoid rotating the motor while doing so to keep position 2048 set during configuration. It is especially tricky for the leader motors as it is more sensible without the gears, but it's ok if it's a bit rotated.
|
||||
|
||||
## D. Assemble the arms
|
||||
|
||||
<details>
|
||||
<summary><strong>Video assembling arms</strong></summary>
|
||||
|
||||
<video src="https://github.com/user-attachments/assets/488a39de-0189-4461-9de3-05b015f90cca"></video>
|
||||
|
||||
</details>
|
||||
|
||||
Follow the video for assembling the arms. It is important to insert the cables into the motor that is being assembled before you assemble the motor into the arm! Inserting the cables beforehand is much easier than doing this afterward. The first arm should take a bit more than 1 hour to assemble, but once you get used to it, you can do it under 1 hour for the second arm.
|
||||
Follow step 4 of the [assembly video](https://youtu.be/FioA2oeFZ5I?t=610). The first arm should take a bit more than 1 hour to assemble, but once you get use to it, you can do it under 1 hour for the second arm.
|
||||
|
||||
## E. Calibrate
|
||||
|
||||
Next, you'll need to calibrate your SO-100 robot to ensure that the leader and follower arms have the same position values when they are in the same physical position. This calibration is essential because it allows a neural network trained on one SO-100 robot to work on another.
|
||||
|
||||
#### a. Manual calibration of follower arm
|
||||
|
||||
> [!IMPORTANT]
|
||||
> Contrarily to step 6 of the [assembly video](https://youtu.be/FioA2oeFZ5I?t=724) which illustrates the auto calibration, we will actually do manual calibration of follower for now.
|
||||
/!\ Contrarily to step 6 of the [assembly video](https://youtu.be/FioA2oeFZ5I?t=724) which illustrates the auto calibration, we will actually do manual calibration of follower for now.
|
||||
|
||||
You will need to move the follower arm to these positions sequentially:
|
||||
|
||||
@@ -261,11 +155,9 @@ You will need to move the follower arm to these positions sequentially:
|
||||
|
||||
Make sure both arms are connected and run this script to launch manual calibration:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py \
|
||||
--robot.type=so100 \
|
||||
--robot.cameras='{}' \
|
||||
--control.type=calibrate \
|
||||
--control.arms='["main_follower"]'
|
||||
python lerobot/scripts/control_robot.py calibrate \
|
||||
--robot-path lerobot/configs/robot/so100.yaml \
|
||||
--robot-overrides '~cameras' --arms main_follower
|
||||
```
|
||||
|
||||
#### b. Manual calibration of leader arm
|
||||
@@ -277,11 +169,9 @@ Follow step 6 of the [assembly video](https://youtu.be/FioA2oeFZ5I?t=724) which
|
||||
|
||||
Run this script to launch manual calibration:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py \
|
||||
--robot.type=so100 \
|
||||
--robot.cameras='{}' \
|
||||
--control.type=calibrate \
|
||||
--control.arms='["main_leader"]'
|
||||
python lerobot/scripts/control_robot.py calibrate \
|
||||
--robot-path lerobot/configs/robot/so100.yaml \
|
||||
--robot-overrides '~cameras' --arms main_leader
|
||||
```
|
||||
|
||||
## F. Teleoperate
|
||||
@@ -289,19 +179,18 @@ python lerobot/scripts/control_robot.py \
|
||||
**Simple teleop**
|
||||
Then you are ready to teleoperate your robot! Run this simple script (it won't connect and display the cameras):
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py \
|
||||
--robot.type=so100 \
|
||||
--robot.cameras='{}' \
|
||||
--control.type=teleoperate
|
||||
python lerobot/scripts/control_robot.py teleoperate \
|
||||
--robot-path lerobot/configs/robot/so100.yaml \
|
||||
--robot-overrides '~cameras' \
|
||||
--display-cameras 0
|
||||
```
|
||||
|
||||
|
||||
#### a. Teleop with displaying cameras
|
||||
Follow [this guide to setup your cameras](https://github.com/huggingface/lerobot/blob/main/examples/7_get_started_with_real_robot.md#c-add-your-cameras-with-opencvcamera). Then you will be able to display the cameras on your computer while you are teleoperating by running the following code. This is useful to prepare your setup before recording your first dataset.
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py \
|
||||
--robot.type=so100 \
|
||||
--control.type=teleoperate
|
||||
python lerobot/scripts/control_robot.py teleoperate \
|
||||
--robot-path lerobot/configs/robot/so100.yaml
|
||||
```
|
||||
|
||||
## G. Record a dataset
|
||||
@@ -321,46 +210,40 @@ echo $HF_USER
|
||||
|
||||
Record 2 episodes and upload your dataset to the hub:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py \
|
||||
--robot.type=so100 \
|
||||
--control.type=record \
|
||||
--control.fps=30 \
|
||||
--control.single_task="Grasp a lego block and put it in the bin." \
|
||||
--control.repo_id=${HF_USER}/so100_test \
|
||||
--control.tags='["so100","tutorial"]' \
|
||||
--control.warmup_time_s=5 \
|
||||
--control.episode_time_s=30 \
|
||||
--control.reset_time_s=30 \
|
||||
--control.num_episodes=2 \
|
||||
--control.push_to_hub=true
|
||||
python lerobot/scripts/control_robot.py record \
|
||||
--robot-path lerobot/configs/robot/so100.yaml \
|
||||
--fps 30 \
|
||||
--repo-id ${HF_USER}/so100_test \
|
||||
--tags so100 tutorial \
|
||||
--warmup-time-s 5 \
|
||||
--episode-time-s 40 \
|
||||
--reset-time-s 10 \
|
||||
--num-episodes 2 \
|
||||
--push-to-hub 1
|
||||
```
|
||||
|
||||
Note: You can resume recording by adding `--control.resume=true`.
|
||||
|
||||
## H. Visualize a dataset
|
||||
|
||||
If you uploaded your dataset to the hub with `--control.push_to_hub=true`, you can [visualize your dataset online](https://huggingface.co/spaces/lerobot/visualize_dataset) by copy pasting your repo id given by:
|
||||
If you uploaded your dataset to the hub with `--push-to-hub 1`, you can [visualize your dataset online](https://huggingface.co/spaces/lerobot/visualize_dataset) by copy pasting your repo id given by:
|
||||
```bash
|
||||
echo ${HF_USER}/so100_test
|
||||
```
|
||||
|
||||
If you didn't upload with `--control.push_to_hub=false`, you can also visualize it locally with (a window can be opened in the browser `http://127.0.0.1:9090` with the visualization tool):
|
||||
If you didn't upload with `--push-to-hub 0`, you can also visualize it locally with:
|
||||
```bash
|
||||
python lerobot/scripts/visualize_dataset_html.py \
|
||||
--repo-id ${HF_USER}/so100_test \
|
||||
--local-files-only 1
|
||||
--repo-id ${HF_USER}/so100_test
|
||||
```
|
||||
|
||||
## I. Replay an episode
|
||||
|
||||
Now try to replay the first episode on your robot:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py \
|
||||
--robot.type=so100 \
|
||||
--control.type=replay \
|
||||
--control.fps=30 \
|
||||
--control.repo_id=${HF_USER}/so100_test \
|
||||
--control.episode=0
|
||||
python lerobot/scripts/control_robot.py replay \
|
||||
--robot-path lerobot/configs/robot/so100.yaml \
|
||||
--fps 30 \
|
||||
--repo-id ${HF_USER}/so100_test \
|
||||
--episode 0
|
||||
```
|
||||
|
||||
## J. Train a policy
|
||||
@@ -368,18 +251,20 @@ python lerobot/scripts/control_robot.py \
|
||||
To train a policy to control your robot, use the [`python lerobot/scripts/train.py`](../lerobot/scripts/train.py) script. A few arguments are required. Here is an example command:
|
||||
```bash
|
||||
python lerobot/scripts/train.py \
|
||||
--dataset.repo_id=${HF_USER}/so100_test \
|
||||
--policy.type=act \
|
||||
--output_dir=outputs/train/act_so100_test \
|
||||
--job_name=act_so100_test \
|
||||
--device=cuda \
|
||||
--wandb.enable=true
|
||||
dataset_repo_id=${HF_USER}/so100_test \
|
||||
policy=act_so100_real \
|
||||
env=so100_real \
|
||||
hydra.run.dir=outputs/train/act_so100_test \
|
||||
hydra.job.name=act_so100_test \
|
||||
device=cuda \
|
||||
wandb.enable=true
|
||||
```
|
||||
|
||||
Let's explain it:
|
||||
1. We provided the dataset as argument with `--dataset.repo_id=${HF_USER}/so100_test`.
|
||||
2. We provided the policy with `policy.type=act`. This loads configurations from [`configuration_act.py`](../lerobot/common/policies/act/configuration_act.py). Importantly, this policy will automatically adapt to the number of motor sates, motor actions and cameras of your robot (e.g. `laptop` and `phone`) which have been saved in your dataset.
|
||||
4. We provided `device=cuda` since we are training on a Nvidia GPU, but you could use `device=mps` to train on Apple silicon.
|
||||
1. We provided the dataset as argument with `dataset_repo_id=${HF_USER}/so100_test`.
|
||||
2. We provided the policy with `policy=act_so100_real`. This loads configurations from [`lerobot/configs/policy/act_so100_real.yaml`](../lerobot/configs/policy/act_so100_real.yaml). Importantly, this policy uses 2 cameras as input `laptop`, `phone`.
|
||||
3. We provided an environment as argument with `env=so100_real`. This loads configurations from [`lerobot/configs/env/so100_real.yaml`](../lerobot/configs/env/so100_real.yaml).
|
||||
4. We provided `device=cuda` since we are training on a Nvidia GPU, but you can also use `device=mps` if you are using a Mac with Apple silicon, or `device=cpu` otherwise.
|
||||
5. We provided `wandb.enable=true` to use [Weights and Biases](https://docs.wandb.ai/quickstart) for visualizing training plots. This is optional but if you use it, make sure you are logged in by running `wandb login`.
|
||||
|
||||
Training should take several hours. You will find checkpoints in `outputs/train/act_so100_test/checkpoints`.
|
||||
@@ -388,28 +273,24 @@ Training should take several hours. You will find checkpoints in `outputs/train/
|
||||
|
||||
You can use the `record` function from [`lerobot/scripts/control_robot.py`](../lerobot/scripts/control_robot.py) but with a policy checkpoint as input. For instance, run this command to record 10 evaluation episodes:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py \
|
||||
--robot.type=so100 \
|
||||
--control.type=record \
|
||||
--control.fps=30 \
|
||||
--control.single_task="Grasp a lego block and put it in the bin." \
|
||||
--control.repo_id=${HF_USER}/eval_act_so100_test \
|
||||
--control.tags='["tutorial"]' \
|
||||
--control.warmup_time_s=5 \
|
||||
--control.episode_time_s=30 \
|
||||
--control.reset_time_s=30 \
|
||||
--control.num_episodes=10 \
|
||||
--control.push_to_hub=true \
|
||||
--control.policy.path=outputs/train/act_so100_test/checkpoints/last/pretrained_model
|
||||
python lerobot/scripts/control_robot.py record \
|
||||
--robot-path lerobot/configs/robot/so100.yaml \
|
||||
--fps 30 \
|
||||
--repo-id ${HF_USER}/eval_act_so100_test \
|
||||
--tags so100 tutorial eval \
|
||||
--warmup-time-s 5 \
|
||||
--episode-time-s 40 \
|
||||
--reset-time-s 10 \
|
||||
--num-episodes 10 \
|
||||
-p outputs/train/act_so100_test/checkpoints/last/pretrained_model
|
||||
```
|
||||
|
||||
As you can see, it's almost the same command as previously used to record your training dataset. Two things changed:
|
||||
1. There is an additional `--control.policy.path` argument which indicates the path to your policy checkpoint with (e.g. `outputs/train/eval_act_so100_test/checkpoints/last/pretrained_model`). You can also use the model repository if you uploaded a model checkpoint to the hub (e.g. `${HF_USER}/act_so100_test`).
|
||||
2. The name of dataset begins by `eval` to reflect that you are running inference (e.g. `${HF_USER}/eval_act_so100_test`).
|
||||
1. There is an additional `-p` argument which indicates the path to your policy checkpoint with (e.g. `-p outputs/train/eval_so100_test/checkpoints/last/pretrained_model`). You can also use the model repository if you uploaded a model checkpoint to the hub (e.g. `-p ${HF_USER}/act_so100_test`).
|
||||
2. The name of dataset begins by `eval` to reflect that you are running inference (e.g. `--repo-id ${HF_USER}/eval_act_so100_test`).
|
||||
|
||||
## L. More Information
|
||||
|
||||
Follow this [previous tutorial](https://github.com/huggingface/lerobot/blob/main/examples/7_get_started_with_real_robot.md#4-train-a-policy-on-your-data) for a more in-depth tutorial on controlling real robots with LeRobot.
|
||||
|
||||
> [!TIP]
|
||||
> If you have any questions or need help, please reach out on [Discord](https://discord.com/invite/s3KuuzsPFb) in the channel [`#so100-arm`](https://discord.com/channels/1216765309076115607/1237741463832363039).
|
||||
If you have any question or need help, please reach out on Discord in the channel [`#so100-arm`](https://discord.com/channels/1216765309076115607/1237741463832363039).
|
||||
|
||||
@@ -1,463 +0,0 @@
|
||||
# Using the [LeKiwi](https://github.com/SIGRobotics-UIUC/LeKiwi) Robot with LeRobot
|
||||
|
||||
## Table of Contents
|
||||
|
||||
- [A. Source the parts](#a-source-the-parts)
|
||||
- [B. Install software Pi](#b-install-software-on-pi)
|
||||
- [C. Setup LeRobot laptop/pc](#c-install-lerobot-on-laptop)
|
||||
- [D. Assemble the arms](#d-assembly)
|
||||
- [E. Calibrate](#e-calibration)
|
||||
- [F. Teleoperate](#f-teleoperate)
|
||||
- [G. Record a dataset](#g-record-a-dataset)
|
||||
- [H. Visualize a dataset](#h-visualize-a-dataset)
|
||||
- [I. Replay an episode](#i-replay-an-episode)
|
||||
- [J. Train a policy](#j-train-a-policy)
|
||||
- [K. Evaluate your policy](#k-evaluate-your-policy)
|
||||
|
||||
> [!TIP]
|
||||
> If you have any questions or need help, please reach out on [Discord](https://discord.com/invite/s3KuuzsPFb) in the channel [`#mobile-so-100-arm`](https://discord.com/channels/1216765309076115607/1318390825528332371).
|
||||
|
||||
## A. Source the parts
|
||||
|
||||
Follow this [README](https://github.com/SIGRobotics-UIUC/LeKiwi). It contains the bill of materials, with a link to source the parts, as well as the instructions to 3D print the parts, and advice if it's your first time printing or if you don't own a 3D printer.
|
||||
|
||||
Before assembling, you will first need to configure your motors. To this end, we provide a nice script, so let's first install LeRobot. After configuration, we will also guide you through assembly.
|
||||
|
||||
## B. Install software on Pi
|
||||
Now we have to setup the remote PC that will run on the LeKiwi Robot. This is normally a Raspberry Pi, but can be any PC that can run on 5V and has enough usb ports (2 or more) for the cameras and motor control board.
|
||||
|
||||
### Install OS
|
||||
For setting up the Raspberry Pi and its SD-card see: [Setup PI](https://www.raspberrypi.com/documentation/computers/getting-started.html). Here is explained how to download the [Imager](https://www.raspberrypi.com/software/) to install Raspberry Pi OS or Ubuntu.
|
||||
|
||||
### Setup SSH
|
||||
After setting up your Pi, you should enable and setup [SSH](https://www.raspberrypi.com/news/coding-on-raspberry-pi-remotely-with-visual-studio-code/) (Secure Shell Protocol) so you can login into the Pi from your laptop without requiring a screen, keyboard and mouse in the Pi. A great tutorial on how to do this can be found [here](https://www.raspberrypi.com/documentation/computers/remote-access.html#ssh). Logging into your Pi can be done in your Command Prompt (cmd) or if you use VSCode you can use [this](https://marketplace.visualstudio.com/items?itemName=ms-vscode-remote.remote-ssh) extension.
|
||||
|
||||
### Install LeRobot
|
||||
|
||||
On your Raspberry Pi:
|
||||
|
||||
#### 1. [Install Miniconda](https://docs.anaconda.com/miniconda/install/#quick-command-line-install):
|
||||
|
||||
#### 2. Restart shell
|
||||
Copy paste in your shell: `source ~/.bashrc` or for Mac: `source ~/.bash_profile` or `source ~/.zshrc` if you're using zshell
|
||||
|
||||
#### 3. Create and activate a fresh conda environment for lerobot
|
||||
|
||||
<details>
|
||||
<summary><strong>Video install instructions</strong></summary>
|
||||
|
||||
<video src="https://github.com/user-attachments/assets/17172d3b-3b64-4b80-9cf1-b2b7c5cbd236"></video>
|
||||
|
||||
</details>
|
||||
|
||||
```bash
|
||||
conda create -y -n lerobot python=3.10
|
||||
```
|
||||
|
||||
Then activate your conda environment (do this each time you open a shell to use lerobot!):
|
||||
```bash
|
||||
conda activate lerobot
|
||||
```
|
||||
|
||||
#### 4. Clone LeRobot:
|
||||
```bash
|
||||
git clone https://github.com/huggingface/lerobot.git ~/lerobot
|
||||
```
|
||||
|
||||
#### 5. Install LeRobot with dependencies for the feetech motors:
|
||||
```bash
|
||||
cd ~/lerobot && pip install -e ".[feetech]"
|
||||
```
|
||||
|
||||
## C. Install LeRobot on laptop
|
||||
If you already have install LeRobot on your laptop you can skip this step, otherwise please follow along as we do the same steps we did on the Pi.
|
||||
|
||||
> [!TIP]
|
||||
> We use the Command Prompt (cmd) quite a lot. If you are not comfortable using the cmd or want to brush up using the command line you can have a look here: [Command line crash course](https://developer.mozilla.org/en-US/docs/Learn_web_development/Getting_started/Environment_setup/Command_line)
|
||||
|
||||
On your computer:
|
||||
|
||||
#### 1. [Install Miniconda](https://docs.anaconda.com/miniconda/install/#quick-command-line-install):
|
||||
|
||||
#### 2. Restart shell
|
||||
Copy paste in your shell: `source ~/.bashrc` or for Mac: `source ~/.bash_profile` or `source ~/.zshrc` if you're using zshell
|
||||
|
||||
#### 3. Create and activate a fresh conda environment for lerobot
|
||||
|
||||
<details>
|
||||
<summary><strong>Video install instructions</strong></summary>
|
||||
|
||||
<video src="https://github.com/user-attachments/assets/17172d3b-3b64-4b80-9cf1-b2b7c5cbd236"></video>
|
||||
|
||||
</details>
|
||||
|
||||
```bash
|
||||
conda create -y -n lerobot python=3.10
|
||||
```
|
||||
|
||||
Then activate your conda environment (do this each time you open a shell to use lerobot!):
|
||||
```bash
|
||||
conda activate lerobot
|
||||
```
|
||||
|
||||
#### 4. Clone LeRobot:
|
||||
```bash
|
||||
git clone https://github.com/huggingface/lerobot.git ~/lerobot
|
||||
```
|
||||
|
||||
#### 5. Install LeRobot with dependencies for the feetech motors:
|
||||
```bash
|
||||
cd ~/lerobot && pip install -e ".[feetech]"
|
||||
```
|
||||
|
||||
*EXTRA: For Linux only (not Mac)*: install extra dependencies for recording datasets:
|
||||
```bash
|
||||
conda install -y -c conda-forge ffmpeg
|
||||
pip uninstall -y opencv-python
|
||||
conda install -y -c conda-forge "opencv>=4.10.0"
|
||||
```
|
||||
Great :hugs:! You are now done installing LeRobot and we can begin assembling the SO100 arms and Mobile base :robot:.
|
||||
Every time you now want to use LeRobot you can go to the `~/lerobot` folder where we installed LeRobot and run one of the commands.
|
||||
|
||||
# D. Assembly
|
||||
|
||||
First we will assemble the two SO100 arms. One to attach to the mobile base and one for teleoperation. Then we will assemble the mobile base.
|
||||
|
||||
## SO100 Arms
|
||||
### Configure motors
|
||||
The instructions for configuring the motors can be found [Here](https://github.com/huggingface/lerobot/blob/main/examples/10_use_so100.md#c-configure-the-motors) in step C of the SO100 tutorial. Besides the ID's for the arm motors we also need to set the motor ID's for the mobile base. These needs to be in a specific order to work. Below an image of the motor ID's and motor mounting positions for the mobile base. Note that we only use one Motor Control board on LeKiwi. This means the motor ID's for the wheels are 7, 8 and 9.
|
||||
|
||||
<img src="../media/lekiwi/motor_ids.webp?raw=true" alt="Motor ID's for mobile robot" title="Motor ID's for mobile robot" width="60%">
|
||||
|
||||
### Assemble arms
|
||||
[Assemble arms instruction](https://github.com/huggingface/lerobot/blob/main/examples/10_use_so100.md#d-assemble-the-arms)
|
||||
|
||||
## Mobile base (LeKiwi)
|
||||
[Assemble LeKiwi](https://github.com/SIGRobotics-UIUC/LeKiwi)
|
||||
|
||||
### Update config
|
||||
Both config files on the LeKiwi LeRobot and on the laptop should be the same. First we should find the Ip address of the Raspberry Pi of the mobile manipulator. This is the same Ip address used in SSH. We also need the usb port of the control board of the leader arm on the laptop and the port of the control board on LeKiwi. We can find these ports with the following script.
|
||||
|
||||
#### a. Run the script to find port
|
||||
|
||||
<details>
|
||||
<summary><strong>Video finding port</strong></summary>
|
||||
<video src="https://github.com/user-attachments/assets/4a21a14d-2046-4805-93c4-ee97a30ba33f"></video>
|
||||
<video src="https://github.com/user-attachments/assets/1cc3aecf-c16d-4ff9-aec7-8c175afbbce2"></video>
|
||||
</details>
|
||||
|
||||
To find the port for each bus servo adapter, run the utility script:
|
||||
```bash
|
||||
python lerobot/scripts/find_motors_bus_port.py
|
||||
```
|
||||
|
||||
#### b. Example outputs
|
||||
|
||||
Example output when identifying the leader arm's port (e.g., `/dev/tty.usbmodem575E0031751` on Mac, or possibly `/dev/ttyACM0` on Linux):
|
||||
```
|
||||
Finding all available ports for the MotorBus.
|
||||
['/dev/tty.usbmodem575E0032081', '/dev/tty.usbmodem575E0031751']
|
||||
Remove the usb cable from your DynamixelMotorsBus and press Enter when done.
|
||||
|
||||
[...Disconnect leader arm and press Enter...]
|
||||
|
||||
The port of this DynamixelMotorsBus is /dev/tty.usbmodem575E0031751
|
||||
Reconnect the usb cable.
|
||||
```
|
||||
Example output when identifying the follower arm's port (e.g., `/dev/tty.usbmodem575E0032081`, or possibly `/dev/ttyACM1` on Linux):
|
||||
```
|
||||
Finding all available ports for the MotorBus.
|
||||
['/dev/tty.usbmodem575E0032081', '/dev/tty.usbmodem575E0031751']
|
||||
Remove the usb cable from your DynamixelMotorsBus and press Enter when done.
|
||||
|
||||
[...Disconnect follower arm and press Enter...]
|
||||
|
||||
The port of this DynamixelMotorsBus is /dev/tty.usbmodem575E0032081
|
||||
Reconnect the usb cable.
|
||||
```
|
||||
|
||||
#### c. Troubleshooting
|
||||
On Linux, you might need to give access to the USB ports by running:
|
||||
```bash
|
||||
sudo chmod 666 /dev/ttyACM0
|
||||
sudo chmod 666 /dev/ttyACM1
|
||||
```
|
||||
|
||||
#### d. Update config file
|
||||
|
||||
IMPORTANTLY: Now that you have your ports of leader and follower arm and ip address of the mobile-so100, update the **ip** in Network configuration, **port** in leader_arms and **port** in lekiwi. In the [`LeKiwiRobotConfig`](../lerobot/common/robot_devices/robots/configs.py) file. Where you will find something like:
|
||||
```python
|
||||
@RobotConfig.register_subclass("lekiwi")
|
||||
@dataclass
|
||||
class LeKiwiRobotConfig(RobotConfig):
|
||||
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
|
||||
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
|
||||
# the number of motors in your follower arms.
|
||||
max_relative_target: int | None = None
|
||||
|
||||
# Network Configuration
|
||||
ip: str = "172.17.133.91"
|
||||
port: int = 5555
|
||||
video_port: int = 5556
|
||||
|
||||
cameras: dict[str, CameraConfig] = field(
|
||||
default_factory=lambda: {
|
||||
"mobile": OpenCVCameraConfig(camera_index="/dev/video0", fps=30, width=640, height=480),
|
||||
"mobile2": OpenCVCameraConfig(camera_index="/dev/video2", fps=30, width=640, height=480),
|
||||
}
|
||||
)
|
||||
|
||||
calibration_dir: str = ".cache/calibration/lekiwi"
|
||||
|
||||
leader_arms: dict[str, MotorsBusConfig] = field(
|
||||
default_factory=lambda: {
|
||||
"main": FeetechMotorsBusConfig(
|
||||
port="/dev/tty.usbmodem585A0077581",
|
||||
motors={
|
||||
# name: (index, model)
|
||||
"shoulder_pan": [1, "sts3215"],
|
||||
"shoulder_lift": [2, "sts3215"],
|
||||
"elbow_flex": [3, "sts3215"],
|
||||
"wrist_flex": [4, "sts3215"],
|
||||
"wrist_roll": [5, "sts3215"],
|
||||
"gripper": [6, "sts3215"],
|
||||
},
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
follower_arms: dict[str, MotorsBusConfig] = field(
|
||||
default_factory=lambda: {
|
||||
"main": FeetechMotorsBusConfig(
|
||||
port="/dev/ttyACM0",
|
||||
motors={
|
||||
# name: (index, model)
|
||||
"shoulder_pan": [1, "sts3215"],
|
||||
"shoulder_lift": [2, "sts3215"],
|
||||
"elbow_flex": [3, "sts3215"],
|
||||
"wrist_flex": [4, "sts3215"],
|
||||
"wrist_roll": [5, "sts3215"],
|
||||
"gripper": [6, "sts3215"],
|
||||
"left_wheel": (7, "sts3215"),
|
||||
"back_wheel": (8, "sts3215"),
|
||||
"right_wheel": (9, "sts3215"),
|
||||
},
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
mock: bool = False
|
||||
```
|
||||
|
||||
# E. Calibration
|
||||
Now we have to calibrate the leader arm and the follower arm. The wheel motors don't have to be calibrated.
|
||||
|
||||
|
||||
### Calibrate follower arm (on mobile base)
|
||||
> [!IMPORTANT]
|
||||
> Contrarily to step 6 of the [assembly video](https://youtu.be/FioA2oeFZ5I?t=724) which illustrates the auto calibration, we will actually do manual calibration of follower for now.
|
||||
|
||||
You will need to move the follower arm to these positions sequentially:
|
||||
|
||||
| 1. Zero position | 2. Rotated position | 3. Rest position |
|
||||
|---|---|---|
|
||||
| <img src="../media/lekiwi/mobile_calib_zero.webp?raw=true" alt="SO-100 follower arm zero position" title="SO-100 follower arm zero position" style="width:100%;"> | <img src="../media/lekiwi/mobile_calib_rotated.webp?raw=true" alt="SO-100 follower arm rotated position" title="SO-100 follower arm rotated position" style="width:100%;"> | <img src="../media/lekiwi/mobile_calib_rest.webp?raw=true" alt="SO-100 follower arm rest position" title="SO-100 follower arm rest position" style="width:100%;"> |
|
||||
|
||||
Make sure the arm is connected to the Raspberry Pi and run this script (on the Raspberry Pi) to launch manual calibration:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py \
|
||||
--robot.type=lekiwi \
|
||||
--robot.cameras='{}' \
|
||||
--control.type=calibrate \
|
||||
--control.arms='["main_follower"]'
|
||||
```
|
||||
|
||||
### Calibrate leader arm
|
||||
Then to calibrate the leader arm (which is attached to the laptop/pc). You will need to move the leader arm to these positions sequentially:
|
||||
|
||||
| 1. Zero position | 2. Rotated position | 3. Rest position |
|
||||
|---|---|---|
|
||||
| <img src="../media/so100/leader_zero.webp?raw=true" alt="SO-100 leader arm zero position" title="SO-100 leader arm zero position" style="width:100%;"> | <img src="../media/so100/leader_rotated.webp?raw=true" alt="SO-100 leader arm rotated position" title="SO-100 leader arm rotated position" style="width:100%;"> | <img src="../media/so100/leader_rest.webp?raw=true" alt="SO-100 leader arm rest position" title="SO-100 leader arm rest position" style="width:100%;"> |
|
||||
|
||||
Run this script (on your laptop/pc) to launch manual calibration:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py \
|
||||
--robot.type=lekiwi \
|
||||
--robot.cameras='{}' \
|
||||
--control.type=calibrate \
|
||||
--control.arms='["main_leader"]'
|
||||
```
|
||||
|
||||
# F. Teleoperate
|
||||
To teleoperate SSH into your Raspberry Pi, and run `conda activate lerobot` and this script:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py \
|
||||
--robot.type=lekiwi \
|
||||
--control.type=remote_robot
|
||||
```
|
||||
|
||||
Then on your laptop, also run `conda activate lerobot` and this script:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py \
|
||||
--robot.type=lekiwi \
|
||||
--control.type=teleoperate \
|
||||
--control.fps=30
|
||||
```
|
||||
|
||||
You should see on your laptop something like this: ```[INFO] Connected to remote robot at tcp://172.17.133.91:5555 and video stream at tcp://172.17.133.91:5556.``` Now you can move the leader arm and use the keyboard (w,a,s,d) to drive forward, left, backwards, right. And use (z,x) to turn left or turn right. You can use (r,f) to increase and decrease the speed of the mobile robot. There are three speed modes, see the table below:
|
||||
| Speed Mode | Linear Speed (m/s) | Rotation Speed (deg/s) |
|
||||
|------------|-------------------|-----------------------|
|
||||
| Fast | 0.4 | 90 |
|
||||
| Medium | 0.25 | 60 |
|
||||
| Slow | 0.1 | 30 |
|
||||
|
||||
|
||||
| Key | Action |
|
||||
|------|--------------------------------|
|
||||
| W | Move forward |
|
||||
| A | Move left |
|
||||
| S | Move backward |
|
||||
| D | Move right |
|
||||
| Z | Turn left |
|
||||
| X | Turn right |
|
||||
| R | Increase speed |
|
||||
| F | Decrease speed |
|
||||
|
||||
> [!TIP]
|
||||
> If you use a different keyboard you can change the keys for each command in the [`LeKiwiRobotConfig`](../lerobot/common/robot_devices/robots/configs.py).
|
||||
|
||||
## Troubleshoot communication
|
||||
|
||||
If you are having trouble connecting to the Mobile SO100, follow these steps to diagnose and resolve the issue.
|
||||
|
||||
### 1. Verify IP Address Configuration
|
||||
Make sure that the correct ip for the Pi is set in the configuration file. To check the Raspberry Pi's IP address, run (on the Pi command line):
|
||||
```bash
|
||||
hostname -I
|
||||
```
|
||||
|
||||
### 2. Check if Pi is reachable from laptop/pc
|
||||
Try pinging the Raspberry Pi from your laptop:
|
||||
```bach
|
||||
ping <your_pi_ip_address>
|
||||
```
|
||||
|
||||
If the ping fails:
|
||||
- Ensure the Pi is powered on and connected to the same network.
|
||||
- Check if SSH is enabled on the Pi.
|
||||
|
||||
### 3. Try SSH connection
|
||||
If you can't SSH into the Pi, it might not be properly connected. Use:
|
||||
```bash
|
||||
ssh <your_pi_user_name>@<your_pi_ip_address>
|
||||
```
|
||||
If you get a connection error:
|
||||
- Ensure SSH is enabled on the Pi by running:
|
||||
```bash
|
||||
sudo raspi-config
|
||||
```
|
||||
Then navigate to: **Interfacing Options -> SSH** and enable it.
|
||||
|
||||
### 4. Same config file
|
||||
Make sure the configuration file on both your laptop/pc and the Raspberry Pi is the same.
|
||||
|
||||
# G. Record a dataset
|
||||
Once you're familiar with teleoperation, you can record your first dataset with LeKiwi.
|
||||
|
||||
If you want to use the Hugging Face hub features for uploading your dataset and you haven't previously done it, make sure you've logged in using a write-access token, which can be generated from the [Hugging Face settings](https://huggingface.co/settings/tokens):
|
||||
```bash
|
||||
huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
|
||||
```
|
||||
|
||||
Store your Hugging Face repository name in a variable to run these commands:
|
||||
```bash
|
||||
HF_USER=$(huggingface-cli whoami | head -n 1)
|
||||
echo $HF_USER
|
||||
```
|
||||
|
||||
Record 2 episodes and upload your dataset to the hub:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py \
|
||||
--robot.type=lekiwi \
|
||||
--control.type=record \
|
||||
--control.fps=30 \
|
||||
--control.single_task="Grasp a lego block and put it in the bin." \
|
||||
--control.repo_id=${HF_USER}/lekiwi_test \
|
||||
--control.tags='["tutorial"]' \
|
||||
--control.warmup_time_s=5 \
|
||||
--control.episode_time_s=30 \
|
||||
--control.reset_time_s=30 \
|
||||
--control.num_episodes=2 \
|
||||
--control.push_to_hub=true
|
||||
```
|
||||
|
||||
Note: You can resume recording by adding `--control.resume=true`.
|
||||
|
||||
# H. Visualize a dataset
|
||||
|
||||
If you uploaded your dataset to the hub with `--control.push_to_hub=true`, you can [visualize your dataset online](https://huggingface.co/spaces/lerobot/visualize_dataset) by copy pasting your repo id given by:
|
||||
```bash
|
||||
echo ${HF_USER}/lekiwi_test
|
||||
```
|
||||
|
||||
If you didn't upload with `--control.push_to_hub=false`, you can also visualize it locally with (a window can be opened in the browser `http://127.0.0.1:9090` with the visualization tool):
|
||||
```bash
|
||||
python lerobot/scripts/visualize_dataset_html.py \
|
||||
--repo-id ${HF_USER}/lekiwi_test \
|
||||
--local-files-only 1
|
||||
```
|
||||
|
||||
# I. Replay an episode
|
||||
Now try to replay the first episode on your robot:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py \
|
||||
--robot.type=lekiwi \
|
||||
--control.type=replay \
|
||||
--control.fps=30 \
|
||||
--control.repo_id=${HF_USER}/lekiwi_test \
|
||||
--control.episode=0
|
||||
```
|
||||
|
||||
## J. Train a policy
|
||||
|
||||
To train a policy to control your robot, use the [`python lerobot/scripts/train.py`](../lerobot/scripts/train.py) script. A few arguments are required. Here is an example command:
|
||||
```bash
|
||||
python lerobot/scripts/train.py \
|
||||
--dataset.repo_id=${HF_USER}/lekiwi_test \
|
||||
--policy.type=act \
|
||||
--output_dir=outputs/train/act_lekiwi_test \
|
||||
--job_name=act_lekiwi_test \
|
||||
--device=cuda \
|
||||
--wandb.enable=true
|
||||
```
|
||||
|
||||
Let's explain it:
|
||||
1. We provided the dataset as argument with `--dataset.repo_id=${HF_USER}/lekiwi_test`.
|
||||
2. We provided the policy with `policy.type=act`. This loads configurations from [`configuration_act.py`](../lerobot/common/policies/act/configuration_act.py). Importantly, this policy will automatically adapt to the number of motor sates, motor actions and cameras of your robot (e.g. `laptop` and `phone`) which have been saved in your dataset.
|
||||
4. We provided `device=cuda` since we are training on a Nvidia GPU, but you could use `device=mps` to train on Apple silicon.
|
||||
5. We provided `wandb.enable=true` to use [Weights and Biases](https://docs.wandb.ai/quickstart) for visualizing training plots. This is optional but if you use it, make sure you are logged in by running `wandb login`.
|
||||
|
||||
Training should take several hours. You will find checkpoints in `outputs/train/act_lekiwi_test/checkpoints`.
|
||||
|
||||
## K. Evaluate your policy
|
||||
|
||||
You can use the `record` function from [`lerobot/scripts/control_robot.py`](../lerobot/scripts/control_robot.py) but with a policy checkpoint as input. For instance, run this command to record 10 evaluation episodes:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py \
|
||||
--robot.type=lekiwi \
|
||||
--control.type=record \
|
||||
--control.fps=30 \
|
||||
--control.single_task="Drive to the red block and pick it up" \
|
||||
--control.repo_id=${HF_USER}/eval_act_lekiwi_test \
|
||||
--control.tags='["tutorial"]' \
|
||||
--control.warmup_time_s=5 \
|
||||
--control.episode_time_s=30 \
|
||||
--control.reset_time_s=30 \
|
||||
--control.num_episodes=10 \
|
||||
--control.push_to_hub=true \
|
||||
--control.policy.path=outputs/train/act_lekiwi_test/checkpoints/last/pretrained_model
|
||||
```
|
||||
|
||||
As you can see, it's almost the same command as previously used to record your training dataset. Two things changed:
|
||||
1. There is an additional `--control.policy.path` argument which indicates the path to your policy checkpoint with (e.g. `outputs/train/eval_act_lekiwi_test/checkpoints/last/pretrained_model`). You can also use the model repository if you uploaded a model checkpoint to the hub (e.g. `${HF_USER}/act_lekiwi_test`).
|
||||
2. The name of dataset begins by `eval` to reflect that you are running inference (e.g. `${HF_USER}/eval_act_lekiwi_test`).
|
||||
@@ -2,7 +2,7 @@ This tutorial explains how to use [Moss v1](https://github.com/jess-moss/moss-ro
|
||||
|
||||
## Source the parts
|
||||
|
||||
Follow this [README](https://github.com/jess-moss/moss-robot-arms). It contains the bill of materials with link to source the parts, as well as the instructions to 3D print the parts and advice if it's your first time printing or if you don't own a 3D printer already.
|
||||
Follow this [README](https://github.com/jess-moss/moss-robot-arms). It contains the bill of materials, with link to source the parts, as well as the instructions to 3D print the parts, and advices if it's your first time printing or if you don't own a 3D printer already.
|
||||
|
||||
**Important**: Before assembling, you will first need to configure your motors. To this end, we provide a nice script, so let's first install LeRobot. After configuration, we will also guide you through assembly.
|
||||
|
||||
@@ -83,54 +83,6 @@ sudo chmod 666 /dev/ttyACM0
|
||||
sudo chmod 666 /dev/ttyACM1
|
||||
```
|
||||
|
||||
#### Update config file
|
||||
|
||||
IMPORTANTLY: Now that you have your ports, update the **port** default values of [`MossRobotConfig`](../lerobot/common/robot_devices/robots/configs.py). You will find something like:
|
||||
```python
|
||||
@RobotConfig.register_subclass("moss")
|
||||
@dataclass
|
||||
class MossRobotConfig(ManipulatorRobotConfig):
|
||||
calibration_dir: str = ".cache/calibration/moss"
|
||||
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
|
||||
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
|
||||
# the number of motors in your follower arms.
|
||||
max_relative_target: int | None = None
|
||||
|
||||
leader_arms: dict[str, MotorsBusConfig] = field(
|
||||
default_factory=lambda: {
|
||||
"main": FeetechMotorsBusConfig(
|
||||
port="/dev/tty.usbmodem58760431091", <-- UPDATE HERE
|
||||
motors={
|
||||
# name: (index, model)
|
||||
"shoulder_pan": [1, "sts3215"],
|
||||
"shoulder_lift": [2, "sts3215"],
|
||||
"elbow_flex": [3, "sts3215"],
|
||||
"wrist_flex": [4, "sts3215"],
|
||||
"wrist_roll": [5, "sts3215"],
|
||||
"gripper": [6, "sts3215"],
|
||||
},
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
follower_arms: dict[str, MotorsBusConfig] = field(
|
||||
default_factory=lambda: {
|
||||
"main": FeetechMotorsBusConfig(
|
||||
port="/dev/tty.usbmodem585A0076891", <-- UPDATE HERE
|
||||
motors={
|
||||
# name: (index, model)
|
||||
"shoulder_pan": [1, "sts3215"],
|
||||
"shoulder_lift": [2, "sts3215"],
|
||||
"elbow_flex": [3, "sts3215"],
|
||||
"wrist_flex": [4, "sts3215"],
|
||||
"wrist_roll": [5, "sts3215"],
|
||||
"gripper": [6, "sts3215"],
|
||||
},
|
||||
),
|
||||
}
|
||||
)
|
||||
```
|
||||
|
||||
**Configure your motors**
|
||||
Plug your first motor and run this script to set its ID to 1. It will also set its present position to 2048, so expect your motor to rotate:
|
||||
```bash
|
||||
@@ -182,11 +134,9 @@ You will need to move the follower arm to these positions sequentially:
|
||||
|
||||
Make sure both arms are connected and run this script to launch manual calibration:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py \
|
||||
--robot.type=moss \
|
||||
--robot.cameras='{}' \
|
||||
--control.type=calibrate \
|
||||
--control.arms='["main_follower"]'
|
||||
python lerobot/scripts/control_robot.py calibrate \
|
||||
--robot-path lerobot/configs/robot/moss.yaml \
|
||||
--robot-overrides '~cameras' --arms main_follower
|
||||
```
|
||||
|
||||
**Manual calibration of leader arm**
|
||||
@@ -198,11 +148,9 @@ Follow step 6 of the [assembly video](https://www.youtube.com/watch?v=DA91NJOtMi
|
||||
|
||||
Run this script to launch manual calibration:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py \
|
||||
--robot.type=moss \
|
||||
--robot.cameras='{}' \
|
||||
--control.type=calibrate \
|
||||
--control.arms='["main_leader"]'
|
||||
python lerobot/scripts/control_robot.py calibrate \
|
||||
--robot-path lerobot/configs/robot/moss.yaml \
|
||||
--robot-overrides '~cameras' --arms main_leader
|
||||
```
|
||||
|
||||
## Teleoperate
|
||||
@@ -210,19 +158,18 @@ python lerobot/scripts/control_robot.py \
|
||||
**Simple teleop**
|
||||
Then you are ready to teleoperate your robot! Run this simple script (it won't connect and display the cameras):
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py \
|
||||
--robot.type=moss \
|
||||
--robot.cameras='{}' \
|
||||
--control.type=teleoperate
|
||||
python lerobot/scripts/control_robot.py teleoperate \
|
||||
--robot-path lerobot/configs/robot/moss.yaml \
|
||||
--robot-overrides '~cameras' \
|
||||
--display-cameras 0
|
||||
```
|
||||
|
||||
|
||||
**Teleop with displaying cameras**
|
||||
Follow [this guide to setup your cameras](https://github.com/huggingface/lerobot/blob/main/examples/7_get_started_with_real_robot.md#c-add-your-cameras-with-opencvcamera). Then you will be able to display the cameras on your computer while you are teleoperating by running the following code. This is useful to prepare your setup before recording your first dataset.
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py \
|
||||
--robot.type=moss \
|
||||
--control.type=teleoperate
|
||||
python lerobot/scripts/control_robot.py teleoperate \
|
||||
--robot-path lerobot/configs/robot/moss.yaml
|
||||
```
|
||||
|
||||
## Record a dataset
|
||||
@@ -242,46 +189,40 @@ echo $HF_USER
|
||||
|
||||
Record 2 episodes and upload your dataset to the hub:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py \
|
||||
--robot.type=moss \
|
||||
--control.type=record \
|
||||
--control.fps=30 \
|
||||
--control.single_task="Grasp a lego block and put it in the bin." \
|
||||
--control.repo_id=${HF_USER}/moss_test \
|
||||
--control.tags='["moss","tutorial"]' \
|
||||
--control.warmup_time_s=5 \
|
||||
--control.episode_time_s=30 \
|
||||
--control.reset_time_s=30 \
|
||||
--control.num_episodes=2 \
|
||||
--control.push_to_hub=true
|
||||
python lerobot/scripts/control_robot.py record \
|
||||
--robot-path lerobot/configs/robot/moss.yaml \
|
||||
--fps 30 \
|
||||
--repo-id ${HF_USER}/moss_test \
|
||||
--tags moss tutorial \
|
||||
--warmup-time-s 5 \
|
||||
--episode-time-s 40 \
|
||||
--reset-time-s 10 \
|
||||
--num-episodes 2 \
|
||||
--push-to-hub 1
|
||||
```
|
||||
|
||||
Note: You can resume recording by adding `--control.resume=true`.
|
||||
|
||||
## Visualize a dataset
|
||||
|
||||
If you uploaded your dataset to the hub with `--control.push_to_hub=true`, you can [visualize your dataset online](https://huggingface.co/spaces/lerobot/visualize_dataset) by copy pasting your repo id given by:
|
||||
If you uploaded your dataset to the hub with `--push-to-hub 1`, you can [visualize your dataset online](https://huggingface.co/spaces/lerobot/visualize_dataset) by copy pasting your repo id given by:
|
||||
```bash
|
||||
echo ${HF_USER}/moss_test
|
||||
```
|
||||
|
||||
If you didn't upload with `--control.push_to_hub=false`, you can also visualize it locally with:
|
||||
If you didn't upload with `--push-to-hub 0`, you can also visualize it locally with:
|
||||
```bash
|
||||
python lerobot/scripts/visualize_dataset_html.py \
|
||||
--repo-id ${HF_USER}/moss_test \
|
||||
--local-files-only 1
|
||||
--repo-id ${HF_USER}/moss_test
|
||||
```
|
||||
|
||||
## Replay an episode
|
||||
|
||||
Now try to replay the first episode on your robot:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py \
|
||||
--robot.type=moss \
|
||||
--control.type=replay \
|
||||
--control.fps=30 \
|
||||
--control.repo_id=${HF_USER}/moss_test \
|
||||
--control.episode=0
|
||||
python lerobot/scripts/control_robot.py replay \
|
||||
--robot-path lerobot/configs/robot/moss.yaml \
|
||||
--fps 30 \
|
||||
--repo-id ${HF_USER}/moss_test \
|
||||
--episode 0
|
||||
```
|
||||
|
||||
## Train a policy
|
||||
@@ -289,18 +230,20 @@ python lerobot/scripts/control_robot.py \
|
||||
To train a policy to control your robot, use the [`python lerobot/scripts/train.py`](../lerobot/scripts/train.py) script. A few arguments are required. Here is an example command:
|
||||
```bash
|
||||
python lerobot/scripts/train.py \
|
||||
--dataset.repo_id=${HF_USER}/moss_test \
|
||||
--policy.type=act \
|
||||
--output_dir=outputs/train/act_moss_test \
|
||||
--job_name=act_moss_test \
|
||||
--device=cuda \
|
||||
--wandb.enable=true
|
||||
dataset_repo_id=${HF_USER}/moss_test \
|
||||
policy=act_moss_real \
|
||||
env=moss_real \
|
||||
hydra.run.dir=outputs/train/act_moss_test \
|
||||
hydra.job.name=act_moss_test \
|
||||
device=cuda \
|
||||
wandb.enable=true
|
||||
```
|
||||
|
||||
Let's explain it:
|
||||
1. We provided the dataset as argument with `--dataset.repo_id=${HF_USER}/moss_test`.
|
||||
2. We provided the policy with `policy.type=act`. This loads configurations from [`configuration_act.py`](../lerobot/common/policies/act/configuration_act.py). Importantly, this policy will automatically adapt to the number of motor sates, motor actions and cameras of your robot (e.g. `laptop` and `phone`) which have been saved in your dataset.
|
||||
4. We provided `device=cuda` since we are training on a Nvidia GPU, but you could use `device=mps` to train on Apple silicon.
|
||||
1. We provided the dataset as argument with `dataset_repo_id=${HF_USER}/moss_test`.
|
||||
2. We provided the policy with `policy=act_moss_real`. This loads configurations from [`lerobot/configs/policy/act_moss_real.yaml`](../lerobot/configs/policy/act_moss_real.yaml). Importantly, this policy uses 2 cameras as input `laptop`, `phone`.
|
||||
3. We provided an environment as argument with `env=moss_real`. This loads configurations from [`lerobot/configs/env/moss_real.yaml`](../lerobot/configs/env/moss_real.yaml).
|
||||
4. We provided `device=cuda` since we are training on a Nvidia GPU, but you can also use `device=mps` if you are using a Mac with Apple silicon, or `device=cpu` otherwise.
|
||||
5. We provided `wandb.enable=true` to use [Weights and Biases](https://docs.wandb.ai/quickstart) for visualizing training plots. This is optional but if you use it, make sure you are logged in by running `wandb login`.
|
||||
|
||||
Training should take several hours. You will find checkpoints in `outputs/train/act_moss_test/checkpoints`.
|
||||
@@ -309,24 +252,21 @@ Training should take several hours. You will find checkpoints in `outputs/train/
|
||||
|
||||
You can use the `record` function from [`lerobot/scripts/control_robot.py`](../lerobot/scripts/control_robot.py) but with a policy checkpoint as input. For instance, run this command to record 10 evaluation episodes:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py \
|
||||
--robot.type=moss \
|
||||
--control.type=record \
|
||||
--control.fps=30 \
|
||||
--control.single_task="Grasp a lego block and put it in the bin." \
|
||||
--control.repo_id=${HF_USER}/eval_act_moss_test \
|
||||
--control.tags='["tutorial"]' \
|
||||
--control.warmup_time_s=5 \
|
||||
--control.episode_time_s=30 \
|
||||
--control.reset_time_s=30 \
|
||||
--control.num_episodes=10 \
|
||||
--control.push_to_hub=true \
|
||||
--control.policy.path=outputs/train/act_moss_test/checkpoints/last/pretrained_model
|
||||
python lerobot/scripts/control_robot.py record \
|
||||
--robot-path lerobot/configs/robot/moss.yaml \
|
||||
--fps 30 \
|
||||
--repo-id ${HF_USER}/eval_act_moss_test \
|
||||
--tags moss tutorial eval \
|
||||
--warmup-time-s 5 \
|
||||
--episode-time-s 40 \
|
||||
--reset-time-s 10 \
|
||||
--num-episodes 10 \
|
||||
-p outputs/train/act_moss_test/checkpoints/last/pretrained_model
|
||||
```
|
||||
|
||||
As you can see, it's almost the same command as previously used to record your training dataset. Two things changed:
|
||||
1. There is an additional `--control.policy.path` argument which indicates the path to your policy checkpoint with (e.g. `outputs/train/eval_act_moss_test/checkpoints/last/pretrained_model`). You can also use the model repository if you uploaded a model checkpoint to the hub (e.g. `${HF_USER}/act_moss_test`).
|
||||
2. The name of dataset begins by `eval` to reflect that you are running inference (e.g. `${HF_USER}/eval_act_moss_test`).
|
||||
1. There is an additional `-p` argument which indicates the path to your policy checkpoint with (e.g. `-p outputs/train/eval_moss_test/checkpoints/last/pretrained_model`). You can also use the model repository if you uploaded a model checkpoint to the hub (e.g. `-p ${HF_USER}/act_moss_test`).
|
||||
2. The name of dataset begins by `eval` to reflect that you are running inference (e.g. `--repo-id ${HF_USER}/eval_act_moss_test`).
|
||||
|
||||
## More
|
||||
|
||||
|
||||
83
examples/12_train_hilserl_classifier.md
Normal file
83
examples/12_train_hilserl_classifier.md
Normal file
@@ -0,0 +1,83 @@
|
||||
# Training a HIL-SERL Reward Classifier with LeRobot
|
||||
|
||||
This tutorial provides step-by-step instructions for training a reward classifier using LeRobot.
|
||||
|
||||
---
|
||||
|
||||
## Training Script Overview
|
||||
|
||||
LeRobot includes a ready-to-use training script located at [`lerobot/scripts/train_hilserl_classifier.py`](../../lerobot/scripts/train_hilserl_classifier.py). Here's an outline of its workflow:
|
||||
|
||||
1. **Configuration Loading**
|
||||
The script uses Hydra to load a configuration file for subsequent steps. (Details on Hydra follow below.)
|
||||
|
||||
2. **Dataset Initialization**
|
||||
It loads a `LeRobotDataset` containing images and rewards. To optimize performance, a weighted random sampler is used to balance class sampling.
|
||||
|
||||
3. **Classifier Initialization**
|
||||
A lightweight classification head is built on top of a frozen, pretrained image encoder from HuggingFace. The classifier outputs either:
|
||||
- A single probability (binary classification), or
|
||||
- Logits (multi-class classification).
|
||||
|
||||
4. **Training Loop Execution**
|
||||
The script performs:
|
||||
- Forward and backward passes,
|
||||
- Optimization steps,
|
||||
- Periodic logging, evaluation, and checkpoint saving.
|
||||
|
||||
---
|
||||
|
||||
## Configuring with Hydra
|
||||
|
||||
For detailed information about Hydra usage, refer to [`examples/4_train_policy_with_script.md`](../examples/4_train_policy_with_script.md). However, note that training the reward classifier differs slightly and requires a separate configuration file.
|
||||
|
||||
### Config File Setup
|
||||
|
||||
The default `default.yaml` cannot launch the reward classifier training directly. Instead, you need a configuration file like [`lerobot/configs/policy/hilserl_classifier.yaml`](../../lerobot/configs/policy/hilserl_classifier.yaml), with the following adjustment:
|
||||
|
||||
Replace the `dataset_repo_id` field with the identifier for your dataset, which contains images and sparse rewards:
|
||||
|
||||
```yaml
|
||||
# Example: lerobot/configs/policy/reward_classifier.yaml
|
||||
dataset_repo_id: "my_dataset_repo_id"
|
||||
## Typical logs and metrics
|
||||
```
|
||||
When you start the training process, you will first see your full configuration being printed in the terminal. You can check it to make sure that you config it correctly and your config is not overrided by other files. The final configuration will also be saved with the checkpoint.
|
||||
|
||||
After that, you will see training log like this one:
|
||||
|
||||
```
|
||||
[2024-11-29 18:26:36,999][root][INFO] -
|
||||
Epoch 5/5
|
||||
Training: 82%|██████████████████████████████████████████████████████████████████████████████▋ | 91/111 [00:50<00:09, 2.04it/s, loss=0.2999, acc=69.99%]
|
||||
```
|
||||
|
||||
or evaluation log like:
|
||||
|
||||
```
|
||||
Validation: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:20<00:00, 1.37it/s]
|
||||
```
|
||||
|
||||
### Metrics Tracking with Weights & Biases (WandB)
|
||||
|
||||
If `wandb.enable` is set to `true`, the training and evaluation logs will also be saved in WandB. This allows you to track key metrics in real-time, including:
|
||||
|
||||
- **Training Metrics**:
|
||||
- `train/accuracy`
|
||||
- `train/loss`
|
||||
- `train/dataloading_s`
|
||||
- **Evaluation Metrics**:
|
||||
- `eval/accuracy`
|
||||
- `eval/loss`
|
||||
- `eval/eval_s`
|
||||
|
||||
#### Additional Features
|
||||
|
||||
You can also log sample predictions during evaluation. Each logged sample will include:
|
||||
|
||||
- The **input image**.
|
||||
- The **predicted label**.
|
||||
- The **true label**.
|
||||
- The **classifier's "confidence" (logits/probability)**.
|
||||
|
||||
These logs can be useful for diagnosing and debugging performance issues.
|
||||
@@ -1,11 +1,6 @@
|
||||
"""
|
||||
This scripts demonstrates how to evaluate a pretrained policy from the HuggingFace Hub or from your local
|
||||
training outputs directory. In the latter case, you might want to run examples/3_train_policy.py first.
|
||||
|
||||
It requires the installation of the 'gym_pusht' simulation environment. Install it by running:
|
||||
```bash
|
||||
pip install -e ".[pusht]"`
|
||||
```
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
@@ -15,6 +10,7 @@ import gymnasium as gym
|
||||
import imageio
|
||||
import numpy
|
||||
import torch
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
|
||||
|
||||
@@ -22,15 +18,25 @@ from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
|
||||
output_directory = Path("outputs/eval/example_pusht_diffusion")
|
||||
output_directory.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Select your device
|
||||
device = "cuda"
|
||||
|
||||
# Provide the [hugging face repo id](https://huggingface.co/lerobot/diffusion_pusht):
|
||||
pretrained_policy_path = "lerobot/diffusion_pusht"
|
||||
# OR a path to a local outputs/train folder.
|
||||
# Download the diffusion policy for pusht environment
|
||||
pretrained_policy_path = Path(snapshot_download("lerobot/diffusion_pusht"))
|
||||
# OR uncomment the following to evaluate a policy from the local outputs/train folder.
|
||||
# pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")
|
||||
|
||||
policy = DiffusionPolicy.from_pretrained(pretrained_policy_path, map_location=device)
|
||||
policy = DiffusionPolicy.from_pretrained(pretrained_policy_path)
|
||||
policy.eval()
|
||||
|
||||
# Check if GPU is available
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda")
|
||||
print("GPU is available. Device set to:", device)
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
print(f"GPU is not available. Device set to: {device}. Inference will be slower than on GPU.")
|
||||
# Decrease the number of reverse-diffusion steps (trades off a bit of quality for 10x speed)
|
||||
policy.diffusion.num_inference_steps = 10
|
||||
|
||||
policy.to(device)
|
||||
|
||||
# Initialize evaluation environment to render two observation types:
|
||||
# an image of the scene and state/position of the agent. The environment
|
||||
@@ -41,17 +47,7 @@ env = gym.make(
|
||||
max_episode_steps=300,
|
||||
)
|
||||
|
||||
# We can verify that the shapes of the features expected by the policy match the ones from the observations
|
||||
# produced by the environment
|
||||
print(policy.config.input_features)
|
||||
print(env.observation_space)
|
||||
|
||||
# Similarly, we can check that the actions produced by the policy will match the actions expected by the
|
||||
# environment
|
||||
print(policy.config.output_features)
|
||||
print(env.action_space)
|
||||
|
||||
# Reset the policy and environments to prepare for rollout
|
||||
# Reset the policy and environmens to prepare for rollout
|
||||
policy.reset()
|
||||
numpy_observation, info = env.reset(seed=42)
|
||||
|
||||
|
||||
@@ -8,99 +8,72 @@ from pathlib import Path
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
|
||||
from lerobot.common.datasets.utils import dataset_to_policy_features
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
|
||||
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
|
||||
from lerobot.configs.types import FeatureType
|
||||
|
||||
# Create a directory to store the training checkpoint.
|
||||
output_directory = Path("outputs/train/example_pusht_diffusion")
|
||||
output_directory.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def main():
|
||||
# Create a directory to store the training checkpoint.
|
||||
output_directory = Path("outputs/train/example_pusht_diffusion")
|
||||
output_directory.mkdir(parents=True, exist_ok=True)
|
||||
# Number of offline training steps (we'll only do offline training for this example.)
|
||||
# Adjust as you prefer. 5000 steps are needed to get something worth evaluating.
|
||||
training_steps = 5000
|
||||
device = torch.device("cuda")
|
||||
log_freq = 250
|
||||
|
||||
# # Select your device
|
||||
device = torch.device("cuda")
|
||||
# Set up the dataset.
|
||||
delta_timestamps = {
|
||||
# Load the previous image and state at -0.1 seconds before current frame,
|
||||
# then load current image and state corresponding to 0.0 second.
|
||||
"observation.image": [-0.1, 0.0],
|
||||
"observation.state": [-0.1, 0.0],
|
||||
# Load the previous action (-0.1), the next action to be executed (0.0),
|
||||
# and 14 future actions with a 0.1 seconds spacing. All these actions will be
|
||||
# used to supervise the policy.
|
||||
"action": [-0.1, 0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4],
|
||||
}
|
||||
dataset = LeRobotDataset("lerobot/pusht", delta_timestamps=delta_timestamps)
|
||||
|
||||
# Number of offline training steps (we'll only do offline training for this example.)
|
||||
# Adjust as you prefer. 5000 steps are needed to get something worth evaluating.
|
||||
training_steps = 5000
|
||||
log_freq = 1
|
||||
# Set up the the policy.
|
||||
# Policies are initialized with a configuration class, in this case `DiffusionConfig`.
|
||||
# For this example, no arguments need to be passed because the defaults are set up for PushT.
|
||||
# If you're doing something different, you will likely need to change at least some of the defaults.
|
||||
cfg = DiffusionConfig()
|
||||
policy = DiffusionPolicy(cfg, dataset_stats=dataset.meta.stats)
|
||||
policy.train()
|
||||
policy.to(device)
|
||||
|
||||
# When starting from scratch (i.e. not from a pretrained policy), we need to specify 2 things before
|
||||
# creating the policy:
|
||||
# - input/output shapes: to properly size the policy
|
||||
# - dataset stats: for normalization and denormalization of input/outputs
|
||||
dataset_metadata = LeRobotDatasetMetadata("lerobot/pusht")
|
||||
features = dataset_to_policy_features(dataset_metadata.features)
|
||||
output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
|
||||
input_features = {key: ft for key, ft in features.items() if key not in output_features}
|
||||
optimizer = torch.optim.Adam(policy.parameters(), lr=1e-4)
|
||||
|
||||
# Policies are initialized with a configuration class, in this case `DiffusionConfig`. For this example,
|
||||
# we'll just use the defaults and so no arguments other than input/output features need to be passed.
|
||||
cfg = DiffusionConfig(input_features=input_features, output_features=output_features)
|
||||
# Create dataloader for offline training.
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
num_workers=4,
|
||||
batch_size=64,
|
||||
shuffle=True,
|
||||
pin_memory=device != torch.device("cpu"),
|
||||
drop_last=True,
|
||||
)
|
||||
|
||||
# We can now instantiate our policy with this config and the dataset stats.
|
||||
policy = DiffusionPolicy(cfg, dataset_stats=dataset_metadata.stats)
|
||||
policy.train()
|
||||
policy.to(device)
|
||||
# Run training loop.
|
||||
step = 0
|
||||
done = False
|
||||
while not done:
|
||||
for batch in dataloader:
|
||||
batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
|
||||
output_dict = policy.forward(batch)
|
||||
loss = output_dict["loss"]
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
# Another policy-dataset interaction is with the delta_timestamps. Each policy expects a given number frames
|
||||
# which can differ for inputs, outputs and rewards (if there are some).
|
||||
delta_timestamps = {
|
||||
"observation.image": [i / dataset_metadata.fps for i in cfg.observation_delta_indices],
|
||||
"observation.state": [i / dataset_metadata.fps for i in cfg.observation_delta_indices],
|
||||
"action": [i / dataset_metadata.fps for i in cfg.action_delta_indices],
|
||||
}
|
||||
if step % log_freq == 0:
|
||||
print(f"step: {step} loss: {loss.item():.3f}")
|
||||
step += 1
|
||||
if step >= training_steps:
|
||||
done = True
|
||||
break
|
||||
|
||||
# In this case with the standard configuration for Diffusion Policy, it is equivalent to this:
|
||||
delta_timestamps = {
|
||||
# Load the previous image and state at -0.1 seconds before current frame,
|
||||
# then load current image and state corresponding to 0.0 second.
|
||||
"observation.image": [-0.1, 0.0],
|
||||
"observation.state": [-0.1, 0.0],
|
||||
# Load the previous action (-0.1), the next action to be executed (0.0),
|
||||
# and 14 future actions with a 0.1 seconds spacing. All these actions will be
|
||||
# used to supervise the policy.
|
||||
"action": [-0.1, 0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4],
|
||||
}
|
||||
|
||||
# We can then instantiate the dataset with these delta_timestamps configuration.
|
||||
dataset = LeRobotDataset("lerobot/pusht", delta_timestamps=delta_timestamps)
|
||||
|
||||
# Then we create our optimizer and dataloader for offline training.
|
||||
optimizer = torch.optim.Adam(policy.parameters(), lr=1e-4)
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
num_workers=4,
|
||||
batch_size=64,
|
||||
shuffle=True,
|
||||
pin_memory=device.type != "cpu",
|
||||
drop_last=True,
|
||||
)
|
||||
|
||||
# Run training loop.
|
||||
step = 0
|
||||
done = False
|
||||
while not done:
|
||||
for batch in dataloader:
|
||||
batch = {k: (v.to(device) if isinstance(v, torch.Tensor) else v) for k, v in batch.items()}
|
||||
loss, _ = policy.forward(batch)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
if step % log_freq == 0:
|
||||
print(f"step: {step} loss: {loss.item():.3f}")
|
||||
step += 1
|
||||
if step >= training_steps:
|
||||
done = True
|
||||
break
|
||||
|
||||
# Save a policy checkpoint.
|
||||
policy.save_pretrained(output_directory)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
# Save a policy checkpoint.
|
||||
policy.save_pretrained(output_directory)
|
||||
|
||||
@@ -1,223 +1,193 @@
|
||||
This tutorial will explain the training script, how to use it, and particularly how to configure everything needed for the training run.
|
||||
> **Note:** The following assume you're running these commands on a machine equipped with a cuda GPU. If you don't have one (or if you're using a Mac), you can add `--device=cpu` (`--device=mps` respectively). However, be advised that the code executes much slower on cpu.
|
||||
|
||||
This tutorial will explain the training script, how to use it, and particularly the use of Hydra to configure everything needed for the training run.
|
||||
|
||||
## The training script
|
||||
|
||||
LeRobot offers a training script at [`lerobot/scripts/train.py`](../../lerobot/scripts/train.py). At a high level it does the following:
|
||||
|
||||
- Initialize/load a configuration for the following steps using.
|
||||
- Instantiates a dataset.
|
||||
- (Optional) Instantiates a simulation environment corresponding to that dataset.
|
||||
- Instantiates a policy.
|
||||
- Loads a Hydra configuration file for the following steps (more on Hydra in a moment).
|
||||
- Makes a simulation environment.
|
||||
- Makes a dataset corresponding to that simulation environment.
|
||||
- Makes a policy.
|
||||
- Runs a standard training loop with forward pass, backward pass, optimization step, and occasional logging, evaluation (of the policy on the environment), and checkpointing.
|
||||
|
||||
## Overview of the configuration system
|
||||
## Basics of how we use Hydra
|
||||
|
||||
Explaining the ins and outs of [Hydra](https://hydra.cc/docs/intro/) is beyond the scope of this document, but here we'll share the main points you need to know.
|
||||
|
||||
First, `lerobot/configs` has a directory structure like this:
|
||||
|
||||
```
|
||||
.
|
||||
├── default.yaml
|
||||
├── env
|
||||
│ ├── aloha.yaml
|
||||
│ ├── pusht.yaml
|
||||
│ └── xarm.yaml
|
||||
└── policy
|
||||
├── act.yaml
|
||||
├── diffusion.yaml
|
||||
└── tdmpc.yaml
|
||||
```
|
||||
|
||||
**_For brevity, in the rest of this document we'll drop the leading `lerobot/configs` path. So `default.yaml` really refers to `lerobot/configs/default.yaml`._**
|
||||
|
||||
When you run the training script with
|
||||
|
||||
In the training script, the main function `train` expects a `TrainPipelineConfig` object:
|
||||
```python
|
||||
# train.py
|
||||
@parser.wrap()
|
||||
def train(cfg: TrainPipelineConfig):
|
||||
python lerobot/scripts/train.py
|
||||
```
|
||||
|
||||
You can inspect the `TrainPipelineConfig` defined in [`lerobot/configs/train.py`](../../lerobot/configs/train.py) (which is heavily commented and meant to be a reference to understand any option)
|
||||
Hydra is set up to read `default.yaml` (via the `@hydra.main` decorator). If you take a look at the `@hydra.main`'s arguments you will see `config_path="../configs", config_name="default"`. At the top of `default.yaml`, is a `defaults` section which looks likes this:
|
||||
|
||||
When running the script, inputs for the command line are parsed thanks to the `@parser.wrap()` decorator and an instance of this class is automatically generated. Under the hood, this is done with [Draccus](https://github.com/dlwh/draccus) which is a tool dedicated for this purpose. If you're familiar with Hydra, Draccus can similarly load configurations from config files (.json, .yaml) and also override their values through command line inputs. Unlike Hydra, these configurations are pre-defined in the code through dataclasses rather than being defined entirely in config files. This allows for more rigorous serialization/deserialization, typing, and to manipulate configuration as objects directly in the code and not as dictionaries or namespaces (which enables nice features in an IDE such as autocomplete, jump-to-def, etc.)
|
||||
|
||||
Let's have a look at a simplified example. Amongst other attributes, the training config has the following attributes:
|
||||
```python
|
||||
@dataclass
|
||||
class TrainPipelineConfig:
|
||||
dataset: DatasetConfig
|
||||
env: envs.EnvConfig | None = None
|
||||
policy: PreTrainedConfig | None = None
|
||||
```
|
||||
in which `DatasetConfig` for example is defined as such:
|
||||
```python
|
||||
@dataclass
|
||||
class DatasetConfig:
|
||||
repo_id: str
|
||||
episodes: list[int] | None = None
|
||||
video_backend: str = "pyav"
|
||||
```yaml
|
||||
defaults:
|
||||
- _self_
|
||||
- env: pusht
|
||||
- policy: diffusion
|
||||
```
|
||||
|
||||
This creates a hierarchical relationship where, for example assuming we have a `cfg` instance of `TrainPipelineConfig`, we can access the `repo_id` value with `cfg.dataset.repo_id`.
|
||||
From the command line, we can specify this value with using a very similar syntax `--dataset.repo_id=repo/id`.
|
||||
This logic tells Hydra to incorporate configuration parameters from `env/pusht.yaml` and `policy/diffusion.yaml`. _Note: Be aware of the order as any configuration parameters with the same name will be overidden. Thus, `default.yaml` is overridden by `env/pusht.yaml` which is overidden by `policy/diffusion.yaml`_.
|
||||
|
||||
By default, every field takes its default value specified in the dataclass. If a field doesn't have a default value, it needs to be specified either from the command line or from a config file – which path is also given in the command line (more in this below). In the example above, the `dataset` field doesn't have a default value which means it must be specified.
|
||||
Then, `default.yaml` also contains common configuration parameters such as `device: cuda` or `use_amp: false` (for enabling fp16 training). Some other parameters are set to `???` which indicates that they are expected to be set in additional yaml files. For instance, `training.offline_steps: ???` in `default.yaml` is set to `200000` in `diffusion.yaml`.
|
||||
|
||||
Thanks to this `defaults` section in `default.yaml`, if you want to train Diffusion Policy with PushT, you really only need to run:
|
||||
|
||||
## Specifying values from the CLI
|
||||
```bash
|
||||
python lerobot/scripts/train.py
|
||||
```
|
||||
|
||||
However, you can be more explicit and launch the exact same Diffusion Policy training on PushT with:
|
||||
|
||||
```bash
|
||||
python lerobot/scripts/train.py policy=diffusion env=pusht
|
||||
```
|
||||
|
||||
This way of overriding defaults via the CLI is especially useful when you want to change the policy and/or environment. For instance, you can train ACT on the default Aloha environment with:
|
||||
|
||||
```bash
|
||||
python lerobot/scripts/train.py policy=act env=aloha
|
||||
```
|
||||
|
||||
There are two things to note here:
|
||||
- Config overrides are passed as `param_name=param_value`.
|
||||
- Here we have overridden the defaults section. `policy=act` tells Hydra to use `policy/act.yaml`, and `env=aloha` tells Hydra to use `env/aloha.yaml`.
|
||||
|
||||
_As an aside: we've set up all of our configurations so that they reproduce state-of-the-art results from papers in the literature._
|
||||
|
||||
## Overriding configuration parameters in the CLI
|
||||
|
||||
Now let's say that we want to train on a different task in the Aloha environment. If you look in `env/aloha.yaml` you will see something like:
|
||||
|
||||
```yaml
|
||||
# lerobot/configs/env/aloha.yaml
|
||||
env:
|
||||
task: AlohaInsertion-v0
|
||||
```
|
||||
|
||||
And if you look in `policy/act.yaml` you will see something like:
|
||||
|
||||
```yaml
|
||||
# lerobot/configs/policy/act.yaml
|
||||
dataset_repo_id: lerobot/aloha_sim_insertion_human
|
||||
```
|
||||
|
||||
But our Aloha environment actually supports a cube transfer task as well. To train for this task, you could manually modify the two yaml configuration files respectively.
|
||||
|
||||
First, we'd need to switch to using the cube transfer task for the ALOHA environment.
|
||||
|
||||
```diff
|
||||
# lerobot/configs/env/aloha.yaml
|
||||
env:
|
||||
- task: AlohaInsertion-v0
|
||||
+ task: AlohaTransferCube-v0
|
||||
```
|
||||
|
||||
Then, we'd also need to switch to using the cube transfer dataset.
|
||||
|
||||
```diff
|
||||
# lerobot/configs/policy/act.yaml
|
||||
-dataset_repo_id: lerobot/aloha_sim_insertion_human
|
||||
+dataset_repo_id: lerobot/aloha_sim_transfer_cube_human
|
||||
```
|
||||
|
||||
Then, you'd be able to run:
|
||||
|
||||
```bash
|
||||
python lerobot/scripts/train.py policy=act env=aloha
|
||||
```
|
||||
|
||||
and you'd be training and evaluating on the cube transfer task.
|
||||
|
||||
An alternative approach to editing the yaml configuration files, would be to override the defaults via the command line:
|
||||
|
||||
Let's say that we want to train [Diffusion Policy](../../lerobot/common/policies/diffusion) on the [pusht](https://huggingface.co/datasets/lerobot/pusht) dataset, using the [gym_pusht](https://github.com/huggingface/gym-pusht) environment for evaluation. The command to do so would look like this:
|
||||
```bash
|
||||
python lerobot/scripts/train.py \
|
||||
--dataset.repo_id=lerobot/pusht \
|
||||
--policy.type=diffusion \
|
||||
--env.type=pusht
|
||||
policy=act \
|
||||
dataset_repo_id=lerobot/aloha_sim_transfer_cube_human \
|
||||
env=aloha \
|
||||
env.task=AlohaTransferCube-v0
|
||||
```
|
||||
|
||||
Let's break this down:
|
||||
- To specify the dataset, we just need to specify its `repo_id` on the hub which is the only required argument in the `DatasetConfig`. The rest of the fields have default values and in this case we are fine with those so we can just add the option `--dataset.repo_id=lerobot/pusht`.
|
||||
- To specify the policy, we can just select diffusion policy using `--policy` appended with `.type`. Here, `.type` is a special argument which allows us to select config classes inheriting from `draccus.ChoiceRegistry` and that have been decorated with the `register_subclass()` method. To have a better explanation of this feature, have a look at this [Draccus demo](https://github.com/dlwh/draccus?tab=readme-ov-file#more-flexible-configuration-with-choice-types). In our code, we use this mechanism mainly to select policies, environments, robots, and some other components like optimizers. The policies available to select are located in [lerobot/common/policies](../../lerobot/common/policies)
|
||||
- Similarly, we select the environment with `--env.type=pusht`. The different environment configs are available in [`lerobot/common/envs/configs.py`](../../lerobot/common/envs/configs.py)
|
||||
There's something new here. Notice the `.` delimiter used to traverse the configuration hierarchy. _But be aware that the `defaults` section is an exception. As you saw above, we didn't need to write `defaults.policy=act` in the CLI. `policy=act` was enough._
|
||||
|
||||
Putting all that knowledge together, here's the command that was used to train https://huggingface.co/lerobot/act_aloha_sim_transfer_cube_human.
|
||||
|
||||
Let's see another example. Let's say you've been training [ACT](../../lerobot/common/policies/act) on [lerobot/aloha_sim_insertion_human](https://huggingface.co/datasets/lerobot/aloha_sim_insertion_human) using the [gym-aloha](https://github.com/huggingface/gym-aloha) environment for evaluation with:
|
||||
```bash
|
||||
python lerobot/scripts/train.py \
|
||||
--policy.type=act \
|
||||
--dataset.repo_id=lerobot/aloha_sim_insertion_human \
|
||||
--env.type=aloha \
|
||||
--output_dir=outputs/train/act_aloha_insertion
|
||||
hydra.run.dir=outputs/train/act_aloha_sim_transfer_cube_human \
|
||||
device=cuda
|
||||
env=aloha \
|
||||
env.task=AlohaTransferCube-v0 \
|
||||
dataset_repo_id=lerobot/aloha_sim_transfer_cube_human \
|
||||
policy=act \
|
||||
training.eval_freq=10000 \
|
||||
training.log_freq=250 \
|
||||
training.offline_steps=100000 \
|
||||
training.save_model=true \
|
||||
training.save_freq=25000 \
|
||||
eval.n_episodes=50 \
|
||||
eval.batch_size=50 \
|
||||
wandb.enable=false \
|
||||
```
|
||||
> Notice we added `--output_dir` to explicitly tell where to write outputs from this run (checkpoints, training state, configs etc.). This is not mandatory and if you don't specify it, a default directory will be created from the current date and time, env.type and policy.type. This will typically look like `outputs/train/2025-01-24/16-10-05_aloha_act`.
|
||||
|
||||
We now want to train a different policy for aloha on another task. We'll change the dataset and use [lerobot/aloha_sim_transfer_cube_human](https://huggingface.co/datasets/lerobot/aloha_sim_transfer_cube_human) instead. Of course, we also need to change the task of the environment as well to match this other task.
|
||||
Looking at the [`AlohaEnv`](../../lerobot/common/envs/configs.py) config, the task is `"AlohaInsertion-v0"` by default, which corresponds to the task we trained on in the command above. The [gym-aloha](https://github.com/huggingface/gym-aloha?tab=readme-ov-file#description) environment also has the `AlohaTransferCube-v0` task which corresponds to this other task we want to train on. Putting this together, we can train this new policy on this different task using:
|
||||
There's one new thing here: `hydra.run.dir=outputs/train/act_aloha_sim_transfer_cube_human`, which specifies where to save the training output.
|
||||
|
||||
## Using a configuration file not in `lerobot/configs`
|
||||
|
||||
Above we discusses the our training script is set up such that Hydra looks for `default.yaml` in `lerobot/configs`. But, if you have a configuration file elsewhere in your filesystem you may use:
|
||||
|
||||
```bash
|
||||
python lerobot/scripts/train.py \
|
||||
--policy.type=act \
|
||||
--dataset.repo_id=lerobot/aloha_sim_transfer_cube_human \
|
||||
--env.type=aloha \
|
||||
--env.task=AlohaTransferCube-v0 \
|
||||
--output_dir=outputs/train/act_aloha_transfer
|
||||
python lerobot/scripts/train.py --config-dir PARENT/PATH --config-name FILE_NAME_WITHOUT_EXTENSION
|
||||
```
|
||||
|
||||
## Loading from a config file
|
||||
Note: here we use regular syntax for providing CLI arguments to a Python script, not Hydra's `param_name=param_value` syntax.
|
||||
|
||||
Now, let's assume that we want to reproduce the run just above. That run has produced a `train_config.json` file in its checkpoints, which serializes the `TrainPipelineConfig` instance it used:
|
||||
```json
|
||||
{
|
||||
"dataset": {
|
||||
"repo_id": "lerobot/aloha_sim_transfer_cube_human",
|
||||
"episodes": null,
|
||||
...
|
||||
},
|
||||
"env": {
|
||||
"type": "aloha",
|
||||
"task": "AlohaTransferCube-v0",
|
||||
"fps": 50,
|
||||
...
|
||||
},
|
||||
"policy": {
|
||||
"type": "act",
|
||||
"n_obs_steps": 1,
|
||||
...
|
||||
},
|
||||
...
|
||||
}
|
||||
```
|
||||
As a concrete example, this becomes particularly handy when you have a folder with training outputs, and would like to re-run the training. For example, say you previously ran the training script with one of the earlier commands and have `outputs/train/my_experiment/checkpoints/pretrained_model/config.yaml`. This `config.yaml` file will have the full set of configuration parameters within it. To run the training with the same configuration again, do:
|
||||
|
||||
We can then simply load the config values from this file using:
|
||||
```bash
|
||||
python lerobot/scripts/train.py \
|
||||
--config_path=outputs/train/act_aloha_transfer/checkpoints/last/pretrained_model/ \
|
||||
--output_dir=outputs/train/act_aloha_transfer_2
|
||||
```
|
||||
`--config_path` is also a special argument which allows to initialize the config from a local config file. It can point to a directory that contains `train_config.json` or to the config file itself directly.
|
||||
|
||||
Similarly to Hydra, we can still override some parameters in the CLI if we want to, e.g.:
|
||||
```bash
|
||||
python lerobot/scripts/train.py \
|
||||
--config_path=outputs/train/act_aloha_transfer/checkpoints/last/pretrained_model/ \
|
||||
--output_dir=outputs/train/act_aloha_transfer_2
|
||||
--policy.n_action_steps=80
|
||||
```
|
||||
> Note: While `--output_dir` is not required in general, in this case we need to specify it since it will otherwise take the value from the `train_config.json` (which is `outputs/train/act_aloha_transfer`). In order to prevent accidental deletion of previous run checkpoints, we raise an error if you're trying to write in an existing directory. This is not the case when resuming a run, which is what you'll learn next.
|
||||
|
||||
`--config_path` can also accept the repo_id of a repo on the hub that contains a `train_config.json` file, e.g. running:
|
||||
```bash
|
||||
python lerobot/scripts/train.py --config_path=lerobot/diffusion_pusht
|
||||
```
|
||||
will start a training run with the same configuration used for training [lerobot/diffusion_pusht](https://huggingface.co/lerobot/diffusion_pusht)
|
||||
|
||||
|
||||
## Resume training
|
||||
|
||||
Being able to resume a training run is important in case it crashed or aborted for any reason. We'll demonstrate how to that here.
|
||||
|
||||
Let's reuse the command from the previous run and add a few more options:
|
||||
```bash
|
||||
python lerobot/scripts/train.py \
|
||||
--policy.type=act \
|
||||
--dataset.repo_id=lerobot/aloha_sim_transfer_cube_human \
|
||||
--env.type=aloha \
|
||||
--env.task=AlohaTransferCube-v0 \
|
||||
--log_freq=25 \
|
||||
--save_freq=100 \
|
||||
--output_dir=outputs/train/run_resumption
|
||||
python lerobot/scripts/train.py --config-dir outputs/train/my_experiment/checkpoints/last/pretrained_model --config-name config
|
||||
```
|
||||
|
||||
Here we've taken care to set up the log frequency and checkpointing frequency to low numbers so we can showcase resumption. You should be able to see some logging and have a first checkpoint within 1 minute (depending on hardware). Wait for the first checkpoint to happen, you should see a line that looks like this in your terminal:
|
||||
```
|
||||
INFO 2025-01-24 16:10:56 ts/train.py:263 Checkpoint policy after step 100
|
||||
```
|
||||
Now let's simulate a crash by killing the process (hit `ctrl`+`c`). We can then simply resume this run from the last checkpoint available with:
|
||||
```bash
|
||||
python lerobot/scripts/train.py \
|
||||
--config_path=outputs/train/run_resumption/checkpoints/last/pretrained_model/ \
|
||||
--resume=true
|
||||
```
|
||||
You should see from the logging that your training picks up from where it left off.
|
||||
|
||||
Another reason for which you might want to resume a run is simply to extend training and add more training steps. The number of training steps is set by the option `--steps`, which is 100 000 by default.
|
||||
You could double the number of steps of the previous run with:
|
||||
```bash
|
||||
python lerobot/scripts/train.py \
|
||||
--config_path=outputs/train/run_resumption/checkpoints/last/pretrained_model/ \
|
||||
--resume=true \
|
||||
--steps=200000
|
||||
```
|
||||
|
||||
## Outputs of a run
|
||||
In the output directory, there will be a folder called `checkpoints` with the following structure:
|
||||
```bash
|
||||
outputs/train/run_resumption/checkpoints
|
||||
├── 000100 # checkpoint_dir for training step 100
|
||||
│ ├── pretrained_model/
|
||||
│ │ ├── config.json # policy config
|
||||
│ │ ├── model.safetensors # policy weights
|
||||
│ │ └── train_config.json # train config
|
||||
│ └── training_state/
|
||||
│ ├── optimizer_param_groups.json # optimizer param groups
|
||||
│ ├── optimizer_state.safetensors # optimizer state
|
||||
│ ├── rng_state.safetensors # rng states
|
||||
│ ├── scheduler_state.json # scheduler state
|
||||
│ └── training_step.json # training step
|
||||
├── 000200
|
||||
└── last -> 000200 # symlink to the last available checkpoint
|
||||
```
|
||||
|
||||
## Fine-tuning a pre-trained policy
|
||||
|
||||
In addition to the features currently in Draccus, we've added a special `.path` argument for the policy, which allows to load a policy as you would with `PreTrainedPolicy.from_pretrained()`. In that case, `path` can be a local directory that contains a checkpoint or a repo_id pointing to a pretrained policy on the hub.
|
||||
|
||||
For example, we could fine-tune a [policy pre-trained on the aloha transfer task](https://huggingface.co/lerobot/act_aloha_sim_transfer_cube_human) on the aloha insertion task. We can achieve this with:
|
||||
```bash
|
||||
python lerobot/scripts/train.py \
|
||||
--policy.path=lerobot/act_aloha_sim_transfer_cube_human \
|
||||
--dataset.repo_id=lerobot/aloha_sim_insertion_human \
|
||||
--env.type=aloha \
|
||||
--env.task=AlohaInsertion-v0
|
||||
```
|
||||
|
||||
When doing so, keep in mind that the features of the fine-tuning dataset would have to match the input/output features of the pretrained policy.
|
||||
Note that you may still use the regular syntax for config parameter overrides (eg: by adding `training.offline_steps=200000`).
|
||||
|
||||
## Typical logs and metrics
|
||||
|
||||
When you start the training process, you will first see your full configuration being printed in the terminal. You can check it to make sure that you configured your run correctly. The final configuration will also be saved with the checkpoint.
|
||||
When you start the training process, you will first see your full configuration being printed in the terminal. You can check it to make sure that you config it correctly and your config is not overrided by other files. The final configuration will also be saved with the checkpoint.
|
||||
|
||||
After that, you will see training log like this one:
|
||||
|
||||
```
|
||||
INFO 2024-08-14 13:35:12 ts/train.py:192 step:0 smpl:64 ep:1 epch:0.00 loss:1.112 grdn:15.387 lr:2.0e-07 updt_s:1.738 data_s:4.774
|
||||
```
|
||||
or evaluation log:
|
||||
|
||||
or evaluation log like:
|
||||
|
||||
```
|
||||
INFO 2024-08-14 13:38:45 ts/train.py:226 step:100 smpl:6K ep:52 epch:0.25 ∑rwrd:20.693 success:0.0% eval_s:120.266
|
||||
```
|
||||
|
||||
These logs will also be saved in wandb if `wandb.enable` is set to `true`. Here are the meaning of some abbreviations:
|
||||
|
||||
- `smpl`: number of samples seen during training.
|
||||
- `ep`: number of episodes seen during training. An episode contains multiple samples in a complete manipulation task.
|
||||
- `epch`: number of time all unique samples are seen (epoch).
|
||||
@@ -230,45 +200,14 @@ These logs will also be saved in wandb if `wandb.enable` is set to `true`. Here
|
||||
|
||||
Some metrics are useful for initial performance profiling. For example, if you find the current GPU utilization is low via the `nvidia-smi` command and `data_s` sometimes is too high, you may need to modify batch size or number of dataloading workers to accelerate dataloading. We also recommend [pytorch profiler](https://github.com/huggingface/lerobot?tab=readme-ov-file#improve-your-code-with-profiling) for detailed performance probing.
|
||||
|
||||
## In short
|
||||
|
||||
We'll summarize here the main use cases to remember from this tutorial.
|
||||
|
||||
#### Train a policy from scratch – CLI
|
||||
```bash
|
||||
python lerobot/scripts/train.py \
|
||||
--policy.type=act \ # <- select 'act' policy
|
||||
--env.type=pusht \ # <- select 'pusht' environment
|
||||
--dataset.repo_id=lerobot/pusht # <- train on this dataset
|
||||
```
|
||||
|
||||
#### Train a policy from scratch - config file + CLI
|
||||
```bash
|
||||
python lerobot/scripts/train.py \
|
||||
--config_path=path/to/pretrained_model \ # <- can also be a repo_id
|
||||
--policy.n_action_steps=80 # <- you may still override values
|
||||
```
|
||||
|
||||
#### Resume/continue a training run
|
||||
```bash
|
||||
python lerobot/scripts/train.py \
|
||||
--config_path=checkpoint/pretrained_model/ \
|
||||
--resume=true \
|
||||
--steps=200000 # <- you can change some training parameters
|
||||
```
|
||||
|
||||
#### Fine-tuning
|
||||
```bash
|
||||
python lerobot/scripts/train.py \
|
||||
--policy.path=lerobot/act_aloha_sim_transfer_cube_human \ # <- can also be a local path to a checkpoint
|
||||
--dataset.repo_id=lerobot/aloha_sim_insertion_human \
|
||||
--env.type=aloha \
|
||||
--env.task=AlohaInsertion-v0
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
Now that you know the basics of how to train a policy, you might want to know how to apply this knowledge to actual robots, or how to record your own datasets and train policies on your specific task?
|
||||
If that's the case, head over to the next tutorial [`7_get_started_with_real_robot.md`](./7_get_started_with_real_robot.md).
|
||||
So far we've seen how to train Diffusion Policy for PushT and ACT for ALOHA. Now, what if we want to train ACT for PushT? Well, there are aspects of the ACT configuration that are specific to the ALOHA environments, and these happen to be incompatible with PushT. Therefore, trying to run the following will almost certainly raise an exception of sorts (eg: feature dimension mismatch):
|
||||
|
||||
Or in the meantime, happy training! 🤗
|
||||
```bash
|
||||
python lerobot/scripts/train.py policy=act env=pusht dataset_repo_id=lerobot/pusht
|
||||
```
|
||||
|
||||
Please, head on over to our [advanced tutorial on adapting policy configuration to various environments](./advanced/train_act_pusht/train_act_pusht.md) to learn more.
|
||||
|
||||
Or in the meantime, happy coding! 🤗
|
||||
|
||||
37
examples/5_resume_training.md
Normal file
37
examples/5_resume_training.md
Normal file
@@ -0,0 +1,37 @@
|
||||
This tutorial explains how to resume a training run that you've started with the training script. If you don't know how our training script and configuration system works, please read [4_train_policy_with_script.md](./4_train_policy_with_script.md) first.
|
||||
|
||||
## Basic training resumption
|
||||
|
||||
Let's consider the example of training ACT for one of the ALOHA tasks. Here's a command that can achieve that:
|
||||
|
||||
```bash
|
||||
python lerobot/scripts/train.py \
|
||||
hydra.run.dir=outputs/train/run_resumption \
|
||||
policy=act \
|
||||
dataset_repo_id=lerobot/aloha_sim_transfer_cube_human \
|
||||
env=aloha \
|
||||
env.task=AlohaTransferCube-v0 \
|
||||
training.log_freq=25 \
|
||||
training.save_checkpoint=true \
|
||||
training.save_freq=100
|
||||
```
|
||||
|
||||
Here we're using the default dataset and environment for ACT, and we've taken care to set up the log frequency and checkpointing frequency to low numbers so we can test resumption. You should be able to see some logging and have a first checkpoint within 1 minute. Please interrupt the training after the first checkpoint.
|
||||
|
||||
To resume, all that we have to do is run the training script, providing the run directory, and the resume option:
|
||||
|
||||
```bash
|
||||
python lerobot/scripts/train.py \
|
||||
hydra.run.dir=outputs/train/run_resumption \
|
||||
resume=true
|
||||
```
|
||||
|
||||
You should see from the logging that your training picks up from where it left off.
|
||||
|
||||
Note that with `resume=true`, the configuration file from the last checkpoint in the training output directory is loaded. So it doesn't matter that we haven't provided all the other configuration parameters from our previous command (although there may be warnings to notify you that your command has a different configuration than than the checkpoint).
|
||||
|
||||
---
|
||||
|
||||
Now you should know how to resume your training run in case it gets interrupted or you want to extend a finished training run.
|
||||
|
||||
Happy coding! 🤗
|
||||
@@ -36,14 +36,9 @@ Using `pip`:
|
||||
pip install -e ".[dynamixel]"
|
||||
```
|
||||
|
||||
Using `poetry`:
|
||||
Or using `poetry`:
|
||||
```bash
|
||||
poetry sync --extras "dynamixel"
|
||||
```
|
||||
|
||||
Using `uv`:
|
||||
```bash
|
||||
uv sync --extra "dynamixel"
|
||||
poetry install --sync --extras "dynamixel"
|
||||
```
|
||||
|
||||
/!\ For Linux only, ffmpeg and opencv requires conda install for now. Run this exact sequence of commands:
|
||||
@@ -59,53 +54,24 @@ Then plug the 12V power supply to the motor bus of the follower arm. It has two
|
||||
|
||||
Finally, connect both arms to your computer via USB. Note that the USB doesn't provide any power, and both arms need to be plugged in with their associated power supply to be detected by your computer.
|
||||
|
||||
Now you are ready to configure your motors for the first time, as detailed in the sections below. In the upcoming sections, you'll learn about our classes and functions by running some python code in an interactive session, or by copy-pasting it in a python file.
|
||||
*Copy pasting python code*
|
||||
|
||||
If you have already configured your motors the first time, you can streamline the process by directly running the teleoperate script (which is detailed further in the tutorial):
|
||||
In the upcoming sections, you'll learn about our classes and functions by running some python code, in an interactive session, or by copy-pasting it in a python file. If this is your first time using the tutorial., we highly recommend going through these steps to get deeper intuition about how things work. Once you're more familiar, you can streamline the process by directly running the teleoperate script (which is detailed further in the tutorial):
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py \
|
||||
--robot.type=koch \
|
||||
--control.type=teleoperate
|
||||
python lerobot/scripts/control_robot.py teleoperate \
|
||||
--robot-path lerobot/configs/robot/koch.yaml \
|
||||
--robot-overrides '~cameras' # do not instantiate the cameras
|
||||
```
|
||||
|
||||
It will automatically:
|
||||
1. Identify any missing calibrations and initiate the calibration procedure.
|
||||
2. Connect the robot and start teleoperation.
|
||||
1. Detect and help you correct any motor configuration issues.
|
||||
2. Identify any missing calibrations and initiate the calibration procedure.
|
||||
3. Connect the robot and start teleoperation.
|
||||
|
||||
### a. Control your motors with DynamixelMotorsBus
|
||||
|
||||
You can use the [`DynamixelMotorsBus`](../lerobot/common/robot_devices/motors/dynamixel.py) to communicate with the motors connected as a chain to the corresponding USB bus. This class leverages the Python [Dynamixel SDK](https://emanual.robotis.com/docs/en/software/dynamixel/dynamixel_sdk/sample_code/python_read_write_protocol_2_0/#python-read-write-protocol-20) to facilitate reading from and writing to the motors.
|
||||
|
||||
**First Configuration of your motors**
|
||||
|
||||
You will need to unplug each motor in turn and run a command the identify the motor. The motor will save its own identification, so you only need to do this once. Start by unplugging all of the motors.
|
||||
|
||||
Do the Leader arm first, as all of its motors are of the same type. Plug in your first motor on your leader arm and run this script to set its ID to 1.
|
||||
```bash
|
||||
python lerobot/scripts/configure_motor.py \
|
||||
--port /dev/tty.usbmodem58760432961 \
|
||||
--brand dynamixel \
|
||||
--model xl330-m288 \
|
||||
--baudrate 1000000 \
|
||||
--ID 1
|
||||
```
|
||||
|
||||
Then unplug your first motor and plug the second motor and set its ID to 2.
|
||||
```bash
|
||||
python lerobot/scripts/configure_motor.py \
|
||||
--port /dev/tty.usbmodem58760432961 \
|
||||
--brand dynamixel \
|
||||
--model xl330-m288 \
|
||||
--baudrate 1000000 \
|
||||
--ID 2
|
||||
```
|
||||
|
||||
Redo the process for all your motors until ID 6.
|
||||
|
||||
The process for the follower arm is almost the same, but the follower arm has two types of motors. For the first two motors, make sure you set the model to `xl430-w250`. _Important: configuring follower motors requires plugging and unplugging power. Make sure you use the 5V power for the XL330s and the 12V power for the XL430s!_
|
||||
|
||||
After all of your motors are configured properly, you're ready to plug them all together in a daisy-chain as shown in the original video.
|
||||
|
||||
**Instantiate the DynamixelMotorsBus**
|
||||
|
||||
To begin, create two instances of the [`DynamixelMotorsBus`](../lerobot/common/robot_devices/motors/dynamixel.py), one for each arm, using their corresponding USB ports (e.g. `DynamixelMotorsBus(port="/dev/tty.usbmodem575E0031751"`).
|
||||
@@ -139,10 +105,10 @@ The port of this DynamixelMotorsBus is /dev/tty.usbmodem575E0032081
|
||||
Reconnect the usb cable.
|
||||
```
|
||||
|
||||
Troubleshooting: On Linux, you might need to give access to the USB ports by running this command with your ports:
|
||||
Troubleshooting: On Linux, you might need to give access to the USB ports by running:
|
||||
```bash
|
||||
sudo chmod 666 /dev/tty.usbmodem575E0032081
|
||||
sudo chmod 666 /dev/tty.usbmodem575E0031751
|
||||
sudo chmod 666 /dev/ttyACM0
|
||||
sudo chmod 666 /dev/ttyACM1
|
||||
```
|
||||
|
||||
*Listing and Configuring Motors*
|
||||
@@ -151,11 +117,13 @@ Next, you'll need to list the motors for each arm, including their name, index,
|
||||
|
||||
To assign indices to the motors, run this code in an interactive Python session. Replace the `port` values with the ones you identified earlier:
|
||||
```python
|
||||
from lerobot.common.robot_devices.motors.configs import DynamixelMotorsBusConfig
|
||||
from lerobot.common.robot_devices.motors.dynamixel import DynamixelMotorsBus
|
||||
|
||||
leader_config = DynamixelMotorsBusConfig(
|
||||
port="/dev/tty.usbmodem575E0031751",
|
||||
leader_port = "/dev/tty.usbmodem575E0031751"
|
||||
follower_port = "/dev/tty.usbmodem575E0032081"
|
||||
|
||||
leader_arm = DynamixelMotorsBus(
|
||||
port=leader_port,
|
||||
motors={
|
||||
# name: (index, model)
|
||||
"shoulder_pan": (1, "xl330-m077"),
|
||||
@@ -167,8 +135,8 @@ leader_config = DynamixelMotorsBusConfig(
|
||||
},
|
||||
)
|
||||
|
||||
follower_config = DynamixelMotorsBusConfig(
|
||||
port="/dev/tty.usbmodem575E0032081",
|
||||
follower_arm = DynamixelMotorsBus(
|
||||
port=follower_port,
|
||||
motors={
|
||||
# name: (index, model)
|
||||
"shoulder_pan": (1, "xl430-w250"),
|
||||
@@ -179,57 +147,45 @@ follower_config = DynamixelMotorsBusConfig(
|
||||
"gripper": (6, "xl330-m288"),
|
||||
},
|
||||
)
|
||||
|
||||
leader_arm = DynamixelMotorsBus(leader_config)
|
||||
follower_arm = DynamixelMotorsBus(follower_config)
|
||||
```
|
||||
|
||||
IMPORTANTLY: Now that you have your ports, update [`KochRobotConfig`](../lerobot/common/robot_devices/robots/configs.py). You will find something like:
|
||||
```python
|
||||
@RobotConfig.register_subclass("koch")
|
||||
@dataclass
|
||||
class KochRobotConfig(ManipulatorRobotConfig):
|
||||
calibration_dir: str = ".cache/calibration/koch"
|
||||
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
|
||||
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
|
||||
# the number of motors in your follower arms.
|
||||
max_relative_target: int | None = None
|
||||
*Updating the YAML Configuration File*
|
||||
|
||||
leader_arms: dict[str, MotorsBusConfig] = field(
|
||||
default_factory=lambda: {
|
||||
"main": DynamixelMotorsBusConfig(
|
||||
port="/dev/tty.usbmodem585A0085511", <-- UPDATE HERE
|
||||
motors={
|
||||
# name: (index, model)
|
||||
"shoulder_pan": [1, "xl330-m077"],
|
||||
"shoulder_lift": [2, "xl330-m077"],
|
||||
"elbow_flex": [3, "xl330-m077"],
|
||||
"wrist_flex": [4, "xl330-m077"],
|
||||
"wrist_roll": [5, "xl330-m077"],
|
||||
"gripper": [6, "xl330-m077"],
|
||||
},
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
follower_arms: dict[str, MotorsBusConfig] = field(
|
||||
default_factory=lambda: {
|
||||
"main": DynamixelMotorsBusConfig(
|
||||
port="/dev/tty.usbmodem585A0076891", <-- UPDATE HERE
|
||||
motors={
|
||||
# name: (index, model)
|
||||
"shoulder_pan": [1, "xl430-w250"],
|
||||
"shoulder_lift": [2, "xl430-w250"],
|
||||
"elbow_flex": [3, "xl330-m288"],
|
||||
"wrist_flex": [4, "xl330-m288"],
|
||||
"wrist_roll": [5, "xl330-m288"],
|
||||
"gripper": [6, "xl330-m288"],
|
||||
},
|
||||
),
|
||||
}
|
||||
)
|
||||
Next, update the port values in the YAML configuration file for the Koch robot at [`lerobot/configs/robot/koch.yaml`](../lerobot/configs/robot/koch.yaml) with the ports you've identified:
|
||||
```yaml
|
||||
[...]
|
||||
robot_type: koch
|
||||
leader_arms:
|
||||
main:
|
||||
_target_: lerobot.common.robot_devices.motors.dynamixel.DynamixelMotorsBus
|
||||
port: /dev/tty.usbmodem575E0031751 # <- Update
|
||||
motors:
|
||||
# name: (index, model)
|
||||
shoulder_pan: [1, "xl330-m077"]
|
||||
shoulder_lift: [2, "xl330-m077"]
|
||||
elbow_flex: [3, "xl330-m077"]
|
||||
wrist_flex: [4, "xl330-m077"]
|
||||
wrist_roll: [5, "xl330-m077"]
|
||||
gripper: [6, "xl330-m077"]
|
||||
follower_arms:
|
||||
main:
|
||||
_target_: lerobot.common.robot_devices.motors.dynamixel.DynamixelMotorsBus
|
||||
port: /dev/tty.usbmodem575E0032081 # <- Update
|
||||
motors:
|
||||
# name: (index, model)
|
||||
shoulder_pan: [1, "xl430-w250"]
|
||||
shoulder_lift: [2, "xl430-w250"]
|
||||
elbow_flex: [3, "xl330-m288"]
|
||||
wrist_flex: [4, "xl330-m288"]
|
||||
wrist_roll: [5, "xl330-m288"]
|
||||
gripper: [6, "xl330-m288"]
|
||||
[...]
|
||||
```
|
||||
|
||||
Don't forget to set `robot_type: aloha` if you follow this tutorial with [Aloha bimanual robot](aloha-2.github.io) instead of Koch v1.1
|
||||
|
||||
This configuration file is used to instantiate your robot across all scripts. We'll cover how this works later on.
|
||||
|
||||
**Connect and Configure your Motors**
|
||||
|
||||
Before you can start using your motors, you'll need to configure them to ensure proper communication. When you first connect the motors, the [`DynamixelMotorsBus`](../lerobot/common/robot_devices/motors/dynamixel.py) automatically detects any mismatch between the current motor indices (factory set to `1`) and the specified indices (e.g., `1, 2, 3, 4, 5, 6`). This triggers a configuration procedure that requires you to unplug the power cord and motors, then reconnect each motor sequentially, starting from the one closest to the bus.
|
||||
@@ -356,27 +312,27 @@ Alternatively, you can unplug the power cord, which will automatically disable t
|
||||
|
||||
**Instantiate the ManipulatorRobot**
|
||||
|
||||
Before you can teleoperate your robot, you need to instantiate the [`ManipulatorRobot`](../lerobot/common/robot_devices/robots/manipulator.py) using the previously defined `leader_config` and `follower_config`.
|
||||
Before you can teleoperate your robot, you need to instantiate the [`ManipulatorRobot`](../lerobot/common/robot_devices/robots/manipulator.py) using the previously defined `leader_arm` and `follower_arm`.
|
||||
|
||||
For the Koch v1.1 robot, we only have one leader, so we refer to it as `"main"` and define it as `leader_arms={"main": leader_config}`. We do the same for the follower arm. For other robots (like the Aloha), which may have two pairs of leader and follower arms, you would define them like this: `leader_arms={"left": left_leader_config, "right": right_leader_config},`. Same thing for the follower arms.
|
||||
For the Koch v1.1 robot, we only have one leader, so we refer to it as `"main"` and define it as `leader_arms={"main": leader_arm}`. We do the same for the follower arm. For other robots (like the Aloha), which may have two pairs of leader and follower arms, you would define them like this: `leader_arms={"left": left_leader_arm, "right": right_leader_arm},`. Same thing for the follower arms.
|
||||
|
||||
You also need to provide a path to a calibration directory, such as `calibration_dir=".cache/calibration/koch"`. More on this in the next section.
|
||||
|
||||
Run the following code to instantiate your manipulator robot:
|
||||
```python
|
||||
from lerobot.common.robot_devices.robots.configs import KochRobotConfig
|
||||
from lerobot.common.robot_devices.robots.manipulator import ManipulatorRobot
|
||||
|
||||
robot_config = KochRobotConfig(
|
||||
leader_arms={"main": leader_config},
|
||||
follower_arms={"main": follower_config},
|
||||
cameras={}, # We don't use any camera for now
|
||||
robot = ManipulatorRobot(
|
||||
robot_type="koch",
|
||||
leader_arms={"main": leader_arm},
|
||||
follower_arms={"main": follower_arm},
|
||||
calibration_dir=".cache/calibration/koch",
|
||||
)
|
||||
robot = ManipulatorRobot(robot_config)
|
||||
```
|
||||
|
||||
The `KochRobotConfig` is used to set the associated settings and calibration process. For instance, we activate the torque of the gripper of the leader Koch v1.1 arm and position it at a 40 degree angle to use it as a trigger.
|
||||
The `robot_type="koch"` is used to set the associated settings and calibration process. For instance, we activate the torque of the gripper of the leader Koch v1.1 arm and position it at a 40 degree angle to use it as a trigger.
|
||||
|
||||
For the [Aloha bimanual robot](https://aloha-2.github.io), we would use `AlohaRobotConfig` to set different settings such as a secondary ID for shadow joints (shoulder, elbow). Specific to Aloha, LeRobot comes with default calibration files stored in in `.cache/calibration/aloha_default`. Assuming the motors have been properly assembled, no manual calibration step is expected for Aloha.
|
||||
For the [Aloha bimanual robot](https://aloha-2.github.io), we would use `robot_type="aloha"` to set different settings such as a secondary ID for shadow joints (shoulder, elbow). Specific to Aloha, LeRobot comes with default calibration files stored in in `.cache/calibration/aloha_default`. Assuming the motors have been properly assembled, no manual calibration step is expected. If you need to run manual calibration, simply update `calibration_dir` to `.cache/calibration/aloha`.
|
||||
|
||||
**Calibrate and Connect the ManipulatorRobot**
|
||||
|
||||
@@ -398,7 +354,7 @@ And here are the corresponding positions for the leader arm:
|
||||
|
||||
You can watch a [video tutorial of the calibration procedure](https://youtu.be/8drnU9uRY24) for more details.
|
||||
|
||||
During calibration, we count the number of full 360-degree rotations your motors have made since they were first used. That's why we ask yo to move to this arbitrary "zero" position. We don't actually "set" the zero position, so you don't need to be accurate. After calculating these "offsets" to shift the motor values around 0, we need to assess the rotation direction of each motor, which might differ. That's why we ask you to rotate all motors to roughly 90 degrees, to measure if the values changed negatively or positively.
|
||||
During calibration, we count the number of full 360-degree rotations your motors have made since they were first used. That's why we ask yo to move to this arbitrary "zero" position. We don't actually "set" the zero position, so you don't need to be accurate. After calculating these "offsets" to shift the motor values around 0, we need to assess the rotation direction of each motor, which might differ. That's why we ask you to rotate all motors to roughly 90 degrees, to mesure if the values changed negatively or positively.
|
||||
|
||||
Finally, the rest position ensures that the follower and leader arms are roughly aligned after calibration, preventing sudden movements that could damage the motors when starting teleoperation.
|
||||
|
||||
@@ -623,11 +579,9 @@ Note: Some cameras may take a few seconds to warm up, and the first frame might
|
||||
|
||||
Finally, run this code to instantiate and connectyour camera:
|
||||
```python
|
||||
from lerobot.common.robot_devices.cameras.configs import OpenCVCameraConfig
|
||||
from lerobot.common.robot_devices.cameras.opencv import OpenCVCamera
|
||||
|
||||
config = OpenCVCameraConfig(camera_index=0)
|
||||
camera = OpenCVCamera(config)
|
||||
camera = OpenCVCamera(camera_index=0)
|
||||
camera.connect()
|
||||
color_image = camera.read()
|
||||
|
||||
@@ -649,7 +603,7 @@ uint8
|
||||
|
||||
With certain camera, you can also specify additional parameters like frame rate, resolution, and color mode during instantiation. For instance:
|
||||
```python
|
||||
config = OpenCVCameraConfig(camera_index=0, fps=30, width=640, height=480)
|
||||
camera = OpenCVCamera(camera_index=0, fps=30, width=640, height=480)
|
||||
```
|
||||
|
||||
If the provided arguments are not compatible with the camera, an exception will be raised.
|
||||
@@ -663,20 +617,18 @@ camera.disconnect()
|
||||
|
||||
**Instantiate your robot with cameras**
|
||||
|
||||
Additionally, you can set up your robot to work with your cameras.
|
||||
Additionaly, you can set up your robot to work with your cameras.
|
||||
|
||||
Modify the following Python code with the appropriate camera names and configurations:
|
||||
```python
|
||||
robot = ManipulatorRobot(
|
||||
KochRobotConfig(
|
||||
leader_arms={"main": leader_arm},
|
||||
follower_arms={"main": follower_arm},
|
||||
calibration_dir=".cache/calibration/koch",
|
||||
cameras={
|
||||
"laptop": OpenCVCameraConfig(0, fps=30, width=640, height=480),
|
||||
"phone": OpenCVCameraConfig(1, fps=30, width=640, height=480),
|
||||
},
|
||||
)
|
||||
leader_arms={"main": leader_arm},
|
||||
follower_arms={"main": follower_arm},
|
||||
calibration_dir=".cache/calibration/koch",
|
||||
cameras={
|
||||
"laptop": OpenCVCamera(0, fps=30, width=640, height=480),
|
||||
"phone": OpenCVCamera(1, fps=30, width=640, height=480),
|
||||
},
|
||||
)
|
||||
robot.connect()
|
||||
```
|
||||
@@ -700,20 +652,39 @@ torch.Size([3, 480, 640])
|
||||
255
|
||||
```
|
||||
|
||||
### d. Use `control_robot.py` and our `teleoperate` function
|
||||
Also, update the following lines of the yaml file for Koch robot [`lerobot/configs/robot/koch.yaml`](../lerobot/configs/robot/koch.yaml) with the names and configurations of your cameras:
|
||||
```yaml
|
||||
[...]
|
||||
cameras:
|
||||
laptop:
|
||||
_target_: lerobot.common.robot_devices.cameras.opencv.OpenCVCamera
|
||||
camera_index: 0
|
||||
fps: 30
|
||||
width: 640
|
||||
height: 480
|
||||
phone:
|
||||
_target_: lerobot.common.robot_devices.cameras.opencv.OpenCVCamera
|
||||
camera_index: 1
|
||||
fps: 30
|
||||
width: 640
|
||||
height: 480
|
||||
```
|
||||
|
||||
Instead of manually running the python code in a terminal window, you can use [`lerobot/scripts/control_robot.py`](../lerobot/scripts/control_robot.py) to instantiate your robot by providing the robot configurations via command line and control your robot with various modes as explained next.
|
||||
This file is used to instantiate your robot in all our scripts. We will explain how this works in the next section.
|
||||
|
||||
### d. Use `koch.yaml` and our `teleoperate` function
|
||||
|
||||
Instead of manually running the python code in a terminal window, you can use [`lerobot/scripts/control_robot.py`](../lerobot/scripts/control_robot.py) to instantiate your robot by providing the path to the robot yaml file (e.g. [`lerobot/configs/robot/koch.yaml`](../lerobot/configs/robot/koch.yaml)) and control your robot with various modes as explained next.
|
||||
|
||||
Try running this code to teleoperate your robot (if you dont have a camera, keep reading):
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py \
|
||||
--robot.type=koch \
|
||||
--control.type=teleoperate
|
||||
python lerobot/scripts/control_robot.py teleoperate \
|
||||
--robot-path lerobot/configs/robot/koch.yaml
|
||||
```
|
||||
|
||||
You will see a lot of lines appearing like this one:
|
||||
```
|
||||
INFO 2024-08-10 11:15:03 ol_robot.py:209 dt: 5.12 (195.1hz) dtRlead: 4.93 (203.0hz) dtWfoll: 0.19 (5239.0hz)
|
||||
INFO 2024-08-10 11:15:03 ol_robot.py:209 dt: 5.12 (195.1hz) dtRlead: 4.93 (203.0hz) dtRfoll: 0.19 (5239.0hz)
|
||||
```
|
||||
|
||||
It contains
|
||||
@@ -723,12 +694,21 @@ It contains
|
||||
- `dtRlead: 4.93 (203.0hz)` which is the number of milliseconds it took to read the position of the leader arm using `leader_arm.read("Present_Position")`.
|
||||
- `dtWfoll: 0.22 (4446.9hz)` which is the number of milliseconds it took to set a new goal position for the follower arm using `follower_arm.write("Goal_position", leader_pos)` ; note that writing is done asynchronously so it takes less time than reading.
|
||||
|
||||
Importantly: If you don't have any camera, you can remove them dynamically with this [draccus](https://github.com/dlwh/draccus) syntax `--robot.cameras='{}'`:
|
||||
Note: you can override any entry in the yaml file using `--robot-overrides` and the [hydra.cc](https://hydra.cc/docs/advanced/override_grammar/basic) syntax. If needed, you can override the ports like this:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py \
|
||||
--robot.type=koch \
|
||||
--robot.cameras='{}' \
|
||||
--control.type=teleoperate
|
||||
python lerobot/scripts/control_robot.py teleoperate \
|
||||
--robot-path lerobot/configs/robot/koch.yaml \
|
||||
--robot-overrides \
|
||||
leader_arms.main.port=/dev/tty.usbmodem575E0031751 \
|
||||
follower_arms.main.port=/dev/tty.usbmodem575E0032081
|
||||
```
|
||||
|
||||
Importantly: If you don't have any camera, you can remove them dynamically with this [hydra.cc](https://hydra.cc/docs/advanced/override_grammar/basic) syntax `'~cameras'`:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py teleoperate \
|
||||
--robot-path lerobot/configs/robot/koch.yaml \
|
||||
--robot-overrides \
|
||||
'~cameras'
|
||||
```
|
||||
|
||||
We advise to create a new yaml file when the command becomes too long.
|
||||
@@ -764,23 +744,23 @@ for _ in range(record_time_s * fps):
|
||||
|
||||
Importantly, many utilities are still missing. For instance, if you have cameras, you will need to save the images on disk to not go out of RAM, and to do so in threads to not slow down communication with your robot. Also, you will need to store your data in a format optimized for training and web sharing like [`LeRobotDataset`](../lerobot/common/datasets/lerobot_dataset.py). More on this in the next section.
|
||||
|
||||
### a. Use the `record` function
|
||||
### a. Use `koch.yaml` and the `record` function
|
||||
|
||||
You can use the `record` function from [`lerobot/scripts/control_robot.py`](../lerobot/scripts/control_robot.py) to achieve efficient data recording. It encompasses many recording utilities:
|
||||
1. Frames from cameras are saved on disk in threads, and encoded into videos at the end of each episode recording.
|
||||
1. Frames from cameras are saved on disk in threads, and encoded into videos at the end of recording.
|
||||
2. Video streams from cameras are displayed in window so that you can verify them.
|
||||
3. Data is stored with [`LeRobotDataset`](../lerobot/common/datasets/lerobot_dataset.py) format which is pushed to your Hugging Face page (unless `--control.push_to_hub=false` is provided).
|
||||
4. Checkpoints are done during recording, so if any issue occurs, you can resume recording by re-running the same command again with `--control.resume=true`. You will need to manually delete the dataset directory if you want to start recording from scratch.
|
||||
3. Data is stored with [`LeRobotDataset`](../lerobot/common/datasets/lerobot_dataset.py) format which is pushed to your Hugging Face page (unless `--push-to-hub 0` is provided).
|
||||
4. Checkpoints are done during recording, so if any issue occurs, you can resume recording by re-running the same command again. You can also use `--force-override 1` to start recording from scratch.
|
||||
5. Set the flow of data recording using command line arguments:
|
||||
- `--control.warmup_time_s=10` defines the number of seconds before starting data collection. It allows the robot devices to warmup and synchronize (10 seconds by default).
|
||||
- `--control.episode_time_s=60` defines the number of seconds for data recording for each episode (60 seconds by default).
|
||||
- `--control.reset_time_s=60` defines the number of seconds for resetting the environment after each episode (60 seconds by default).
|
||||
- `--control.num_episodes=50` defines the number of episodes to record (50 by default).
|
||||
- `--warmup-time-s` defines the number of seconds before starting data collection. It allows the robot devices to warmup and synchronize (10 seconds by default).
|
||||
- `--episode-time-s` defines the number of seconds for data recording for each episode (60 seconds by default).
|
||||
- `--reset-time-s` defines the number of seconds for resetting the environment after each episode (60 seconds by default).
|
||||
- `--num-episodes` defines the number of episodes to record (50 by default).
|
||||
6. Control the flow during data recording using keyboard keys:
|
||||
- Press right arrow `->` at any time during episode recording to early stop and go to resetting. Same during resetting, to early stop and to go to the next episode recording.
|
||||
- Press left arrow `<-` at any time during episode recording or resetting to early stop, cancel the current episode, and re-record it.
|
||||
- Press escape `ESC` at any time during episode recording to end the session early and go straight to video encoding and dataset uploading.
|
||||
7. Similarly to `teleoperate`, you can also use the command line to override anything.
|
||||
7. Similarly to `teleoperate`, you can also use `--robot-path` and `--robot-overrides` to specify your robots.
|
||||
|
||||
Before trying `record`, if you want to push your dataset to the hub, make sure you've logged in using a write-access token, which can be generated from the [Hugging Face settings](https://huggingface.co/settings/tokens):
|
||||
```bash
|
||||
@@ -791,29 +771,27 @@ Also, store your Hugging Face repository name in a variable (e.g. `cadene` or `l
|
||||
HF_USER=$(huggingface-cli whoami | head -n 1)
|
||||
echo $HF_USER
|
||||
```
|
||||
If you don't want to push to hub, use `--control.push_to_hub=false`.
|
||||
If you don't want to push to hub, use `--push-to-hub 0`.
|
||||
|
||||
Now run this to record 2 episodes:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py \
|
||||
--robot.type=koch \
|
||||
--control.type=record \
|
||||
--control.single_task="Grasp a lego block and put it in the bin." \
|
||||
--control.fps=30 \
|
||||
--control.repo_id=${HF_USER}/koch_test \
|
||||
--control.tags='["tutorial"]' \
|
||||
--control.warmup_time_s=5 \
|
||||
--control.episode_time_s=30 \
|
||||
--control.reset_time_s=30 \
|
||||
--control.num_episodes=2 \
|
||||
--control.push_to_hub=true
|
||||
python lerobot/scripts/control_robot.py record \
|
||||
--robot-path lerobot/configs/robot/koch.yaml \
|
||||
--fps 30 \
|
||||
--repo-id ${HF_USER}/koch_test \
|
||||
--tags tutorial \
|
||||
--warmup-time-s 5 \
|
||||
--episode-time-s 30 \
|
||||
--reset-time-s 30 \
|
||||
--num-episodes 2
|
||||
```
|
||||
|
||||
|
||||
This will write your dataset locally to `~/.cache/huggingface/lerobot/{repo-id}` (e.g. `data/cadene/koch_test`) and push it on the hub at `https://huggingface.co/datasets/{HF_USER}/{repo-id}`. Your dataset will be automatically tagged with `LeRobot` for the community to find it easily, and you can also add custom tags (in this case `tutorial` for example).
|
||||
|
||||
You can look for other LeRobot datasets on the hub by searching for `LeRobot` tags: https://huggingface.co/datasets?other=LeRobot
|
||||
|
||||
Remember to add `--robot-overrides '~cameras'` if you don't have any cameras and you still use the default `koch.yaml` configuration.
|
||||
|
||||
You will see a lot of lines appearing like this one:
|
||||
```
|
||||
INFO 2024-08-10 15:02:58 ol_robot.py:219 dt:33.34 (30.0hz) dtRlead: 5.06 (197.5hz) dtWfoll: 0.25 (3963.7hz) dtRfoll: 6.22 (160.7hz) dtRlaptop: 32.57 (30.7hz) dtRphone: 33.84 (29.5hz)
|
||||
@@ -825,8 +803,8 @@ It contains:
|
||||
- `dtRlead: 5.06 (197.5hz)` which is the delta time of reading the present position of the leader arm.
|
||||
- `dtWfoll: 0.25 (3963.7hz)` which is the delta time of writing the goal position on the follower arm ; writing is asynchronous so it takes less time than reading.
|
||||
- `dtRfoll: 6.22 (160.7hz)` which is the delta time of reading the present position on the follower arm.
|
||||
- `dtRlaptop:32.57 (30.7hz) ` which is the delta time of capturing an image from the laptop camera in the thread running asynchronously.
|
||||
- `dtRphone:33.84 (29.5hz)` which is the delta time of capturing an image from the phone camera in the thread running asynchronously.
|
||||
- `dtRlaptop:32.57 (30.7hz) ` which is the delta time of capturing an image from the laptop camera in the thread running asynchrously.
|
||||
- `dtRphone:33.84 (29.5hz)` which is the delta time of capturing an image from the phone camera in the thread running asynchrously.
|
||||
|
||||
Troubleshooting:
|
||||
- On Linux, if you encounter a hanging issue when using cameras, uninstall opencv and re-install it with conda:
|
||||
@@ -846,7 +824,7 @@ At the end of data recording, your dataset will be uploaded on your Hugging Face
|
||||
echo https://huggingface.co/datasets/${HF_USER}/koch_test
|
||||
```
|
||||
|
||||
### b. Advice for recording dataset
|
||||
### b. Advices for recording dataset
|
||||
|
||||
Once you're comfortable with data recording, it's time to create a larger dataset for training. A good starting task is grasping an object at different locations and placing it in a bin. We suggest recording at least 50 episodes, with 10 episodes per location. Keep the cameras fixed and maintain consistent grasping behavior throughout the recordings.
|
||||
|
||||
@@ -864,8 +842,6 @@ python lerobot/scripts/visualize_dataset_html.py \
|
||||
--repo-id ${HF_USER}/koch_test
|
||||
```
|
||||
|
||||
Note: You might need to add `--local-files-only 1` if your dataset was not uploaded to hugging face hub.
|
||||
|
||||
This will launch a local web server that looks like this:
|
||||
<div style="text-align:center;">
|
||||
<img src="../media/tutorial/visualize_dataset_html.webp?raw=true" alt="Koch v1.1 leader and follower arms" title="Koch v1.1 leader and follower arms" width="100%">
|
||||
@@ -877,12 +853,11 @@ A useful feature of [`lerobot/scripts/control_robot.py`](../lerobot/scripts/cont
|
||||
|
||||
To replay the first episode of the dataset you just recorded, run the following command:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py \
|
||||
--robot.type=koch \
|
||||
--control.type=replay \
|
||||
--control.fps=30 \
|
||||
--control.repo_id=${HF_USER}/koch_test \
|
||||
--control.episode=0
|
||||
python lerobot/scripts/control_robot.py replay \
|
||||
--robot-path lerobot/configs/robot/koch.yaml \
|
||||
--fps 30 \
|
||||
--repo-id ${HF_USER}/koch_test \
|
||||
--episode 0
|
||||
```
|
||||
|
||||
Your robot should replicate movements similar to those you recorded. For example, check out [this video](https://x.com/RemiCadene/status/1793654950905680090) where we use `replay` on a Aloha robot from [Trossen Robotics](https://www.trossenrobotics.com).
|
||||
@@ -894,17 +869,50 @@ Your robot should replicate movements similar to those you recorded. For example
|
||||
To train a policy to control your robot, use the [`python lerobot/scripts/train.py`](../lerobot/scripts/train.py) script. A few arguments are required. Here is an example command:
|
||||
```bash
|
||||
python lerobot/scripts/train.py \
|
||||
--dataset.repo_id=${HF_USER}/koch_test \
|
||||
--policy.type=act \
|
||||
--output_dir=outputs/train/act_koch_test \
|
||||
--job_name=act_koch_test \
|
||||
--device=cuda \
|
||||
--wandb.enable=true
|
||||
dataset_repo_id=${HF_USER}/koch_test \
|
||||
policy=act_koch_real \
|
||||
env=koch_real \
|
||||
hydra.run.dir=outputs/train/act_koch_test \
|
||||
hydra.job.name=act_koch_test \
|
||||
device=cuda \
|
||||
wandb.enable=true
|
||||
```
|
||||
|
||||
Let's explain it:
|
||||
1. We provided the dataset as argument with `--dataset.repo_id=${HF_USER}/koch_test`.
|
||||
2. We provided the policy with `policy.type=act`. This loads configurations from [`configuration_act.py`](../lerobot/common/policies/act/configuration_act.py). Importantly, this policy will automatically adapt to the number of motor sates, motor actions and cameras of your robot (e.g. `laptop` and `phone`) which have been saved in your dataset.
|
||||
1. We provided the dataset as argument with `dataset_repo_id=${HF_USER}/koch_test`.
|
||||
2. We provided the policy with `policy=act_koch_real`. This loads configurations from [`lerobot/configs/policy/act_koch_real.yaml`](../lerobot/configs/policy/act_koch_real.yaml). Importantly, this policy uses 2 cameras as input `laptop` and `phone`. If your dataset has different cameras, update the yaml file to account for it in the following parts:
|
||||
```yaml
|
||||
...
|
||||
override_dataset_stats:
|
||||
observation.images.laptop:
|
||||
# stats from imagenet, since we use a pretrained vision model
|
||||
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
|
||||
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
|
||||
observation.images.phone:
|
||||
# stats from imagenet, since we use a pretrained vision model
|
||||
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
|
||||
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
|
||||
...
|
||||
input_shapes:
|
||||
observation.images.laptop: [3, 480, 640]
|
||||
observation.images.phone: [3, 480, 640]
|
||||
...
|
||||
input_normalization_modes:
|
||||
observation.images.laptop: mean_std
|
||||
observation.images.phone: mean_std
|
||||
...
|
||||
```
|
||||
3. We provided an environment as argument with `env=koch_real`. This loads configurations from [`lerobot/configs/env/koch_real.yaml`](../lerobot/configs/env/koch_real.yaml). It looks like
|
||||
```yaml
|
||||
fps: 30
|
||||
env:
|
||||
name: real_world
|
||||
task: null
|
||||
state_dim: 6
|
||||
action_dim: 6
|
||||
fps: ${fps}
|
||||
```
|
||||
It should match your dataset (e.g. `fps: 30`) and your robot (e.g. `state_dim: 6` and `action_dim: 6`). We are still working on simplifying this in future versions of `lerobot`.
|
||||
4. We provided `device=cuda` since we are training on a Nvidia GPU, but you could use `device=mps` to train on Apple silicon.
|
||||
5. We provided `wandb.enable=true` to use [Weights and Biases](https://docs.wandb.ai/quickstart) for visualizing training plots. This is optional but if you use it, make sure you are logged in by running `wandb login`.
|
||||
|
||||
@@ -970,36 +978,34 @@ for _ in range(inference_time_s * fps):
|
||||
busy_wait(1 / fps - dt_s)
|
||||
```
|
||||
|
||||
### a. Use our `record` function
|
||||
### a. Use `koch.yaml` and our `record` function
|
||||
|
||||
Ideally, when controlling your robot with your neural network, you would want to record evaluation episodes and to be able to visualize them later on, or even train on them like in Reinforcement Learning. This pretty much corresponds to recording a new dataset but with a neural network providing the actions instead of teleoperation.
|
||||
|
||||
To this end, you can use the `record` function from [`lerobot/scripts/control_robot.py`](../lerobot/scripts/control_robot.py) but with a policy checkpoint as input. For instance, run this command to record 10 evaluation episodes:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py \
|
||||
--robot.type=koch \
|
||||
--control.type=record \
|
||||
--control.fps=30 \
|
||||
--control.repo_id=${HF_USER}/eval_act_koch_test \
|
||||
--control.tags='["tutorial"]' \
|
||||
--control.warmup_time_s=5 \
|
||||
--control.episode_time_s=30 \
|
||||
--control.reset_time_s=30 \
|
||||
--control.num_episodes=10 \
|
||||
--control.push_to_hub=true \
|
||||
--control.policy.path=outputs/train/act_koch_test/checkpoints/last/pretrained_model
|
||||
python lerobot/scripts/control_robot.py record \
|
||||
--robot-path lerobot/configs/robot/koch.yaml \
|
||||
--fps 30 \
|
||||
--repo-id ${HF_USER}/eval_koch_test \
|
||||
--tags tutorial eval \
|
||||
--warmup-time-s 5 \
|
||||
--episode-time-s 30 \
|
||||
--reset-time-s 30 \
|
||||
--num-episodes 10 \
|
||||
-p outputs/train/act_koch_test/checkpoints/last/pretrained_model
|
||||
```
|
||||
|
||||
As you can see, it's almost the same command as previously used to record your training dataset. Two things changed:
|
||||
1. There is an additional `--control.policy.path` argument which indicates the path to your policy checkpoint with (e.g. `outputs/train/eval_koch_test/checkpoints/last/pretrained_model`). You can also use the model repository if you uploaded a model checkpoint to the hub (e.g. `${HF_USER}/act_koch_test`).
|
||||
2. The name of dataset begins by `eval` to reflect that you are running inference (e.g. `${HF_USER}/eval_act_koch_test`).
|
||||
1. There is an additional `-p` argument which indicates the path to your policy checkpoint with (e.g. `-p outputs/train/eval_koch_test/checkpoints/last/pretrained_model`). You can also use the model repository if you uploaded a model checkpoint to the hub (e.g. `-p ${HF_USER}/act_koch_test`).
|
||||
2. The name of dataset begins by `eval` to reflect that you are running inference (e.g. `--repo-id ${HF_USER}/eval_koch_test`).
|
||||
|
||||
### b. Visualize evaluation afterwards
|
||||
|
||||
You can then visualize your evaluation dataset by running the same command as before but with the new inference dataset as argument:
|
||||
```bash
|
||||
python lerobot/scripts/visualize_dataset.py \
|
||||
--repo-id ${HF_USER}/eval_act_koch_test
|
||||
--repo-id ${HF_USER}/eval_koch_test
|
||||
```
|
||||
|
||||
## 6. Next step
|
||||
|
||||
@@ -92,22 +92,20 @@ Serial Number = stretch-se3-3054
|
||||
**Calibrate (Optional)**
|
||||
Before operating Stretch, you need to [home](https://docs.hello-robot.com/0.3/getting_started/stretch_hardware_overview/#homing) it first. Be mindful about giving Stretch some space as this procedure will move the robot's arm and gripper. Now run this command:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py \
|
||||
--robot.type=stretch \
|
||||
--control.type=calibrate
|
||||
python lerobot/scripts/control_robot.py calibrate \
|
||||
--robot-path lerobot/configs/robot/stretch.yaml
|
||||
```
|
||||
This is equivalent to running `stretch_robot_home.py`
|
||||
|
||||
> **Note:** If you run any of the LeRobot scripts below and Stretch is not properly homed, it will automatically home/calibrate first.
|
||||
> **Note:** If you run any of the LeRobot scripts below and Stretch is not poperly homed, it will automatically home/calibrate first.
|
||||
|
||||
**Teleoperate**
|
||||
Before trying teleoperation, you need activate the gamepad controller by pressing the middle button. For more info, see Stretch's [doc](https://docs.hello-robot.com/0.3/getting_started/hello_robot/#gamepad-teleoperation).
|
||||
|
||||
Now try out teleoperation (see above documentation to learn about the gamepad controls):
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py \
|
||||
--robot.type=stretch \
|
||||
--control.type=teleoperate
|
||||
python lerobot/scripts/control_robot.py teleoperate \
|
||||
--robot-path lerobot/configs/robot/stretch.yaml
|
||||
```
|
||||
This is essentially the same as running `stretch_gamepad_teleop.py`
|
||||
|
||||
@@ -127,18 +125,16 @@ echo $HF_USER
|
||||
|
||||
Record one episode:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py \
|
||||
--robot.type=stretch \
|
||||
--control.type=record \
|
||||
--control.fps=30 \
|
||||
--control.single_task="Grasp a lego block and put it in the bin." \
|
||||
--control.repo_id=${HF_USER}/stretch_test \
|
||||
--control.tags='["tutorial"]' \
|
||||
--control.warmup_time_s=5 \
|
||||
--control.episode_time_s=30 \
|
||||
--control.reset_time_s=30 \
|
||||
--control.num_episodes=2 \
|
||||
--control.push_to_hub=true
|
||||
python lerobot/scripts/control_robot.py record \
|
||||
--robot-path lerobot/configs/robot/stretch.yaml \
|
||||
--fps 20 \
|
||||
--repo-id ${HF_USER}/stretch_test \
|
||||
--tags stretch tutorial \
|
||||
--warmup-time-s 3 \
|
||||
--episode-time-s 40 \
|
||||
--reset-time-s 10 \
|
||||
--num-episodes 1 \
|
||||
--push-to-hub 0
|
||||
```
|
||||
|
||||
> **Note:** If you're using ssh to connect to Stretch and run this script, you won't be able to visualize its cameras feed (though they will still be recording). To see the cameras stream, use [tethered](https://docs.hello-robot.com/0.3/getting_started/connecting_to_stretch/#tethered-setup) or [untethered setup](https://docs.hello-robot.com/0.3/getting_started/connecting_to_stretch/#untethered-setup).
|
||||
@@ -146,12 +142,11 @@ python lerobot/scripts/control_robot.py \
|
||||
**Replay an episode**
|
||||
Now try to replay this episode (make sure the robot's initial position is the same):
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py \
|
||||
--robot.type=stretch \
|
||||
--control.type=replay \
|
||||
--control.fps=30 \
|
||||
--control.repo_id=${HF_USER}/stretch_test \
|
||||
--control.episode=0
|
||||
python lerobot/scripts/control_robot.py replay \
|
||||
--robot-path lerobot/configs/robot/stretch.yaml \
|
||||
--fps 20 \
|
||||
--repo-id ${HF_USER}/stretch_test \
|
||||
--episode 0
|
||||
```
|
||||
|
||||
Follow [previous tutorial](https://github.com/huggingface/lerobot/blob/main/examples/7_get_started_with_real_robot.md#4-train-a-policy-on-your-data) to train a policy on your data and run inference on your robot. You will need to adapt the code for Stretch.
|
||||
|
||||
@@ -2,7 +2,7 @@ This tutorial explains how to use [Aloha and Aloha 2 stationary](https://www.tro
|
||||
|
||||
## Setup
|
||||
|
||||
Follow the [documentation from Trossen Robotics](https://docs.trossenrobotics.com/aloha_docs/2.0/getting_started/stationary/hardware_setup.html) for setting up the hardware and plugging the 4 arms and 4 cameras to your computer.
|
||||
Follow the [documentation from Trossen Robotics](https://docs.trossenrobotics.com/aloha_docs/getting_started/stationary/hardware_setup.html) for setting up the hardware and plugging the 4 arms and 4 cameras to your computer.
|
||||
|
||||
|
||||
## Install LeRobot
|
||||
@@ -51,18 +51,16 @@ Teleoperation consists in manually operating the leader arms to move the followe
|
||||
|
||||
By running the following code, you can start your first **SAFE** teleoperation:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py \
|
||||
--robot.type=aloha \
|
||||
--robot.max_relative_target=5 \
|
||||
--control.type=teleoperate
|
||||
python lerobot/scripts/control_robot.py teleoperate \
|
||||
--robot-path lerobot/configs/robot/aloha.yaml \
|
||||
--robot-overrides max_relative_target=5
|
||||
```
|
||||
|
||||
By adding `--robot.max_relative_target=5`, we override the default value for `max_relative_target` defined in [`AlohaRobotConfig`](lerobot/common/robot_devices/robots/configs.py). It is expected to be `5` to limit the magnitude of the movement for more safety, but the teleoperation won't be smooth. When you feel confident, you can disable this limit by adding `--robot.max_relative_target=null` to the command line:
|
||||
By adding `--robot-overrides max_relative_target=5`, we override the default value for `max_relative_target` defined in `lerobot/configs/robot/aloha.yaml`. It is expected to be `5` to limit the magnitude of the movement for more safety, but the teleoperation won't be smooth. When you feel confident, you can disable this limit by adding `--robot-overrides max_relative_target=null` to the command line:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py \
|
||||
--robot.type=aloha \
|
||||
--robot.max_relative_target=null \
|
||||
--control.type=teleoperate
|
||||
python lerobot/scripts/control_robot.py teleoperate \
|
||||
--robot-path lerobot/configs/robot/aloha.yaml \
|
||||
--robot-overrides max_relative_target=null
|
||||
```
|
||||
|
||||
## Record a dataset
|
||||
@@ -82,29 +80,27 @@ echo $HF_USER
|
||||
|
||||
Record 2 episodes and upload your dataset to the hub:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py \
|
||||
--robot.type=aloha \
|
||||
--robot.max_relative_target=null \
|
||||
--control.type=record \
|
||||
--control.fps=30 \
|
||||
--control.single_task="Grasp a lego block and put it in the bin." \
|
||||
--control.repo_id=${HF_USER}/aloha_test \
|
||||
--control.tags='["tutorial"]' \
|
||||
--control.warmup_time_s=5 \
|
||||
--control.episode_time_s=30 \
|
||||
--control.reset_time_s=30 \
|
||||
--control.num_episodes=2 \
|
||||
--control.push_to_hub=true
|
||||
python lerobot/scripts/control_robot.py record \
|
||||
--robot-path lerobot/configs/robot/aloha.yaml \
|
||||
--robot-overrides max_relative_target=null \
|
||||
--fps 30 \
|
||||
--repo-id ${HF_USER}/aloha_test \
|
||||
--tags aloha tutorial \
|
||||
--warmup-time-s 5 \
|
||||
--episode-time-s 40 \
|
||||
--reset-time-s 10 \
|
||||
--num-episodes 2 \
|
||||
--push-to-hub 1
|
||||
```
|
||||
|
||||
## Visualize a dataset
|
||||
|
||||
If you uploaded your dataset to the hub with `--control.push_to_hub=true`, you can [visualize your dataset online](https://huggingface.co/spaces/lerobot/visualize_dataset) by copy pasting your repo id given by:
|
||||
If you uploaded your dataset to the hub with `--push-to-hub 1`, you can [visualize your dataset online](https://huggingface.co/spaces/lerobot/visualize_dataset) by copy pasting your repo id given by:
|
||||
```bash
|
||||
echo ${HF_USER}/aloha_test
|
||||
```
|
||||
|
||||
If you didn't upload with `--control.push_to_hub=false`, you can also visualize it locally with:
|
||||
If you didn't upload with `--push-to-hub 0`, you can also visualize it locally with:
|
||||
```bash
|
||||
python lerobot/scripts/visualize_dataset_html.py \
|
||||
--repo-id ${HF_USER}/aloha_test
|
||||
@@ -113,17 +109,16 @@ python lerobot/scripts/visualize_dataset_html.py \
|
||||
## Replay an episode
|
||||
|
||||
**/!\ FOR SAFETY, READ THIS /!\**
|
||||
Replay consists in automatically replaying the sequence of actions (i.e. goal positions for your motors) recorded in a given dataset episode. Make sure the current initial position of your robot is similar to the one in your episode, so that your follower arms don't move too fast to go to the first goal positions. For safety, you might want to add `--robot.max_relative_target=5` to your command line as explained above.
|
||||
Replay consists in automatically replaying the sequence of actions (i.e. goal positions for your motors) recorded in a given dataset episode. Make sure the current initial position of your robot is similar to the one in your episode, so that your follower arms don't move too fast to go to the first goal positions. For safety, you might want to add `--robot-overrides max_relative_target=5` to your command line as explained above.
|
||||
|
||||
Now try to replay the first episode on your robot:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py \
|
||||
--robot.type=aloha \
|
||||
--robot.max_relative_target=null \
|
||||
--control.type=replay \
|
||||
--control.fps=30 \
|
||||
--control.repo_id=${HF_USER}/aloha_test \
|
||||
--control.episode=0
|
||||
python lerobot/scripts/control_robot.py replay \
|
||||
--robot-path lerobot/configs/robot/aloha.yaml \
|
||||
--robot-overrides max_relative_target=null \
|
||||
--fps 30 \
|
||||
--repo-id ${HF_USER}/aloha_test \
|
||||
--episode 0
|
||||
```
|
||||
|
||||
## Train a policy
|
||||
@@ -131,51 +126,49 @@ python lerobot/scripts/control_robot.py \
|
||||
To train a policy to control your robot, use the [`python lerobot/scripts/train.py`](../lerobot/scripts/train.py) script. A few arguments are required. Here is an example command:
|
||||
```bash
|
||||
python lerobot/scripts/train.py \
|
||||
--dataset.repo_id=${HF_USER}/aloha_test \
|
||||
--policy.type=act \
|
||||
--output_dir=outputs/train/act_aloha_test \
|
||||
--job_name=act_aloha_test \
|
||||
--device=cuda \
|
||||
--wandb.enable=true
|
||||
dataset_repo_id=${HF_USER}/aloha_test \
|
||||
policy=act_aloha_real \
|
||||
env=aloha_real \
|
||||
hydra.run.dir=outputs/train/act_aloha_test \
|
||||
hydra.job.name=act_aloha_test \
|
||||
device=cuda \
|
||||
wandb.enable=true
|
||||
```
|
||||
|
||||
Let's explain it:
|
||||
1. We provided the dataset as argument with `--dataset.repo_id=${HF_USER}/aloha_test`.
|
||||
2. We provided the policy with `policy.type=act`. This loads configurations from [`configuration_act.py`](../lerobot/common/policies/act/configuration_act.py). Importantly, this policy will automatically adapt to the number of motor sates, motor actions and cameras of your robot (e.g. `laptop` and `phone`) which have been saved in your dataset.
|
||||
4. We provided `device=cuda` since we are training on a Nvidia GPU, but you could use `device=mps` to train on Apple silicon.
|
||||
1. We provided the dataset as argument with `dataset_repo_id=${HF_USER}/aloha_test`.
|
||||
2. We provided the policy with `policy=act_aloha_real`. This loads configurations from [`lerobot/configs/policy/act_aloha_real.yaml`](../lerobot/configs/policy/act_aloha_real.yaml). Importantly, this policy uses 4 cameras as input `cam_right_wrist`, `cam_left_wrist`, `cam_high`, and `cam_low`.
|
||||
3. We provided an environment as argument with `env=aloha_real`. This loads configurations from [`lerobot/configs/env/aloha_real.yaml`](../lerobot/configs/env/aloha_real.yaml). Note: this yaml defines 18 dimensions for the `state_dim` and `action_dim`, corresponding to 18 motors, not 14 motors as used in previous Aloha work. This is because, we include the `shoulder_shadow` and `elbow_shadow` motors for simplicity.
|
||||
4. We provided `device=cuda` since we are training on a Nvidia GPU.
|
||||
5. We provided `wandb.enable=true` to use [Weights and Biases](https://docs.wandb.ai/quickstart) for visualizing training plots. This is optional but if you use it, make sure you are logged in by running `wandb login`.
|
||||
|
||||
For more information on the `train` script see the previous tutorial: [`examples/4_train_policy_with_script.md`](../examples/4_train_policy_with_script.md)
|
||||
|
||||
Training should take several hours. You will find checkpoints in `outputs/train/act_aloha_test/checkpoints`.
|
||||
|
||||
## Evaluate your policy
|
||||
|
||||
You can use the `record` function from [`lerobot/scripts/control_robot.py`](../lerobot/scripts/control_robot.py) but with a policy checkpoint as input. For instance, run this command to record 10 evaluation episodes:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py \
|
||||
--robot.type=aloha \
|
||||
--control.type=record \
|
||||
--control.fps=30 \
|
||||
--control.single_task="Grasp a lego block and put it in the bin." \
|
||||
--control.repo_id=${HF_USER}/eval_act_aloha_test \
|
||||
--control.tags='["tutorial"]' \
|
||||
--control.warmup_time_s=5 \
|
||||
--control.episode_time_s=30 \
|
||||
--control.reset_time_s=30 \
|
||||
--control.num_episodes=10 \
|
||||
--control.push_to_hub=true \
|
||||
--control.policy.path=outputs/train/act_aloha_test/checkpoints/last/pretrained_model \
|
||||
--control.num_image_writer_processes=1
|
||||
python lerobot/scripts/control_robot.py record \
|
||||
--robot-path lerobot/configs/robot/aloha.yaml \
|
||||
--robot-overrides max_relative_target=null \
|
||||
--fps 30 \
|
||||
--repo-id ${HF_USER}/eval_act_aloha_test \
|
||||
--tags aloha tutorial eval \
|
||||
--warmup-time-s 5 \
|
||||
--episode-time-s 40 \
|
||||
--reset-time-s 10 \
|
||||
--num-episodes 10 \
|
||||
--num-image-writer-processes 1 \
|
||||
-p outputs/train/act_aloha_test/checkpoints/last/pretrained_model
|
||||
```
|
||||
|
||||
As you can see, it's almost the same command as previously used to record your training dataset. Two things changed:
|
||||
1. There is an additional `--control.policy.path` argument which indicates the path to your policy checkpoint with (e.g. `outputs/train/eval_act_aloha_test/checkpoints/last/pretrained_model`). You can also use the model repository if you uploaded a model checkpoint to the hub (e.g. `${HF_USER}/act_aloha_test`).
|
||||
2. The name of dataset begins by `eval` to reflect that you are running inference (e.g. `${HF_USER}/eval_act_aloha_test`).
|
||||
3. We use `--control.num_image_writer_processes=1` instead of the default value (`0`). On our computer, using a dedicated process to write images from the 4 cameras on disk allows to reach constant 30 fps during inference. Feel free to explore different values for `--control.num_image_writer_processes`.
|
||||
1. There is an additional `-p` argument which indicates the path to your policy checkpoint with (e.g. `-p outputs/train/eval_aloha_test/checkpoints/last/pretrained_model`). You can also use the model repository if you uploaded a model checkpoint to the hub (e.g. `-p ${HF_USER}/act_aloha_test`).
|
||||
2. The name of dataset begins by `eval` to reflect that you are running inference (e.g. `--repo-id ${HF_USER}/eval_act_aloha_test`).
|
||||
3. We use `--num-image-writer-processes 1` instead of the default value (`0`). On our computer, using a dedicated process to write images from the 4 cameras on disk allows to reach constent 30 fps during inference. Feel free to explore different values for `--num-image-writer-processes`.
|
||||
|
||||
## More
|
||||
|
||||
Follow this [previous tutorial](https://github.com/huggingface/lerobot/blob/main/examples/7_get_started_with_real_robot.md#4-train-a-policy-on-your-data) for a more in-depth explanation.
|
||||
Follow this [previous tutorial](https://github.com/huggingface/lerobot/blob/main/examples/7_get_started_with_real_robot.md#4-train-a-policy-on-your-data) for a more in-depth explaination.
|
||||
|
||||
If you have any question or need help, please reach out on Discord in the channel `#aloha-arm`.
|
||||
|
||||
87
examples/advanced/1_train_act_pusht/act_pusht.yaml
Normal file
87
examples/advanced/1_train_act_pusht/act_pusht.yaml
Normal file
@@ -0,0 +1,87 @@
|
||||
# @package _global_
|
||||
|
||||
# Change the seed to match what PushT eval uses
|
||||
# (to avoid evaluating on seeds used for generating the training data).
|
||||
seed: 100000
|
||||
# Change the dataset repository to the PushT one.
|
||||
dataset_repo_id: lerobot/pusht
|
||||
|
||||
override_dataset_stats:
|
||||
observation.image:
|
||||
# stats from imagenet, since we use a pretrained vision model
|
||||
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
|
||||
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
|
||||
|
||||
training:
|
||||
offline_steps: 80000
|
||||
online_steps: 0
|
||||
eval_freq: 10000
|
||||
save_freq: 100000
|
||||
log_freq: 250
|
||||
save_model: true
|
||||
|
||||
batch_size: 8
|
||||
lr: 1e-5
|
||||
lr_backbone: 1e-5
|
||||
weight_decay: 1e-4
|
||||
grad_clip_norm: 10
|
||||
online_steps_between_rollouts: 1
|
||||
|
||||
delta_timestamps:
|
||||
action: "[i / ${fps} for i in range(${policy.chunk_size})]"
|
||||
|
||||
eval:
|
||||
n_episodes: 50
|
||||
batch_size: 50
|
||||
|
||||
# See `configuration_act.py` for more details.
|
||||
policy:
|
||||
name: act
|
||||
|
||||
# Input / output structure.
|
||||
n_obs_steps: 1
|
||||
chunk_size: 100 # chunk_size
|
||||
n_action_steps: 100
|
||||
|
||||
input_shapes:
|
||||
observation.image: [3, 96, 96]
|
||||
observation.state: ["${env.state_dim}"]
|
||||
output_shapes:
|
||||
action: ["${env.action_dim}"]
|
||||
|
||||
# Normalization / Unnormalization
|
||||
input_normalization_modes:
|
||||
observation.image: mean_std
|
||||
# Use min_max normalization just because it's more standard.
|
||||
observation.state: min_max
|
||||
output_normalization_modes:
|
||||
# Use min_max normalization just because it's more standard.
|
||||
action: min_max
|
||||
|
||||
# Architecture.
|
||||
# Vision backbone.
|
||||
vision_backbone: resnet18
|
||||
pretrained_backbone_weights: ResNet18_Weights.IMAGENET1K_V1
|
||||
replace_final_stride_with_dilation: false
|
||||
# Transformer layers.
|
||||
pre_norm: false
|
||||
dim_model: 512
|
||||
n_heads: 8
|
||||
dim_feedforward: 3200
|
||||
feedforward_activation: relu
|
||||
n_encoder_layers: 4
|
||||
# Note: Although the original ACT implementation has 7 for `n_decoder_layers`, there is a bug in the code
|
||||
# that means only the first layer is used. Here we match the original implementation by setting this to 1.
|
||||
# See this issue https://github.com/tonyzhaozh/act/issues/25#issue-2258740521.
|
||||
n_decoder_layers: 1
|
||||
# VAE.
|
||||
use_vae: true
|
||||
latent_dim: 32
|
||||
n_vae_encoder_layers: 4
|
||||
|
||||
# Inference.
|
||||
temporal_ensemble_coeff: null
|
||||
|
||||
# Training and loss computation.
|
||||
dropout: 0.1
|
||||
kl_weight: 10.0
|
||||
70
examples/advanced/1_train_act_pusht/train_act_pusht.md
Normal file
70
examples/advanced/1_train_act_pusht/train_act_pusht.md
Normal file
@@ -0,0 +1,70 @@
|
||||
In this tutorial we will learn how to adapt a policy configuration to be compatible with a new environment and dataset. As a concrete example, we will adapt the default configuration for ACT to be compatible with the PushT environment and dataset.
|
||||
|
||||
If you haven't already read our tutorial on the [training script and configuration tooling](../4_train_policy_with_script.md) please do so prior to tackling this tutorial.
|
||||
|
||||
Let's get started!
|
||||
|
||||
Suppose we want to train ACT for PushT. Well, there are aspects of the ACT configuration that are specific to the ALOHA environments, and these happen to be incompatible with PushT. Therefore, trying to run the following will almost certainly raise an exception of sorts (eg: feature dimension mismatch):
|
||||
|
||||
```bash
|
||||
python lerobot/scripts/train.py policy=act env=pusht dataset_repo_id=lerobot/pusht
|
||||
```
|
||||
|
||||
We need to adapt the parameters of the ACT policy configuration to the PushT environment. The most important ones are the image keys.
|
||||
|
||||
ALOHA's datasets and environments typically use a variable number of cameras. In `lerobot/configs/policy/act.yaml` you may notice two relevant sections. Here we show you the minimal diff needed to adjust to PushT:
|
||||
|
||||
```diff
|
||||
override_dataset_stats:
|
||||
- observation.images.top:
|
||||
+ observation.image:
|
||||
# stats from imagenet, since we use a pretrained vision model
|
||||
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
|
||||
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
|
||||
|
||||
policy:
|
||||
input_shapes:
|
||||
- observation.images.top: [3, 480, 640]
|
||||
+ observation.image: [3, 96, 96]
|
||||
observation.state: ["${env.state_dim}"]
|
||||
output_shapes:
|
||||
action: ["${env.action_dim}"]
|
||||
|
||||
input_normalization_modes:
|
||||
- observation.images.top: mean_std
|
||||
+ observation.image: mean_std
|
||||
observation.state: min_max
|
||||
output_normalization_modes:
|
||||
action: min_max
|
||||
```
|
||||
|
||||
Here we've accounted for the following:
|
||||
- PushT uses "observation.image" for its image key.
|
||||
- PushT provides smaller images.
|
||||
|
||||
_Side note: technically we could override these via the CLI, but with many changes it gets a bit messy, and we also have a bit of a challenge in that we're using `.` in our observation keys which is treated by Hydra as a hierarchical separator_.
|
||||
|
||||
For your convenience, we provide [`act_pusht.yaml`](./act_pusht.yaml) in this directory. It contains the diff above, plus some other (optional) ones that are explained within. Please copy it into `lerobot/configs/policy` with:
|
||||
|
||||
```bash
|
||||
cp examples/advanced/1_train_act_pusht/act_pusht.yaml lerobot/configs/policy/act_pusht.yaml
|
||||
```
|
||||
|
||||
(remember from a [previous tutorial](../4_train_policy_with_script.md) that Hydra will look in the `lerobot/configs` directory). Now try running the following.
|
||||
|
||||
<!-- Note to contributor: are you changing this command? Note that it's tested in `Makefile`, so change it there too! -->
|
||||
```bash
|
||||
python lerobot/scripts/train.py policy=act_pusht env=pusht
|
||||
```
|
||||
|
||||
Notice that this is much the same as the command that failed at the start of the tutorial, only:
|
||||
- Now we are using `policy=act_pusht` to point to our new configuration file.
|
||||
- We can drop `dataset_repo_id=lerobot/pusht` as the change is incorporated in our new configuration file.
|
||||
|
||||
Hurrah! You're now training ACT for the PushT environment.
|
||||
|
||||
---
|
||||
|
||||
The bottom line of this tutorial is that when training policies for different environments and datasets you will need to understand what parts of the policy configuration are specific to those and make changes accordingly.
|
||||
|
||||
Happy coding! 🤗
|
||||
@@ -9,82 +9,76 @@ on the target environment, whether that be in simulation or the real world.
|
||||
"""
|
||||
|
||||
import math
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
|
||||
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
|
||||
|
||||
device = torch.device("cuda")
|
||||
|
||||
def main():
|
||||
device = torch.device("cuda")
|
||||
# Download the diffusion policy for pusht environment
|
||||
pretrained_policy_path = Path(snapshot_download("lerobot/diffusion_pusht"))
|
||||
# OR uncomment the following to evaluate a policy from the local outputs/train folder.
|
||||
# pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")
|
||||
|
||||
# Download the diffusion policy for pusht environment
|
||||
pretrained_policy_path = "lerobot/diffusion_pusht"
|
||||
# OR uncomment the following to evaluate a policy from the local outputs/train folder.
|
||||
# pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")
|
||||
policy = DiffusionPolicy.from_pretrained(pretrained_policy_path)
|
||||
policy.eval()
|
||||
policy.to(device)
|
||||
|
||||
policy = DiffusionPolicy.from_pretrained(pretrained_policy_path)
|
||||
policy.eval()
|
||||
policy.to(device)
|
||||
# Set up the dataset.
|
||||
delta_timestamps = {
|
||||
# Load the previous image and state at -0.1 seconds before current frame,
|
||||
# then load current image and state corresponding to 0.0 second.
|
||||
"observation.image": [-0.1, 0.0],
|
||||
"observation.state": [-0.1, 0.0],
|
||||
# Load the previous action (-0.1), the next action to be executed (0.0),
|
||||
# and 14 future actions with a 0.1 seconds spacing. All these actions will be
|
||||
# used to calculate the loss.
|
||||
"action": [-0.1, 0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4],
|
||||
}
|
||||
|
||||
# Set up the dataset.
|
||||
delta_timestamps = {
|
||||
# Load the previous image and state at -0.1 seconds before current frame,
|
||||
# then load current image and state corresponding to 0.0 second.
|
||||
"observation.image": [-0.1, 0.0],
|
||||
"observation.state": [-0.1, 0.0],
|
||||
# Load the previous action (-0.1), the next action to be executed (0.0),
|
||||
# and 14 future actions with a 0.1 seconds spacing. All these actions will be
|
||||
# used to calculate the loss.
|
||||
"action": [-0.1, 0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4],
|
||||
}
|
||||
# Load the last 10% of episodes of the dataset as a validation set.
|
||||
# - Load dataset metadata
|
||||
dataset_metadata = LeRobotDatasetMetadata("lerobot/pusht")
|
||||
# - Calculate train and val episodes
|
||||
total_episodes = dataset_metadata.total_episodes
|
||||
episodes = list(range(dataset_metadata.total_episodes))
|
||||
num_train_episodes = math.floor(total_episodes * 90 / 100)
|
||||
train_episodes = episodes[:num_train_episodes]
|
||||
val_episodes = episodes[num_train_episodes:]
|
||||
print(f"Number of episodes in full dataset: {total_episodes}")
|
||||
print(f"Number of episodes in training dataset (90% subset): {len(train_episodes)}")
|
||||
print(f"Number of episodes in validation dataset (10% subset): {len(val_episodes)}")
|
||||
# - Load train an val datasets
|
||||
train_dataset = LeRobotDataset("lerobot/pusht", episodes=train_episodes, delta_timestamps=delta_timestamps)
|
||||
val_dataset = LeRobotDataset("lerobot/pusht", episodes=val_episodes, delta_timestamps=delta_timestamps)
|
||||
print(f"Number of frames in training dataset (90% subset): {len(train_dataset)}")
|
||||
print(f"Number of frames in validation dataset (10% subset): {len(val_dataset)}")
|
||||
|
||||
# Load the last 10% of episodes of the dataset as a validation set.
|
||||
# - Load dataset metadata
|
||||
dataset_metadata = LeRobotDatasetMetadata("lerobot/pusht")
|
||||
# - Calculate train and val episodes
|
||||
total_episodes = dataset_metadata.total_episodes
|
||||
episodes = list(range(dataset_metadata.total_episodes))
|
||||
num_train_episodes = math.floor(total_episodes * 90 / 100)
|
||||
train_episodes = episodes[:num_train_episodes]
|
||||
val_episodes = episodes[num_train_episodes:]
|
||||
print(f"Number of episodes in full dataset: {total_episodes}")
|
||||
print(f"Number of episodes in training dataset (90% subset): {len(train_episodes)}")
|
||||
print(f"Number of episodes in validation dataset (10% subset): {len(val_episodes)}")
|
||||
# - Load train an val datasets
|
||||
train_dataset = LeRobotDataset(
|
||||
"lerobot/pusht", episodes=train_episodes, delta_timestamps=delta_timestamps
|
||||
)
|
||||
val_dataset = LeRobotDataset("lerobot/pusht", episodes=val_episodes, delta_timestamps=delta_timestamps)
|
||||
print(f"Number of frames in training dataset (90% subset): {len(train_dataset)}")
|
||||
print(f"Number of frames in validation dataset (10% subset): {len(val_dataset)}")
|
||||
# Create dataloader for evaluation.
|
||||
val_dataloader = torch.utils.data.DataLoader(
|
||||
val_dataset,
|
||||
num_workers=4,
|
||||
batch_size=64,
|
||||
shuffle=False,
|
||||
pin_memory=device != torch.device("cpu"),
|
||||
drop_last=False,
|
||||
)
|
||||
|
||||
# Create dataloader for evaluation.
|
||||
val_dataloader = torch.utils.data.DataLoader(
|
||||
val_dataset,
|
||||
num_workers=4,
|
||||
batch_size=64,
|
||||
shuffle=False,
|
||||
pin_memory=device != torch.device("cpu"),
|
||||
drop_last=False,
|
||||
)
|
||||
# Run validation loop.
|
||||
loss_cumsum = 0
|
||||
n_examples_evaluated = 0
|
||||
for batch in val_dataloader:
|
||||
batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
|
||||
output_dict = policy.forward(batch)
|
||||
|
||||
# Run validation loop.
|
||||
loss_cumsum = 0
|
||||
n_examples_evaluated = 0
|
||||
for batch in val_dataloader:
|
||||
batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
|
||||
loss, _ = policy.forward(batch)
|
||||
loss_cumsum += output_dict["loss"].item()
|
||||
n_examples_evaluated += batch["index"].shape[0]
|
||||
|
||||
loss_cumsum += loss.item()
|
||||
n_examples_evaluated += batch["index"].shape[0]
|
||||
# Calculate the average loss over the validation set.
|
||||
average_loss = loss_cumsum / n_examples_evaluated
|
||||
|
||||
# Calculate the average loss over the validation set.
|
||||
average_loss = loss_cumsum / n_examples_evaluated
|
||||
|
||||
print(f"Average loss on validation set: {average_loss:.4f}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
print(f"Average loss on validation set: {average_loss:.4f}")
|
||||
|
||||
@@ -2,10 +2,9 @@ import shutil
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
from huggingface_hub import HfApi
|
||||
import torch
|
||||
|
||||
from lerobot.common.constants import HF_LEROBOT_HOME
|
||||
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
|
||||
from lerobot.common.datasets.lerobot_dataset import LEROBOT_HOME, LeRobotDataset
|
||||
from lerobot.common.datasets.push_dataset_to_hub._download_raw import download_raw
|
||||
|
||||
PUSHT_TASK = "Push the T-shaped blue block onto the T-shaped green target surface."
|
||||
@@ -45,7 +44,7 @@ PUSHT_FEATURES = {
|
||||
"dtype": None,
|
||||
"shape": (3, 96, 96),
|
||||
"names": [
|
||||
"channels",
|
||||
"channel",
|
||||
"height",
|
||||
"width",
|
||||
],
|
||||
@@ -90,9 +89,9 @@ def calculate_coverage(zarr_data):
|
||||
|
||||
num_frames = len(block_pos)
|
||||
|
||||
coverage = np.zeros((num_frames,), dtype=np.float32)
|
||||
coverage = np.zeros((num_frames,))
|
||||
# 8 keypoints with 2 coords each
|
||||
keypoints = np.zeros((num_frames, 16), dtype=np.float32)
|
||||
keypoints = np.zeros((num_frames, 16))
|
||||
|
||||
# Set x, y, theta (in radians)
|
||||
goal_pos_angle = np.array([256, 256, np.pi / 4])
|
||||
@@ -118,7 +117,7 @@ def calculate_coverage(zarr_data):
|
||||
intersection_area = goal_geom.intersection(block_geom).area
|
||||
goal_area = goal_geom.area
|
||||
coverage[i] = intersection_area / goal_area
|
||||
keypoints[i] = PushTEnv.get_keypoints(block_shapes).flatten()
|
||||
keypoints[i] = torch.from_numpy(PushTEnv.get_keypoints(block_shapes).flatten())
|
||||
|
||||
return coverage, keypoints
|
||||
|
||||
@@ -135,8 +134,8 @@ def main(raw_dir: Path, repo_id: str, mode: str = "video", push_to_hub: bool = T
|
||||
if mode not in ["video", "image", "keypoints"]:
|
||||
raise ValueError(mode)
|
||||
|
||||
if (HF_LEROBOT_HOME / repo_id).exists():
|
||||
shutil.rmtree(HF_LEROBOT_HOME / repo_id)
|
||||
if (LEROBOT_HOME / repo_id).exists():
|
||||
shutil.rmtree(LEROBOT_HOME / repo_id)
|
||||
|
||||
if not raw_dir.exists():
|
||||
download_raw(raw_dir, repo_id="lerobot-raw/pusht_raw")
|
||||
@@ -149,10 +148,6 @@ def main(raw_dir: Path, repo_id: str, mode: str = "video", push_to_hub: bool = T
|
||||
action = zarr_data["action"][:]
|
||||
image = zarr_data["img"] # (b, h, w, c)
|
||||
|
||||
if image.dtype == np.float32 and image.max() == np.float32(255):
|
||||
# HACK: images are loaded as float32 but they actually encode uint8 data
|
||||
image = image.astype(np.uint8)
|
||||
|
||||
episode_data_index = {
|
||||
"from": np.concatenate(([0], zarr_data.meta["episode_ends"][:-1])),
|
||||
"to": zarr_data.meta["episode_ends"],
|
||||
@@ -180,30 +175,28 @@ def main(raw_dir: Path, repo_id: str, mode: str = "video", push_to_hub: bool = T
|
||||
|
||||
for frame_idx in range(num_frames):
|
||||
i = from_idx + frame_idx
|
||||
idx = i + (frame_idx < num_frames - 1)
|
||||
frame = {
|
||||
"action": action[i],
|
||||
"action": torch.from_numpy(action[i]),
|
||||
# Shift reward and success by +1 until the last item of the episode
|
||||
"next.reward": reward[idx : idx + 1],
|
||||
"next.success": success[idx : idx + 1],
|
||||
"task": PUSHT_TASK,
|
||||
"next.reward": reward[i + (frame_idx < num_frames - 1)],
|
||||
"next.success": success[i + (frame_idx < num_frames - 1)],
|
||||
}
|
||||
|
||||
frame["observation.state"] = agent_pos[i]
|
||||
frame["observation.state"] = torch.from_numpy(agent_pos[i])
|
||||
|
||||
if mode == "keypoints":
|
||||
frame["observation.environment_state"] = keypoints[i]
|
||||
frame["observation.environment_state"] = torch.from_numpy(keypoints[i])
|
||||
else:
|
||||
frame["observation.image"] = image[i]
|
||||
frame["observation.image"] = torch.from_numpy(image[i])
|
||||
|
||||
dataset.add_frame(frame)
|
||||
|
||||
dataset.save_episode()
|
||||
dataset.save_episode(task=PUSHT_TASK)
|
||||
|
||||
dataset.consolidate()
|
||||
|
||||
if push_to_hub:
|
||||
dataset.push_to_hub()
|
||||
hub_api = HfApi()
|
||||
hub_api.create_tag(repo_id, tag=CODEBASE_VERSION, repo_type="dataset")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@@ -225,5 +218,5 @@ if __name__ == "__main__":
|
||||
main(raw_dir, repo_id=repo_id, mode=mode)
|
||||
|
||||
# Uncomment if you want to load the local dataset and explore it
|
||||
# dataset = LeRobotDataset(repo_id=repo_id)
|
||||
# dataset = LeRobotDataset(repo_id=repo_id, local_files_only=True)
|
||||
# breakpoint()
|
||||
|
||||
@@ -58,6 +58,7 @@ available_tasks_per_env = {
|
||||
],
|
||||
"pusht": ["PushT-v0"],
|
||||
"xarm": ["XarmLift-v0"],
|
||||
"dora_aloha_real": ["DoraAloha-v0", "DoraKoch-v0", "DoraReachy2-v0"],
|
||||
}
|
||||
available_envs = list(available_tasks_per_env.keys())
|
||||
|
||||
@@ -85,6 +86,23 @@ available_datasets_per_env = {
|
||||
"lerobot/xarm_push_medium_image",
|
||||
"lerobot/xarm_push_medium_replay_image",
|
||||
],
|
||||
"dora_aloha_real": [
|
||||
"lerobot/aloha_static_battery",
|
||||
"lerobot/aloha_static_candy",
|
||||
"lerobot/aloha_static_coffee",
|
||||
"lerobot/aloha_static_coffee_new",
|
||||
"lerobot/aloha_static_cups_open",
|
||||
"lerobot/aloha_static_fork_pick_up",
|
||||
"lerobot/aloha_static_pingpong_test",
|
||||
"lerobot/aloha_static_pro_pencil",
|
||||
"lerobot/aloha_static_screw_driver",
|
||||
"lerobot/aloha_static_tape",
|
||||
"lerobot/aloha_static_thread_velcro",
|
||||
"lerobot/aloha_static_towel",
|
||||
"lerobot/aloha_static_vinh_cup",
|
||||
"lerobot/aloha_static_vinh_cup_left",
|
||||
"lerobot/aloha_static_ziploc_slide",
|
||||
],
|
||||
}
|
||||
|
||||
available_real_world_datasets = [
|
||||
@@ -203,6 +221,7 @@ available_policies_per_env = {
|
||||
"xarm": ["tdmpc"],
|
||||
"koch_real": ["act_koch_real"],
|
||||
"aloha_real": ["act_aloha_real"],
|
||||
"dora_aloha_real": ["act_aloha_real"],
|
||||
}
|
||||
|
||||
env_task_pairs = [(env, task) for env, tasks in available_tasks_per_env.items() for task in tasks]
|
||||
|
||||
@@ -1,32 +0,0 @@
|
||||
# keys
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from huggingface_hub.constants import HF_HOME
|
||||
|
||||
OBS_ENV = "observation.environment_state"
|
||||
OBS_ROBOT = "observation.state"
|
||||
OBS_IMAGE = "observation.image"
|
||||
OBS_IMAGES = "observation.images"
|
||||
ACTION = "action"
|
||||
|
||||
# files & directories
|
||||
CHECKPOINTS_DIR = "checkpoints"
|
||||
LAST_CHECKPOINT_LINK = "last"
|
||||
PRETRAINED_MODEL_DIR = "pretrained_model"
|
||||
TRAINING_STATE_DIR = "training_state"
|
||||
RNG_STATE = "rng_state.safetensors"
|
||||
TRAINING_STEP = "training_step.json"
|
||||
OPTIMIZER_STATE = "optimizer_state.safetensors"
|
||||
OPTIMIZER_PARAM_GROUPS = "optimizer_param_groups.json"
|
||||
SCHEDULER_STATE = "scheduler_state.json"
|
||||
|
||||
# cache dir
|
||||
default_cache_path = Path(HF_HOME) / "lerobot"
|
||||
HF_LEROBOT_HOME = Path(os.getenv("HF_LEROBOT_HOME", default_cache_path)).expanduser()
|
||||
|
||||
if "LEROBOT_HOME" in os.environ:
|
||||
raise ValueError(
|
||||
f"You have a 'LEROBOT_HOME' environment variable set to '{os.getenv('LEROBOT_HOME')}'.\n"
|
||||
"'LEROBOT_HOME' is deprecated, please use 'HF_LEROBOT_HOME' instead."
|
||||
)
|
||||
@@ -1,54 +0,0 @@
|
||||
import packaging.version
|
||||
|
||||
V2_MESSAGE = """
|
||||
The dataset you requested ({repo_id}) is in {version} format.
|
||||
|
||||
We introduced a new format since v2.0 which is not backward compatible with v1.x.
|
||||
Please, use our conversion script. Modify the following command with your own task description:
|
||||
```
|
||||
python lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py \\
|
||||
--repo-id {repo_id} \\
|
||||
--single-task "TASK DESCRIPTION." # <---- /!\\ Replace TASK DESCRIPTION /!\\
|
||||
```
|
||||
|
||||
A few examples to replace TASK DESCRIPTION: "Pick up the blue cube and place it into the bin.", "Insert the
|
||||
peg into the socket.", "Slide open the ziploc bag.", "Take the elevator to the 1st floor.", "Open the top
|
||||
cabinet, store the pot inside it then close the cabinet.", "Push the T-shaped block onto the T-shaped
|
||||
target.", "Grab the spray paint on the shelf and place it in the bin on top of the robot dog.", "Fold the
|
||||
sweatshirt.", ...
|
||||
|
||||
If you encounter a problem, contact LeRobot maintainers on [Discord](https://discord.com/invite/s3KuuzsPFb)
|
||||
or open an [issue on GitHub](https://github.com/huggingface/lerobot/issues/new/choose).
|
||||
"""
|
||||
|
||||
V21_MESSAGE = """
|
||||
The dataset you requested ({repo_id}) is in {version} format.
|
||||
While current version of LeRobot is backward-compatible with it, the version of your dataset still uses global
|
||||
stats instead of per-episode stats. Update your dataset stats to the new format using this command:
|
||||
```
|
||||
python lerobot/common/datasets/v21/convert_dataset_v20_to_v21.py --repo-id={repo_id}
|
||||
```
|
||||
|
||||
If you encounter a problem, contact LeRobot maintainers on [Discord](https://discord.com/invite/s3KuuzsPFb)
|
||||
or open an [issue on GitHub](https://github.com/huggingface/lerobot/issues/new/choose).
|
||||
"""
|
||||
|
||||
FUTURE_MESSAGE = """
|
||||
The dataset you requested ({repo_id}) is only available in {version} format.
|
||||
As we cannot ensure forward compatibility with it, please update your current version of lerobot.
|
||||
"""
|
||||
|
||||
|
||||
class CompatibilityError(Exception): ...
|
||||
|
||||
|
||||
class BackwardCompatibilityError(CompatibilityError):
|
||||
def __init__(self, repo_id: str, version: packaging.version.Version):
|
||||
message = V2_MESSAGE.format(repo_id=repo_id, version=version)
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class ForwardCompatibilityError(CompatibilityError):
|
||||
def __init__(self, repo_id: str, version: packaging.version.Version):
|
||||
message = FUTURE_MESSAGE.format(repo_id=repo_id, version=version)
|
||||
super().__init__(message)
|
||||
@@ -13,164 +13,202 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import numpy as np
|
||||
from copy import deepcopy
|
||||
from math import ceil
|
||||
|
||||
from lerobot.common.datasets.utils import load_image_as_numpy
|
||||
import einops
|
||||
import torch
|
||||
import tqdm
|
||||
|
||||
|
||||
def estimate_num_samples(
|
||||
dataset_len: int, min_num_samples: int = 100, max_num_samples: int = 10_000, power: float = 0.75
|
||||
) -> int:
|
||||
"""Heuristic to estimate the number of samples based on dataset size.
|
||||
The power controls the sample growth relative to dataset size.
|
||||
Lower the power for less number of samples.
|
||||
def get_stats_einops_patterns(dataset, num_workers=0):
|
||||
"""These einops patterns will be used to aggregate batches and compute statistics.
|
||||
|
||||
For default arguments, we have:
|
||||
- from 1 to ~500, num_samples=100
|
||||
- at 1000, num_samples=177
|
||||
- at 2000, num_samples=299
|
||||
- at 5000, num_samples=594
|
||||
- at 10000, num_samples=1000
|
||||
- at 20000, num_samples=1681
|
||||
Note: We assume the images are in channel first format
|
||||
"""
|
||||
if dataset_len < min_num_samples:
|
||||
min_num_samples = dataset_len
|
||||
return max(min_num_samples, min(int(dataset_len**power), max_num_samples))
|
||||
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
num_workers=num_workers,
|
||||
batch_size=2,
|
||||
shuffle=False,
|
||||
)
|
||||
batch = next(iter(dataloader))
|
||||
|
||||
def sample_indices(data_len: int) -> list[int]:
|
||||
num_samples = estimate_num_samples(data_len)
|
||||
return np.round(np.linspace(0, data_len - 1, num_samples)).astype(int).tolist()
|
||||
stats_patterns = {}
|
||||
|
||||
for key in dataset.features:
|
||||
# sanity check that tensors are not float64
|
||||
assert batch[key].dtype != torch.float64
|
||||
|
||||
def auto_downsample_height_width(img: np.ndarray, target_size: int = 150, max_size_threshold: int = 300):
|
||||
_, height, width = img.shape
|
||||
# if isinstance(feats_type, (VideoFrame, Image)):
|
||||
if key in dataset.meta.camera_keys:
|
||||
# sanity check that images are channel first
|
||||
_, c, h, w = batch[key].shape
|
||||
assert c < h and c < w, f"expect channel first images, but instead {batch[key].shape}"
|
||||
|
||||
if max(width, height) < max_size_threshold:
|
||||
# no downsampling needed
|
||||
return img
|
||||
# sanity check that images are float32 in range [0,1]
|
||||
assert batch[key].dtype == torch.float32, f"expect torch.float32, but instead {batch[key].dtype=}"
|
||||
assert batch[key].max() <= 1, f"expect pixels lower than 1, but instead {batch[key].max()=}"
|
||||
assert batch[key].min() >= 0, f"expect pixels greater than 1, but instead {batch[key].min()=}"
|
||||
|
||||
downsample_factor = int(width / target_size) if width > height else int(height / target_size)
|
||||
return img[:, ::downsample_factor, ::downsample_factor]
|
||||
|
||||
|
||||
def sample_images(image_paths: list[str]) -> np.ndarray:
|
||||
sampled_indices = sample_indices(len(image_paths))
|
||||
|
||||
images = None
|
||||
for i, idx in enumerate(sampled_indices):
|
||||
path = image_paths[idx]
|
||||
# we load as uint8 to reduce memory usage
|
||||
img = load_image_as_numpy(path, dtype=np.uint8, channel_first=True)
|
||||
img = auto_downsample_height_width(img)
|
||||
|
||||
if images is None:
|
||||
images = np.empty((len(sampled_indices), *img.shape), dtype=np.uint8)
|
||||
|
||||
images[i] = img
|
||||
|
||||
return images
|
||||
|
||||
|
||||
def get_feature_stats(array: np.ndarray, axis: tuple, keepdims: bool) -> dict[str, np.ndarray]:
|
||||
return {
|
||||
"min": np.min(array, axis=axis, keepdims=keepdims),
|
||||
"max": np.max(array, axis=axis, keepdims=keepdims),
|
||||
"mean": np.mean(array, axis=axis, keepdims=keepdims),
|
||||
"std": np.std(array, axis=axis, keepdims=keepdims),
|
||||
"count": np.array([len(array)]),
|
||||
}
|
||||
|
||||
|
||||
def compute_episode_stats(episode_data: dict[str, list[str] | np.ndarray], features: dict) -> dict:
|
||||
ep_stats = {}
|
||||
for key, data in episode_data.items():
|
||||
if features[key]["dtype"] == "string":
|
||||
continue # HACK: we should receive np.arrays of strings
|
||||
elif features[key]["dtype"] in ["image", "video"]:
|
||||
ep_ft_array = sample_images(data) # data is a list of image paths
|
||||
axes_to_reduce = (0, 2, 3) # keep channel dim
|
||||
keepdims = True
|
||||
stats_patterns[key] = "b c h w -> c 1 1"
|
||||
elif batch[key].ndim == 2:
|
||||
stats_patterns[key] = "b c -> c "
|
||||
elif batch[key].ndim == 1:
|
||||
stats_patterns[key] = "b -> 1"
|
||||
else:
|
||||
ep_ft_array = data # data is already a np.ndarray
|
||||
axes_to_reduce = 0 # compute stats over the first axis
|
||||
keepdims = data.ndim == 1 # keep as np.array
|
||||
raise ValueError(f"{key}, {batch[key].shape}")
|
||||
|
||||
ep_stats[key] = get_feature_stats(ep_ft_array, axis=axes_to_reduce, keepdims=keepdims)
|
||||
|
||||
# finally, we normalize and remove batch dim for images
|
||||
if features[key]["dtype"] in ["image", "video"]:
|
||||
ep_stats[key] = {
|
||||
k: v if k == "count" else np.squeeze(v / 255.0, axis=0) for k, v in ep_stats[key].items()
|
||||
}
|
||||
|
||||
return ep_stats
|
||||
return stats_patterns
|
||||
|
||||
|
||||
def _assert_type_and_shape(stats_list: list[dict[str, dict]]):
|
||||
for i in range(len(stats_list)):
|
||||
for fkey in stats_list[i]:
|
||||
for k, v in stats_list[i][fkey].items():
|
||||
if not isinstance(v, np.ndarray):
|
||||
raise ValueError(
|
||||
f"Stats must be composed of numpy array, but key '{k}' of feature '{fkey}' is of type '{type(v)}' instead."
|
||||
)
|
||||
if v.ndim == 0:
|
||||
raise ValueError("Number of dimensions must be at least 1, and is 0 instead.")
|
||||
if k == "count" and v.shape != (1,):
|
||||
raise ValueError(f"Shape of 'count' must be (1), but is {v.shape} instead.")
|
||||
if "image" in fkey and k != "count" and v.shape != (3, 1, 1):
|
||||
raise ValueError(f"Shape of '{k}' must be (3,1,1), but is {v.shape} instead.")
|
||||
def compute_stats(dataset, batch_size=8, num_workers=8, max_num_samples=None):
|
||||
"""Compute mean/std and min/max statistics of all data keys in a LeRobotDataset."""
|
||||
if max_num_samples is None:
|
||||
max_num_samples = len(dataset)
|
||||
|
||||
# for more info on why we need to set the same number of workers, see `load_from_videos`
|
||||
stats_patterns = get_stats_einops_patterns(dataset, num_workers)
|
||||
|
||||
# mean and std will be computed incrementally while max and min will track the running value.
|
||||
mean, std, max, min = {}, {}, {}, {}
|
||||
for key in stats_patterns:
|
||||
mean[key] = torch.tensor(0.0).float()
|
||||
std[key] = torch.tensor(0.0).float()
|
||||
max[key] = torch.tensor(-float("inf")).float()
|
||||
min[key] = torch.tensor(float("inf")).float()
|
||||
|
||||
def create_seeded_dataloader(dataset, batch_size, seed):
|
||||
generator = torch.Generator()
|
||||
generator.manual_seed(seed)
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
num_workers=num_workers,
|
||||
batch_size=batch_size,
|
||||
shuffle=True,
|
||||
drop_last=False,
|
||||
generator=generator,
|
||||
)
|
||||
return dataloader
|
||||
|
||||
# Note: Due to be refactored soon. The point of storing `first_batch` is to make sure we don't get
|
||||
# surprises when rerunning the sampler.
|
||||
first_batch = None
|
||||
running_item_count = 0 # for online mean computation
|
||||
dataloader = create_seeded_dataloader(dataset, batch_size, seed=1337)
|
||||
for i, batch in enumerate(
|
||||
tqdm.tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute mean, min, max")
|
||||
):
|
||||
this_batch_size = len(batch["index"])
|
||||
running_item_count += this_batch_size
|
||||
if first_batch is None:
|
||||
first_batch = deepcopy(batch)
|
||||
for key, pattern in stats_patterns.items():
|
||||
batch[key] = batch[key].float()
|
||||
# Numerically stable update step for mean computation.
|
||||
batch_mean = einops.reduce(batch[key], pattern, "mean")
|
||||
# Hint: to update the mean we need x̄ₙ = (Nₙ₋₁x̄ₙ₋₁ + Bₙxₙ) / Nₙ, where the subscript represents
|
||||
# the update step, N is the running item count, B is this batch size, x̄ is the running mean,
|
||||
# and x is the current batch mean. Some rearrangement is then required to avoid risking
|
||||
# numerical overflow. Another hint: Nₙ₋₁ = Nₙ - Bₙ. Rearrangement yields
|
||||
# x̄ₙ = x̄ₙ₋₁ + Bₙ * (xₙ - x̄ₙ₋₁) / Nₙ
|
||||
mean[key] = mean[key] + this_batch_size * (batch_mean - mean[key]) / running_item_count
|
||||
max[key] = torch.maximum(max[key], einops.reduce(batch[key], pattern, "max"))
|
||||
min[key] = torch.minimum(min[key], einops.reduce(batch[key], pattern, "min"))
|
||||
|
||||
if i == ceil(max_num_samples / batch_size) - 1:
|
||||
break
|
||||
|
||||
first_batch_ = None
|
||||
running_item_count = 0 # for online std computation
|
||||
dataloader = create_seeded_dataloader(dataset, batch_size, seed=1337)
|
||||
for i, batch in enumerate(
|
||||
tqdm.tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute std")
|
||||
):
|
||||
this_batch_size = len(batch["index"])
|
||||
running_item_count += this_batch_size
|
||||
# Sanity check to make sure the batches are still in the same order as before.
|
||||
if first_batch_ is None:
|
||||
first_batch_ = deepcopy(batch)
|
||||
for key in stats_patterns:
|
||||
assert torch.equal(first_batch_[key], first_batch[key])
|
||||
for key, pattern in stats_patterns.items():
|
||||
batch[key] = batch[key].float()
|
||||
# Numerically stable update step for mean computation (where the mean is over squared
|
||||
# residuals).See notes in the mean computation loop above.
|
||||
batch_std = einops.reduce((batch[key] - mean[key]) ** 2, pattern, "mean")
|
||||
std[key] = std[key] + this_batch_size * (batch_std - std[key]) / running_item_count
|
||||
|
||||
if i == ceil(max_num_samples / batch_size) - 1:
|
||||
break
|
||||
|
||||
for key in stats_patterns:
|
||||
std[key] = torch.sqrt(std[key])
|
||||
|
||||
stats = {}
|
||||
for key in stats_patterns:
|
||||
stats[key] = {
|
||||
"mean": mean[key],
|
||||
"std": std[key],
|
||||
"max": max[key],
|
||||
"min": min[key],
|
||||
}
|
||||
return stats
|
||||
|
||||
|
||||
def aggregate_feature_stats(stats_ft_list: list[dict[str, dict]]) -> dict[str, dict[str, np.ndarray]]:
|
||||
"""Aggregates stats for a single feature."""
|
||||
means = np.stack([s["mean"] for s in stats_ft_list])
|
||||
variances = np.stack([s["std"] ** 2 for s in stats_ft_list])
|
||||
counts = np.stack([s["count"] for s in stats_ft_list])
|
||||
total_count = counts.sum(axis=0)
|
||||
def aggregate_stats(ls_datasets) -> dict[str, torch.Tensor]:
|
||||
"""Aggregate stats of multiple LeRobot datasets into one set of stats without recomputing from scratch.
|
||||
|
||||
# Prepare weighted mean by matching number of dimensions
|
||||
while counts.ndim < means.ndim:
|
||||
counts = np.expand_dims(counts, axis=-1)
|
||||
The final stats will have the union of all data keys from each of the datasets.
|
||||
|
||||
# Compute the weighted mean
|
||||
weighted_means = means * counts
|
||||
total_mean = weighted_means.sum(axis=0) / total_count
|
||||
|
||||
# Compute the variance using the parallel algorithm
|
||||
delta_means = means - total_mean
|
||||
weighted_variances = (variances + delta_means**2) * counts
|
||||
total_variance = weighted_variances.sum(axis=0) / total_count
|
||||
|
||||
return {
|
||||
"min": np.min(np.stack([s["min"] for s in stats_ft_list]), axis=0),
|
||||
"max": np.max(np.stack([s["max"] for s in stats_ft_list]), axis=0),
|
||||
"mean": total_mean,
|
||||
"std": np.sqrt(total_variance),
|
||||
"count": total_count,
|
||||
}
|
||||
|
||||
|
||||
def aggregate_stats(stats_list: list[dict[str, dict]]) -> dict[str, dict[str, np.ndarray]]:
|
||||
"""Aggregate stats from multiple compute_stats outputs into a single set of stats.
|
||||
|
||||
The final stats will have the union of all data keys from each of the stats dicts.
|
||||
|
||||
For instance:
|
||||
- new_min = min(min_dataset_0, min_dataset_1, ...)
|
||||
The final stats will have the union of all data keys from each of the datasets. For instance:
|
||||
- new_max = max(max_dataset_0, max_dataset_1, ...)
|
||||
- new_mean = (mean of all data, weighted by counts)
|
||||
- new_min = min(min_dataset_0, min_dataset_1, ...)
|
||||
- new_mean = (mean of all data)
|
||||
- new_std = (std of all data)
|
||||
"""
|
||||
|
||||
_assert_type_and_shape(stats_list)
|
||||
|
||||
data_keys = {key for stats in stats_list for key in stats}
|
||||
aggregated_stats = {key: {} for key in data_keys}
|
||||
|
||||
for key in data_keys:
|
||||
stats_with_key = [stats[key] for stats in stats_list if key in stats]
|
||||
aggregated_stats[key] = aggregate_feature_stats(stats_with_key)
|
||||
|
||||
return aggregated_stats
|
||||
data_keys = set()
|
||||
for dataset in ls_datasets:
|
||||
data_keys.update(dataset.meta.stats.keys())
|
||||
stats = {k: {} for k in data_keys}
|
||||
for data_key in data_keys:
|
||||
for stat_key in ["min", "max"]:
|
||||
# compute `max(dataset_0["max"], dataset_1["max"], ...)`
|
||||
stats[data_key][stat_key] = einops.reduce(
|
||||
torch.stack(
|
||||
[ds.meta.stats[data_key][stat_key] for ds in ls_datasets if data_key in ds.meta.stats],
|
||||
dim=0,
|
||||
),
|
||||
"n ... -> ...",
|
||||
stat_key,
|
||||
)
|
||||
total_samples = sum(d.num_frames for d in ls_datasets if data_key in d.meta.stats)
|
||||
# Compute the "sum" statistic by multiplying each mean by the number of samples in the respective
|
||||
# dataset, then divide by total_samples to get the overall "mean".
|
||||
# NOTE: the brackets around (d.num_frames / total_samples) are needed tor minimize the risk of
|
||||
# numerical overflow!
|
||||
stats[data_key]["mean"] = sum(
|
||||
d.meta.stats[data_key]["mean"] * (d.num_frames / total_samples)
|
||||
for d in ls_datasets
|
||||
if data_key in d.meta.stats
|
||||
)
|
||||
# The derivation for standard deviation is a little more involved but is much in the same spirit as
|
||||
# the computation of the mean.
|
||||
# Given two sets of data where the statistics are known:
|
||||
# σ_combined = sqrt[ (n1 * (σ1^2 + d1^2) + n2 * (σ2^2 + d2^2)) / (n1 + n2) ]
|
||||
# where d1 = μ1 - μ_combined, d2 = μ2 - μ_combined
|
||||
# NOTE: the brackets around (d.num_frames / total_samples) are needed tor minimize the risk of
|
||||
# numerical overflow!
|
||||
stats[data_key]["std"] = torch.sqrt(
|
||||
sum(
|
||||
(
|
||||
d.meta.stats[data_key]["std"] ** 2
|
||||
+ (d.meta.stats[data_key]["mean"] - stats[data_key]["mean"]) ** 2
|
||||
)
|
||||
* (d.num_frames / total_samples)
|
||||
for d in ls_datasets
|
||||
if data_key in d.meta.stats
|
||||
)
|
||||
)
|
||||
return stats
|
||||
|
||||
@@ -14,105 +14,103 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
from pprint import pformat
|
||||
|
||||
import torch
|
||||
from omegaconf import ListConfig, OmegaConf
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import (
|
||||
LeRobotDataset,
|
||||
LeRobotDatasetMetadata,
|
||||
MultiLeRobotDataset,
|
||||
)
|
||||
from lerobot.common.datasets.transforms import ImageTransforms
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
|
||||
IMAGENET_STATS = {
|
||||
"mean": [[[0.485]], [[0.456]], [[0.406]]], # (c,1,1)
|
||||
"std": [[[0.229]], [[0.224]], [[0.225]]], # (c,1,1)
|
||||
}
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, MultiLeRobotDataset
|
||||
from lerobot.common.datasets.transforms import get_image_transforms
|
||||
|
||||
|
||||
def resolve_delta_timestamps(
|
||||
cfg: PreTrainedConfig, ds_meta: LeRobotDatasetMetadata
|
||||
) -> dict[str, list] | None:
|
||||
"""Resolves delta_timestamps by reading from the 'delta_indices' properties of the PreTrainedConfig.
|
||||
def resolve_delta_timestamps(cfg):
|
||||
"""Resolves delta_timestamps config key (in-place) by using `eval`.
|
||||
|
||||
Args:
|
||||
cfg (PreTrainedConfig): The PreTrainedConfig to read delta_indices from.
|
||||
ds_meta (LeRobotDatasetMetadata): The dataset from which features and fps are used to build
|
||||
delta_timestamps against.
|
||||
|
||||
Returns:
|
||||
dict[str, list] | None: A dictionary of delta_timestamps, e.g.:
|
||||
{
|
||||
"observation.state": [-0.04, -0.02, 0]
|
||||
"observation.action": [-0.02, 0, 0.02]
|
||||
}
|
||||
returns `None` if the the resulting dict is empty.
|
||||
Doesn't do anything if delta_timestamps is not specified or has already been resolve (as evidenced by
|
||||
the data type of its values).
|
||||
"""
|
||||
delta_timestamps = {}
|
||||
for key in ds_meta.features:
|
||||
if key == "next.reward" and cfg.reward_delta_indices is not None:
|
||||
delta_timestamps[key] = [i / ds_meta.fps for i in cfg.reward_delta_indices]
|
||||
if key == "action" and cfg.action_delta_indices is not None:
|
||||
delta_timestamps[key] = [i / ds_meta.fps for i in cfg.action_delta_indices]
|
||||
if key.startswith("observation.") and cfg.observation_delta_indices is not None:
|
||||
delta_timestamps[key] = [i / ds_meta.fps for i in cfg.observation_delta_indices]
|
||||
|
||||
if len(delta_timestamps) == 0:
|
||||
delta_timestamps = None
|
||||
|
||||
return delta_timestamps
|
||||
delta_timestamps = cfg.training.get("delta_timestamps")
|
||||
if delta_timestamps is not None:
|
||||
for key in delta_timestamps:
|
||||
if isinstance(delta_timestamps[key], str):
|
||||
# TODO(rcadene, alexander-soare): remove `eval` to avoid exploit
|
||||
cfg.training.delta_timestamps[key] = eval(delta_timestamps[key])
|
||||
|
||||
|
||||
def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDataset:
|
||||
"""Handles the logic of setting up delta timestamps and image transforms before creating a dataset.
|
||||
|
||||
Args:
|
||||
cfg (TrainPipelineConfig): A TrainPipelineConfig config which contains a DatasetConfig and a PreTrainedConfig.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: The MultiLeRobotDataset is currently deactivated.
|
||||
|
||||
Returns:
|
||||
LeRobotDataset | MultiLeRobotDataset
|
||||
def make_dataset(cfg, split: str = "train") -> LeRobotDataset | MultiLeRobotDataset:
|
||||
"""
|
||||
image_transforms = (
|
||||
ImageTransforms(cfg.dataset.image_transforms) if cfg.dataset.image_transforms.enable else None
|
||||
)
|
||||
|
||||
if isinstance(cfg.dataset.repo_id, str):
|
||||
ds_meta = LeRobotDatasetMetadata(
|
||||
cfg.dataset.repo_id, root=cfg.dataset.root, revision=cfg.dataset.revision
|
||||
Args:
|
||||
cfg: A Hydra config as per the LeRobot config scheme.
|
||||
split: Select the data subset used to create an instance of LeRobotDataset.
|
||||
All datasets hosted on [lerobot](https://huggingface.co/lerobot) contain only one subset: "train".
|
||||
Thus, by default, `split="train"` selects all the available data. `split` aims to work like the
|
||||
slicer in the hugging face datasets:
|
||||
https://huggingface.co/docs/datasets/v2.19.0/loading#slice-splits
|
||||
As of now, it only supports `split="train[:n]"` to load the first n frames of the dataset or
|
||||
`split="train[n:]"` to load the last n frames. For instance `split="train[:1000]"`.
|
||||
Returns:
|
||||
The LeRobotDataset.
|
||||
"""
|
||||
if not isinstance(cfg.dataset_repo_id, (str, ListConfig)):
|
||||
raise ValueError(
|
||||
"Expected cfg.dataset_repo_id to be either a single string to load one dataset or a list of "
|
||||
"strings to load multiple datasets."
|
||||
)
|
||||
delta_timestamps = resolve_delta_timestamps(cfg.policy, ds_meta)
|
||||
|
||||
# A soft check to warn if the environment matches the dataset. Don't check if we are using a real world env (dora).
|
||||
if cfg.env.name != "dora":
|
||||
if isinstance(cfg.dataset_repo_id, str):
|
||||
dataset_repo_ids = [cfg.dataset_repo_id] # single dataset
|
||||
else:
|
||||
dataset_repo_ids = cfg.dataset_repo_id # multiple datasets
|
||||
|
||||
for dataset_repo_id in dataset_repo_ids:
|
||||
if cfg.env.name not in dataset_repo_id:
|
||||
logging.warning(
|
||||
f"There might be a mismatch between your training dataset ({dataset_repo_id=}) and your "
|
||||
f"environment ({cfg.env.name=})."
|
||||
)
|
||||
|
||||
resolve_delta_timestamps(cfg)
|
||||
|
||||
image_transforms = None
|
||||
if cfg.training.image_transforms.enable:
|
||||
cfg_tf = cfg.training.image_transforms
|
||||
image_transforms = get_image_transforms(
|
||||
brightness_weight=cfg_tf.brightness.weight,
|
||||
brightness_min_max=cfg_tf.brightness.min_max,
|
||||
contrast_weight=cfg_tf.contrast.weight,
|
||||
contrast_min_max=cfg_tf.contrast.min_max,
|
||||
saturation_weight=cfg_tf.saturation.weight,
|
||||
saturation_min_max=cfg_tf.saturation.min_max,
|
||||
hue_weight=cfg_tf.hue.weight,
|
||||
hue_min_max=cfg_tf.hue.min_max,
|
||||
sharpness_weight=cfg_tf.sharpness.weight,
|
||||
sharpness_min_max=cfg_tf.sharpness.min_max,
|
||||
max_num_transforms=cfg_tf.max_num_transforms,
|
||||
random_order=cfg_tf.random_order,
|
||||
)
|
||||
|
||||
if isinstance(cfg.dataset_repo_id, str):
|
||||
# TODO (aliberts): add 'episodes' arg from config after removing hydra
|
||||
dataset = LeRobotDataset(
|
||||
cfg.dataset.repo_id,
|
||||
root=cfg.dataset.root,
|
||||
episodes=cfg.dataset.episodes,
|
||||
delta_timestamps=delta_timestamps,
|
||||
cfg.dataset_repo_id,
|
||||
delta_timestamps=cfg.training.get("delta_timestamps"),
|
||||
image_transforms=image_transforms,
|
||||
revision=cfg.dataset.revision,
|
||||
video_backend=cfg.dataset.video_backend,
|
||||
video_backend=cfg.video_backend,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError("The MultiLeRobotDataset isn't supported for now.")
|
||||
dataset = MultiLeRobotDataset(
|
||||
cfg.dataset.repo_id,
|
||||
# TODO(aliberts): add proper support for multi dataset
|
||||
# delta_timestamps=delta_timestamps,
|
||||
cfg.dataset_repo_id,
|
||||
delta_timestamps=cfg.training.get("delta_timestamps"),
|
||||
image_transforms=image_transforms,
|
||||
video_backend=cfg.dataset.video_backend,
|
||||
)
|
||||
logging.info(
|
||||
"Multiple datasets were provided. Applied the following index mapping to the provided datasets: "
|
||||
f"{pformat(dataset.repo_id_to_index, indent=2)}"
|
||||
video_backend=cfg.video_backend,
|
||||
)
|
||||
|
||||
if cfg.dataset.use_imagenet_stats:
|
||||
for key in dataset.meta.camera_keys:
|
||||
for stats_type, stats in IMAGENET_STATS.items():
|
||||
if cfg.get("override_dataset_stats"):
|
||||
for key, stats_dict in cfg.override_dataset_stats.items():
|
||||
for stats_type, listconfig in stats_dict.items():
|
||||
# example of stats_type: min, max, mean, std
|
||||
stats = OmegaConf.to_container(listconfig, resolve=True)
|
||||
dataset.meta.stats[key][stats_type] = torch.tensor(stats, dtype=torch.float32)
|
||||
|
||||
return dataset
|
||||
|
||||
@@ -38,40 +38,22 @@ def safe_stop_image_writer(func):
|
||||
return wrapper
|
||||
|
||||
|
||||
def image_array_to_pil_image(image_array: np.ndarray, range_check: bool = True) -> PIL.Image.Image:
|
||||
def image_array_to_image(image_array: np.ndarray) -> PIL.Image.Image:
|
||||
# TODO(aliberts): handle 1 channel and 4 for depth images
|
||||
if image_array.ndim != 3:
|
||||
raise ValueError(f"The array has {image_array.ndim} dimensions, but 3 is expected for an image.")
|
||||
|
||||
if image_array.shape[0] == 3:
|
||||
if image_array.ndim == 3 and image_array.shape[0] in [1, 3]:
|
||||
# Transpose from pytorch convention (C, H, W) to (H, W, C)
|
||||
image_array = image_array.transpose(1, 2, 0)
|
||||
|
||||
elif image_array.shape[-1] != 3:
|
||||
raise NotImplementedError(
|
||||
f"The image has {image_array.shape[-1]} channels, but 3 is required for now."
|
||||
)
|
||||
|
||||
if image_array.dtype != np.uint8:
|
||||
if range_check:
|
||||
max_ = image_array.max().item()
|
||||
min_ = image_array.min().item()
|
||||
if max_ > 1.0 or min_ < 0.0:
|
||||
raise ValueError(
|
||||
"The image data type is float, which requires values in the range [0.0, 1.0]. "
|
||||
f"However, the provided range is [{min_}, {max_}]. Please adjust the range or "
|
||||
"provide a uint8 image with values in the range [0, 255]."
|
||||
)
|
||||
|
||||
# Assume the image is in [0, 1] range for floating-point data
|
||||
image_array = np.clip(image_array, 0, 1)
|
||||
image_array = (image_array * 255).astype(np.uint8)
|
||||
|
||||
return PIL.Image.fromarray(image_array)
|
||||
|
||||
|
||||
def write_image(image: np.ndarray | PIL.Image.Image, fpath: Path):
|
||||
try:
|
||||
if isinstance(image, np.ndarray):
|
||||
img = image_array_to_pil_image(image)
|
||||
img = image_array_to_image(image)
|
||||
elif isinstance(image, PIL.Image.Image):
|
||||
img = image
|
||||
else:
|
||||
|
||||
@@ -13,57 +13,50 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import contextlib
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
from functools import cached_property
|
||||
from pathlib import Path
|
||||
from typing import Callable
|
||||
|
||||
import datasets
|
||||
import numpy as np
|
||||
import packaging.version
|
||||
import PIL.Image
|
||||
import torch
|
||||
import torch.utils
|
||||
from datasets import concatenate_datasets, load_dataset
|
||||
from huggingface_hub import HfApi, snapshot_download
|
||||
from huggingface_hub.constants import REPOCARD_NAME
|
||||
from huggingface_hub.errors import RevisionNotFoundError
|
||||
from datasets import load_dataset
|
||||
from huggingface_hub import create_repo, snapshot_download, upload_folder
|
||||
|
||||
from lerobot.common.constants import HF_LEROBOT_HOME
|
||||
from lerobot.common.datasets.compute_stats import aggregate_stats, compute_episode_stats
|
||||
from lerobot.common.datasets.compute_stats import aggregate_stats, compute_stats
|
||||
from lerobot.common.datasets.image_writer import AsyncImageWriter, write_image
|
||||
from lerobot.common.datasets.utils import (
|
||||
DEFAULT_FEATURES,
|
||||
DEFAULT_IMAGE_PATH,
|
||||
EPISODES_PATH,
|
||||
INFO_PATH,
|
||||
STATS_PATH,
|
||||
TASKS_PATH,
|
||||
append_jsonlines,
|
||||
backward_compatible_episodes_stats,
|
||||
check_delta_timestamps,
|
||||
check_timestamps_sync,
|
||||
check_version_compatibility,
|
||||
create_branch,
|
||||
create_empty_dataset_info,
|
||||
create_lerobot_dataset_card,
|
||||
embed_images,
|
||||
get_delta_indices,
|
||||
get_episode_data_index,
|
||||
get_features_from_robot,
|
||||
get_hf_features_from_features,
|
||||
get_safe_version,
|
||||
get_hub_safe_version,
|
||||
hf_transform_to_torch,
|
||||
is_valid_version,
|
||||
load_episodes,
|
||||
load_episodes_stats,
|
||||
load_info,
|
||||
load_stats,
|
||||
load_tasks,
|
||||
validate_episode_buffer,
|
||||
validate_frame,
|
||||
write_episode,
|
||||
write_episode_stats,
|
||||
write_info,
|
||||
serialize_dict,
|
||||
write_json,
|
||||
write_parquet,
|
||||
)
|
||||
from lerobot.common.datasets.video_utils import (
|
||||
VideoFrame,
|
||||
@@ -73,7 +66,9 @@ from lerobot.common.datasets.video_utils import (
|
||||
)
|
||||
from lerobot.common.robot_devices.robots.utils import Robot
|
||||
|
||||
CODEBASE_VERSION = "v2.1"
|
||||
# For maintainers, see lerobot/common/datasets/push_dataset_to_hub/CODEBASE_VERSION.md
|
||||
CODEBASE_VERSION = "v2.0"
|
||||
LEROBOT_HOME = Path(os.getenv("LEROBOT_HOME", "~/.cache/huggingface/lerobot")).expanduser()
|
||||
|
||||
|
||||
class LeRobotDatasetMetadata:
|
||||
@@ -81,36 +76,19 @@ class LeRobotDatasetMetadata:
|
||||
self,
|
||||
repo_id: str,
|
||||
root: str | Path | None = None,
|
||||
revision: str | None = None,
|
||||
force_cache_sync: bool = False,
|
||||
local_files_only: bool = False,
|
||||
):
|
||||
self.repo_id = repo_id
|
||||
self.revision = revision if revision else CODEBASE_VERSION
|
||||
self.root = Path(root) if root is not None else HF_LEROBOT_HOME / repo_id
|
||||
self.root = Path(root) if root is not None else LEROBOT_HOME / repo_id
|
||||
self.local_files_only = local_files_only
|
||||
|
||||
try:
|
||||
if force_cache_sync:
|
||||
raise FileNotFoundError
|
||||
self.load_metadata()
|
||||
except (FileNotFoundError, NotADirectoryError):
|
||||
if is_valid_version(self.revision):
|
||||
self.revision = get_safe_version(self.repo_id, self.revision)
|
||||
|
||||
(self.root / "meta").mkdir(exist_ok=True, parents=True)
|
||||
self.pull_from_repo(allow_patterns="meta/")
|
||||
self.load_metadata()
|
||||
|
||||
def load_metadata(self):
|
||||
# Load metadata
|
||||
(self.root / "meta").mkdir(exist_ok=True, parents=True)
|
||||
self.pull_from_repo(allow_patterns="meta/")
|
||||
self.info = load_info(self.root)
|
||||
check_version_compatibility(self.repo_id, self._version, CODEBASE_VERSION)
|
||||
self.tasks, self.task_to_task_index = load_tasks(self.root)
|
||||
self.stats = load_stats(self.root)
|
||||
self.tasks = load_tasks(self.root)
|
||||
self.episodes = load_episodes(self.root)
|
||||
if self._version < packaging.version.parse("v2.1"):
|
||||
self.stats = load_stats(self.root)
|
||||
self.episodes_stats = backward_compatible_episodes_stats(self.stats, self.episodes)
|
||||
else:
|
||||
self.episodes_stats = load_episodes_stats(self.root)
|
||||
self.stats = aggregate_stats(list(self.episodes_stats.values()))
|
||||
|
||||
def pull_from_repo(
|
||||
self,
|
||||
@@ -120,16 +98,21 @@ class LeRobotDatasetMetadata:
|
||||
snapshot_download(
|
||||
self.repo_id,
|
||||
repo_type="dataset",
|
||||
revision=self.revision,
|
||||
revision=self._hub_version,
|
||||
local_dir=self.root,
|
||||
allow_patterns=allow_patterns,
|
||||
ignore_patterns=ignore_patterns,
|
||||
local_files_only=self.local_files_only,
|
||||
)
|
||||
|
||||
@cached_property
|
||||
def _hub_version(self) -> str | None:
|
||||
return None if self.local_files_only else get_hub_safe_version(self.repo_id, CODEBASE_VERSION)
|
||||
|
||||
@property
|
||||
def _version(self) -> packaging.version.Version:
|
||||
def _version(self) -> str:
|
||||
"""Codebase version used to create this dataset."""
|
||||
return packaging.version.parse(self.info["codebase_version"])
|
||||
return self.info["codebase_version"]
|
||||
|
||||
def get_data_file_path(self, ep_index: int) -> Path:
|
||||
ep_chunk = self.get_episode_chunk(ep_index)
|
||||
@@ -219,65 +202,54 @@ class LeRobotDatasetMetadata:
|
||||
"""Max number of episodes per chunk."""
|
||||
return self.info["chunks_size"]
|
||||
|
||||
def get_task_index(self, task: str) -> int | None:
|
||||
@property
|
||||
def task_to_task_index(self) -> dict:
|
||||
return {task: task_idx for task_idx, task in self.tasks.items()}
|
||||
|
||||
def get_task_index(self, task: str) -> int:
|
||||
"""
|
||||
Given a task in natural language, returns its task_index if the task already exists in the dataset,
|
||||
otherwise return None.
|
||||
otherwise creates a new task_index.
|
||||
"""
|
||||
return self.task_to_task_index.get(task, None)
|
||||
task_index = self.task_to_task_index.get(task, None)
|
||||
return task_index if task_index is not None else self.total_tasks
|
||||
|
||||
def add_task(self, task: str):
|
||||
"""
|
||||
Given a task in natural language, add it to the dictionary of tasks.
|
||||
"""
|
||||
if task in self.task_to_task_index:
|
||||
raise ValueError(f"The task '{task}' already exists and can't be added twice.")
|
||||
|
||||
task_index = self.info["total_tasks"]
|
||||
self.task_to_task_index[task] = task_index
|
||||
self.tasks[task_index] = task
|
||||
self.info["total_tasks"] += 1
|
||||
|
||||
task_dict = {
|
||||
"task_index": task_index,
|
||||
"task": task,
|
||||
}
|
||||
append_jsonlines(task_dict, self.root / TASKS_PATH)
|
||||
|
||||
def save_episode(
|
||||
self,
|
||||
episode_index: int,
|
||||
episode_length: int,
|
||||
episode_tasks: list[str],
|
||||
episode_stats: dict[str, dict],
|
||||
) -> None:
|
||||
def save_episode(self, episode_index: int, episode_length: int, task: str, task_index: int) -> None:
|
||||
self.info["total_episodes"] += 1
|
||||
self.info["total_frames"] += episode_length
|
||||
|
||||
if task_index not in self.tasks:
|
||||
self.info["total_tasks"] += 1
|
||||
self.tasks[task_index] = task
|
||||
task_dict = {
|
||||
"task_index": task_index,
|
||||
"task": task,
|
||||
}
|
||||
append_jsonlines(task_dict, self.root / TASKS_PATH)
|
||||
|
||||
chunk = self.get_episode_chunk(episode_index)
|
||||
if chunk >= self.total_chunks:
|
||||
self.info["total_chunks"] += 1
|
||||
|
||||
self.info["splits"] = {"train": f"0:{self.info['total_episodes']}"}
|
||||
self.info["total_videos"] += len(self.video_keys)
|
||||
if len(self.video_keys) > 0:
|
||||
self.update_video_info()
|
||||
|
||||
write_info(self.info, self.root)
|
||||
write_json(self.info, self.root / INFO_PATH)
|
||||
|
||||
episode_dict = {
|
||||
"episode_index": episode_index,
|
||||
"tasks": episode_tasks,
|
||||
"tasks": [task],
|
||||
"length": episode_length,
|
||||
}
|
||||
self.episodes[episode_index] = episode_dict
|
||||
write_episode(episode_dict, self.root)
|
||||
self.episodes.append(episode_dict)
|
||||
append_jsonlines(episode_dict, self.root / EPISODES_PATH)
|
||||
|
||||
self.episodes_stats[episode_index] = episode_stats
|
||||
self.stats = aggregate_stats([self.stats, episode_stats]) if self.stats else episode_stats
|
||||
write_episode_stats(episode_index, episode_stats, self.root)
|
||||
# TODO(aliberts): refactor stats in save_episodes
|
||||
# image_sampling = int(self.fps / 2) # sample 2 img/s for the stats
|
||||
# ep_stats = compute_episode_stats(episode_buffer, self.features, episode_length, image_sampling=image_sampling)
|
||||
# ep_stats = serialize_dict(ep_stats)
|
||||
# append_jsonlines(ep_stats, self.root / STATS_PATH)
|
||||
|
||||
def update_video_info(self) -> None:
|
||||
def write_video_info(self) -> None:
|
||||
"""
|
||||
Warning: this function writes info from first episode videos, implicitly assuming that all videos have
|
||||
been encoded the same way. Also, this means it assumes the first episode exists.
|
||||
@@ -287,6 +259,8 @@ class LeRobotDatasetMetadata:
|
||||
video_path = self.root / self.get_video_file_path(ep_index=0, vid_key=key)
|
||||
self.info["features"][key]["info"] = get_video_info(video_path)
|
||||
|
||||
write_json(self.info, self.root / INFO_PATH)
|
||||
|
||||
def __repr__(self):
|
||||
feature_keys = list(self.features)
|
||||
return (
|
||||
@@ -312,12 +286,12 @@ class LeRobotDatasetMetadata:
|
||||
"""Creates metadata for a LeRobotDataset."""
|
||||
obj = cls.__new__(cls)
|
||||
obj.repo_id = repo_id
|
||||
obj.root = Path(root) if root is not None else HF_LEROBOT_HOME / repo_id
|
||||
obj.root = Path(root) if root is not None else LEROBOT_HOME / repo_id
|
||||
|
||||
obj.root.mkdir(parents=True, exist_ok=False)
|
||||
|
||||
if robot is not None:
|
||||
features = get_features_from_robot(robot, use_videos)
|
||||
features = {**(features or {}), **get_features_from_robot(robot)}
|
||||
robot_type = robot.robot_type
|
||||
if not all(cam.fps == fps for cam in robot.cameras.values()):
|
||||
logging.warning(
|
||||
@@ -332,21 +306,12 @@ class LeRobotDatasetMetadata:
|
||||
# TODO(aliberts, rcadene): implement sanity check for features
|
||||
features = {**features, **DEFAULT_FEATURES}
|
||||
|
||||
# check if none of the features contains a "/" in their names,
|
||||
# as this would break the dict flattening in the stats computation, which uses '/' as separator
|
||||
for key in features:
|
||||
if "/" in key:
|
||||
raise ValueError(f"Feature names should not contain '/'. Found '/' in feature '{key}'.")
|
||||
|
||||
features = {**features, **DEFAULT_FEATURES}
|
||||
|
||||
obj.tasks, obj.task_to_task_index = {}, {}
|
||||
obj.episodes_stats, obj.stats, obj.episodes = {}, {}, {}
|
||||
obj.tasks, obj.stats, obj.episodes = {}, {}, []
|
||||
obj.info = create_empty_dataset_info(CODEBASE_VERSION, fps, robot_type, features, use_videos)
|
||||
if len(obj.video_keys) > 0 and not use_videos:
|
||||
raise ValueError()
|
||||
write_json(obj.info, obj.root / INFO_PATH)
|
||||
obj.revision = None
|
||||
obj.local_files_only = True
|
||||
return obj
|
||||
|
||||
|
||||
@@ -359,9 +324,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
image_transforms: Callable | None = None,
|
||||
delta_timestamps: dict[list[float]] | None = None,
|
||||
tolerance_s: float = 1e-4,
|
||||
revision: str | None = None,
|
||||
force_cache_sync: bool = False,
|
||||
download_videos: bool = True,
|
||||
local_files_only: bool = False,
|
||||
video_backend: str | None = None,
|
||||
):
|
||||
"""
|
||||
@@ -371,7 +335,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
- On your local disk in the 'root' folder. This is typically the case when you recorded your
|
||||
dataset locally and you may or may not have pushed it to the hub yet. Instantiating this class
|
||||
with 'root' will load your dataset directly from disk. This can happen while you're offline (no
|
||||
internet connection).
|
||||
internet connection), in that case, use local_files_only=True.
|
||||
|
||||
- On the Hugging Face Hub at the address https://huggingface.co/datasets/{repo_id} and not on
|
||||
your local disk in the 'root' folder. Instantiating this class with this 'repo_id' will download
|
||||
@@ -391,7 +355,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
- info contains various information about the dataset like shapes, keys, fps etc.
|
||||
- stats stores the dataset statistics of the different modalities for normalization
|
||||
- tasks contains the prompts for each task of the dataset, which can be used for
|
||||
task-conditioned training.
|
||||
task-conditionned training.
|
||||
- hf_dataset (from datasets.Dataset), which will read any values from parquet files.
|
||||
- videos (optional) from which frames are loaded to be synchronous with data from parquet files.
|
||||
|
||||
@@ -453,28 +417,24 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
timestamps is separated to the next by 1/fps +/- tolerance_s. This also applies to frames
|
||||
decoded from video files. It is also used to check that `delta_timestamps` (when provided) are
|
||||
multiples of 1/fps. Defaults to 1e-4.
|
||||
revision (str, optional): An optional Git revision id which can be a branch name, a tag, or a
|
||||
commit hash. Defaults to current codebase version tag.
|
||||
sync_cache_first (bool, optional): Flag to sync and refresh local files first. If True and files
|
||||
are already present in the local cache, this will be faster. However, files loaded might not
|
||||
be in sync with the version on the hub, especially if you specified 'revision'. Defaults to
|
||||
False.
|
||||
download_videos (bool, optional): Flag to download the videos. Note that when set to True but the
|
||||
video files are already present on local disk, they won't be downloaded again. Defaults to
|
||||
True.
|
||||
local_files_only (bool, optional): Flag to use local files only. If True, no requests to the hub
|
||||
will be made. Defaults to False.
|
||||
video_backend (str | None, optional): Video backend to use for decoding videos. There is currently
|
||||
a single option which is the pyav decoder used by Torchvision. Defaults to pyav.
|
||||
"""
|
||||
super().__init__()
|
||||
self.repo_id = repo_id
|
||||
self.root = Path(root) if root else HF_LEROBOT_HOME / repo_id
|
||||
self.root = Path(root) if root else LEROBOT_HOME / repo_id
|
||||
self.image_transforms = image_transforms
|
||||
self.delta_timestamps = delta_timestamps
|
||||
self.episodes = episodes
|
||||
self.tolerance_s = tolerance_s
|
||||
self.revision = revision if revision else CODEBASE_VERSION
|
||||
self.video_backend = video_backend if video_backend else "pyav"
|
||||
self.delta_indices = None
|
||||
self.local_files_only = local_files_only
|
||||
|
||||
# Unused attributes
|
||||
self.image_writer = None
|
||||
@@ -483,92 +443,64 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
self.root.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
# Load metadata
|
||||
self.meta = LeRobotDatasetMetadata(
|
||||
self.repo_id, self.root, self.revision, force_cache_sync=force_cache_sync
|
||||
)
|
||||
if self.episodes is not None and self.meta._version >= packaging.version.parse("v2.1"):
|
||||
episodes_stats = [self.meta.episodes_stats[ep_idx] for ep_idx in self.episodes]
|
||||
self.stats = aggregate_stats(episodes_stats)
|
||||
self.meta = LeRobotDatasetMetadata(self.repo_id, self.root, self.local_files_only)
|
||||
|
||||
# Check version
|
||||
check_version_compatibility(self.repo_id, self.meta._version, CODEBASE_VERSION)
|
||||
|
||||
# Load actual data
|
||||
try:
|
||||
if force_cache_sync:
|
||||
raise FileNotFoundError
|
||||
assert all((self.root / fpath).is_file() for fpath in self.get_episodes_file_paths())
|
||||
self.hf_dataset = self.load_hf_dataset()
|
||||
except (AssertionError, FileNotFoundError, NotADirectoryError):
|
||||
self.revision = get_safe_version(self.repo_id, self.revision)
|
||||
self.download_episodes(download_videos)
|
||||
self.hf_dataset = self.load_hf_dataset()
|
||||
|
||||
self.download_episodes(download_videos)
|
||||
self.hf_dataset = self.load_hf_dataset()
|
||||
self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes)
|
||||
|
||||
# Check timestamps
|
||||
timestamps = torch.stack(self.hf_dataset["timestamp"]).numpy()
|
||||
episode_indices = torch.stack(self.hf_dataset["episode_index"]).numpy()
|
||||
ep_data_index_np = {k: t.numpy() for k, t in self.episode_data_index.items()}
|
||||
check_timestamps_sync(timestamps, episode_indices, ep_data_index_np, self.fps, self.tolerance_s)
|
||||
check_timestamps_sync(self.hf_dataset, self.episode_data_index, self.fps, self.tolerance_s)
|
||||
|
||||
# Setup delta_indices
|
||||
if self.delta_timestamps is not None:
|
||||
check_delta_timestamps(self.delta_timestamps, self.fps, self.tolerance_s)
|
||||
self.delta_indices = get_delta_indices(self.delta_timestamps, self.fps)
|
||||
|
||||
# Available stats implies all videos have been encoded and dataset is iterable
|
||||
self.consolidated = self.meta.stats is not None
|
||||
|
||||
def push_to_hub(
|
||||
self,
|
||||
branch: str | None = None,
|
||||
tags: list | None = None,
|
||||
license: str | None = "apache-2.0",
|
||||
tag_version: bool = True,
|
||||
push_videos: bool = True,
|
||||
private: bool = False,
|
||||
allow_patterns: list[str] | str | None = None,
|
||||
upload_large_folder: bool = False,
|
||||
**card_kwargs,
|
||||
) -> None:
|
||||
if not self.consolidated:
|
||||
logging.warning(
|
||||
"You are trying to upload to the hub a LeRobotDataset that has not been consolidated yet. "
|
||||
"Consolidating first."
|
||||
)
|
||||
self.consolidate()
|
||||
|
||||
ignore_patterns = ["images/"]
|
||||
if not push_videos:
|
||||
ignore_patterns.append("videos/")
|
||||
|
||||
hub_api = HfApi()
|
||||
hub_api.create_repo(
|
||||
create_repo(
|
||||
repo_id=self.repo_id,
|
||||
private=private,
|
||||
repo_type="dataset",
|
||||
exist_ok=True,
|
||||
)
|
||||
if branch:
|
||||
hub_api.create_branch(
|
||||
repo_id=self.repo_id,
|
||||
branch=branch,
|
||||
revision=self.revision,
|
||||
repo_type="dataset",
|
||||
exist_ok=True,
|
||||
)
|
||||
|
||||
upload_kwargs = {
|
||||
"repo_id": self.repo_id,
|
||||
"folder_path": self.root,
|
||||
"repo_type": "dataset",
|
||||
"revision": branch,
|
||||
"allow_patterns": allow_patterns,
|
||||
"ignore_patterns": ignore_patterns,
|
||||
}
|
||||
if upload_large_folder:
|
||||
hub_api.upload_large_folder(**upload_kwargs)
|
||||
else:
|
||||
hub_api.upload_folder(**upload_kwargs)
|
||||
|
||||
if not hub_api.file_exists(self.repo_id, REPOCARD_NAME, repo_type="dataset", revision=branch):
|
||||
card = create_lerobot_dataset_card(
|
||||
tags=tags, dataset_info=self.meta.info, license=license, **card_kwargs
|
||||
)
|
||||
card.push_to_hub(repo_id=self.repo_id, repo_type="dataset", revision=branch)
|
||||
|
||||
if tag_version:
|
||||
with contextlib.suppress(RevisionNotFoundError):
|
||||
hub_api.delete_tag(self.repo_id, tag=CODEBASE_VERSION, repo_type="dataset")
|
||||
hub_api.create_tag(self.repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset")
|
||||
upload_folder(
|
||||
repo_id=self.repo_id,
|
||||
folder_path=self.root,
|
||||
repo_type="dataset",
|
||||
ignore_patterns=ignore_patterns,
|
||||
)
|
||||
card = create_lerobot_dataset_card(
|
||||
tags=tags, dataset_info=self.meta.info, license=license, **card_kwargs
|
||||
)
|
||||
card.push_to_hub(repo_id=self.repo_id, repo_type="dataset")
|
||||
create_branch(repo_id=self.repo_id, branch=CODEBASE_VERSION, repo_type="dataset")
|
||||
|
||||
def pull_from_repo(
|
||||
self,
|
||||
@@ -578,10 +510,11 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
snapshot_download(
|
||||
self.repo_id,
|
||||
repo_type="dataset",
|
||||
revision=self.revision,
|
||||
revision=self.meta._hub_version,
|
||||
local_dir=self.root,
|
||||
allow_patterns=allow_patterns,
|
||||
ignore_patterns=ignore_patterns,
|
||||
local_files_only=self.local_files_only,
|
||||
)
|
||||
|
||||
def download_episodes(self, download_videos: bool = True) -> None:
|
||||
@@ -595,23 +528,17 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
files = None
|
||||
ignore_patterns = None if download_videos else "videos/"
|
||||
if self.episodes is not None:
|
||||
files = self.get_episodes_file_paths()
|
||||
files = [str(self.meta.get_data_file_path(ep_idx)) for ep_idx in self.episodes]
|
||||
if len(self.meta.video_keys) > 0 and download_videos:
|
||||
video_files = [
|
||||
str(self.meta.get_video_file_path(ep_idx, vid_key))
|
||||
for vid_key in self.meta.video_keys
|
||||
for ep_idx in self.episodes
|
||||
]
|
||||
files += video_files
|
||||
|
||||
self.pull_from_repo(allow_patterns=files, ignore_patterns=ignore_patterns)
|
||||
|
||||
def get_episodes_file_paths(self) -> list[Path]:
|
||||
episodes = self.episodes if self.episodes is not None else list(range(self.meta.total_episodes))
|
||||
fpaths = [str(self.meta.get_data_file_path(ep_idx)) for ep_idx in episodes]
|
||||
if len(self.meta.video_keys) > 0:
|
||||
video_files = [
|
||||
str(self.meta.get_video_file_path(ep_idx, vid_key))
|
||||
for vid_key in self.meta.video_keys
|
||||
for ep_idx in episodes
|
||||
]
|
||||
fpaths += video_files
|
||||
|
||||
return fpaths
|
||||
|
||||
def load_hf_dataset(self) -> datasets.Dataset:
|
||||
"""hf_dataset contains all the observations, states, actions, rewards, etc."""
|
||||
if self.episodes is None:
|
||||
@@ -623,15 +550,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
|
||||
# TODO(aliberts): hf_dataset.set_format("torch")
|
||||
hf_dataset.set_transform(hf_transform_to_torch)
|
||||
return hf_dataset
|
||||
|
||||
def create_hf_dataset(self) -> datasets.Dataset:
|
||||
features = get_hf_features_from_features(self.features)
|
||||
ft_dict = {col: [] for col in features}
|
||||
hf_dataset = datasets.Dataset.from_dict(ft_dict, features=features, split="train")
|
||||
|
||||
# TODO(aliberts): hf_dataset.set_format("torch")
|
||||
hf_dataset.set_transform(hf_transform_to_torch)
|
||||
return hf_dataset
|
||||
|
||||
@property
|
||||
@@ -698,7 +617,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
if key not in self.meta.video_keys
|
||||
}
|
||||
|
||||
def _query_videos(self, query_timestamps: dict[str, list[float]], ep_idx: int) -> dict[str, torch.Tensor]:
|
||||
def _query_videos(self, query_timestamps: dict[str, list[float]], ep_idx: int) -> dict:
|
||||
"""Note: When using data workers (e.g. DataLoader with num_workers>0), do not call this function
|
||||
in the main process (e.g. by using a second Dataloader with num_workers=0). It will result in a
|
||||
Segmentation Fault. This probably happens because a memory reference to the video loader is created in
|
||||
@@ -728,7 +647,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
|
||||
query_indices = None
|
||||
if self.delta_indices is not None:
|
||||
query_indices, padding = self._get_query_indices(idx, ep_idx)
|
||||
current_ep_idx = self.episodes.index(ep_idx) if self.episodes is not None else ep_idx
|
||||
query_indices, padding = self._get_query_indices(idx, current_ep_idx)
|
||||
query_result = self._query_hf_dataset(query_indices)
|
||||
item = {**item, **padding}
|
||||
for key, val in query_result.items():
|
||||
@@ -745,10 +665,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
for cam in image_keys:
|
||||
item[cam] = self.image_transforms(item[cam])
|
||||
|
||||
# Add task as a string
|
||||
task_idx = item["task_index"].item()
|
||||
item["task"] = self.meta.tasks[task_idx]
|
||||
|
||||
return item
|
||||
|
||||
def __repr__(self):
|
||||
@@ -764,13 +680,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
|
||||
def create_episode_buffer(self, episode_index: int | None = None) -> dict:
|
||||
current_ep_idx = self.meta.total_episodes if episode_index is None else episode_index
|
||||
ep_buffer = {}
|
||||
# size and task are special cases that are not in self.features
|
||||
ep_buffer["size"] = 0
|
||||
ep_buffer["task"] = []
|
||||
for key in self.features:
|
||||
ep_buffer[key] = current_ep_idx if key == "episode_index" else []
|
||||
return ep_buffer
|
||||
return {
|
||||
"size": 0,
|
||||
**{key: current_ep_idx if key == "episode_index" else [] for key in self.features},
|
||||
}
|
||||
|
||||
def _get_image_file_path(self, episode_index: int, image_key: str, frame_index: int) -> Path:
|
||||
fpath = DEFAULT_IMAGE_PATH.format(
|
||||
@@ -792,35 +705,25 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
temporary directory — nothing is written to disk. To save those frames, the 'save_episode()' method
|
||||
then needs to be called.
|
||||
"""
|
||||
# Convert torch to numpy if needed
|
||||
for name in frame:
|
||||
if isinstance(frame[name], torch.Tensor):
|
||||
frame[name] = frame[name].numpy()
|
||||
|
||||
validate_frame(frame, self.features)
|
||||
# TODO(aliberts, rcadene): Add sanity check for the input, check it's numpy or torch,
|
||||
# check the dtype and shape matches, etc.
|
||||
|
||||
if self.episode_buffer is None:
|
||||
self.episode_buffer = self.create_episode_buffer()
|
||||
|
||||
# Automatically add frame_index and timestamp to episode buffer
|
||||
frame_index = self.episode_buffer["size"]
|
||||
timestamp = frame.pop("timestamp") if "timestamp" in frame else frame_index / self.fps
|
||||
self.episode_buffer["frame_index"].append(frame_index)
|
||||
self.episode_buffer["timestamp"].append(timestamp)
|
||||
|
||||
# Add frame features to episode_buffer
|
||||
for key in frame:
|
||||
if key == "task":
|
||||
# Note: we associate the task in natural language to its task index during `save_episode`
|
||||
self.episode_buffer["task"].append(frame["task"])
|
||||
continue
|
||||
|
||||
if key not in self.features:
|
||||
raise ValueError(
|
||||
f"An element of the frame is not in the features. '{key}' not in '{self.features.keys()}'."
|
||||
)
|
||||
raise ValueError(key)
|
||||
|
||||
if self.features[key]["dtype"] in ["image", "video"]:
|
||||
if self.features[key]["dtype"] not in ["image", "video"]:
|
||||
item = frame[key].numpy() if isinstance(frame[key], torch.Tensor) else frame[key]
|
||||
self.episode_buffer[key].append(item)
|
||||
elif self.features[key]["dtype"] in ["image", "video"]:
|
||||
img_path = self._get_image_file_path(
|
||||
episode_index=self.episode_buffer["episode_index"], image_key=key, frame_index=frame_index
|
||||
)
|
||||
@@ -828,95 +731,80 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
img_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
self._save_image(frame[key], img_path)
|
||||
self.episode_buffer[key].append(str(img_path))
|
||||
else:
|
||||
self.episode_buffer[key].append(frame[key])
|
||||
|
||||
self.episode_buffer["size"] += 1
|
||||
|
||||
def save_episode(self, episode_data: dict | None = None) -> None:
|
||||
def save_episode(self, task: str, encode_videos: bool = True, episode_data: dict | None = None) -> None:
|
||||
"""
|
||||
This will save to disk the current episode in self.episode_buffer.
|
||||
This will save to disk the current episode in self.episode_buffer. Note that since it affects files on
|
||||
disk, it sets self.consolidated to False to ensure proper consolidation later on before uploading to
|
||||
the hub.
|
||||
|
||||
Args:
|
||||
episode_data (dict | None, optional): Dict containing the episode data to save. If None, this will
|
||||
save the current episode in self.episode_buffer, which is filled with 'add_frame'. Defaults to
|
||||
None.
|
||||
Use 'encode_videos' if you want to encode videos during the saving of this episode. Otherwise,
|
||||
you can do it later with dataset.consolidate(). This is to give more flexibility on when to spend
|
||||
time for video encoding.
|
||||
"""
|
||||
if not episode_data:
|
||||
episode_buffer = self.episode_buffer
|
||||
|
||||
validate_episode_buffer(episode_buffer, self.meta.total_episodes, self.features)
|
||||
|
||||
# size and task are special cases that won't be added to hf_dataset
|
||||
episode_length = episode_buffer.pop("size")
|
||||
tasks = episode_buffer.pop("task")
|
||||
episode_tasks = list(set(tasks))
|
||||
episode_index = episode_buffer["episode_index"]
|
||||
if episode_index != self.meta.total_episodes:
|
||||
# TODO(aliberts): Add option to use existing episode_index
|
||||
raise NotImplementedError(
|
||||
"You might have manually provided the episode_buffer with an episode_index that doesn't "
|
||||
"match the total number of episodes in the dataset. This is not supported for now."
|
||||
)
|
||||
|
||||
episode_buffer["index"] = np.arange(self.meta.total_frames, self.meta.total_frames + episode_length)
|
||||
episode_buffer["episode_index"] = np.full((episode_length,), episode_index)
|
||||
if episode_length == 0:
|
||||
raise ValueError(
|
||||
"You must add one or several frames with `add_frame` before calling `add_episode`."
|
||||
)
|
||||
|
||||
# Add new tasks to the tasks dictionary
|
||||
for task in episode_tasks:
|
||||
task_index = self.meta.get_task_index(task)
|
||||
if task_index is None:
|
||||
self.meta.add_task(task)
|
||||
task_index = self.meta.get_task_index(task)
|
||||
|
||||
# Given tasks in natural language, find their corresponding task indices
|
||||
episode_buffer["task_index"] = np.array([self.meta.get_task_index(task) for task in tasks])
|
||||
if not set(episode_buffer.keys()) == set(self.features):
|
||||
raise ValueError()
|
||||
|
||||
for key, ft in self.features.items():
|
||||
# index, episode_index, task_index are already processed above, and image and video
|
||||
# are processed separately by storing image path and frame info as meta data
|
||||
if key in ["index", "episode_index", "task_index"] or ft["dtype"] in ["image", "video"]:
|
||||
if key == "index":
|
||||
episode_buffer[key] = np.arange(
|
||||
self.meta.total_frames, self.meta.total_frames + episode_length
|
||||
)
|
||||
elif key == "episode_index":
|
||||
episode_buffer[key] = np.full((episode_length,), episode_index)
|
||||
elif key == "task_index":
|
||||
episode_buffer[key] = np.full((episode_length,), task_index)
|
||||
elif ft["dtype"] in ["image", "video"]:
|
||||
continue
|
||||
episode_buffer[key] = np.stack(episode_buffer[key])
|
||||
elif len(ft["shape"]) == 1 and ft["shape"][0] == 1:
|
||||
episode_buffer[key] = np.array(episode_buffer[key], dtype=ft["dtype"])
|
||||
elif len(ft["shape"]) == 1 and ft["shape"][0] > 1:
|
||||
episode_buffer[key] = np.stack(episode_buffer[key])
|
||||
else:
|
||||
raise ValueError(key)
|
||||
|
||||
self._wait_image_writer()
|
||||
self._save_episode_table(episode_buffer, episode_index)
|
||||
ep_stats = compute_episode_stats(episode_buffer, self.features)
|
||||
|
||||
if len(self.meta.video_keys) > 0:
|
||||
self.meta.save_episode(episode_index, episode_length, task, task_index)
|
||||
|
||||
if encode_videos and len(self.meta.video_keys) > 0:
|
||||
video_paths = self.encode_episode_videos(episode_index)
|
||||
for key in self.meta.video_keys:
|
||||
episode_buffer[key] = video_paths[key]
|
||||
|
||||
# `meta.save_episode` be executed after encoding the videos
|
||||
self.meta.save_episode(episode_index, episode_length, episode_tasks, ep_stats)
|
||||
|
||||
ep_data_index = get_episode_data_index(self.meta.episodes, [episode_index])
|
||||
ep_data_index_np = {k: t.numpy() for k, t in ep_data_index.items()}
|
||||
check_timestamps_sync(
|
||||
episode_buffer["timestamp"],
|
||||
episode_buffer["episode_index"],
|
||||
ep_data_index_np,
|
||||
self.fps,
|
||||
self.tolerance_s,
|
||||
)
|
||||
|
||||
video_files = list(self.root.rglob("*.mp4"))
|
||||
assert len(video_files) == self.num_episodes * len(self.meta.video_keys)
|
||||
|
||||
parquet_files = list(self.root.rglob("*.parquet"))
|
||||
assert len(parquet_files) == self.num_episodes
|
||||
|
||||
# delete images
|
||||
img_dir = self.root / "images"
|
||||
if img_dir.is_dir():
|
||||
shutil.rmtree(self.root / "images")
|
||||
|
||||
if not episode_data: # Reset the buffer
|
||||
self.episode_buffer = self.create_episode_buffer()
|
||||
|
||||
self.consolidated = False
|
||||
|
||||
def _save_episode_table(self, episode_buffer: dict, episode_index: int) -> None:
|
||||
episode_dict = {key: episode_buffer[key] for key in self.hf_features}
|
||||
ep_dataset = datasets.Dataset.from_dict(episode_dict, features=self.hf_features, split="train")
|
||||
ep_dataset = embed_images(ep_dataset)
|
||||
self.hf_dataset = concatenate_datasets([self.hf_dataset, ep_dataset])
|
||||
self.hf_dataset.set_transform(hf_transform_to_torch)
|
||||
ep_data_path = self.root / self.meta.get_data_file_path(ep_index=episode_index)
|
||||
ep_data_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
ep_dataset.to_parquet(ep_data_path)
|
||||
write_parquet(ep_dataset, ep_data_path)
|
||||
|
||||
def clear_episode_buffer(self) -> None:
|
||||
episode_index = self.episode_buffer["episode_index"]
|
||||
@@ -945,7 +833,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
def stop_image_writer(self) -> None:
|
||||
"""
|
||||
Whenever wrapping this dataset inside a parallelized DataLoader, this needs to be called first to
|
||||
remove the image_writer in order for the LeRobotDataset object to be pickleable and parallelized.
|
||||
remove the image_write in order for the LeRobotDataset object to be pickleable and parallelized.
|
||||
"""
|
||||
if self.image_writer is not None:
|
||||
self.image_writer.stop()
|
||||
@@ -985,6 +873,38 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
|
||||
return video_paths
|
||||
|
||||
def consolidate(self, run_compute_stats: bool = True, keep_image_files: bool = False) -> None:
|
||||
self.hf_dataset = self.load_hf_dataset()
|
||||
self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes)
|
||||
check_timestamps_sync(self.hf_dataset, self.episode_data_index, self.fps, self.tolerance_s)
|
||||
|
||||
if len(self.meta.video_keys) > 0:
|
||||
self.encode_videos()
|
||||
self.meta.write_video_info()
|
||||
|
||||
if not keep_image_files:
|
||||
img_dir = self.root / "images"
|
||||
if img_dir.is_dir():
|
||||
shutil.rmtree(self.root / "images")
|
||||
|
||||
video_files = list(self.root.rglob("*.mp4"))
|
||||
assert len(video_files) == self.num_episodes * len(self.meta.video_keys)
|
||||
|
||||
parquet_files = list(self.root.rglob("*.parquet"))
|
||||
assert len(parquet_files) == self.num_episodes
|
||||
|
||||
if run_compute_stats:
|
||||
self.stop_image_writer()
|
||||
# TODO(aliberts): refactor stats in save_episodes
|
||||
self.meta.stats = compute_stats(self)
|
||||
serialized_stats = serialize_dict(self.meta.stats)
|
||||
write_json(serialized_stats, self.root / STATS_PATH)
|
||||
self.consolidated = True
|
||||
else:
|
||||
logging.warning(
|
||||
"Skipping computation of the dataset statistics, dataset is not fully consolidated."
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
cls,
|
||||
@@ -1013,7 +933,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
)
|
||||
obj.repo_id = obj.meta.repo_id
|
||||
obj.root = obj.meta.root
|
||||
obj.revision = None
|
||||
obj.local_files_only = obj.meta.local_files_only
|
||||
obj.tolerance_s = tolerance_s
|
||||
obj.image_writer = None
|
||||
|
||||
@@ -1023,8 +943,14 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
# TODO(aliberts, rcadene, alexander-soare): Merge this with OnlineBuffer/DataBuffer
|
||||
obj.episode_buffer = obj.create_episode_buffer()
|
||||
|
||||
# This bool indicates that the current LeRobotDataset instance is in sync with the files on disk. It
|
||||
# is used to know when certain operations are need (for instance, computing dataset statistics). In
|
||||
# order to be able to push the dataset to the hub, it needs to be consolidated first by calling
|
||||
# self.consolidate().
|
||||
obj.consolidated = True
|
||||
|
||||
obj.episodes = None
|
||||
obj.hf_dataset = obj.create_hf_dataset()
|
||||
obj.hf_dataset = None
|
||||
obj.image_transforms = None
|
||||
obj.delta_timestamps = None
|
||||
obj.delta_indices = None
|
||||
@@ -1049,11 +975,12 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
||||
delta_timestamps: dict[list[float]] | None = None,
|
||||
tolerances_s: dict | None = None,
|
||||
download_videos: bool = True,
|
||||
local_files_only: bool = False,
|
||||
video_backend: str | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.repo_ids = repo_ids
|
||||
self.root = Path(root) if root else HF_LEROBOT_HOME
|
||||
self.root = Path(root) if root else LEROBOT_HOME
|
||||
self.tolerances_s = tolerances_s if tolerances_s else {repo_id: 1e-4 for repo_id in repo_ids}
|
||||
# Construct the underlying datasets passing everything but `transform` and `delta_timestamps` which
|
||||
# are handled by this class.
|
||||
@@ -1066,6 +993,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
||||
delta_timestamps=delta_timestamps,
|
||||
tolerance_s=self.tolerances_s[repo_id],
|
||||
download_videos=download_videos,
|
||||
local_files_only=local_files_only,
|
||||
video_backend=video_backend,
|
||||
)
|
||||
for repo_id in repo_ids
|
||||
@@ -1093,10 +1021,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
||||
|
||||
self.image_transforms = image_transforms
|
||||
self.delta_timestamps = delta_timestamps
|
||||
# TODO(rcadene, aliberts): We should not perform this aggregation for datasets
|
||||
# with multiple robots of different ranges. Instead we should have one normalization
|
||||
# per robot.
|
||||
self.stats = aggregate_stats([dataset.meta.stats for dataset in self._datasets])
|
||||
self.stats = aggregate_stats(self._datasets)
|
||||
|
||||
@property
|
||||
def repo_id_to_index(self):
|
||||
|
||||
@@ -0,0 +1,56 @@
|
||||
## Using / Updating `CODEBASE_VERSION` (for maintainers)
|
||||
|
||||
Since our dataset pushed to the hub are decoupled with the evolution of this repo, we ensure compatibility of
|
||||
the datasets with our code, we use a `CODEBASE_VERSION` (defined in
|
||||
lerobot/common/datasets/lerobot_dataset.py) variable.
|
||||
|
||||
For instance, [`lerobot/pusht`](https://huggingface.co/datasets/lerobot/pusht) has many versions to maintain backward compatibility between LeRobot codebase versions:
|
||||
- [v1.0](https://huggingface.co/datasets/lerobot/pusht/tree/v1.0)
|
||||
- [v1.1](https://huggingface.co/datasets/lerobot/pusht/tree/v1.1)
|
||||
- [v1.2](https://huggingface.co/datasets/lerobot/pusht/tree/v1.2)
|
||||
- [v1.3](https://huggingface.co/datasets/lerobot/pusht/tree/v1.3)
|
||||
- [v1.4](https://huggingface.co/datasets/lerobot/pusht/tree/v1.4)
|
||||
- [v1.5](https://huggingface.co/datasets/lerobot/pusht/tree/v1.5)
|
||||
- [v1.6](https://huggingface.co/datasets/lerobot/pusht/tree/v1.6) <-- last version
|
||||
- [main](https://huggingface.co/datasets/lerobot/pusht/tree/main) <-- points to the last version
|
||||
|
||||
Starting with v1.6, every dataset pushed to the hub or saved locally also have this version number in their
|
||||
`info.json` metadata.
|
||||
|
||||
### Uploading a new dataset
|
||||
If you are pushing a new dataset, you don't need to worry about any of the instructions below, nor to be
|
||||
compatible with previous codebase versions. The `push_dataset_to_hub.py` script will automatically tag your
|
||||
dataset with the current `CODEBASE_VERSION`.
|
||||
|
||||
### Updating an existing dataset
|
||||
If you want to update an existing dataset, you need to change the `CODEBASE_VERSION` from `lerobot_dataset.py`
|
||||
before running `push_dataset_to_hub.py`. This is especially useful if you introduce a breaking change
|
||||
intentionally or not (i.e. something not backward compatible such as modifying the reward functions used,
|
||||
deleting some frames at the end of an episode, etc.). That way, people running a previous version of the
|
||||
codebase won't be affected by your change and backward compatibility is maintained.
|
||||
|
||||
However, you will need to update the version of ALL the other datasets so that they have the new
|
||||
`CODEBASE_VERSION` as a branch in their hugging face dataset repository. Don't worry, there is an easy way
|
||||
that doesn't require to run `push_dataset_to_hub.py`. You can just "branch-out" from the `main` branch on HF
|
||||
dataset repo by running this script which corresponds to a `git checkout -b` (so no copy or upload needed):
|
||||
|
||||
```python
|
||||
from huggingface_hub import HfApi
|
||||
|
||||
from lerobot import available_datasets
|
||||
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
|
||||
|
||||
api = HfApi()
|
||||
|
||||
for repo_id in available_datasets:
|
||||
dataset_info = api.list_repo_refs(repo_id, repo_type="dataset")
|
||||
branches = [b.name for b in dataset_info.branches]
|
||||
if CODEBASE_VERSION in branches:
|
||||
print(f"{repo_id} already @{CODEBASE_VERSION}, skipping.")
|
||||
continue
|
||||
else:
|
||||
# Now create a branch named after the new version by branching out from "main"
|
||||
# which is expected to be the preceding version
|
||||
api.create_branch(repo_id, repo_type="dataset", branch=CODEBASE_VERSION, revision="main")
|
||||
print(f"{repo_id} successfully updated @{CODEBASE_VERSION}")
|
||||
```
|
||||
@@ -152,7 +152,7 @@ def download_raw(raw_dir: Path, repo_id: str):
|
||||
stacklevel=1,
|
||||
)
|
||||
|
||||
# Send warning if raw_dir isn't well formatted
|
||||
# Send warning if raw_dir isn't well formated
|
||||
if raw_dir.parts[-2] != user_id or raw_dir.parts[-1] != dataset_id:
|
||||
warnings.warn(
|
||||
f"""`raw_dir` ({raw_dir}) doesn't contain a community or user id `/` the name of the dataset that
|
||||
|
||||
@@ -68,11 +68,11 @@ def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episod
|
||||
modality_df,
|
||||
on="timestamp_utc",
|
||||
# "nearest" is the best option over "backward", since the latter can desynchronizes camera timestamps by
|
||||
# matching timestamps that are too far apart, in order to fit the backward constraints. It's not the case for "nearest".
|
||||
# matching timestamps that are too far appart, in order to fit the backward constraints. It's not the case for "nearest".
|
||||
# However, note that "nearest" might synchronize the reference camera with other cameras on slightly future timestamps.
|
||||
# are too far apart.
|
||||
# are too far appart.
|
||||
direction="nearest",
|
||||
tolerance=pd.Timedelta(f"{1 / fps} seconds"),
|
||||
tolerance=pd.Timedelta(f"{1/fps} seconds"),
|
||||
)
|
||||
# Remove rows with episode_index -1 which indicates data that correspond to in-between episodes
|
||||
df = df[df["episode_index"] != -1]
|
||||
@@ -126,7 +126,7 @@ def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episod
|
||||
videos_dir.parent.mkdir(parents=True, exist_ok=True)
|
||||
videos_dir.symlink_to((raw_dir / "videos").absolute())
|
||||
|
||||
# sanity check the video paths are well formatted
|
||||
# sanity check the video paths are well formated
|
||||
for key in df:
|
||||
if "observation.images." not in key:
|
||||
continue
|
||||
@@ -143,7 +143,7 @@ def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episod
|
||||
# it is the case for video_frame dictionary = [{"path": ..., "timestamp": ...}]
|
||||
data_dict[key] = [video_frame[0] for video_frame in df[key].values]
|
||||
|
||||
# sanity check the video path is well formatted
|
||||
# sanity check the video path is well formated
|
||||
video_path = videos_dir.parent / data_dict[key][0]["path"]
|
||||
if not video_path.exists():
|
||||
raise ValueError(f"Video file not found in {video_path}")
|
||||
|
||||
@@ -17,7 +17,7 @@
|
||||
For all datasets in the RLDS format.
|
||||
For https://github.com/google-deepmind/open_x_embodiment (OPENX) datasets.
|
||||
|
||||
NOTE: You need to install tensorflow and tensorflow_datasets before running this script.
|
||||
NOTE: You need to install tensorflow and tensorflow_datsets before running this script.
|
||||
|
||||
Example:
|
||||
python lerobot/scripts/push_dataset_to_hub.py \
|
||||
|
||||
@@ -14,8 +14,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import collections
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Callable, Sequence
|
||||
from typing import Any, Callable, Dict, Sequence
|
||||
|
||||
import torch
|
||||
from torchvision.transforms import v2
|
||||
@@ -66,8 +65,6 @@ class RandomSubsetApply(Transform):
|
||||
self.n_subset = n_subset
|
||||
self.random_order = random_order
|
||||
|
||||
self.selected_transforms = None
|
||||
|
||||
def forward(self, *inputs: Any) -> Any:
|
||||
needs_unpacking = len(inputs) > 1
|
||||
|
||||
@@ -75,9 +72,9 @@ class RandomSubsetApply(Transform):
|
||||
if not self.random_order:
|
||||
selected_indices = selected_indices.sort().values
|
||||
|
||||
self.selected_transforms = [self.transforms[i] for i in selected_indices]
|
||||
selected_transforms = [self.transforms[i] for i in selected_indices]
|
||||
|
||||
for transform in self.selected_transforms:
|
||||
for transform in selected_transforms:
|
||||
outputs = transform(*inputs)
|
||||
inputs = outputs if needs_unpacking else (outputs,)
|
||||
|
||||
@@ -132,118 +129,69 @@ class SharpnessJitter(Transform):
|
||||
|
||||
return float(sharpness[0]), float(sharpness[1])
|
||||
|
||||
def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]:
|
||||
sharpness_factor = torch.empty(1).uniform_(self.sharpness[0], self.sharpness[1]).item()
|
||||
return {"sharpness_factor": sharpness_factor}
|
||||
def _generate_value(self, left: float, right: float) -> float:
|
||||
return torch.empty(1).uniform_(left, right).item()
|
||||
|
||||
def transform(self, inpt: Any, params: dict[str, Any]) -> Any:
|
||||
sharpness_factor = params["sharpness_factor"]
|
||||
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
|
||||
sharpness_factor = self._generate_value(self.sharpness[0], self.sharpness[1])
|
||||
return self._call_kernel(F.adjust_sharpness, inpt, sharpness_factor=sharpness_factor)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ImageTransformConfig:
|
||||
"""
|
||||
For each transform, the following parameters are available:
|
||||
weight: This represents the multinomial probability (with no replacement)
|
||||
used for sampling the transform. If the sum of the weights is not 1,
|
||||
they will be normalized.
|
||||
type: The name of the class used. This is either a class available under torchvision.transforms.v2 or a
|
||||
custom transform defined here.
|
||||
kwargs: Lower & upper bound respectively used for sampling the transform's parameter
|
||||
(following uniform distribution) when it's applied.
|
||||
"""
|
||||
def get_image_transforms(
|
||||
brightness_weight: float = 1.0,
|
||||
brightness_min_max: tuple[float, float] | None = None,
|
||||
contrast_weight: float = 1.0,
|
||||
contrast_min_max: tuple[float, float] | None = None,
|
||||
saturation_weight: float = 1.0,
|
||||
saturation_min_max: tuple[float, float] | None = None,
|
||||
hue_weight: float = 1.0,
|
||||
hue_min_max: tuple[float, float] | None = None,
|
||||
sharpness_weight: float = 1.0,
|
||||
sharpness_min_max: tuple[float, float] | None = None,
|
||||
max_num_transforms: int | None = None,
|
||||
random_order: bool = False,
|
||||
):
|
||||
def check_value(name, weight, min_max):
|
||||
if min_max is not None:
|
||||
if len(min_max) != 2:
|
||||
raise ValueError(
|
||||
f"`{name}_min_max` is expected to be a tuple of 2 dimensions, but {min_max} provided."
|
||||
)
|
||||
if weight < 0.0:
|
||||
raise ValueError(
|
||||
f"`{name}_weight` is expected to be 0 or positive, but is negative ({weight})."
|
||||
)
|
||||
|
||||
weight: float = 1.0
|
||||
type: str = "Identity"
|
||||
kwargs: dict[str, Any] = field(default_factory=dict)
|
||||
check_value("brightness", brightness_weight, brightness_min_max)
|
||||
check_value("contrast", contrast_weight, contrast_min_max)
|
||||
check_value("saturation", saturation_weight, saturation_min_max)
|
||||
check_value("hue", hue_weight, hue_min_max)
|
||||
check_value("sharpness", sharpness_weight, sharpness_min_max)
|
||||
|
||||
weights = []
|
||||
transforms = []
|
||||
if brightness_min_max is not None and brightness_weight > 0.0:
|
||||
weights.append(brightness_weight)
|
||||
transforms.append(v2.ColorJitter(brightness=brightness_min_max))
|
||||
if contrast_min_max is not None and contrast_weight > 0.0:
|
||||
weights.append(contrast_weight)
|
||||
transforms.append(v2.ColorJitter(contrast=contrast_min_max))
|
||||
if saturation_min_max is not None and saturation_weight > 0.0:
|
||||
weights.append(saturation_weight)
|
||||
transforms.append(v2.ColorJitter(saturation=saturation_min_max))
|
||||
if hue_min_max is not None and hue_weight > 0.0:
|
||||
weights.append(hue_weight)
|
||||
transforms.append(v2.ColorJitter(hue=hue_min_max))
|
||||
if sharpness_min_max is not None and sharpness_weight > 0.0:
|
||||
weights.append(sharpness_weight)
|
||||
transforms.append(SharpnessJitter(sharpness=sharpness_min_max))
|
||||
|
||||
@dataclass
|
||||
class ImageTransformsConfig:
|
||||
"""
|
||||
These transforms are all using standard torchvision.transforms.v2
|
||||
You can find out how these transformations affect images here:
|
||||
https://pytorch.org/vision/0.18/auto_examples/transforms/plot_transforms_illustrations.html
|
||||
We use a custom RandomSubsetApply container to sample them.
|
||||
"""
|
||||
n_subset = len(transforms)
|
||||
if max_num_transforms is not None:
|
||||
n_subset = min(n_subset, max_num_transforms)
|
||||
|
||||
# Set this flag to `true` to enable transforms during training
|
||||
enable: bool = False
|
||||
# This is the maximum number of transforms (sampled from these below) that will be applied to each frame.
|
||||
# It's an integer in the interval [1, number_of_available_transforms].
|
||||
max_num_transforms: int = 3
|
||||
# By default, transforms are applied in Torchvision's suggested order (shown below).
|
||||
# Set this to True to apply them in a random order.
|
||||
random_order: bool = False
|
||||
tfs: dict[str, ImageTransformConfig] = field(
|
||||
default_factory=lambda: {
|
||||
"brightness": ImageTransformConfig(
|
||||
weight=1.0,
|
||||
type="ColorJitter",
|
||||
kwargs={"brightness": (0.8, 1.2)},
|
||||
),
|
||||
"contrast": ImageTransformConfig(
|
||||
weight=1.0,
|
||||
type="ColorJitter",
|
||||
kwargs={"contrast": (0.8, 1.2)},
|
||||
),
|
||||
"saturation": ImageTransformConfig(
|
||||
weight=1.0,
|
||||
type="ColorJitter",
|
||||
kwargs={"saturation": (0.5, 1.5)},
|
||||
),
|
||||
"hue": ImageTransformConfig(
|
||||
weight=1.0,
|
||||
type="ColorJitter",
|
||||
kwargs={"hue": (-0.05, 0.05)},
|
||||
),
|
||||
"sharpness": ImageTransformConfig(
|
||||
weight=1.0,
|
||||
type="SharpnessJitter",
|
||||
kwargs={"sharpness": (0.5, 1.5)},
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def make_transform_from_config(cfg: ImageTransformConfig):
|
||||
if cfg.type == "Identity":
|
||||
return v2.Identity(**cfg.kwargs)
|
||||
elif cfg.type == "ColorJitter":
|
||||
return v2.ColorJitter(**cfg.kwargs)
|
||||
elif cfg.type == "SharpnessJitter":
|
||||
return SharpnessJitter(**cfg.kwargs)
|
||||
if n_subset == 0:
|
||||
return v2.Identity()
|
||||
else:
|
||||
raise ValueError(f"Transform '{cfg.type}' is not valid.")
|
||||
|
||||
|
||||
class ImageTransforms(Transform):
|
||||
"""A class to compose image transforms based on configuration."""
|
||||
|
||||
def __init__(self, cfg: ImageTransformsConfig) -> None:
|
||||
super().__init__()
|
||||
self._cfg = cfg
|
||||
|
||||
self.weights = []
|
||||
self.transforms = {}
|
||||
for tf_name, tf_cfg in cfg.tfs.items():
|
||||
if tf_cfg.weight <= 0.0:
|
||||
continue
|
||||
|
||||
self.transforms[tf_name] = make_transform_from_config(tf_cfg)
|
||||
self.weights.append(tf_cfg.weight)
|
||||
|
||||
n_subset = min(len(self.transforms), cfg.max_num_transforms)
|
||||
if n_subset == 0 or not cfg.enable:
|
||||
self.tf = v2.Identity()
|
||||
else:
|
||||
self.tf = RandomSubsetApply(
|
||||
transforms=list(self.transforms.values()),
|
||||
p=self.weights,
|
||||
n_subset=n_subset,
|
||||
random_order=cfg.random_order,
|
||||
)
|
||||
|
||||
def forward(self, *inputs: Any) -> Any:
|
||||
return self.tf(*inputs)
|
||||
# TODO(rcadene, aliberts): add v2.ToDtype float16?
|
||||
return RandomSubsetApply(transforms, p=weights, n_subset=n_subset, random_order=random_order)
|
||||
|
||||
@@ -13,10 +13,10 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import contextlib
|
||||
import importlib.resources
|
||||
import json
|
||||
import logging
|
||||
import textwrap
|
||||
from collections.abc import Iterator
|
||||
from itertools import accumulate
|
||||
from pathlib import Path
|
||||
@@ -27,29 +27,20 @@ from typing import Any
|
||||
import datasets
|
||||
import jsonlines
|
||||
import numpy as np
|
||||
import packaging.version
|
||||
import pyarrow.compute as pc
|
||||
import torch
|
||||
from datasets.table import embed_table_storage
|
||||
from huggingface_hub import DatasetCard, DatasetCardData, HfApi
|
||||
from huggingface_hub.errors import RevisionNotFoundError
|
||||
from PIL import Image as PILImage
|
||||
from torchvision import transforms
|
||||
|
||||
from lerobot.common.datasets.backward_compatibility import (
|
||||
V21_MESSAGE,
|
||||
BackwardCompatibilityError,
|
||||
ForwardCompatibilityError,
|
||||
)
|
||||
from lerobot.common.robot_devices.robots.utils import Robot
|
||||
from lerobot.common.utils.utils import is_valid_numpy_dtype_string
|
||||
from lerobot.configs.types import DictLike, FeatureType, PolicyFeature
|
||||
|
||||
DEFAULT_CHUNK_SIZE = 1000 # Max number of episodes per chunk
|
||||
|
||||
INFO_PATH = "meta/info.json"
|
||||
EPISODES_PATH = "meta/episodes.jsonl"
|
||||
STATS_PATH = "meta/stats.json"
|
||||
EPISODES_STATS_PATH = "meta/episodes_stats.jsonl"
|
||||
TASKS_PATH = "meta/tasks.jsonl"
|
||||
|
||||
DEFAULT_VIDEO_PATH = "videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4"
|
||||
@@ -107,39 +98,18 @@ def unflatten_dict(d: dict, sep: str = "/") -> dict:
|
||||
return outdict
|
||||
|
||||
|
||||
def get_nested_item(obj: DictLike, flattened_key: str, sep: str = "/") -> Any:
|
||||
split_keys = flattened_key.split(sep)
|
||||
getter = obj[split_keys[0]]
|
||||
if len(split_keys) == 1:
|
||||
return getter
|
||||
|
||||
for key in split_keys[1:]:
|
||||
getter = getter[key]
|
||||
|
||||
return getter
|
||||
|
||||
|
||||
def serialize_dict(stats: dict[str, torch.Tensor | np.ndarray | dict]) -> dict:
|
||||
serialized_dict = {}
|
||||
for key, value in flatten_dict(stats).items():
|
||||
if isinstance(value, (torch.Tensor, np.ndarray)):
|
||||
serialized_dict[key] = value.tolist()
|
||||
elif isinstance(value, np.generic):
|
||||
serialized_dict[key] = value.item()
|
||||
elif isinstance(value, (int, float)):
|
||||
serialized_dict[key] = value
|
||||
else:
|
||||
raise NotImplementedError(f"The value '{value}' of type '{type(value)}' is not supported.")
|
||||
serialized_dict = {key: value.tolist() for key, value in flatten_dict(stats).items()}
|
||||
return unflatten_dict(serialized_dict)
|
||||
|
||||
|
||||
def embed_images(dataset: datasets.Dataset) -> datasets.Dataset:
|
||||
def write_parquet(dataset: datasets.Dataset, fpath: Path) -> None:
|
||||
# Embed image bytes into the table before saving to parquet
|
||||
format = dataset.format
|
||||
dataset = dataset.with_format("arrow")
|
||||
dataset = dataset.map(embed_table_storage, batched=False)
|
||||
dataset = dataset.with_format(**format)
|
||||
return dataset
|
||||
dataset.to_parquet(fpath)
|
||||
|
||||
|
||||
def load_json(fpath: Path) -> Any:
|
||||
@@ -170,10 +140,6 @@ def append_jsonlines(data: dict, fpath: Path) -> None:
|
||||
writer.write(data)
|
||||
|
||||
|
||||
def write_info(info: dict, local_dir: Path):
|
||||
write_json(info, local_dir / INFO_PATH)
|
||||
|
||||
|
||||
def load_info(local_dir: Path) -> dict:
|
||||
info = load_json(local_dir / INFO_PATH)
|
||||
for ft in info["features"].values():
|
||||
@@ -181,76 +147,29 @@ def load_info(local_dir: Path) -> dict:
|
||||
return info
|
||||
|
||||
|
||||
def write_stats(stats: dict, local_dir: Path):
|
||||
serialized_stats = serialize_dict(stats)
|
||||
write_json(serialized_stats, local_dir / STATS_PATH)
|
||||
|
||||
|
||||
def cast_stats_to_numpy(stats) -> dict[str, dict[str, np.ndarray]]:
|
||||
stats = {key: np.array(value) for key, value in flatten_dict(stats).items()}
|
||||
return unflatten_dict(stats)
|
||||
|
||||
|
||||
def load_stats(local_dir: Path) -> dict[str, dict[str, np.ndarray]]:
|
||||
def load_stats(local_dir: Path) -> dict:
|
||||
if not (local_dir / STATS_PATH).exists():
|
||||
return None
|
||||
stats = load_json(local_dir / STATS_PATH)
|
||||
return cast_stats_to_numpy(stats)
|
||||
stats = {key: torch.tensor(value) for key, value in flatten_dict(stats).items()}
|
||||
return unflatten_dict(stats)
|
||||
|
||||
|
||||
def write_task(task_index: int, task: dict, local_dir: Path):
|
||||
task_dict = {
|
||||
"task_index": task_index,
|
||||
"task": task,
|
||||
}
|
||||
append_jsonlines(task_dict, local_dir / TASKS_PATH)
|
||||
|
||||
|
||||
def load_tasks(local_dir: Path) -> tuple[dict, dict]:
|
||||
def load_tasks(local_dir: Path) -> dict:
|
||||
tasks = load_jsonlines(local_dir / TASKS_PATH)
|
||||
tasks = {item["task_index"]: item["task"] for item in sorted(tasks, key=lambda x: x["task_index"])}
|
||||
task_to_task_index = {task: task_index for task_index, task in tasks.items()}
|
||||
return tasks, task_to_task_index
|
||||
|
||||
|
||||
def write_episode(episode: dict, local_dir: Path):
|
||||
append_jsonlines(episode, local_dir / EPISODES_PATH)
|
||||
return {item["task_index"]: item["task"] for item in sorted(tasks, key=lambda x: x["task_index"])}
|
||||
|
||||
|
||||
def load_episodes(local_dir: Path) -> dict:
|
||||
episodes = load_jsonlines(local_dir / EPISODES_PATH)
|
||||
return {item["episode_index"]: item for item in sorted(episodes, key=lambda x: x["episode_index"])}
|
||||
return load_jsonlines(local_dir / EPISODES_PATH)
|
||||
|
||||
|
||||
def write_episode_stats(episode_index: int, episode_stats: dict, local_dir: Path):
|
||||
# We wrap episode_stats in a dictionary since `episode_stats["episode_index"]`
|
||||
# is a dictionary of stats and not an integer.
|
||||
episode_stats = {"episode_index": episode_index, "stats": serialize_dict(episode_stats)}
|
||||
append_jsonlines(episode_stats, local_dir / EPISODES_STATS_PATH)
|
||||
|
||||
|
||||
def load_episodes_stats(local_dir: Path) -> dict:
|
||||
episodes_stats = load_jsonlines(local_dir / EPISODES_STATS_PATH)
|
||||
return {
|
||||
item["episode_index"]: cast_stats_to_numpy(item["stats"])
|
||||
for item in sorted(episodes_stats, key=lambda x: x["episode_index"])
|
||||
}
|
||||
|
||||
|
||||
def backward_compatible_episodes_stats(
|
||||
stats: dict[str, dict[str, np.ndarray]], episodes: list[int]
|
||||
) -> dict[str, dict[str, np.ndarray]]:
|
||||
return {ep_idx: stats for ep_idx in episodes}
|
||||
|
||||
|
||||
def load_image_as_numpy(
|
||||
fpath: str | Path, dtype: np.dtype = np.float32, channel_first: bool = True
|
||||
) -> np.ndarray:
|
||||
def load_image_as_numpy(fpath: str | Path, dtype="float32", channel_first: bool = True) -> np.ndarray:
|
||||
img = PILImage.open(fpath).convert("RGB")
|
||||
img_array = np.array(img, dtype=dtype)
|
||||
if channel_first: # (H, W, C) -> (C, H, W)
|
||||
img_array = np.transpose(img_array, (2, 0, 1))
|
||||
if np.issubdtype(dtype, np.floating):
|
||||
if "float" in dtype:
|
||||
img_array /= 255.0
|
||||
return img_array
|
||||
|
||||
@@ -269,95 +188,77 @@ def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
|
||||
elif first_item is None:
|
||||
pass
|
||||
else:
|
||||
items_dict[key] = [x if isinstance(x, str) else torch.tensor(x) for x in items_dict[key]]
|
||||
items_dict[key] = [torch.tensor(x) for x in items_dict[key]]
|
||||
return items_dict
|
||||
|
||||
|
||||
def is_valid_version(version: str) -> bool:
|
||||
try:
|
||||
packaging.version.parse(version)
|
||||
return True
|
||||
except packaging.version.InvalidVersion:
|
||||
return False
|
||||
def _get_major_minor(version: str) -> tuple[int]:
|
||||
split = version.strip("v").split(".")
|
||||
return int(split[0]), int(split[1])
|
||||
|
||||
|
||||
class BackwardCompatibilityError(Exception):
|
||||
def __init__(self, repo_id, version):
|
||||
message = textwrap.dedent(f"""
|
||||
BackwardCompatibilityError: The dataset you requested ({repo_id}) is in {version} format.
|
||||
|
||||
We introduced a new format since v2.0 which is not backward compatible with v1.x.
|
||||
Please, use our conversion script. Modify the following command with your own task description:
|
||||
```
|
||||
python lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py \\
|
||||
--repo-id {repo_id} \\
|
||||
--single-task "TASK DESCRIPTION." # <---- /!\\ Replace TASK DESCRIPTION /!\\
|
||||
```
|
||||
|
||||
A few examples to replace TASK DESCRIPTION: "Pick up the blue cube and place it into the bin.",
|
||||
"Insert the peg into the socket.", "Slide open the ziploc bag.", "Take the elevator to the 1st floor.",
|
||||
"Open the top cabinet, store the pot inside it then close the cabinet.", "Push the T-shaped block onto the T-shaped target.",
|
||||
"Grab the spray paint on the shelf and place it in the bin on top of the robot dog.", "Fold the sweatshirt.", ...
|
||||
|
||||
If you encounter a problem, contact LeRobot maintainers on [Discord](https://discord.com/invite/s3KuuzsPFb)
|
||||
or open an [issue on GitHub](https://github.com/huggingface/lerobot/issues/new/choose).
|
||||
""")
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
def check_version_compatibility(
|
||||
repo_id: str,
|
||||
version_to_check: str | packaging.version.Version,
|
||||
current_version: str | packaging.version.Version,
|
||||
enforce_breaking_major: bool = True,
|
||||
repo_id: str, version_to_check: str, current_version: str, enforce_breaking_major: bool = True
|
||||
) -> None:
|
||||
v_check = (
|
||||
packaging.version.parse(version_to_check)
|
||||
if not isinstance(version_to_check, packaging.version.Version)
|
||||
else version_to_check
|
||||
)
|
||||
v_current = (
|
||||
packaging.version.parse(current_version)
|
||||
if not isinstance(current_version, packaging.version.Version)
|
||||
else current_version
|
||||
)
|
||||
if v_check.major < v_current.major and enforce_breaking_major:
|
||||
raise BackwardCompatibilityError(repo_id, v_check)
|
||||
elif v_check.minor < v_current.minor:
|
||||
logging.warning(V21_MESSAGE.format(repo_id=repo_id, version=v_check))
|
||||
|
||||
|
||||
def get_repo_versions(repo_id: str) -> list[packaging.version.Version]:
|
||||
"""Returns available valid versions (branches and tags) on given repo."""
|
||||
api = HfApi()
|
||||
repo_refs = api.list_repo_refs(repo_id, repo_type="dataset")
|
||||
repo_refs = [b.name for b in repo_refs.branches + repo_refs.tags]
|
||||
repo_versions = []
|
||||
for ref in repo_refs:
|
||||
with contextlib.suppress(packaging.version.InvalidVersion):
|
||||
repo_versions.append(packaging.version.parse(ref))
|
||||
|
||||
return repo_versions
|
||||
|
||||
|
||||
def get_safe_version(repo_id: str, version: str | packaging.version.Version) -> str:
|
||||
"""
|
||||
Returns the version if available on repo or the latest compatible one.
|
||||
Otherwise, will throw a `CompatibilityError`.
|
||||
"""
|
||||
target_version = (
|
||||
packaging.version.parse(version) if not isinstance(version, packaging.version.Version) else version
|
||||
)
|
||||
hub_versions = get_repo_versions(repo_id)
|
||||
|
||||
if not hub_versions:
|
||||
raise RevisionNotFoundError(
|
||||
f"""Your dataset must be tagged with a codebase version.
|
||||
Assuming _version_ is the codebase_version value in the info.json, you can run this:
|
||||
```python
|
||||
from huggingface_hub import HfApi
|
||||
|
||||
hub_api = HfApi()
|
||||
hub_api.create_tag("{repo_id}", tag="_version_", repo_type="dataset")
|
||||
```
|
||||
"""
|
||||
current_major, _ = _get_major_minor(current_version)
|
||||
major_to_check, _ = _get_major_minor(version_to_check)
|
||||
if major_to_check < current_major and enforce_breaking_major:
|
||||
raise BackwardCompatibilityError(repo_id, version_to_check)
|
||||
elif float(version_to_check.strip("v")) < float(current_version.strip("v")):
|
||||
logging.warning(
|
||||
f"""The dataset you requested ({repo_id}) was created with a previous version ({version_to_check}) of the
|
||||
codebase. The current codebase version is {current_version}. You should be fine since
|
||||
backward compatibility is maintained. If you encounter a problem, contact LeRobot maintainers on
|
||||
Discord ('https://discord.com/invite/s3KuuzsPFb') or open an issue on github.""",
|
||||
)
|
||||
|
||||
if target_version in hub_versions:
|
||||
return f"v{target_version}"
|
||||
|
||||
compatibles = [
|
||||
v for v in hub_versions if v.major == target_version.major and v.minor <= target_version.minor
|
||||
]
|
||||
if compatibles:
|
||||
return_version = max(compatibles)
|
||||
if return_version < target_version:
|
||||
logging.warning(f"Revision {version} for {repo_id} not found, using version v{return_version}")
|
||||
return f"v{return_version}"
|
||||
def get_hub_safe_version(repo_id: str, version: str) -> str:
|
||||
api = HfApi()
|
||||
dataset_info = api.list_repo_refs(repo_id, repo_type="dataset")
|
||||
branches = [b.name for b in dataset_info.branches]
|
||||
if version not in branches:
|
||||
num_version = float(version.strip("v"))
|
||||
hub_num_versions = [float(v.strip("v")) for v in branches if v.startswith("v")]
|
||||
if num_version >= 2.0 and all(v < 2.0 for v in hub_num_versions):
|
||||
raise BackwardCompatibilityError(repo_id, version)
|
||||
|
||||
lower_major = [v for v in hub_versions if v.major < target_version.major]
|
||||
if lower_major:
|
||||
raise BackwardCompatibilityError(repo_id, max(lower_major))
|
||||
|
||||
upper_versions = [v for v in hub_versions if v > target_version]
|
||||
assert len(upper_versions) > 0
|
||||
raise ForwardCompatibilityError(repo_id, min(upper_versions))
|
||||
logging.warning(
|
||||
f"""You are trying to load a dataset from {repo_id} created with a previous version of the
|
||||
codebase. The following versions are available: {branches}.
|
||||
The requested version ('{version}') is not found. You should be fine since
|
||||
backward compatibility is maintained. If you encounter a problem, contact LeRobot maintainers on
|
||||
Discord ('https://discord.com/invite/s3KuuzsPFb') or open an issue on github.""",
|
||||
)
|
||||
if "main" not in branches:
|
||||
raise ValueError(f"Version 'main' not found on {repo_id}")
|
||||
return "main"
|
||||
else:
|
||||
return version
|
||||
|
||||
|
||||
def get_hf_features_from_features(features: dict) -> datasets.Features:
|
||||
@@ -369,20 +270,11 @@ def get_hf_features_from_features(features: dict) -> datasets.Features:
|
||||
hf_features[key] = datasets.Image()
|
||||
elif ft["shape"] == (1,):
|
||||
hf_features[key] = datasets.Value(dtype=ft["dtype"])
|
||||
elif len(ft["shape"]) == 1:
|
||||
else:
|
||||
assert len(ft["shape"]) == 1
|
||||
hf_features[key] = datasets.Sequence(
|
||||
length=ft["shape"][0], feature=datasets.Value(dtype=ft["dtype"])
|
||||
)
|
||||
elif len(ft["shape"]) == 2:
|
||||
hf_features[key] = datasets.Array2D(shape=ft["shape"], dtype=ft["dtype"])
|
||||
elif len(ft["shape"]) == 3:
|
||||
hf_features[key] = datasets.Array3D(shape=ft["shape"], dtype=ft["dtype"])
|
||||
elif len(ft["shape"]) == 4:
|
||||
hf_features[key] = datasets.Array4D(shape=ft["shape"], dtype=ft["dtype"])
|
||||
elif len(ft["shape"]) == 5:
|
||||
hf_features[key] = datasets.Array5D(shape=ft["shape"], dtype=ft["dtype"])
|
||||
else:
|
||||
raise ValueError(f"Corresponding feature is not valid: {ft}")
|
||||
|
||||
return datasets.Features(hf_features)
|
||||
|
||||
@@ -397,37 +289,6 @@ def get_features_from_robot(robot: Robot, use_videos: bool = True) -> dict:
|
||||
return {**robot.motor_features, **camera_ft, **DEFAULT_FEATURES}
|
||||
|
||||
|
||||
def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFeature]:
|
||||
# TODO(aliberts): Implement "type" in dataset features and simplify this
|
||||
policy_features = {}
|
||||
for key, ft in features.items():
|
||||
shape = ft["shape"]
|
||||
if ft["dtype"] in ["image", "video"]:
|
||||
type = FeatureType.VISUAL
|
||||
if len(shape) != 3:
|
||||
raise ValueError(f"Number of dimensions of {key} != 3 (shape={shape})")
|
||||
|
||||
names = ft["names"]
|
||||
# Backward compatibility for "channel" which is an error introduced in LeRobotDataset v2.0 for ported datasets.
|
||||
if names[2] in ["channel", "channels"]: # (h, w, c) -> (c, h, w)
|
||||
shape = (shape[2], shape[0], shape[1])
|
||||
elif key == "observation.environment_state":
|
||||
type = FeatureType.ENV
|
||||
elif key.startswith("observation"):
|
||||
type = FeatureType.STATE
|
||||
elif key == "action":
|
||||
type = FeatureType.ACTION
|
||||
else:
|
||||
continue
|
||||
|
||||
policy_features[key] = PolicyFeature(
|
||||
type=type,
|
||||
shape=shape,
|
||||
)
|
||||
|
||||
return policy_features
|
||||
|
||||
|
||||
def create_empty_dataset_info(
|
||||
codebase_version: str,
|
||||
fps: int,
|
||||
@@ -453,85 +314,88 @@ def create_empty_dataset_info(
|
||||
|
||||
|
||||
def get_episode_data_index(
|
||||
episode_dicts: dict[dict], episodes: list[int] | None = None
|
||||
episode_dicts: list[dict], episodes: list[int] | None = None
|
||||
) -> dict[str, torch.Tensor]:
|
||||
episode_lengths = {ep_idx: ep_dict["length"] for ep_idx, ep_dict in episode_dicts.items()}
|
||||
episode_lengths = {ep_idx: ep_dict["length"] for ep_idx, ep_dict in enumerate(episode_dicts)}
|
||||
if episodes is not None:
|
||||
episode_lengths = {ep_idx: episode_lengths[ep_idx] for ep_idx in episodes}
|
||||
|
||||
cumulative_lengths = list(accumulate(episode_lengths.values()))
|
||||
cumulative_lenghts = list(accumulate(episode_lengths.values()))
|
||||
return {
|
||||
"from": torch.LongTensor([0] + cumulative_lengths[:-1]),
|
||||
"to": torch.LongTensor(cumulative_lengths),
|
||||
"from": torch.LongTensor([0] + cumulative_lenghts[:-1]),
|
||||
"to": torch.LongTensor(cumulative_lenghts),
|
||||
}
|
||||
|
||||
|
||||
def calculate_total_episode(
|
||||
hf_dataset: datasets.Dataset, raise_if_not_contiguous: bool = True
|
||||
) -> dict[str, torch.Tensor]:
|
||||
episode_indices = sorted(hf_dataset.unique("episode_index"))
|
||||
total_episodes = len(episode_indices)
|
||||
if raise_if_not_contiguous and episode_indices != list(range(total_episodes)):
|
||||
raise ValueError("episode_index values are not sorted and contiguous.")
|
||||
return total_episodes
|
||||
|
||||
|
||||
def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> dict[str, torch.Tensor]:
|
||||
episode_lengths = []
|
||||
table = hf_dataset.data.table
|
||||
total_episodes = calculate_total_episode(hf_dataset)
|
||||
for ep_idx in range(total_episodes):
|
||||
ep_table = table.filter(pc.equal(table["episode_index"], ep_idx))
|
||||
episode_lengths.insert(ep_idx, len(ep_table))
|
||||
|
||||
cumulative_lenghts = list(accumulate(episode_lengths))
|
||||
return {
|
||||
"from": torch.LongTensor([0] + cumulative_lenghts[:-1]),
|
||||
"to": torch.LongTensor(cumulative_lenghts),
|
||||
}
|
||||
|
||||
|
||||
def check_timestamps_sync(
|
||||
timestamps: np.ndarray,
|
||||
episode_indices: np.ndarray,
|
||||
episode_data_index: dict[str, np.ndarray],
|
||||
hf_dataset: datasets.Dataset,
|
||||
episode_data_index: dict[str, torch.Tensor],
|
||||
fps: int,
|
||||
tolerance_s: float,
|
||||
raise_value_error: bool = True,
|
||||
) -> bool:
|
||||
"""
|
||||
This check is to make sure that each timestamp is separated from the next by (1/fps) +/- tolerance
|
||||
to account for possible numerical error.
|
||||
|
||||
Args:
|
||||
timestamps (np.ndarray): Array of timestamps in seconds.
|
||||
episode_indices (np.ndarray): Array indicating the episode index for each timestamp.
|
||||
episode_data_index (dict[str, np.ndarray]): A dictionary that includes 'to',
|
||||
which identifies indices for the end of each episode.
|
||||
fps (int): Frames per second. Used to check the expected difference between consecutive timestamps.
|
||||
tolerance_s (float): Allowed deviation from the expected (1/fps) difference.
|
||||
raise_value_error (bool): Whether to raise a ValueError if the check fails.
|
||||
|
||||
Returns:
|
||||
bool: True if all checked timestamp differences lie within tolerance, False otherwise.
|
||||
|
||||
Raises:
|
||||
ValueError: If the check fails and `raise_value_error` is True.
|
||||
This check is to make sure that each timestamps is separated to the next by 1/fps +/- tolerance to
|
||||
account for possible numerical error.
|
||||
"""
|
||||
if timestamps.shape != episode_indices.shape:
|
||||
raise ValueError(
|
||||
"timestamps and episode_indices should have the same shape. "
|
||||
f"Found {timestamps.shape=} and {episode_indices.shape=}."
|
||||
)
|
||||
timestamps = torch.stack(hf_dataset["timestamp"])
|
||||
diffs = torch.diff(timestamps)
|
||||
within_tolerance = torch.abs(diffs - 1 / fps) <= tolerance_s
|
||||
|
||||
# Consecutive differences
|
||||
diffs = np.diff(timestamps)
|
||||
within_tolerance = np.abs(diffs - (1.0 / fps)) <= tolerance_s
|
||||
|
||||
# Mask to ignore differences at the boundaries between episodes
|
||||
mask = np.ones(len(diffs), dtype=bool)
|
||||
ignored_diffs = episode_data_index["to"][:-1] - 1 # indices at the end of each episode
|
||||
# We mask differences between the timestamp at the end of an episode
|
||||
# and the one at the start of the next episode since these are expected
|
||||
# to be outside tolerance.
|
||||
mask = torch.ones(len(diffs), dtype=torch.bool)
|
||||
ignored_diffs = episode_data_index["to"][:-1] - 1
|
||||
mask[ignored_diffs] = False
|
||||
filtered_within_tolerance = within_tolerance[mask]
|
||||
|
||||
# Check if all remaining diffs are within tolerance
|
||||
if not np.all(filtered_within_tolerance):
|
||||
if not torch.all(filtered_within_tolerance):
|
||||
# Track original indices before masking
|
||||
original_indices = np.arange(len(diffs))
|
||||
original_indices = torch.arange(len(diffs))
|
||||
filtered_indices = original_indices[mask]
|
||||
outside_tolerance_filtered_indices = np.nonzero(~filtered_within_tolerance)[0]
|
||||
outside_tolerance_filtered_indices = torch.nonzero(~filtered_within_tolerance) # .squeeze()
|
||||
outside_tolerance_indices = filtered_indices[outside_tolerance_filtered_indices]
|
||||
episode_indices = torch.stack(hf_dataset["episode_index"])
|
||||
|
||||
outside_tolerances = []
|
||||
for idx in outside_tolerance_indices:
|
||||
entry = {
|
||||
"timestamps": [timestamps[idx], timestamps[idx + 1]],
|
||||
"diff": diffs[idx],
|
||||
"episode_index": episode_indices[idx].item()
|
||||
if hasattr(episode_indices[idx], "item")
|
||||
else episode_indices[idx],
|
||||
"episode_index": episode_indices[idx].item(),
|
||||
}
|
||||
outside_tolerances.append(entry)
|
||||
|
||||
if raise_value_error:
|
||||
raise ValueError(
|
||||
f"""One or several timestamps unexpectedly violate the tolerance inside episode range.
|
||||
This might be due to synchronization issues during data collection.
|
||||
This might be due to synchronization issues with timestamps during data collection.
|
||||
\n{pformat(outside_tolerances)}"""
|
||||
)
|
||||
return False
|
||||
@@ -572,7 +436,7 @@ def check_delta_timestamps(
|
||||
def get_delta_indices(delta_timestamps: dict[str, list[float]], fps: int) -> dict[str, list[int]]:
|
||||
delta_indices = {}
|
||||
for key, delta_ts in delta_timestamps.items():
|
||||
delta_indices[key] = [round(d * fps) for d in delta_ts]
|
||||
delta_indices[key] = (torch.tensor(delta_ts) * fps).long().tolist()
|
||||
|
||||
return delta_indices
|
||||
|
||||
@@ -696,118 +560,3 @@ class IterableNamespace(SimpleNamespace):
|
||||
|
||||
def keys(self):
|
||||
return vars(self).keys()
|
||||
|
||||
|
||||
def validate_frame(frame: dict, features: dict):
|
||||
optional_features = {"timestamp"}
|
||||
expected_features = (set(features) - set(DEFAULT_FEATURES.keys())) | {"task"}
|
||||
actual_features = set(frame.keys())
|
||||
|
||||
error_message = validate_features_presence(actual_features, expected_features, optional_features)
|
||||
|
||||
if "task" in frame:
|
||||
error_message += validate_feature_string("task", frame["task"])
|
||||
|
||||
common_features = actual_features & (expected_features | optional_features)
|
||||
for name in common_features - {"task"}:
|
||||
error_message += validate_feature_dtype_and_shape(name, features[name], frame[name])
|
||||
|
||||
if error_message:
|
||||
raise ValueError(error_message)
|
||||
|
||||
|
||||
def validate_features_presence(
|
||||
actual_features: set[str], expected_features: set[str], optional_features: set[str]
|
||||
):
|
||||
error_message = ""
|
||||
missing_features = expected_features - actual_features
|
||||
extra_features = actual_features - (expected_features | optional_features)
|
||||
|
||||
if missing_features or extra_features:
|
||||
error_message += "Feature mismatch in `frame` dictionary:\n"
|
||||
if missing_features:
|
||||
error_message += f"Missing features: {missing_features}\n"
|
||||
if extra_features:
|
||||
error_message += f"Extra features: {extra_features}\n"
|
||||
|
||||
return error_message
|
||||
|
||||
|
||||
def validate_feature_dtype_and_shape(name: str, feature: dict, value: np.ndarray | PILImage.Image | str):
|
||||
expected_dtype = feature["dtype"]
|
||||
expected_shape = feature["shape"]
|
||||
if is_valid_numpy_dtype_string(expected_dtype):
|
||||
return validate_feature_numpy_array(name, expected_dtype, expected_shape, value)
|
||||
elif expected_dtype in ["image", "video"]:
|
||||
return validate_feature_image_or_video(name, expected_shape, value)
|
||||
elif expected_dtype == "string":
|
||||
return validate_feature_string(name, value)
|
||||
else:
|
||||
raise NotImplementedError(f"The feature dtype '{expected_dtype}' is not implemented yet.")
|
||||
|
||||
|
||||
def validate_feature_numpy_array(
|
||||
name: str, expected_dtype: str, expected_shape: list[int], value: np.ndarray
|
||||
):
|
||||
error_message = ""
|
||||
if isinstance(value, np.ndarray):
|
||||
actual_dtype = value.dtype
|
||||
actual_shape = value.shape
|
||||
|
||||
if actual_dtype != np.dtype(expected_dtype):
|
||||
error_message += f"The feature '{name}' of dtype '{actual_dtype}' is not of the expected dtype '{expected_dtype}'.\n"
|
||||
|
||||
if actual_shape != expected_shape:
|
||||
error_message += f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{expected_shape}'.\n"
|
||||
else:
|
||||
error_message += f"The feature '{name}' is not a 'np.ndarray'. Expected type is '{expected_dtype}', but type '{type(value)}' provided instead.\n"
|
||||
|
||||
return error_message
|
||||
|
||||
|
||||
def validate_feature_image_or_video(name: str, expected_shape: list[str], value: np.ndarray | PILImage.Image):
|
||||
# Note: The check of pixels range ([0,1] for float and [0,255] for uint8) is done by the image writer threads.
|
||||
error_message = ""
|
||||
if isinstance(value, np.ndarray):
|
||||
actual_shape = value.shape
|
||||
c, h, w = expected_shape
|
||||
if len(actual_shape) != 3 or (actual_shape != (c, h, w) and actual_shape != (h, w, c)):
|
||||
error_message += f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{(c, h, w)}' or '{(h, w, c)}'.\n"
|
||||
elif isinstance(value, PILImage.Image):
|
||||
pass
|
||||
else:
|
||||
error_message += f"The feature '{name}' is expected to be of type 'PIL.Image' or 'np.ndarray' channel first or channel last, but type '{type(value)}' provided instead.\n"
|
||||
|
||||
return error_message
|
||||
|
||||
|
||||
def validate_feature_string(name: str, value: str):
|
||||
if not isinstance(value, str):
|
||||
return f"The feature '{name}' is expected to be of type 'str', but type '{type(value)}' provided instead.\n"
|
||||
return ""
|
||||
|
||||
|
||||
def validate_episode_buffer(episode_buffer: dict, total_episodes: int, features: dict):
|
||||
if "size" not in episode_buffer:
|
||||
raise ValueError("size key not found in episode_buffer")
|
||||
|
||||
if "task" not in episode_buffer:
|
||||
raise ValueError("task key not found in episode_buffer")
|
||||
|
||||
if episode_buffer["episode_index"] != total_episodes:
|
||||
# TODO(aliberts): Add option to use existing episode_index
|
||||
raise NotImplementedError(
|
||||
"You might have manually provided the episode_buffer with an episode_index that doesn't "
|
||||
"match the total number of episodes already in the dataset. This is not supported for now."
|
||||
)
|
||||
|
||||
if episode_buffer["size"] == 0:
|
||||
raise ValueError("You must add one or several frames with `add_frame` before calling `add_episode`.")
|
||||
|
||||
buffer_keys = set(episode_buffer.keys()) - {"task", "size"}
|
||||
if not buffer_keys == set(features):
|
||||
raise ValueError(
|
||||
f"Features from `episode_buffer` don't match the ones in `features`."
|
||||
f"In episode_buffer not in features: {buffer_keys - set(features)}"
|
||||
f"In features not in episode_buffer: {set(features) - buffer_keys}"
|
||||
)
|
||||
|
||||
@@ -26,14 +26,13 @@ from pathlib import Path
|
||||
from textwrap import dedent
|
||||
|
||||
from lerobot import available_datasets
|
||||
from lerobot.common.datasets.v2.convert_dataset_v1_to_v2 import convert_dataset
|
||||
from lerobot.common.robot_devices.robots.configs import AlohaRobotConfig
|
||||
from lerobot.common.datasets.v2.convert_dataset_v1_to_v2 import convert_dataset, parse_robot_config
|
||||
|
||||
LOCAL_DIR = Path("data/")
|
||||
|
||||
# spellchecker:off
|
||||
ALOHA_CONFIG = Path("lerobot/configs/robot/aloha.yaml")
|
||||
ALOHA_MOBILE_INFO = {
|
||||
"robot_config": AlohaRobotConfig(),
|
||||
"robot_config": parse_robot_config(ALOHA_CONFIG),
|
||||
"license": "mit",
|
||||
"url": "https://mobile-aloha.github.io/",
|
||||
"paper": "https://arxiv.org/abs/2401.02117",
|
||||
@@ -46,7 +45,7 @@ ALOHA_MOBILE_INFO = {
|
||||
}""").lstrip(),
|
||||
}
|
||||
ALOHA_STATIC_INFO = {
|
||||
"robot_config": AlohaRobotConfig(),
|
||||
"robot_config": parse_robot_config(ALOHA_CONFIG),
|
||||
"license": "mit",
|
||||
"url": "https://tonyzhaozh.github.io/aloha/",
|
||||
"paper": "https://arxiv.org/abs/2304.13705",
|
||||
@@ -857,7 +856,6 @@ DATASETS = {
|
||||
}""").lstrip(),
|
||||
},
|
||||
}
|
||||
# spellchecker:on
|
||||
|
||||
|
||||
def batch_convert():
|
||||
|
||||
@@ -17,7 +17,7 @@
|
||||
"""
|
||||
This script will help you convert any LeRobot dataset already pushed to the hub from codebase version 1.6 to
|
||||
2.0. You will be required to provide the 'tasks', which is a short but accurate description in plain English
|
||||
for each of the task performed in the dataset. This will allow to easily train models with task-conditioning.
|
||||
for each of the task performed in the dataset. This will allow to easily train models with task-conditionning.
|
||||
|
||||
We support 3 different scenarios for these tasks (see instructions below):
|
||||
1. Single task dataset: all episodes of your dataset have the same single task.
|
||||
@@ -130,7 +130,7 @@ from lerobot.common.datasets.utils import (
|
||||
create_branch,
|
||||
create_lerobot_dataset_card,
|
||||
flatten_dict,
|
||||
get_safe_version,
|
||||
get_hub_safe_version,
|
||||
load_json,
|
||||
unflatten_dict,
|
||||
write_json,
|
||||
@@ -141,8 +141,7 @@ from lerobot.common.datasets.video_utils import (
|
||||
get_image_pixel_channels,
|
||||
get_video_info,
|
||||
)
|
||||
from lerobot.common.robot_devices.robots.configs import RobotConfig
|
||||
from lerobot.common.robot_devices.robots.utils import make_robot_config
|
||||
from lerobot.common.utils.utils import init_hydra_config
|
||||
|
||||
V16 = "v1.6"
|
||||
V20 = "v2.0"
|
||||
@@ -153,18 +152,19 @@ V1_INFO_PATH = "meta_data/info.json"
|
||||
V1_STATS_PATH = "meta_data/stats.safetensors"
|
||||
|
||||
|
||||
def parse_robot_config(robot_cfg: RobotConfig) -> tuple[str, dict]:
|
||||
if robot_cfg.type in ["aloha", "koch"]:
|
||||
def parse_robot_config(config_path: Path, config_overrides: list[str] | None = None) -> tuple[str, dict]:
|
||||
robot_cfg = init_hydra_config(config_path, config_overrides)
|
||||
if robot_cfg["robot_type"] in ["aloha", "koch"]:
|
||||
state_names = [
|
||||
f"{arm}_{motor}" if len(robot_cfg.follower_arms) > 1 else motor
|
||||
for arm in robot_cfg.follower_arms
|
||||
for motor in robot_cfg.follower_arms[arm].motors
|
||||
f"{arm}_{motor}" if len(robot_cfg["follower_arms"]) > 1 else motor
|
||||
for arm in robot_cfg["follower_arms"]
|
||||
for motor in robot_cfg["follower_arms"][arm]["motors"]
|
||||
]
|
||||
action_names = [
|
||||
# f"{arm}_{motor}" for arm in ["left", "right"] for motor in robot_cfg["leader_arms"][arm]["motors"]
|
||||
f"{arm}_{motor}" if len(robot_cfg.leader_arms) > 1 else motor
|
||||
for arm in robot_cfg.leader_arms
|
||||
for motor in robot_cfg.leader_arms[arm].motors
|
||||
f"{arm}_{motor}" if len(robot_cfg["leader_arms"]) > 1 else motor
|
||||
for arm in robot_cfg["leader_arms"]
|
||||
for motor in robot_cfg["leader_arms"][arm]["motors"]
|
||||
]
|
||||
# elif robot_cfg["robot_type"] == "stretch3": TODO
|
||||
else:
|
||||
@@ -173,7 +173,7 @@ def parse_robot_config(robot_cfg: RobotConfig) -> tuple[str, dict]:
|
||||
)
|
||||
|
||||
return {
|
||||
"robot_type": robot_cfg.type,
|
||||
"robot_type": robot_cfg["robot_type"],
|
||||
"names": {
|
||||
"observation.state": state_names,
|
||||
"observation.effort": state_names,
|
||||
@@ -203,10 +203,7 @@ def convert_stats_to_json(v1_dir: Path, v2_dir: Path) -> None:
|
||||
torch.testing.assert_close(stats_json[key], stats[key])
|
||||
|
||||
|
||||
def get_features_from_hf_dataset(
|
||||
dataset: Dataset, robot_config: RobotConfig | None = None
|
||||
) -> dict[str, list]:
|
||||
robot_config = parse_robot_config(robot_config)
|
||||
def get_features_from_hf_dataset(dataset: Dataset, robot_config: dict | None = None) -> dict[str, list]:
|
||||
features = {}
|
||||
for key, ft in dataset.features.items():
|
||||
if isinstance(ft, datasets.Value):
|
||||
@@ -227,11 +224,11 @@ def get_features_from_hf_dataset(
|
||||
image = dataset[0][key] # Assuming first row
|
||||
channels = get_image_pixel_channels(image)
|
||||
shape = (image.height, image.width, channels)
|
||||
names = ["height", "width", "channels"]
|
||||
names = ["height", "width", "channel"]
|
||||
elif ft._type == "VideoFrame":
|
||||
dtype = "video"
|
||||
shape = None # Add shape later
|
||||
names = ["height", "width", "channels"]
|
||||
names = ["height", "width", "channel"]
|
||||
|
||||
features[key] = {
|
||||
"dtype": dtype,
|
||||
@@ -439,11 +436,11 @@ def convert_dataset(
|
||||
single_task: str | None = None,
|
||||
tasks_path: Path | None = None,
|
||||
tasks_col: Path | None = None,
|
||||
robot_config: RobotConfig | None = None,
|
||||
robot_config: dict | None = None,
|
||||
test_branch: str | None = None,
|
||||
**card_kwargs,
|
||||
):
|
||||
v1 = get_safe_version(repo_id, V16)
|
||||
v1 = get_hub_safe_version(repo_id, V16)
|
||||
v1x_dir = local_dir / V16 / repo_id
|
||||
v20_dir = local_dir / V20 / repo_id
|
||||
v1x_dir.mkdir(parents=True, exist_ok=True)
|
||||
@@ -535,7 +532,7 @@ def convert_dataset(
|
||||
episode_lengths = split_parquet_by_episodes(dataset, total_episodes, total_chunks, v20_dir)
|
||||
|
||||
if robot_config is not None:
|
||||
robot_type = robot_config.type
|
||||
robot_type = robot_config["robot_type"]
|
||||
repo_tags = [robot_type]
|
||||
else:
|
||||
robot_type = "unknown"
|
||||
@@ -624,10 +621,16 @@ def main():
|
||||
help="The path to a .json file containing one language instruction for each episode_index",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--robot",
|
||||
type=str,
|
||||
"--robot-config",
|
||||
type=Path,
|
||||
default=None,
|
||||
help="Robot config used for the dataset during conversion (e.g. 'koch', 'aloha', 'so100', etc.)",
|
||||
help="Path to the robot's config yaml the dataset during conversion.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--robot-overrides",
|
||||
type=str,
|
||||
nargs="*",
|
||||
help="Any key=value arguments to override the robot config values (use dots for.nested=overrides)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--local-dir",
|
||||
@@ -652,10 +655,8 @@ def main():
|
||||
if not args.local_dir:
|
||||
args.local_dir = Path("/tmp/lerobot_dataset_v2")
|
||||
|
||||
if args.robot is not None:
|
||||
robot_config = make_robot_config(args.robot)
|
||||
|
||||
del args.robot
|
||||
robot_config = parse_robot_config(args.robot_config, args.robot_overrides) if args.robot_config else None
|
||||
del args.robot_config, args.robot_overrides
|
||||
|
||||
convert_dataset(**vars(args), robot_config=robot_config)
|
||||
|
||||
|
||||
@@ -1,73 +0,0 @@
|
||||
import logging
|
||||
import traceback
|
||||
from pathlib import Path
|
||||
|
||||
from datasets import get_dataset_config_info
|
||||
from huggingface_hub import HfApi
|
||||
|
||||
from lerobot import available_datasets
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDatasetMetadata
|
||||
from lerobot.common.datasets.utils import INFO_PATH, write_info
|
||||
from lerobot.common.datasets.v21.convert_dataset_v20_to_v21 import V20, SuppressWarnings
|
||||
|
||||
LOCAL_DIR = Path("data/")
|
||||
|
||||
hub_api = HfApi()
|
||||
|
||||
|
||||
def fix_dataset(repo_id: str) -> str:
|
||||
if not hub_api.revision_exists(repo_id, V20, repo_type="dataset"):
|
||||
return f"{repo_id}: skipped (not in {V20})."
|
||||
|
||||
dataset_info = get_dataset_config_info(repo_id, "default")
|
||||
with SuppressWarnings():
|
||||
lerobot_metadata = LeRobotDatasetMetadata(repo_id, revision=V20, force_cache_sync=True)
|
||||
|
||||
meta_features = {key for key, ft in lerobot_metadata.features.items() if ft["dtype"] != "video"}
|
||||
parquet_features = set(dataset_info.features)
|
||||
|
||||
diff_parquet_meta = parquet_features - meta_features
|
||||
diff_meta_parquet = meta_features - parquet_features
|
||||
|
||||
if diff_parquet_meta:
|
||||
raise ValueError(f"In parquet not in info.json: {parquet_features - meta_features}")
|
||||
|
||||
if not diff_meta_parquet:
|
||||
return f"{repo_id}: skipped (no diff)"
|
||||
|
||||
if diff_meta_parquet:
|
||||
logging.warning(f"In info.json not in parquet: {meta_features - parquet_features}")
|
||||
assert diff_meta_parquet == {"language_instruction"}
|
||||
lerobot_metadata.features.pop("language_instruction")
|
||||
write_info(lerobot_metadata.info, lerobot_metadata.root)
|
||||
commit_info = hub_api.upload_file(
|
||||
path_or_fileobj=lerobot_metadata.root / INFO_PATH,
|
||||
path_in_repo=INFO_PATH,
|
||||
repo_id=repo_id,
|
||||
repo_type="dataset",
|
||||
revision=V20,
|
||||
commit_message="Remove 'language_instruction'",
|
||||
create_pr=True,
|
||||
)
|
||||
return f"{repo_id}: success - PR: {commit_info.pr_url}"
|
||||
|
||||
|
||||
def batch_fix():
|
||||
status = {}
|
||||
LOCAL_DIR.mkdir(parents=True, exist_ok=True)
|
||||
logfile = LOCAL_DIR / "fix_features_v20.txt"
|
||||
for num, repo_id in enumerate(available_datasets):
|
||||
print(f"\nConverting {repo_id} ({num}/{len(available_datasets)})")
|
||||
print("---------------------------------------------------------")
|
||||
try:
|
||||
status = fix_dataset(repo_id)
|
||||
except Exception:
|
||||
status = f"{repo_id}: failed\n {traceback.format_exc()}"
|
||||
|
||||
logging.info(status)
|
||||
with open(logfile, "a") as file:
|
||||
file.write(status + "\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
batch_fix()
|
||||
@@ -1,54 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
This script is for internal use to convert all datasets under the 'lerobot' hub user account to v2.1.
|
||||
"""
|
||||
|
||||
import traceback
|
||||
from pathlib import Path
|
||||
|
||||
from huggingface_hub import HfApi
|
||||
|
||||
from lerobot import available_datasets
|
||||
from lerobot.common.datasets.v21.convert_dataset_v20_to_v21 import V21, convert_dataset
|
||||
|
||||
LOCAL_DIR = Path("data/")
|
||||
|
||||
|
||||
def batch_convert():
|
||||
status = {}
|
||||
LOCAL_DIR.mkdir(parents=True, exist_ok=True)
|
||||
logfile = LOCAL_DIR / "conversion_log_v21.txt"
|
||||
hub_api = HfApi()
|
||||
for num, repo_id in enumerate(available_datasets):
|
||||
print(f"\nConverting {repo_id} ({num}/{len(available_datasets)})")
|
||||
print("---------------------------------------------------------")
|
||||
try:
|
||||
if hub_api.revision_exists(repo_id, V21, repo_type="dataset"):
|
||||
status = f"{repo_id}: success (already in {V21})."
|
||||
else:
|
||||
convert_dataset(repo_id)
|
||||
status = f"{repo_id}: success."
|
||||
except Exception:
|
||||
status = f"{repo_id}: failed\n {traceback.format_exc()}"
|
||||
|
||||
with open(logfile, "a") as file:
|
||||
file.write(status + "\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
batch_convert()
|
||||
@@ -1,100 +0,0 @@
|
||||
"""
|
||||
This script will help you convert any LeRobot dataset already pushed to the hub from codebase version 2.0 to
|
||||
2.1. It will:
|
||||
|
||||
- Generate per-episodes stats and writes them in `episodes_stats.jsonl`
|
||||
- Check consistency between these new stats and the old ones.
|
||||
- Remove the deprecated `stats.json`.
|
||||
- Update codebase_version in `info.json`.
|
||||
- Push this new version to the hub on the 'main' branch and tags it with "v2.1".
|
||||
|
||||
Usage:
|
||||
|
||||
```bash
|
||||
python lerobot/common/datasets/v21/convert_dataset_v20_to_v21.py \
|
||||
--repo-id=aliberts/koch_tutorial
|
||||
```
|
||||
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
|
||||
from huggingface_hub import HfApi
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
|
||||
from lerobot.common.datasets.utils import EPISODES_STATS_PATH, STATS_PATH, load_stats, write_info
|
||||
from lerobot.common.datasets.v21.convert_stats import check_aggregate_stats, convert_stats
|
||||
|
||||
V20 = "v2.0"
|
||||
V21 = "v2.1"
|
||||
|
||||
|
||||
class SuppressWarnings:
|
||||
def __enter__(self):
|
||||
self.previous_level = logging.getLogger().getEffectiveLevel()
|
||||
logging.getLogger().setLevel(logging.ERROR)
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
logging.getLogger().setLevel(self.previous_level)
|
||||
|
||||
|
||||
def convert_dataset(
|
||||
repo_id: str,
|
||||
branch: str | None = None,
|
||||
num_workers: int = 4,
|
||||
):
|
||||
with SuppressWarnings():
|
||||
dataset = LeRobotDataset(repo_id, revision=V20, force_cache_sync=True)
|
||||
|
||||
if (dataset.root / EPISODES_STATS_PATH).is_file():
|
||||
(dataset.root / EPISODES_STATS_PATH).unlink()
|
||||
|
||||
convert_stats(dataset, num_workers=num_workers)
|
||||
ref_stats = load_stats(dataset.root)
|
||||
check_aggregate_stats(dataset, ref_stats)
|
||||
|
||||
dataset.meta.info["codebase_version"] = CODEBASE_VERSION
|
||||
write_info(dataset.meta.info, dataset.root)
|
||||
|
||||
dataset.push_to_hub(branch=branch, tag_version=False, allow_patterns="meta/")
|
||||
|
||||
# delete old stats.json file
|
||||
if (dataset.root / STATS_PATH).is_file:
|
||||
(dataset.root / STATS_PATH).unlink()
|
||||
|
||||
hub_api = HfApi()
|
||||
if hub_api.file_exists(
|
||||
repo_id=dataset.repo_id, filename=STATS_PATH, revision=branch, repo_type="dataset"
|
||||
):
|
||||
hub_api.delete_file(
|
||||
path_in_repo=STATS_PATH, repo_id=dataset.repo_id, revision=branch, repo_type="dataset"
|
||||
)
|
||||
|
||||
hub_api.create_tag(repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--repo-id",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Repository identifier on Hugging Face: a community or a user name `/` the name of the dataset "
|
||||
"(e.g. `lerobot/pusht`, `cadene/aloha_sim_insertion_human`).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--branch",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Repo branch to push your dataset. Defaults to the main branch.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-workers",
|
||||
type=int,
|
||||
default=4,
|
||||
help="Number of workers for parallelizing stats compute. Defaults to 4.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
convert_dataset(**vars(args))
|
||||
@@ -1,85 +0,0 @@
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
|
||||
from lerobot.common.datasets.compute_stats import aggregate_stats, get_feature_stats, sample_indices
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.common.datasets.utils import write_episode_stats
|
||||
|
||||
|
||||
def sample_episode_video_frames(dataset: LeRobotDataset, episode_index: int, ft_key: str) -> np.ndarray:
|
||||
ep_len = dataset.meta.episodes[episode_index]["length"]
|
||||
sampled_indices = sample_indices(ep_len)
|
||||
query_timestamps = dataset._get_query_timestamps(0.0, {ft_key: sampled_indices})
|
||||
video_frames = dataset._query_videos(query_timestamps, episode_index)
|
||||
return video_frames[ft_key].numpy()
|
||||
|
||||
|
||||
def convert_episode_stats(dataset: LeRobotDataset, ep_idx: int):
|
||||
ep_start_idx = dataset.episode_data_index["from"][ep_idx]
|
||||
ep_end_idx = dataset.episode_data_index["to"][ep_idx]
|
||||
ep_data = dataset.hf_dataset.select(range(ep_start_idx, ep_end_idx))
|
||||
|
||||
ep_stats = {}
|
||||
for key, ft in dataset.features.items():
|
||||
if ft["dtype"] == "video":
|
||||
# We sample only for videos
|
||||
ep_ft_data = sample_episode_video_frames(dataset, ep_idx, key)
|
||||
else:
|
||||
ep_ft_data = np.array(ep_data[key])
|
||||
|
||||
axes_to_reduce = (0, 2, 3) if ft["dtype"] in ["image", "video"] else 0
|
||||
keepdims = True if ft["dtype"] in ["image", "video"] else ep_ft_data.ndim == 1
|
||||
ep_stats[key] = get_feature_stats(ep_ft_data, axis=axes_to_reduce, keepdims=keepdims)
|
||||
|
||||
if ft["dtype"] in ["image", "video"]: # remove batch dim
|
||||
ep_stats[key] = {
|
||||
k: v if k == "count" else np.squeeze(v, axis=0) for k, v in ep_stats[key].items()
|
||||
}
|
||||
|
||||
dataset.meta.episodes_stats[ep_idx] = ep_stats
|
||||
|
||||
|
||||
def convert_stats(dataset: LeRobotDataset, num_workers: int = 0):
|
||||
assert dataset.episodes is None
|
||||
print("Computing episodes stats")
|
||||
total_episodes = dataset.meta.total_episodes
|
||||
if num_workers > 0:
|
||||
with ThreadPoolExecutor(max_workers=num_workers) as executor:
|
||||
futures = {
|
||||
executor.submit(convert_episode_stats, dataset, ep_idx): ep_idx
|
||||
for ep_idx in range(total_episodes)
|
||||
}
|
||||
for future in tqdm(as_completed(futures), total=total_episodes):
|
||||
future.result()
|
||||
else:
|
||||
for ep_idx in tqdm(range(total_episodes)):
|
||||
convert_episode_stats(dataset, ep_idx)
|
||||
|
||||
for ep_idx in tqdm(range(total_episodes)):
|
||||
write_episode_stats(ep_idx, dataset.meta.episodes_stats[ep_idx], dataset.root)
|
||||
|
||||
|
||||
def check_aggregate_stats(
|
||||
dataset: LeRobotDataset,
|
||||
reference_stats: dict[str, dict[str, np.ndarray]],
|
||||
video_rtol_atol: tuple[float] = (1e-2, 1e-2),
|
||||
default_rtol_atol: tuple[float] = (5e-6, 6e-5),
|
||||
):
|
||||
"""Verifies that the aggregated stats from episodes_stats are close to reference stats."""
|
||||
agg_stats = aggregate_stats(list(dataset.meta.episodes_stats.values()))
|
||||
for key, ft in dataset.features.items():
|
||||
# These values might need some fine-tuning
|
||||
if ft["dtype"] == "video":
|
||||
# to account for image sub-sampling
|
||||
rtol, atol = video_rtol_atol
|
||||
else:
|
||||
rtol, atol = default_rtol_atol
|
||||
|
||||
for stat, val in agg_stats[key].items():
|
||||
if key in reference_stats and stat in reference_stats[key]:
|
||||
err_msg = f"feature='{key}' stats='{stat}'"
|
||||
np.testing.assert_allclose(
|
||||
val, reference_stats[key][stat], rtol=rtol, atol=atol, err_msg=err_msg
|
||||
)
|
||||
@@ -69,11 +69,11 @@ def decode_video_frames_torchvision(
|
||||
|
||||
# set the first and last requested timestamps
|
||||
# Note: previous timestamps are usually loaded, since we need to access the previous key frame
|
||||
first_ts = min(timestamps)
|
||||
last_ts = max(timestamps)
|
||||
first_ts = timestamps[0]
|
||||
last_ts = timestamps[-1]
|
||||
|
||||
# access closest key frame of the first requested frame
|
||||
# Note: closest key frame timestamp is usually smaller than `first_ts` (e.g. key frame can be the first frame of the video)
|
||||
# Note: closest key frame timestamp is usally smaller than `first_ts` (e.g. key frame can be the first frame of the video)
|
||||
# for details on what `seek` is doing see: https://pyav.basswood-io.com/docs/stable/api/container.html?highlight=inputcontainer#av.container.InputContainer.seek
|
||||
reader.seek(first_ts, keyframes_only=keyframes_only)
|
||||
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
from .configs import AlohaEnv, EnvConfig, PushtEnv, XarmEnv # noqa: F401
|
||||
@@ -1,142 +0,0 @@
|
||||
import abc
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import draccus
|
||||
|
||||
from lerobot.common.constants import ACTION, OBS_ENV, OBS_IMAGE, OBS_IMAGES, OBS_ROBOT
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
|
||||
|
||||
@dataclass
|
||||
class EnvConfig(draccus.ChoiceRegistry, abc.ABC):
|
||||
task: str | None = None
|
||||
fps: int = 30
|
||||
features: dict[str, PolicyFeature] = field(default_factory=dict)
|
||||
features_map: dict[str, str] = field(default_factory=dict)
|
||||
|
||||
@property
|
||||
def type(self) -> str:
|
||||
return self.get_choice_name(self.__class__)
|
||||
|
||||
@abc.abstractproperty
|
||||
def gym_kwargs(self) -> dict:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
@EnvConfig.register_subclass("aloha")
|
||||
@dataclass
|
||||
class AlohaEnv(EnvConfig):
|
||||
task: str = "AlohaInsertion-v0"
|
||||
fps: int = 50
|
||||
episode_length: int = 400
|
||||
obs_type: str = "pixels_agent_pos"
|
||||
render_mode: str = "rgb_array"
|
||||
features: dict[str, PolicyFeature] = field(
|
||||
default_factory=lambda: {
|
||||
"action": PolicyFeature(type=FeatureType.ACTION, shape=(14,)),
|
||||
}
|
||||
)
|
||||
features_map: dict[str, str] = field(
|
||||
default_factory=lambda: {
|
||||
"action": ACTION,
|
||||
"agent_pos": OBS_ROBOT,
|
||||
"top": f"{OBS_IMAGE}.top",
|
||||
"pixels/top": f"{OBS_IMAGES}.top",
|
||||
}
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
if self.obs_type == "pixels":
|
||||
self.features["top"] = PolicyFeature(type=FeatureType.VISUAL, shape=(480, 640, 3))
|
||||
elif self.obs_type == "pixels_agent_pos":
|
||||
self.features["agent_pos"] = PolicyFeature(type=FeatureType.STATE, shape=(14,))
|
||||
self.features["pixels/top"] = PolicyFeature(type=FeatureType.VISUAL, shape=(480, 640, 3))
|
||||
|
||||
@property
|
||||
def gym_kwargs(self) -> dict:
|
||||
return {
|
||||
"obs_type": self.obs_type,
|
||||
"render_mode": self.render_mode,
|
||||
"max_episode_steps": self.episode_length,
|
||||
}
|
||||
|
||||
|
||||
@EnvConfig.register_subclass("pusht")
|
||||
@dataclass
|
||||
class PushtEnv(EnvConfig):
|
||||
task: str = "PushT-v0"
|
||||
fps: int = 10
|
||||
episode_length: int = 300
|
||||
obs_type: str = "pixels_agent_pos"
|
||||
render_mode: str = "rgb_array"
|
||||
visualization_width: int = 384
|
||||
visualization_height: int = 384
|
||||
features: dict[str, PolicyFeature] = field(
|
||||
default_factory=lambda: {
|
||||
"action": PolicyFeature(type=FeatureType.ACTION, shape=(2,)),
|
||||
"agent_pos": PolicyFeature(type=FeatureType.STATE, shape=(2,)),
|
||||
}
|
||||
)
|
||||
features_map: dict[str, str] = field(
|
||||
default_factory=lambda: {
|
||||
"action": ACTION,
|
||||
"agent_pos": OBS_ROBOT,
|
||||
"environment_state": OBS_ENV,
|
||||
"pixels": OBS_IMAGE,
|
||||
}
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
if self.obs_type == "pixels_agent_pos":
|
||||
self.features["pixels"] = PolicyFeature(type=FeatureType.VISUAL, shape=(384, 384, 3))
|
||||
elif self.obs_type == "environment_state_agent_pos":
|
||||
self.features["environment_state"] = PolicyFeature(type=FeatureType.ENV, shape=(16,))
|
||||
|
||||
@property
|
||||
def gym_kwargs(self) -> dict:
|
||||
return {
|
||||
"obs_type": self.obs_type,
|
||||
"render_mode": self.render_mode,
|
||||
"visualization_width": self.visualization_width,
|
||||
"visualization_height": self.visualization_height,
|
||||
"max_episode_steps": self.episode_length,
|
||||
}
|
||||
|
||||
|
||||
@EnvConfig.register_subclass("xarm")
|
||||
@dataclass
|
||||
class XarmEnv(EnvConfig):
|
||||
task: str = "XarmLift-v0"
|
||||
fps: int = 15
|
||||
episode_length: int = 200
|
||||
obs_type: str = "pixels_agent_pos"
|
||||
render_mode: str = "rgb_array"
|
||||
visualization_width: int = 384
|
||||
visualization_height: int = 384
|
||||
features: dict[str, PolicyFeature] = field(
|
||||
default_factory=lambda: {
|
||||
"action": PolicyFeature(type=FeatureType.ACTION, shape=(4,)),
|
||||
"pixels": PolicyFeature(type=FeatureType.VISUAL, shape=(84, 84, 3)),
|
||||
}
|
||||
)
|
||||
features_map: dict[str, str] = field(
|
||||
default_factory=lambda: {
|
||||
"action": ACTION,
|
||||
"agent_pos": OBS_ROBOT,
|
||||
"pixels": OBS_IMAGE,
|
||||
}
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
if self.obs_type == "pixels_agent_pos":
|
||||
self.features["agent_pos"] = PolicyFeature(type=FeatureType.STATE, shape=(4,))
|
||||
|
||||
@property
|
||||
def gym_kwargs(self) -> dict:
|
||||
return {
|
||||
"obs_type": self.obs_type,
|
||||
"render_mode": self.render_mode,
|
||||
"visualization_width": self.visualization_width,
|
||||
"visualization_height": self.visualization_height,
|
||||
"max_episode_steps": self.episode_length,
|
||||
}
|
||||
@@ -14,56 +14,136 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import importlib
|
||||
from collections import deque
|
||||
|
||||
import gymnasium as gym
|
||||
|
||||
from lerobot.common.envs.configs import AlohaEnv, EnvConfig, PushtEnv, XarmEnv
|
||||
import numpy as np
|
||||
import torch
|
||||
from omegaconf import DictConfig
|
||||
from mani_skill.utils import common
|
||||
|
||||
|
||||
def make_env_config(env_type: str, **kwargs) -> EnvConfig:
|
||||
if env_type == "aloha":
|
||||
return AlohaEnv(**kwargs)
|
||||
elif env_type == "pusht":
|
||||
return PushtEnv(**kwargs)
|
||||
elif env_type == "xarm":
|
||||
return XarmEnv(**kwargs)
|
||||
else:
|
||||
raise ValueError(f"Policy type '{env_type}' is not available.")
|
||||
def make_env(cfg: DictConfig, n_envs: int | None = None) -> gym.vector.VectorEnv | None:
|
||||
"""Makes a gym vector environment according to the evaluation config.
|
||||
|
||||
|
||||
def make_env(cfg: EnvConfig, n_envs: int = 1, use_async_envs: bool = False) -> gym.vector.VectorEnv | None:
|
||||
"""Makes a gym vector environment according to the config.
|
||||
|
||||
Args:
|
||||
cfg (EnvConfig): the config of the environment to instantiate.
|
||||
n_envs (int, optional): The number of parallelized env to return. Defaults to 1.
|
||||
use_async_envs (bool, optional): Whether to return an AsyncVectorEnv or a SyncVectorEnv. Defaults to
|
||||
False.
|
||||
|
||||
Raises:
|
||||
ValueError: if n_envs < 1
|
||||
ModuleNotFoundError: If the requested env package is not installed
|
||||
|
||||
Returns:
|
||||
gym.vector.VectorEnv: The parallelized gym.env instance.
|
||||
n_envs can be used to override eval.batch_size in the configuration. Must be at least 1.
|
||||
"""
|
||||
if n_envs < 1:
|
||||
if n_envs is not None and n_envs < 1:
|
||||
raise ValueError("`n_envs must be at least 1")
|
||||
|
||||
package_name = f"gym_{cfg.type}"
|
||||
if cfg.env.name == "real_world":
|
||||
return
|
||||
|
||||
if "maniskill" in cfg.env.name:
|
||||
env = make_maniskill_env(cfg, n_envs if n_envs is not None else cfg.eval.batch_size)
|
||||
return env
|
||||
|
||||
package_name = f"gym_{cfg.env.name}"
|
||||
|
||||
try:
|
||||
importlib.import_module(package_name)
|
||||
except ModuleNotFoundError as e:
|
||||
print(f"{package_name} is not installed. Please install it with `pip install 'lerobot[{cfg.type}]'`")
|
||||
print(
|
||||
f"{package_name} is not installed. Please install it with `pip install 'lerobot[{cfg.env.name}]'`"
|
||||
)
|
||||
raise e
|
||||
|
||||
gym_handle = f"{package_name}/{cfg.task}"
|
||||
gym_handle = f"{package_name}/{cfg.env.task}"
|
||||
gym_kwgs = dict(cfg.env.get("gym", {}))
|
||||
|
||||
if cfg.env.get("episode_length"):
|
||||
gym_kwgs["max_episode_steps"] = cfg.env.episode_length
|
||||
|
||||
# batched version of the env that returns an observation of shape (b, c)
|
||||
env_cls = gym.vector.AsyncVectorEnv if use_async_envs else gym.vector.SyncVectorEnv
|
||||
env_cls = gym.vector.AsyncVectorEnv if cfg.eval.use_async_envs else gym.vector.SyncVectorEnv
|
||||
env = env_cls(
|
||||
[lambda: gym.make(gym_handle, disable_env_checker=True, **cfg.gym_kwargs) for _ in range(n_envs)]
|
||||
[
|
||||
lambda: gym.make(gym_handle, disable_env_checker=True, **gym_kwgs)
|
||||
for _ in range(n_envs if n_envs is not None else cfg.eval.batch_size)
|
||||
]
|
||||
)
|
||||
|
||||
return env
|
||||
|
||||
|
||||
def make_maniskill_env(cfg: DictConfig, n_envs: int | None = None) -> gym.vector.VectorEnv | None:
|
||||
"""Make ManiSkill3 gym environment"""
|
||||
from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv
|
||||
|
||||
env = gym.make(
|
||||
cfg.env.task,
|
||||
obs_mode=cfg.env.obs,
|
||||
control_mode=cfg.env.control_mode,
|
||||
render_mode=cfg.env.render_mode,
|
||||
sensor_configs=dict(width=cfg.env.image_size, height=cfg.env.image_size),
|
||||
num_envs=n_envs,
|
||||
)
|
||||
# cfg.env_cfg.control_mode = cfg.eval_env_cfg.control_mode = env.control_mode
|
||||
env = ManiSkillVectorEnv(env, ignore_terminations=True)
|
||||
# state should have the size of 25
|
||||
# env = ConvertToLeRobotEnv(env, n_envs)
|
||||
# env = PixelWrapper(cfg, env, n_envs)
|
||||
env._max_episode_steps = env.max_episode_steps = 50 # gym_utils.find_max_episode_steps_value(env)
|
||||
env.unwrapped.metadata["render_fps"] = 20
|
||||
|
||||
return env
|
||||
|
||||
|
||||
class PixelWrapper(gym.Wrapper):
|
||||
"""
|
||||
Wrapper for pixel observations. Works with Maniskill vectorized environments
|
||||
"""
|
||||
|
||||
def __init__(self, cfg, env, num_envs, num_frames=3):
|
||||
super().__init__(env)
|
||||
self.cfg = cfg
|
||||
self.env = env
|
||||
self.observation_space = gym.spaces.Box(
|
||||
low=0,
|
||||
high=255,
|
||||
shape=(num_envs, num_frames * 3, cfg.env.render_size, cfg.env.render_size),
|
||||
dtype=np.uint8,
|
||||
)
|
||||
self._frames = deque([], maxlen=num_frames)
|
||||
self._render_size = cfg.env.render_size
|
||||
|
||||
def _get_obs(self, obs):
|
||||
frame = obs["sensor_data"]["base_camera"]["rgb"].cpu().permute(0, 3, 1, 2)
|
||||
self._frames.append(frame)
|
||||
return {"pixels": torch.from_numpy(np.concatenate(self._frames, axis=1)).to(self.env.device)}
|
||||
|
||||
def reset(self, seed):
|
||||
obs, info = self.env.reset() # (seed=seed)
|
||||
for _ in range(self._frames.maxlen):
|
||||
obs_frames = self._get_obs(obs)
|
||||
return obs_frames, info
|
||||
|
||||
def step(self, action):
|
||||
obs, reward, terminated, truncated, info = self.env.step(action)
|
||||
return self._get_obs(obs), reward, terminated, truncated, info
|
||||
|
||||
class ConvertToLeRobotEnv(gym.Wrapper):
|
||||
def __init__(self, env, num_envs):
|
||||
super().__init__(env)
|
||||
def reset(self, seed=None, options=None):
|
||||
obs, info = self.env.reset(seed=seed, options={})
|
||||
return self._get_obs(obs), info
|
||||
def step(self, action):
|
||||
obs, reward, terminated, truncated, info = self.env.step(action)
|
||||
return self._get_obs(obs), reward, terminated, truncated, info
|
||||
def _get_obs(self, observation):
|
||||
sensor_data = observation.pop("sensor_data")
|
||||
del observation["sensor_param"]
|
||||
images = []
|
||||
for cam_data in sensor_data.values():
|
||||
images.append(cam_data["rgb"])
|
||||
|
||||
images = torch.concat(images, axis=-1)
|
||||
# flatten the rest of the data which should just be state data
|
||||
observation = common.flatten_state_dict(
|
||||
observation, use_torch=True, device=self.base_env.device
|
||||
)
|
||||
ret = dict()
|
||||
ret["state"] = observation
|
||||
ret["pixels"] = images
|
||||
return ret
|
||||
@@ -18,13 +18,8 @@ import numpy as np
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.common.envs.configs import EnvConfig
|
||||
from lerobot.common.utils.utils import get_channel_first_image_shape
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
|
||||
|
||||
def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Tensor]:
|
||||
# TODO(aliberts, rcadene): refactor this to use features from the environment (no hardcoding)
|
||||
"""Convert environment observation to LeRobot format observation.
|
||||
Args:
|
||||
observation: Dictionary of observation batches from a Gym vector environment.
|
||||
@@ -33,6 +28,9 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
|
||||
"""
|
||||
# map to expected inputs for the policy
|
||||
return_observations = {}
|
||||
# TODO: You have to merge all tensors from agent key and extra key
|
||||
# You don't keep sensor param key in the observation
|
||||
# And you keep sensor data rgb
|
||||
if "pixels" in observations:
|
||||
if isinstance(observations["pixels"], dict):
|
||||
imgs = {f"observation.images.{key}": img for key, img in observations["pixels"].items()}
|
||||
@@ -40,7 +38,6 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
|
||||
imgs = {"observation.image": observations["pixels"]}
|
||||
|
||||
for imgkey, img in imgs.items():
|
||||
# TODO(aliberts, rcadene): use transforms.ToTensor()?
|
||||
img = torch.from_numpy(img)
|
||||
|
||||
# sanity check that images are channel last
|
||||
@@ -56,6 +53,8 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
|
||||
img /= 255
|
||||
|
||||
return_observations[imgkey] = img
|
||||
# obs state agent qpos and qvel
|
||||
# image
|
||||
|
||||
if "environment_state" in observations:
|
||||
return_observations["observation.environment_state"] = torch.from_numpy(
|
||||
@@ -68,21 +67,36 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
|
||||
return return_observations
|
||||
|
||||
|
||||
def env_to_policy_features(env_cfg: EnvConfig) -> dict[str, PolicyFeature]:
|
||||
# TODO(aliberts, rcadene): remove this hardcoding of keys and just use the nested keys as is
|
||||
# (need to also refactor preprocess_observation and externalize normalization from policies)
|
||||
policy_features = {}
|
||||
for key, ft in env_cfg.features.items():
|
||||
if ft.type is FeatureType.VISUAL:
|
||||
if len(ft.shape) != 3:
|
||||
raise ValueError(f"Number of dimensions of {key} != 3 (shape={ft.shape})")
|
||||
def preprocess_maniskill_observation(observations: dict[str, np.ndarray]) -> dict[str, Tensor]:
|
||||
"""Convert environment observation to LeRobot format observation.
|
||||
Args:
|
||||
observation: Dictionary of observation batches from a Gym vector environment.
|
||||
Returns:
|
||||
Dictionary of observation batches with keys renamed to LeRobot format and values as tensors.
|
||||
"""
|
||||
# map to expected inputs for the policy
|
||||
return_observations = {}
|
||||
# TODO: You have to merge all tensors from agent key and extra key
|
||||
# You don't keep sensor param key in the observation
|
||||
# And you keep sensor data rgb
|
||||
q_pos = observations["agent"]["qpos"]
|
||||
q_vel = observations["agent"]["qvel"]
|
||||
tcp_pos = observations["extra"]["tcp_pose"]
|
||||
img = observations["sensor_data"]["base_camera"]["rgb"]
|
||||
|
||||
shape = get_channel_first_image_shape(ft.shape)
|
||||
feature = PolicyFeature(type=ft.type, shape=shape)
|
||||
else:
|
||||
feature = ft
|
||||
_, h, w, c = img.shape
|
||||
assert c < h and c < w, f"expect channel last images, but instead got {img.shape=}"
|
||||
|
||||
policy_key = env_cfg.features_map[key]
|
||||
policy_features[policy_key] = feature
|
||||
# sanity check that images are uint8
|
||||
assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}"
|
||||
|
||||
return policy_features
|
||||
# convert to channel first of type float32 in range [0,1]
|
||||
img = einops.rearrange(img, "b h w c -> b c h w").contiguous()
|
||||
img = img.type(torch.float32)
|
||||
img /= 255
|
||||
|
||||
state = torch.cat([q_pos, q_vel, tcp_pos], dim=-1)
|
||||
|
||||
return_observations["observation.image"] = img
|
||||
return_observations["observation.state"] = state
|
||||
return return_observations
|
||||
|
||||
245
lerobot/common/logger.py
Normal file
245
lerobot/common/logger.py
Normal file
@@ -0,0 +1,245 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Borrowed from https://github.com/fyhMer/fowm/blob/main/src/logger.py
|
||||
|
||||
# TODO(rcadene, alexander-soare): clean this file
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from glob import glob
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import wandb
|
||||
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
from termcolor import colored
|
||||
from torch.optim import Optimizer
|
||||
from torch.optim.lr_scheduler import LRScheduler
|
||||
|
||||
from lerobot.common.policies.policy_protocol import Policy
|
||||
from lerobot.common.utils.utils import get_global_random_state, set_global_random_state
|
||||
|
||||
|
||||
def log_output_dir(out_dir):
|
||||
logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {out_dir}")
|
||||
|
||||
|
||||
def cfg_to_group(cfg: DictConfig, return_list: bool = False) -> list[str] | str:
|
||||
"""Return a group name for logging. Optionally returns group name as list."""
|
||||
lst = [
|
||||
f"policy:{cfg.policy.name}",
|
||||
f"dataset:{cfg.dataset_repo_id}",
|
||||
f"env:{cfg.env.name}",
|
||||
f"seed:{cfg.seed}",
|
||||
]
|
||||
return lst if return_list else "-".join(lst)
|
||||
|
||||
|
||||
def get_wandb_run_id_from_filesystem(checkpoint_dir: Path) -> str:
|
||||
# Get the WandB run ID.
|
||||
paths = glob(str(checkpoint_dir / "../wandb/latest-run/run-*"))
|
||||
if len(paths) != 1:
|
||||
raise RuntimeError("Couldn't get the previous WandB run ID for run resumption.")
|
||||
match = re.search(r"run-([^\.]+).wandb", paths[0].split("/")[-1])
|
||||
if match is None:
|
||||
raise RuntimeError("Couldn't get the previous WandB run ID for run resumption.")
|
||||
wandb_run_id = match.groups(0)[0]
|
||||
return wandb_run_id
|
||||
|
||||
|
||||
class Logger:
|
||||
"""Primary logger object. Logs either locally or using wandb.
|
||||
|
||||
The logger creates the following directory structure:
|
||||
|
||||
provided_log_dir
|
||||
├── .hydra # hydra's configuration cache
|
||||
├── checkpoints
|
||||
│ ├── specific_checkpoint_name
|
||||
│ │ ├── pretrained_model # Hugging Face pretrained model directory
|
||||
│ │ │ ├── ...
|
||||
│ │ └── training_state.pth # optimizer, scheduler, and random states + training step
|
||||
| ├── another_specific_checkpoint_name
|
||||
│ │ ├── ...
|
||||
| ├── ...
|
||||
│ └── last # a softlink to the last logged checkpoint
|
||||
"""
|
||||
|
||||
pretrained_model_dir_name = "pretrained_model"
|
||||
training_state_file_name = "training_state.pth"
|
||||
|
||||
def __init__(self, cfg: DictConfig, log_dir: str, wandb_job_name: str | None = None):
|
||||
"""
|
||||
Args:
|
||||
log_dir: The directory to save all logs and training outputs to.
|
||||
job_name: The WandB job name.
|
||||
"""
|
||||
self._cfg = cfg
|
||||
self.log_dir = Path(log_dir)
|
||||
self.log_dir.mkdir(parents=True, exist_ok=True)
|
||||
self.checkpoints_dir = self.get_checkpoints_dir(log_dir)
|
||||
self.last_checkpoint_dir = self.get_last_checkpoint_dir(log_dir)
|
||||
self.last_pretrained_model_dir = self.get_last_pretrained_model_dir(log_dir)
|
||||
|
||||
# Set up WandB.
|
||||
self._group = cfg_to_group(cfg)
|
||||
project = cfg.get("wandb", {}).get("project")
|
||||
entity = cfg.get("wandb", {}).get("entity")
|
||||
enable_wandb = cfg.get("wandb", {}).get("enable", False)
|
||||
run_offline = not enable_wandb or not project
|
||||
if run_offline:
|
||||
logging.info(colored("Logs will be saved locally.", "yellow", attrs=["bold"]))
|
||||
self._wandb = None
|
||||
else:
|
||||
os.environ["WANDB_SILENT"] = "true"
|
||||
wandb_run_id = None
|
||||
if cfg.resume:
|
||||
wandb_run_id = get_wandb_run_id_from_filesystem(self.checkpoints_dir)
|
||||
|
||||
wandb.init(
|
||||
id=wandb_run_id,
|
||||
project=project,
|
||||
entity=entity,
|
||||
name=wandb_job_name,
|
||||
notes=cfg.get("wandb", {}).get("notes"),
|
||||
tags=cfg_to_group(cfg, return_list=True),
|
||||
dir=log_dir,
|
||||
config=OmegaConf.to_container(cfg, resolve=True),
|
||||
# TODO(rcadene): try set to True
|
||||
save_code=False,
|
||||
# TODO(rcadene): split train and eval, and run async eval with job_type="eval"
|
||||
job_type="train_eval",
|
||||
resume="must" if cfg.resume else None,
|
||||
)
|
||||
print(colored("Logs will be synced with wandb.", "blue", attrs=["bold"]))
|
||||
logging.info(f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}")
|
||||
self._wandb = wandb
|
||||
|
||||
@classmethod
|
||||
def get_checkpoints_dir(cls, log_dir: str | Path) -> Path:
|
||||
"""Given the log directory, get the sub-directory in which checkpoints will be saved."""
|
||||
return Path(log_dir) / "checkpoints"
|
||||
|
||||
@classmethod
|
||||
def get_last_checkpoint_dir(cls, log_dir: str | Path) -> Path:
|
||||
"""Given the log directory, get the sub-directory in which the last checkpoint will be saved."""
|
||||
return cls.get_checkpoints_dir(log_dir) / "last"
|
||||
|
||||
@classmethod
|
||||
def get_last_pretrained_model_dir(cls, log_dir: str | Path) -> Path:
|
||||
"""
|
||||
Given the log directory, get the sub-directory in which the last checkpoint's pretrained weights will
|
||||
be saved.
|
||||
"""
|
||||
return cls.get_last_checkpoint_dir(log_dir) / cls.pretrained_model_dir_name
|
||||
|
||||
def save_model(self, save_dir: Path, policy: Policy, wandb_artifact_name: str | None = None):
|
||||
"""Save the weights of the Policy model using PyTorchModelHubMixin.
|
||||
|
||||
The weights are saved in a folder called "pretrained_model" under the checkpoint directory.
|
||||
|
||||
Optionally also upload the model to WandB.
|
||||
"""
|
||||
self.checkpoints_dir.mkdir(parents=True, exist_ok=True)
|
||||
policy.save_pretrained(save_dir)
|
||||
# Also save the full Hydra config for the env configuration.
|
||||
OmegaConf.save(self._cfg, save_dir / "config.yaml")
|
||||
if self._wandb and not self._cfg.wandb.disable_artifact:
|
||||
# note wandb artifact does not accept ":" or "/" in its name
|
||||
artifact = self._wandb.Artifact(wandb_artifact_name, type="model")
|
||||
artifact.add_file(save_dir / SAFETENSORS_SINGLE_FILE)
|
||||
self._wandb.log_artifact(artifact)
|
||||
if self.last_checkpoint_dir.exists():
|
||||
os.remove(self.last_checkpoint_dir)
|
||||
|
||||
def save_training_state(
|
||||
self,
|
||||
save_dir: Path,
|
||||
train_step: int,
|
||||
optimizer: Optimizer,
|
||||
scheduler: LRScheduler | None,
|
||||
):
|
||||
"""Checkpoint the global training_step, optimizer state, scheduler state, and random state.
|
||||
|
||||
All of these are saved as "training_state.pth" under the checkpoint directory.
|
||||
"""
|
||||
training_state = {
|
||||
"step": train_step,
|
||||
"optimizer": optimizer.state_dict(),
|
||||
**get_global_random_state(),
|
||||
}
|
||||
if scheduler is not None:
|
||||
training_state["scheduler"] = scheduler.state_dict()
|
||||
torch.save(training_state, save_dir / self.training_state_file_name)
|
||||
|
||||
def save_checkpoint(
|
||||
self,
|
||||
train_step: int,
|
||||
policy: Policy,
|
||||
optimizer: Optimizer,
|
||||
scheduler: LRScheduler | None,
|
||||
identifier: str,
|
||||
):
|
||||
"""Checkpoint the model weights and the training state."""
|
||||
checkpoint_dir = self.checkpoints_dir / str(identifier)
|
||||
wandb_artifact_name = (
|
||||
None
|
||||
if self._wandb is None
|
||||
else f"{self._group.replace(':', '_').replace('/', '_')}-{self._cfg.seed}-{identifier}"
|
||||
)
|
||||
self.save_model(
|
||||
checkpoint_dir / self.pretrained_model_dir_name, policy, wandb_artifact_name=wandb_artifact_name
|
||||
)
|
||||
self.save_training_state(checkpoint_dir, train_step, optimizer, scheduler)
|
||||
os.symlink(checkpoint_dir.absolute(), self.last_checkpoint_dir)
|
||||
|
||||
def load_last_training_state(self, optimizer: Optimizer, scheduler: LRScheduler | None) -> int:
|
||||
"""
|
||||
Given the last checkpoint in the logging directory, load the optimizer state, scheduler state, and
|
||||
random state, and return the global training step.
|
||||
"""
|
||||
training_state = torch.load(self.last_checkpoint_dir / self.training_state_file_name)
|
||||
optimizer.load_state_dict(training_state["optimizer"])
|
||||
if scheduler is not None:
|
||||
scheduler.load_state_dict(training_state["scheduler"])
|
||||
elif "scheduler" in training_state:
|
||||
raise ValueError(
|
||||
"The checkpoint contains a scheduler state_dict, but no LRScheduler was provided."
|
||||
)
|
||||
# Small hack to get the expected keys: use `get_global_random_state`.
|
||||
set_global_random_state({k: training_state[k] for k in get_global_random_state()})
|
||||
return training_state["step"]
|
||||
|
||||
def log_dict(self, d, step, mode="train"):
|
||||
assert mode in {"train", "eval"}
|
||||
# TODO(alexander-soare): Add local text log.
|
||||
if self._wandb is not None:
|
||||
for k, v in d.items():
|
||||
if not isinstance(v, (int, float, str, wandb.Table)):
|
||||
logging.warning(
|
||||
f'WandB logging of key "{k}" was ignored as its type is not handled by this wrapper.'
|
||||
)
|
||||
continue
|
||||
self._wandb.log({f"{mode}/{k}": v}, step=step)
|
||||
|
||||
def log_video(self, video_path: str, step: int, mode: str = "train"):
|
||||
assert mode in {"train", "eval"}
|
||||
assert self._wandb is not None
|
||||
wandb_video = self._wandb.Video(video_path, fps=self._cfg.fps, format="mp4")
|
||||
self._wandb.log({f"{mode}/video": wandb_video}, step=step)
|
||||
@@ -1 +0,0 @@
|
||||
from .optimizers import OptimizerConfig as OptimizerConfig
|
||||
@@ -1,40 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from torch.optim import Optimizer
|
||||
from torch.optim.lr_scheduler import LRScheduler
|
||||
|
||||
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
|
||||
|
||||
def make_optimizer_and_scheduler(
|
||||
cfg: TrainPipelineConfig, policy: PreTrainedPolicy
|
||||
) -> tuple[Optimizer, LRScheduler | None]:
|
||||
"""Generates the optimizer and scheduler based on configs.
|
||||
|
||||
Args:
|
||||
cfg (TrainPipelineConfig): The training config that contains optimizer and scheduler configs
|
||||
policy (PreTrainedPolicy): The policy config from which parameters and presets must be taken from.
|
||||
|
||||
Returns:
|
||||
tuple[Optimizer, LRScheduler | None]: The couple (Optimizer, Scheduler). Scheduler can be `None`.
|
||||
"""
|
||||
params = policy.get_optim_params() if cfg.use_policy_training_preset else policy.parameters()
|
||||
optimizer = cfg.optimizer.build(params)
|
||||
lr_scheduler = cfg.scheduler.build(optimizer, cfg.steps) if cfg.scheduler is not None else None
|
||||
return optimizer, lr_scheduler
|
||||
@@ -1,118 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import abc
|
||||
from dataclasses import asdict, dataclass
|
||||
from pathlib import Path
|
||||
|
||||
import draccus
|
||||
import torch
|
||||
from safetensors.torch import load_file, save_file
|
||||
|
||||
from lerobot.common.constants import (
|
||||
OPTIMIZER_PARAM_GROUPS,
|
||||
OPTIMIZER_STATE,
|
||||
)
|
||||
from lerobot.common.datasets.utils import flatten_dict, unflatten_dict, write_json
|
||||
from lerobot.common.utils.io_utils import deserialize_json_into_object
|
||||
|
||||
|
||||
@dataclass
|
||||
class OptimizerConfig(draccus.ChoiceRegistry, abc.ABC):
|
||||
lr: float
|
||||
weight_decay: float
|
||||
grad_clip_norm: float
|
||||
|
||||
@property
|
||||
def type(self) -> str:
|
||||
return self.get_choice_name(self.__class__)
|
||||
|
||||
@classmethod
|
||||
def default_choice_name(cls) -> str | None:
|
||||
return "adam"
|
||||
|
||||
@abc.abstractmethod
|
||||
def build(self) -> torch.optim.Optimizer:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@OptimizerConfig.register_subclass("adam")
|
||||
@dataclass
|
||||
class AdamConfig(OptimizerConfig):
|
||||
lr: float = 1e-3
|
||||
betas: tuple[float, float] = (0.9, 0.999)
|
||||
eps: float = 1e-8
|
||||
weight_decay: float = 0.0
|
||||
grad_clip_norm: float = 10.0
|
||||
|
||||
def build(self, params: dict) -> torch.optim.Optimizer:
|
||||
kwargs = asdict(self)
|
||||
kwargs.pop("grad_clip_norm")
|
||||
return torch.optim.Adam(params, **kwargs)
|
||||
|
||||
|
||||
@OptimizerConfig.register_subclass("adamw")
|
||||
@dataclass
|
||||
class AdamWConfig(OptimizerConfig):
|
||||
lr: float = 1e-3
|
||||
betas: tuple[float, float] = (0.9, 0.999)
|
||||
eps: float = 1e-8
|
||||
weight_decay: float = 1e-2
|
||||
grad_clip_norm: float = 10.0
|
||||
|
||||
def build(self, params: dict) -> torch.optim.Optimizer:
|
||||
kwargs = asdict(self)
|
||||
kwargs.pop("grad_clip_norm")
|
||||
return torch.optim.AdamW(params, **kwargs)
|
||||
|
||||
|
||||
@OptimizerConfig.register_subclass("sgd")
|
||||
@dataclass
|
||||
class SGDConfig(OptimizerConfig):
|
||||
lr: float = 1e-3
|
||||
momentum: float = 0.0
|
||||
dampening: float = 0.0
|
||||
nesterov: bool = False
|
||||
weight_decay: float = 0.0
|
||||
grad_clip_norm: float = 10.0
|
||||
|
||||
def build(self, params: dict) -> torch.optim.Optimizer:
|
||||
kwargs = asdict(self)
|
||||
kwargs.pop("grad_clip_norm")
|
||||
return torch.optim.SGD(params, **kwargs)
|
||||
|
||||
|
||||
def save_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Path) -> None:
|
||||
state = optimizer.state_dict()
|
||||
param_groups = state.pop("param_groups")
|
||||
flat_state = flatten_dict(state)
|
||||
save_file(flat_state, save_dir / OPTIMIZER_STATE)
|
||||
write_json(param_groups, save_dir / OPTIMIZER_PARAM_GROUPS)
|
||||
|
||||
|
||||
def load_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Path) -> torch.optim.Optimizer:
|
||||
current_state_dict = optimizer.state_dict()
|
||||
flat_state = load_file(save_dir / OPTIMIZER_STATE)
|
||||
state = unflatten_dict(flat_state)
|
||||
loaded_state_dict = {"state": {int(k): v for k, v in state["state"].items()}}
|
||||
|
||||
if "param_groups" in current_state_dict:
|
||||
param_groups = deserialize_json_into_object(
|
||||
save_dir / OPTIMIZER_PARAM_GROUPS, current_state_dict["param_groups"]
|
||||
)
|
||||
loaded_state_dict["param_groups"] = param_groups
|
||||
|
||||
optimizer.load_state_dict(loaded_state_dict)
|
||||
return optimizer
|
||||
@@ -1,122 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import abc
|
||||
import math
|
||||
from dataclasses import asdict, dataclass
|
||||
from pathlib import Path
|
||||
|
||||
import draccus
|
||||
from torch.optim import Optimizer
|
||||
from torch.optim.lr_scheduler import LambdaLR, LRScheduler
|
||||
|
||||
from lerobot.common.constants import SCHEDULER_STATE
|
||||
from lerobot.common.datasets.utils import write_json
|
||||
from lerobot.common.utils.io_utils import deserialize_json_into_object
|
||||
|
||||
|
||||
@dataclass
|
||||
class LRSchedulerConfig(draccus.ChoiceRegistry, abc.ABC):
|
||||
num_warmup_steps: int
|
||||
|
||||
@property
|
||||
def type(self) -> str:
|
||||
return self.get_choice_name(self.__class__)
|
||||
|
||||
@abc.abstractmethod
|
||||
def build(self, optimizer: Optimizer, num_training_steps: int) -> LRScheduler | None:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@LRSchedulerConfig.register_subclass("diffuser")
|
||||
@dataclass
|
||||
class DiffuserSchedulerConfig(LRSchedulerConfig):
|
||||
name: str = "cosine"
|
||||
num_warmup_steps: int | None = None
|
||||
|
||||
def build(self, optimizer: Optimizer, num_training_steps: int) -> LambdaLR:
|
||||
from diffusers.optimization import get_scheduler
|
||||
|
||||
kwargs = {**asdict(self), "num_training_steps": num_training_steps, "optimizer": optimizer}
|
||||
return get_scheduler(**kwargs)
|
||||
|
||||
|
||||
@LRSchedulerConfig.register_subclass("vqbet")
|
||||
@dataclass
|
||||
class VQBeTSchedulerConfig(LRSchedulerConfig):
|
||||
num_warmup_steps: int
|
||||
num_vqvae_training_steps: int
|
||||
num_cycles: float = 0.5
|
||||
|
||||
def build(self, optimizer: Optimizer, num_training_steps: int) -> LambdaLR:
|
||||
def lr_lambda(current_step):
|
||||
if current_step < self.num_vqvae_training_steps:
|
||||
return float(1)
|
||||
else:
|
||||
adjusted_step = current_step - self.num_vqvae_training_steps
|
||||
if adjusted_step < self.num_warmup_steps:
|
||||
return float(adjusted_step) / float(max(1, self.num_warmup_steps))
|
||||
progress = float(adjusted_step - self.num_warmup_steps) / float(
|
||||
max(1, num_training_steps - self.num_warmup_steps)
|
||||
)
|
||||
return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(self.num_cycles) * 2.0 * progress)))
|
||||
|
||||
return LambdaLR(optimizer, lr_lambda, -1)
|
||||
|
||||
|
||||
@LRSchedulerConfig.register_subclass("cosine_decay_with_warmup")
|
||||
@dataclass
|
||||
class CosineDecayWithWarmupSchedulerConfig(LRSchedulerConfig):
|
||||
"""Used by Physical Intelligence to train Pi0"""
|
||||
|
||||
num_warmup_steps: int
|
||||
num_decay_steps: int
|
||||
peak_lr: float
|
||||
decay_lr: float
|
||||
|
||||
def build(self, optimizer: Optimizer, num_training_steps: int) -> LambdaLR:
|
||||
del num_training_steps
|
||||
|
||||
def lr_lambda(current_step):
|
||||
def linear_warmup_schedule(current_step):
|
||||
if current_step <= 0:
|
||||
return 1 / (self.num_warmup_steps + 1)
|
||||
frac = 1 - current_step / self.num_warmup_steps
|
||||
return (1 / (self.num_warmup_steps + 1) - 1) * frac + 1
|
||||
|
||||
def cosine_decay_schedule(current_step):
|
||||
step = min(current_step, self.num_decay_steps)
|
||||
cosine_decay = 0.5 * (1 + math.cos(math.pi * step / self.num_decay_steps))
|
||||
alpha = self.decay_lr / self.peak_lr
|
||||
decayed = (1 - alpha) * cosine_decay + alpha
|
||||
return decayed
|
||||
|
||||
if current_step < self.num_warmup_steps:
|
||||
return linear_warmup_schedule(current_step)
|
||||
|
||||
return cosine_decay_schedule(current_step)
|
||||
|
||||
return LambdaLR(optimizer, lr_lambda, -1)
|
||||
|
||||
|
||||
def save_scheduler_state(scheduler: LRScheduler, save_dir: Path) -> None:
|
||||
state_dict = scheduler.state_dict()
|
||||
write_json(state_dict, save_dir / SCHEDULER_STATE)
|
||||
|
||||
|
||||
def load_scheduler_state(scheduler: LRScheduler, save_dir: Path) -> LRScheduler:
|
||||
state_dict = deserialize_json_into_object(save_dir / SCHEDULER_STATE, scheduler.state_dict())
|
||||
scheduler.load_state_dict(state_dict)
|
||||
return scheduler
|
||||
@@ -1,5 +0,0 @@
|
||||
from .act.configuration_act import ACTConfig as ACTConfig
|
||||
from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig
|
||||
from .pi0.configuration_pi0 import PI0Config as PI0Config
|
||||
from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig
|
||||
from .vqbet.configuration_vqbet import VQBeTConfig as VQBeTConfig
|
||||
@@ -15,14 +15,9 @@
|
||||
# limitations under the License.
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.common.optim.optimizers import AdamWConfig
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import NormalizationMode
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("act")
|
||||
@dataclass
|
||||
class ACTConfig(PreTrainedConfig):
|
||||
class ACTConfig:
|
||||
"""Configuration class for the Action Chunking Transformers policy.
|
||||
|
||||
Defaults are configured for training on bimanual Aloha tasks like "insertion" or "transfer".
|
||||
@@ -64,7 +59,7 @@ class ACTConfig(PreTrainedConfig):
|
||||
output_normalization_modes: Similar dictionary as `normalize_input_modes`, but to unnormalize to the
|
||||
original scale. Note that this is also used for normalizing the training targets.
|
||||
vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
|
||||
pretrained_backbone_weights: Pretrained weights from torchvision to initialize the backbone.
|
||||
pretrained_backbone_weights: Pretrained weights from torchvision to initalize the backbone.
|
||||
`None` means no pretrained weights.
|
||||
replace_final_stride_with_dilation: Whether to replace the ResNet's final 2x2 stride with a dilated
|
||||
convolution.
|
||||
@@ -95,11 +90,28 @@ class ACTConfig(PreTrainedConfig):
|
||||
chunk_size: int = 100
|
||||
n_action_steps: int = 100
|
||||
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
input_shapes: dict[str, list[int]] = field(
|
||||
default_factory=lambda: {
|
||||
"VISUAL": NormalizationMode.MEAN_STD,
|
||||
"STATE": NormalizationMode.MEAN_STD,
|
||||
"ACTION": NormalizationMode.MEAN_STD,
|
||||
"observation.images.top": [3, 480, 640],
|
||||
"observation.state": [14],
|
||||
}
|
||||
)
|
||||
output_shapes: dict[str, list[int]] = field(
|
||||
default_factory=lambda: {
|
||||
"action": [14],
|
||||
}
|
||||
)
|
||||
|
||||
# Normalization / Unnormalization
|
||||
input_normalization_modes: dict[str, str] = field(
|
||||
default_factory=lambda: {
|
||||
"observation.images.top": "mean_std",
|
||||
"observation.state": "mean_std",
|
||||
}
|
||||
)
|
||||
output_normalization_modes: dict[str, str] = field(
|
||||
default_factory=lambda: {
|
||||
"action": "mean_std",
|
||||
}
|
||||
)
|
||||
|
||||
@@ -132,14 +144,7 @@ class ACTConfig(PreTrainedConfig):
|
||||
dropout: float = 0.1
|
||||
kl_weight: float = 10.0
|
||||
|
||||
# Training preset
|
||||
optimizer_lr: float = 1e-5
|
||||
optimizer_weight_decay: float = 1e-4
|
||||
optimizer_lr_backbone: float = 1e-5
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
|
||||
"""Input validation (not exhaustive)."""
|
||||
if not self.vision_backbone.startswith("resnet"):
|
||||
raise ValueError(
|
||||
@@ -159,28 +164,8 @@ class ACTConfig(PreTrainedConfig):
|
||||
raise ValueError(
|
||||
f"Multiple observation steps not handled yet. Got `nobs_steps={self.n_obs_steps}`"
|
||||
)
|
||||
|
||||
def get_optimizer_preset(self) -> AdamWConfig:
|
||||
return AdamWConfig(
|
||||
lr=self.optimizer_lr,
|
||||
weight_decay=self.optimizer_weight_decay,
|
||||
)
|
||||
|
||||
def get_scheduler_preset(self) -> None:
|
||||
return None
|
||||
|
||||
def validate_features(self) -> None:
|
||||
if not self.image_features and not self.env_state_feature:
|
||||
if (
|
||||
not any(k.startswith("observation.image") for k in self.input_shapes)
|
||||
and "observation.environment_state" not in self.input_shapes
|
||||
):
|
||||
raise ValueError("You must provide at least one image or the environment state among the inputs.")
|
||||
|
||||
@property
|
||||
def observation_delta_indices(self) -> None:
|
||||
return None
|
||||
|
||||
@property
|
||||
def action_delta_indices(self) -> list:
|
||||
return list(range(self.chunk_size))
|
||||
|
||||
@property
|
||||
def reward_delta_indices(self) -> None:
|
||||
return None
|
||||
|
||||
@@ -29,27 +29,32 @@ import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
import torchvision
|
||||
from huggingface_hub import PyTorchModelHubMixin
|
||||
from torch import Tensor, nn
|
||||
from torchvision.models._utils import IntermediateLayerGetter
|
||||
from torchvision.ops.misc import FrozenBatchNorm2d
|
||||
|
||||
from lerobot.common.policies.act.configuration_act import ACTConfig
|
||||
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
||||
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||
|
||||
|
||||
class ACTPolicy(PreTrainedPolicy):
|
||||
class ACTPolicy(
|
||||
nn.Module,
|
||||
PyTorchModelHubMixin,
|
||||
library_name="lerobot",
|
||||
repo_url="https://github.com/huggingface/lerobot",
|
||||
tags=["robotics", "act"],
|
||||
):
|
||||
"""
|
||||
Action Chunking Transformer Policy as per Learning Fine-Grained Bimanual Manipulation with Low-Cost
|
||||
Hardware (paper: https://arxiv.org/abs/2304.13705, code: https://github.com/tonyzhaozh/act)
|
||||
"""
|
||||
|
||||
config_class = ACTConfig
|
||||
name = "act"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: ACTConfig,
|
||||
config: ACTConfig | None = None,
|
||||
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
|
||||
):
|
||||
"""
|
||||
@@ -59,46 +64,30 @@ class ACTPolicy(PreTrainedPolicy):
|
||||
dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
|
||||
that they will be passed with a call to `load_state_dict` before the policy is used.
|
||||
"""
|
||||
super().__init__(config)
|
||||
config.validate_features()
|
||||
self.config = config
|
||||
super().__init__()
|
||||
if config is None:
|
||||
config = ACTConfig()
|
||||
self.config: ACTConfig = config
|
||||
|
||||
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
|
||||
self.normalize_inputs = Normalize(
|
||||
config.input_shapes, config.input_normalization_modes, dataset_stats
|
||||
)
|
||||
self.normalize_targets = Normalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
config.output_shapes, config.output_normalization_modes, dataset_stats
|
||||
)
|
||||
self.unnormalize_outputs = Unnormalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
config.output_shapes, config.output_normalization_modes, dataset_stats
|
||||
)
|
||||
|
||||
self.model = ACT(config)
|
||||
|
||||
self.expected_image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
|
||||
|
||||
if config.temporal_ensemble_coeff is not None:
|
||||
self.temporal_ensembler = ACTTemporalEnsembler(config.temporal_ensemble_coeff, config.chunk_size)
|
||||
|
||||
self.reset()
|
||||
|
||||
def get_optim_params(self) -> dict:
|
||||
# TODO(aliberts, rcadene): As of now, lr_backbone == lr
|
||||
# Should we remove this and just `return self.parameters()`?
|
||||
return [
|
||||
{
|
||||
"params": [
|
||||
p
|
||||
for n, p in self.named_parameters()
|
||||
if not n.startswith("model.backbone") and p.requires_grad
|
||||
]
|
||||
},
|
||||
{
|
||||
"params": [
|
||||
p
|
||||
for n, p in self.named_parameters()
|
||||
if n.startswith("model.backbone") and p.requires_grad
|
||||
],
|
||||
"lr": self.config.optimizer_lr_backbone,
|
||||
},
|
||||
]
|
||||
|
||||
def reset(self):
|
||||
"""This should be called whenever the environment is reset."""
|
||||
if self.config.temporal_ensemble_coeff is not None:
|
||||
@@ -117,11 +106,9 @@ class ACTPolicy(PreTrainedPolicy):
|
||||
self.eval()
|
||||
|
||||
batch = self.normalize_inputs(batch)
|
||||
if self.config.image_features:
|
||||
if len(self.expected_image_keys) > 0:
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
batch["observation.images"] = torch.stack(
|
||||
[batch[key] for key in self.config.image_features], dim=-4
|
||||
)
|
||||
batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
|
||||
|
||||
# If we are doing temporal ensembling, do online updates where we keep track of the number of actions
|
||||
# we are ensembling over.
|
||||
@@ -144,14 +131,12 @@ class ACTPolicy(PreTrainedPolicy):
|
||||
self._action_queue.extend(actions.transpose(0, 1))
|
||||
return self._action_queue.popleft()
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
"""Run the batch through the model and compute the loss for training or validation."""
|
||||
batch = self.normalize_inputs(batch)
|
||||
if self.config.image_features:
|
||||
if len(self.expected_image_keys) > 0:
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
batch["observation.images"] = torch.stack(
|
||||
[batch[key] for key in self.config.image_features], dim=-4
|
||||
)
|
||||
batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
|
||||
batch = self.normalize_targets(batch)
|
||||
actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)
|
||||
|
||||
@@ -169,11 +154,11 @@ class ACTPolicy(PreTrainedPolicy):
|
||||
(-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean()
|
||||
)
|
||||
loss_dict["kld_loss"] = mean_kld.item()
|
||||
loss = l1_loss + mean_kld * self.config.kl_weight
|
||||
loss_dict["loss"] = l1_loss + mean_kld * self.config.kl_weight
|
||||
else:
|
||||
loss = l1_loss
|
||||
loss_dict["loss"] = l1_loss
|
||||
|
||||
return loss, loss_dict
|
||||
return loss_dict
|
||||
|
||||
|
||||
class ACTTemporalEnsembler:
|
||||
@@ -303,30 +288,31 @@ class ACT(nn.Module):
|
||||
"""
|
||||
|
||||
def __init__(self, config: ACTConfig):
|
||||
# BERT style VAE encoder with input tokens [cls, robot_state, *action_sequence].
|
||||
# The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]).
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
# BERT style VAE encoder with input tokens [cls, robot_state, *action_sequence].
|
||||
# The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]).
|
||||
self.use_robot_state = "observation.state" in config.input_shapes
|
||||
self.use_images = any(k.startswith("observation.image") for k in config.input_shapes)
|
||||
self.use_env_state = "observation.environment_state" in config.input_shapes
|
||||
if self.config.use_vae:
|
||||
self.vae_encoder = ACTEncoder(config, is_vae_encoder=True)
|
||||
self.vae_encoder_cls_embed = nn.Embedding(1, config.dim_model)
|
||||
# Projection layer for joint-space configuration to hidden dimension.
|
||||
if self.config.robot_state_feature:
|
||||
if self.use_robot_state:
|
||||
self.vae_encoder_robot_state_input_proj = nn.Linear(
|
||||
self.config.robot_state_feature.shape[0], config.dim_model
|
||||
config.input_shapes["observation.state"][0], config.dim_model
|
||||
)
|
||||
# Projection layer for action (joint-space target) to hidden dimension.
|
||||
self.vae_encoder_action_input_proj = nn.Linear(
|
||||
self.config.action_feature.shape[0],
|
||||
config.dim_model,
|
||||
config.output_shapes["action"][0], config.dim_model
|
||||
)
|
||||
# Projection layer from the VAE encoder's output to the latent distribution's parameter space.
|
||||
self.vae_encoder_latent_output_proj = nn.Linear(config.dim_model, config.latent_dim * 2)
|
||||
# Fixed sinusoidal positional embedding for the input to the VAE encoder. Unsqueeze for batch
|
||||
# dimension.
|
||||
num_input_token_encoder = 1 + config.chunk_size
|
||||
if self.config.robot_state_feature:
|
||||
if self.use_robot_state:
|
||||
num_input_token_encoder += 1
|
||||
self.register_buffer(
|
||||
"vae_encoder_pos_enc",
|
||||
@@ -334,7 +320,7 @@ class ACT(nn.Module):
|
||||
)
|
||||
|
||||
# Backbone for image feature extraction.
|
||||
if self.config.image_features:
|
||||
if self.use_images:
|
||||
backbone_model = getattr(torchvision.models, config.vision_backbone)(
|
||||
replace_stride_with_dilation=[False, False, config.replace_final_stride_with_dilation],
|
||||
weights=config.pretrained_backbone_weights,
|
||||
@@ -351,27 +337,27 @@ class ACT(nn.Module):
|
||||
|
||||
# Transformer encoder input projections. The tokens will be structured like
|
||||
# [latent, (robot_state), (env_state), (image_feature_map_pixels)].
|
||||
if self.config.robot_state_feature:
|
||||
if self.use_robot_state:
|
||||
self.encoder_robot_state_input_proj = nn.Linear(
|
||||
self.config.robot_state_feature.shape[0], config.dim_model
|
||||
config.input_shapes["observation.state"][0], config.dim_model
|
||||
)
|
||||
if self.config.env_state_feature:
|
||||
if self.use_env_state:
|
||||
self.encoder_env_state_input_proj = nn.Linear(
|
||||
self.config.env_state_feature.shape[0], config.dim_model
|
||||
config.input_shapes["observation.environment_state"][0], config.dim_model
|
||||
)
|
||||
self.encoder_latent_input_proj = nn.Linear(config.latent_dim, config.dim_model)
|
||||
if self.config.image_features:
|
||||
if self.use_images:
|
||||
self.encoder_img_feat_input_proj = nn.Conv2d(
|
||||
backbone_model.fc.in_features, config.dim_model, kernel_size=1
|
||||
)
|
||||
# Transformer encoder positional embeddings.
|
||||
n_1d_tokens = 1 # for the latent
|
||||
if self.config.robot_state_feature:
|
||||
if self.use_robot_state:
|
||||
n_1d_tokens += 1
|
||||
if self.config.env_state_feature:
|
||||
if self.use_env_state:
|
||||
n_1d_tokens += 1
|
||||
self.encoder_1d_feature_pos_embed = nn.Embedding(n_1d_tokens, config.dim_model)
|
||||
if self.config.image_features:
|
||||
if self.use_images:
|
||||
self.encoder_cam_feat_pos_embed = ACTSinusoidalPositionEmbedding2d(config.dim_model // 2)
|
||||
|
||||
# Transformer decoder.
|
||||
@@ -379,7 +365,7 @@ class ACT(nn.Module):
|
||||
self.decoder_pos_embed = nn.Embedding(config.chunk_size, config.dim_model)
|
||||
|
||||
# Final action regression head on the output of the transformer's decoder.
|
||||
self.action_head = nn.Linear(config.dim_model, self.config.action_feature.shape[0])
|
||||
self.action_head = nn.Linear(config.dim_model, config.output_shapes["action"][0])
|
||||
|
||||
self._reset_parameters()
|
||||
|
||||
@@ -394,13 +380,13 @@ class ACT(nn.Module):
|
||||
|
||||
`batch` should have the following structure:
|
||||
{
|
||||
[robot_state_feature] (optional): (B, state_dim) batch of robot states.
|
||||
"observation.state" (optional): (B, state_dim) batch of robot states.
|
||||
|
||||
[image_features]: (B, n_cameras, C, H, W) batch of images.
|
||||
"observation.images": (B, n_cameras, C, H, W) batch of images.
|
||||
AND/OR
|
||||
[env_state_feature]: (B, env_dim) batch of environment states.
|
||||
"observation.environment_state": (B, env_dim) batch of environment states.
|
||||
|
||||
[action_feature] (optional, only if training with VAE): (B, chunk_size, action dim) batch of actions.
|
||||
"action" (optional, only if training with VAE): (B, chunk_size, action dim) batch of actions.
|
||||
}
|
||||
|
||||
Returns:
|
||||
@@ -409,9 +395,9 @@ class ACT(nn.Module):
|
||||
latent dimension.
|
||||
"""
|
||||
if self.config.use_vae and self.training:
|
||||
assert "action" in batch, (
|
||||
"actions must be provided when using the variational objective in training mode."
|
||||
)
|
||||
assert (
|
||||
"action" in batch
|
||||
), "actions must be provided when using the variational objective in training mode."
|
||||
|
||||
batch_size = (
|
||||
batch["observation.images"]
|
||||
@@ -425,12 +411,12 @@ class ACT(nn.Module):
|
||||
cls_embed = einops.repeat(
|
||||
self.vae_encoder_cls_embed.weight, "1 d -> b 1 d", b=batch_size
|
||||
) # (B, 1, D)
|
||||
if self.config.robot_state_feature:
|
||||
if self.use_robot_state:
|
||||
robot_state_embed = self.vae_encoder_robot_state_input_proj(batch["observation.state"])
|
||||
robot_state_embed = robot_state_embed.unsqueeze(1) # (B, 1, D)
|
||||
action_embed = self.vae_encoder_action_input_proj(batch["action"]) # (B, S, D)
|
||||
|
||||
if self.config.robot_state_feature:
|
||||
if self.use_robot_state:
|
||||
vae_encoder_input = [cls_embed, robot_state_embed, action_embed] # (B, S+2, D)
|
||||
else:
|
||||
vae_encoder_input = [cls_embed, action_embed]
|
||||
@@ -444,7 +430,7 @@ class ACT(nn.Module):
|
||||
# sequence depending whether we use the input states or not (cls and robot state)
|
||||
# False means not a padding token.
|
||||
cls_joint_is_pad = torch.full(
|
||||
(batch_size, 2 if self.config.robot_state_feature else 1),
|
||||
(batch_size, 2 if self.use_robot_state else 1),
|
||||
False,
|
||||
device=batch["observation.state"].device,
|
||||
)
|
||||
@@ -477,16 +463,16 @@ class ACT(nn.Module):
|
||||
encoder_in_tokens = [self.encoder_latent_input_proj(latent_sample)]
|
||||
encoder_in_pos_embed = list(self.encoder_1d_feature_pos_embed.weight.unsqueeze(1))
|
||||
# Robot state token.
|
||||
if self.config.robot_state_feature:
|
||||
if self.use_robot_state:
|
||||
encoder_in_tokens.append(self.encoder_robot_state_input_proj(batch["observation.state"]))
|
||||
# Environment state token.
|
||||
if self.config.env_state_feature:
|
||||
if self.use_env_state:
|
||||
encoder_in_tokens.append(
|
||||
self.encoder_env_state_input_proj(batch["observation.environment_state"])
|
||||
)
|
||||
|
||||
# Camera observation features and positional embeddings.
|
||||
if self.config.image_features:
|
||||
if self.use_images:
|
||||
all_cam_features = []
|
||||
all_cam_pos_embeds = []
|
||||
|
||||
|
||||
@@ -16,15 +16,9 @@
|
||||
# limitations under the License.
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.common.optim.optimizers import AdamConfig
|
||||
from lerobot.common.optim.schedulers import DiffuserSchedulerConfig
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import NormalizationMode
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("diffusion")
|
||||
@dataclass
|
||||
class DiffusionConfig(PreTrainedConfig):
|
||||
class DiffusionConfig:
|
||||
"""Configuration class for DiffusionPolicy.
|
||||
|
||||
Defaults are configured for training with PushT providing proprioceptive and single camera observations.
|
||||
@@ -68,7 +62,7 @@ class DiffusionConfig(PreTrainedConfig):
|
||||
within the image size. If None, no cropping is done.
|
||||
crop_is_random: Whether the crop should be random at training time (it's always a center crop in eval
|
||||
mode).
|
||||
pretrained_backbone_weights: Pretrained weights from torchvision to initialize the backbone.
|
||||
pretrained_backbone_weights: Pretrained weights from torchvision to initalize the backbone.
|
||||
`None` means no pretrained weights.
|
||||
use_group_norm: Whether to replace batch normalization with group normalization in the backbone.
|
||||
The group sizes are set to be about 16 (to be precise, feature_dim // 16).
|
||||
@@ -99,7 +93,7 @@ class DiffusionConfig(PreTrainedConfig):
|
||||
num_inference_steps: Number of reverse diffusion steps to use at inference time (steps are evenly
|
||||
spaced). If not provided, this defaults to be the same as `num_train_timesteps`.
|
||||
do_mask_loss_for_padding: Whether to mask the loss when there are copy-padded actions. See
|
||||
`LeRobotDataset` and `load_previous_and_future_frames` for more information. Note, this defaults
|
||||
`LeRobotDataset` and `load_previous_and_future_frames` for mor information. Note, this defaults
|
||||
to False as the original Diffusion Policy implementation does the same.
|
||||
"""
|
||||
|
||||
@@ -108,17 +102,26 @@ class DiffusionConfig(PreTrainedConfig):
|
||||
horizon: int = 16
|
||||
n_action_steps: int = 8
|
||||
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
input_shapes: dict[str, list[int]] = field(
|
||||
default_factory=lambda: {
|
||||
"VISUAL": NormalizationMode.MEAN_STD,
|
||||
"STATE": NormalizationMode.MIN_MAX,
|
||||
"ACTION": NormalizationMode.MIN_MAX,
|
||||
"observation.image": [3, 96, 96],
|
||||
"observation.state": [2],
|
||||
}
|
||||
)
|
||||
output_shapes: dict[str, list[int]] = field(
|
||||
default_factory=lambda: {
|
||||
"action": [2],
|
||||
}
|
||||
)
|
||||
|
||||
# The original implementation doesn't sample frames for the last 7 steps,
|
||||
# which avoids excessive padding and leads to improved training results.
|
||||
drop_n_last_frames: int = 7 # horizon - n_action_steps - n_obs_steps + 1
|
||||
# Normalization / Unnormalization
|
||||
input_normalization_modes: dict[str, str] = field(
|
||||
default_factory=lambda: {
|
||||
"observation.image": "mean_std",
|
||||
"observation.state": "min_max",
|
||||
}
|
||||
)
|
||||
output_normalization_modes: dict[str, str] = field(default_factory=lambda: {"action": "min_max"})
|
||||
|
||||
# Architecture / modeling.
|
||||
# Vision backbone.
|
||||
@@ -151,23 +154,39 @@ class DiffusionConfig(PreTrainedConfig):
|
||||
# Loss computation
|
||||
do_mask_loss_for_padding: bool = False
|
||||
|
||||
# Training presets
|
||||
optimizer_lr: float = 1e-4
|
||||
optimizer_betas: tuple = (0.95, 0.999)
|
||||
optimizer_eps: float = 1e-8
|
||||
optimizer_weight_decay: float = 1e-6
|
||||
scheduler_name: str = "cosine"
|
||||
scheduler_warmup_steps: int = 500
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
|
||||
"""Input validation (not exhaustive)."""
|
||||
if not self.vision_backbone.startswith("resnet"):
|
||||
raise ValueError(
|
||||
f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
|
||||
)
|
||||
|
||||
image_keys = {k for k in self.input_shapes if k.startswith("observation.image")}
|
||||
|
||||
if len(image_keys) == 0 and "observation.environment_state" not in self.input_shapes:
|
||||
raise ValueError("You must provide at least one image or the environment state among the inputs.")
|
||||
|
||||
if len(image_keys) > 0:
|
||||
if self.crop_shape is not None:
|
||||
for image_key in image_keys:
|
||||
if (
|
||||
self.crop_shape[0] > self.input_shapes[image_key][1]
|
||||
or self.crop_shape[1] > self.input_shapes[image_key][2]
|
||||
):
|
||||
raise ValueError(
|
||||
f"`crop_shape` should fit within `input_shapes[{image_key}]`. Got {self.crop_shape} "
|
||||
f"for `crop_shape` and {self.input_shapes[image_key]} for "
|
||||
"`input_shapes[{image_key}]`."
|
||||
)
|
||||
# Check that all input images have the same shape.
|
||||
first_image_key = next(iter(image_keys))
|
||||
for image_key in image_keys:
|
||||
if self.input_shapes[image_key] != self.input_shapes[first_image_key]:
|
||||
raise ValueError(
|
||||
f"`input_shapes[{image_key}]` does not match `input_shapes[{first_image_key}]`, but we "
|
||||
"expect all image shapes to match."
|
||||
)
|
||||
|
||||
supported_prediction_types = ["epsilon", "sample"]
|
||||
if self.prediction_type not in supported_prediction_types:
|
||||
raise ValueError(
|
||||
@@ -188,50 +207,3 @@ class DiffusionConfig(PreTrainedConfig):
|
||||
"The horizon should be an integer multiple of the downsampling factor (which is determined "
|
||||
f"by `len(down_dims)`). Got {self.horizon=} and {self.down_dims=}"
|
||||
)
|
||||
|
||||
def get_optimizer_preset(self) -> AdamConfig:
|
||||
return AdamConfig(
|
||||
lr=self.optimizer_lr,
|
||||
betas=self.optimizer_betas,
|
||||
eps=self.optimizer_eps,
|
||||
weight_decay=self.optimizer_weight_decay,
|
||||
)
|
||||
|
||||
def get_scheduler_preset(self) -> DiffuserSchedulerConfig:
|
||||
return DiffuserSchedulerConfig(
|
||||
name=self.scheduler_name,
|
||||
num_warmup_steps=self.scheduler_warmup_steps,
|
||||
)
|
||||
|
||||
def validate_features(self) -> None:
|
||||
if len(self.image_features) == 0 and self.env_state_feature is None:
|
||||
raise ValueError("You must provide at least one image or the environment state among the inputs.")
|
||||
|
||||
if self.crop_shape is not None:
|
||||
for key, image_ft in self.image_features.items():
|
||||
if self.crop_shape[0] > image_ft.shape[1] or self.crop_shape[1] > image_ft.shape[2]:
|
||||
raise ValueError(
|
||||
f"`crop_shape` should fit within the images shapes. Got {self.crop_shape} "
|
||||
f"for `crop_shape` and {image_ft.shape} for "
|
||||
f"`{key}`."
|
||||
)
|
||||
|
||||
# Check that all input images have the same shape.
|
||||
first_image_key, first_image_ft = next(iter(self.image_features.items()))
|
||||
for key, image_ft in self.image_features.items():
|
||||
if image_ft.shape != first_image_ft.shape:
|
||||
raise ValueError(
|
||||
f"`{key}` does not match `{first_image_key}`, but we expect all image shapes to match."
|
||||
)
|
||||
|
||||
@property
|
||||
def observation_delta_indices(self) -> list:
|
||||
return list(range(1 - self.n_obs_steps, 1))
|
||||
|
||||
@property
|
||||
def action_delta_indices(self) -> list:
|
||||
return list(range(1 - self.n_obs_steps, 1 - self.n_obs_steps + self.horizon))
|
||||
|
||||
@property
|
||||
def reward_delta_indices(self) -> None:
|
||||
return None
|
||||
|
||||
@@ -31,32 +31,35 @@ import torch.nn.functional as F # noqa: N812
|
||||
import torchvision
|
||||
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
|
||||
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
|
||||
from huggingface_hub import PyTorchModelHubMixin
|
||||
from torch import Tensor, nn
|
||||
|
||||
from lerobot.common.constants import OBS_ENV, OBS_ROBOT
|
||||
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
|
||||
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
||||
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.common.policies.utils import (
|
||||
get_device_from_parameters,
|
||||
get_dtype_from_parameters,
|
||||
get_output_shape,
|
||||
populate_queues,
|
||||
)
|
||||
|
||||
|
||||
class DiffusionPolicy(PreTrainedPolicy):
|
||||
class DiffusionPolicy(
|
||||
nn.Module,
|
||||
PyTorchModelHubMixin,
|
||||
library_name="lerobot",
|
||||
repo_url="https://github.com/huggingface/lerobot",
|
||||
tags=["robotics", "diffusion-policy"],
|
||||
):
|
||||
"""
|
||||
Diffusion Policy as per "Diffusion Policy: Visuomotor Policy Learning via Action Diffusion"
|
||||
(paper: https://arxiv.org/abs/2303.04137, code: https://github.com/real-stanford/diffusion_policy).
|
||||
"""
|
||||
|
||||
config_class = DiffusionConfig
|
||||
name = "diffusion"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: DiffusionConfig,
|
||||
config: DiffusionConfig | None = None,
|
||||
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
|
||||
):
|
||||
"""
|
||||
@@ -66,16 +69,18 @@ class DiffusionPolicy(PreTrainedPolicy):
|
||||
dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
|
||||
that they will be passed with a call to `load_state_dict` before the policy is used.
|
||||
"""
|
||||
super().__init__(config)
|
||||
config.validate_features()
|
||||
super().__init__()
|
||||
if config is None:
|
||||
config = DiffusionConfig()
|
||||
self.config = config
|
||||
|
||||
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
|
||||
self.normalize_inputs = Normalize(
|
||||
config.input_shapes, config.input_normalization_modes, dataset_stats
|
||||
)
|
||||
self.normalize_targets = Normalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
config.output_shapes, config.output_normalization_modes, dataset_stats
|
||||
)
|
||||
self.unnormalize_outputs = Unnormalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
config.output_shapes, config.output_normalization_modes, dataset_stats
|
||||
)
|
||||
|
||||
# queues are populated during rollout of the policy, they contain the n latest observations and actions
|
||||
@@ -83,10 +88,10 @@ class DiffusionPolicy(PreTrainedPolicy):
|
||||
|
||||
self.diffusion = DiffusionModel(config)
|
||||
|
||||
self.reset()
|
||||
self.expected_image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
|
||||
self.use_env_state = "observation.environment_state" in config.input_shapes
|
||||
|
||||
def get_optim_params(self) -> dict:
|
||||
return self.diffusion.parameters()
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
"""Clear observation and action queues. Should be called on `env.reset()`"""
|
||||
@@ -94,9 +99,9 @@ class DiffusionPolicy(PreTrainedPolicy):
|
||||
"observation.state": deque(maxlen=self.config.n_obs_steps),
|
||||
"action": deque(maxlen=self.config.n_action_steps),
|
||||
}
|
||||
if self.config.image_features:
|
||||
if len(self.expected_image_keys) > 0:
|
||||
self._queues["observation.images"] = deque(maxlen=self.config.n_obs_steps)
|
||||
if self.config.env_state_feature:
|
||||
if self.use_env_state:
|
||||
self._queues["observation.environment_state"] = deque(maxlen=self.config.n_obs_steps)
|
||||
|
||||
@torch.no_grad
|
||||
@@ -122,11 +127,9 @@ class DiffusionPolicy(PreTrainedPolicy):
|
||||
actually measured from the first observation which (if `n_obs_steps` > 1) happened in the past.
|
||||
"""
|
||||
batch = self.normalize_inputs(batch)
|
||||
if self.config.image_features:
|
||||
if len(self.expected_image_keys) > 0:
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
batch["observation.images"] = torch.stack(
|
||||
[batch[key] for key in self.config.image_features], dim=-4
|
||||
)
|
||||
batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
|
||||
# Note: It's important that this happens after stacking the images into a single key.
|
||||
self._queues = populate_queues(self._queues, batch)
|
||||
|
||||
@@ -143,18 +146,15 @@ class DiffusionPolicy(PreTrainedPolicy):
|
||||
action = self._queues["action"].popleft()
|
||||
return action
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, None]:
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
"""Run the batch through the model and compute the loss for training or validation."""
|
||||
batch = self.normalize_inputs(batch)
|
||||
if self.config.image_features:
|
||||
if len(self.expected_image_keys) > 0:
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
batch["observation.images"] = torch.stack(
|
||||
[batch[key] for key in self.config.image_features], dim=-4
|
||||
)
|
||||
batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
|
||||
batch = self.normalize_targets(batch)
|
||||
loss = self.diffusion.compute_loss(batch)
|
||||
# no output_dict so returning None
|
||||
return loss, None
|
||||
return {"loss": loss}
|
||||
|
||||
|
||||
def _make_noise_scheduler(name: str, **kwargs: dict) -> DDPMScheduler | DDIMScheduler:
|
||||
@@ -176,9 +176,12 @@ class DiffusionModel(nn.Module):
|
||||
self.config = config
|
||||
|
||||
# Build observation encoders (depending on which observations are provided).
|
||||
global_cond_dim = self.config.robot_state_feature.shape[0]
|
||||
if self.config.image_features:
|
||||
num_images = len(self.config.image_features)
|
||||
global_cond_dim = config.input_shapes["observation.state"][0]
|
||||
num_images = len([k for k in config.input_shapes if k.startswith("observation.image")])
|
||||
self._use_images = False
|
||||
self._use_env_state = False
|
||||
if num_images > 0:
|
||||
self._use_images = True
|
||||
if self.config.use_separate_rgb_encoder_per_camera:
|
||||
encoders = [DiffusionRgbEncoder(config) for _ in range(num_images)]
|
||||
self.rgb_encoder = nn.ModuleList(encoders)
|
||||
@@ -186,8 +189,9 @@ class DiffusionModel(nn.Module):
|
||||
else:
|
||||
self.rgb_encoder = DiffusionRgbEncoder(config)
|
||||
global_cond_dim += self.rgb_encoder.feature_dim * num_images
|
||||
if self.config.env_state_feature:
|
||||
global_cond_dim += self.config.env_state_feature.shape[0]
|
||||
if "observation.environment_state" in config.input_shapes:
|
||||
self._use_env_state = True
|
||||
global_cond_dim += config.input_shapes["observation.environment_state"][0]
|
||||
|
||||
self.unet = DiffusionConditionalUnet1d(config, global_cond_dim=global_cond_dim * config.n_obs_steps)
|
||||
|
||||
@@ -216,7 +220,7 @@ class DiffusionModel(nn.Module):
|
||||
|
||||
# Sample prior.
|
||||
sample = torch.randn(
|
||||
size=(batch_size, self.config.horizon, self.config.action_feature.shape[0]),
|
||||
size=(batch_size, self.config.horizon, self.config.output_shapes["action"][0]),
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
generator=generator,
|
||||
@@ -238,10 +242,10 @@ class DiffusionModel(nn.Module):
|
||||
|
||||
def _prepare_global_conditioning(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Encode image features and concatenate them all together along with the state vector."""
|
||||
batch_size, n_obs_steps = batch[OBS_ROBOT].shape[:2]
|
||||
global_cond_feats = [batch[OBS_ROBOT]]
|
||||
batch_size, n_obs_steps = batch["observation.state"].shape[:2]
|
||||
global_cond_feats = [batch["observation.state"]]
|
||||
# Extract image features.
|
||||
if self.config.image_features:
|
||||
if self._use_images:
|
||||
if self.config.use_separate_rgb_encoder_per_camera:
|
||||
# Combine batch and sequence dims while rearranging to make the camera index dimension first.
|
||||
images_per_camera = einops.rearrange(batch["observation.images"], "b s n ... -> n (b s) ...")
|
||||
@@ -268,8 +272,8 @@ class DiffusionModel(nn.Module):
|
||||
)
|
||||
global_cond_feats.append(img_features)
|
||||
|
||||
if self.config.env_state_feature:
|
||||
global_cond_feats.append(batch[OBS_ENV])
|
||||
if self._use_env_state:
|
||||
global_cond_feats.append(batch["observation.environment_state"])
|
||||
|
||||
# Concatenate features then flatten to (B, global_cond_dim).
|
||||
return torch.cat(global_cond_feats, dim=-1).flatten(start_dim=1)
|
||||
@@ -439,7 +443,7 @@ class SpatialSoftmax(nn.Module):
|
||||
|
||||
|
||||
class DiffusionRgbEncoder(nn.Module):
|
||||
"""Encodes an RGB image into a 1D feature vector.
|
||||
"""Encoder an RGB image into a 1D feature vector.
|
||||
|
||||
Includes the ability to normalize and crop the image first.
|
||||
"""
|
||||
@@ -478,16 +482,19 @@ class DiffusionRgbEncoder(nn.Module):
|
||||
|
||||
# Set up pooling and final layers.
|
||||
# Use a dry run to get the feature map shape.
|
||||
# The dummy input should take the number of image channels from `config.image_features` and it should
|
||||
# The dummy input should take the number of image channels from `config.input_shapes` and it should
|
||||
# use the height and width from `config.crop_shape` if it is provided, otherwise it should use the
|
||||
# height and width from `config.image_features`.
|
||||
|
||||
# height and width from `config.input_shapes`.
|
||||
image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
|
||||
# Note: we have a check in the config class to make sure all images have the same shape.
|
||||
images_shape = next(iter(config.image_features.values())).shape
|
||||
dummy_shape_h_w = config.crop_shape if config.crop_shape is not None else images_shape[1:]
|
||||
dummy_shape = (1, images_shape[0], *dummy_shape_h_w)
|
||||
feature_map_shape = get_output_shape(self.backbone, dummy_shape)[1:]
|
||||
|
||||
image_key = image_keys[0]
|
||||
dummy_input_h_w = (
|
||||
config.crop_shape if config.crop_shape is not None else config.input_shapes[image_key][1:]
|
||||
)
|
||||
dummy_input = torch.zeros(size=(1, config.input_shapes[image_key][0], *dummy_input_h_w))
|
||||
with torch.inference_mode():
|
||||
dummy_feature_map = self.backbone(dummy_input)
|
||||
feature_map_shape = tuple(dummy_feature_map.shape[1:])
|
||||
self.pool = SpatialSoftmax(feature_map_shape, num_kp=config.spatial_softmax_num_keypoints)
|
||||
self.feature_dim = config.spatial_softmax_num_keypoints * 2
|
||||
self.out = nn.Linear(config.spatial_softmax_num_keypoints * 2, self.feature_dim)
|
||||
@@ -604,7 +611,7 @@ class DiffusionConditionalUnet1d(nn.Module):
|
||||
|
||||
# In channels / out channels for each downsampling block in the Unet's encoder. For the decoder, we
|
||||
# just reverse these.
|
||||
in_out = [(config.action_feature.shape[0], config.down_dims[0])] + list(
|
||||
in_out = [(config.output_shapes["action"][0], config.down_dims[0])] + list(
|
||||
zip(config.down_dims[:-1], config.down_dims[1:], strict=True)
|
||||
)
|
||||
|
||||
@@ -659,7 +666,7 @@ class DiffusionConditionalUnet1d(nn.Module):
|
||||
|
||||
self.final_conv = nn.Sequential(
|
||||
DiffusionConv1dBlock(config.down_dims[0], config.down_dims[0], kernel_size=config.kernel_size),
|
||||
nn.Conv1d(config.down_dims[0], config.action_feature.shape[0], 1),
|
||||
nn.Conv1d(config.down_dims[0], config.output_shapes["action"][0], 1),
|
||||
)
|
||||
|
||||
def forward(self, x: Tensor, timestep: Tensor | int, global_cond=None) -> Tensor:
|
||||
|
||||
@@ -13,141 +13,104 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
import logging
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDatasetMetadata
|
||||
from lerobot.common.datasets.utils import dataset_to_policy_features
|
||||
from lerobot.common.envs.configs import EnvConfig
|
||||
from lerobot.common.envs.utils import env_to_policy_features
|
||||
from lerobot.common.policies.act.configuration_act import ACTConfig
|
||||
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
|
||||
from lerobot.common.policies.pi0.configuration_pi0 import PI0Config
|
||||
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig
|
||||
from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import FeatureType
|
||||
from lerobot.common.policies.policy_protocol import Policy
|
||||
from lerobot.common.utils.utils import get_safe_torch_device
|
||||
|
||||
|
||||
def get_policy_class(name: str) -> PreTrainedPolicy:
|
||||
def _policy_cfg_from_hydra_cfg(policy_cfg_class, hydra_cfg):
|
||||
expected_kwargs = set(inspect.signature(policy_cfg_class).parameters)
|
||||
if not set(hydra_cfg.policy).issuperset(expected_kwargs):
|
||||
logging.warning(
|
||||
f"Hydra config is missing arguments: {set(expected_kwargs).difference(hydra_cfg.policy)}"
|
||||
)
|
||||
|
||||
# OmegaConf.to_container returns lists where sequences are found, but our dataclasses use tuples to avoid
|
||||
# issues with mutable defaults. This filter changes all lists to tuples.
|
||||
def list_to_tuple(item):
|
||||
return tuple(item) if isinstance(item, list) else item
|
||||
|
||||
policy_cfg = policy_cfg_class(
|
||||
**{
|
||||
k: list_to_tuple(v)
|
||||
for k, v in OmegaConf.to_container(hydra_cfg.policy, resolve=True).items()
|
||||
if k in expected_kwargs
|
||||
}
|
||||
)
|
||||
return policy_cfg
|
||||
|
||||
|
||||
def get_policy_and_config_classes(name: str) -> tuple[Policy, object]:
|
||||
"""Get the policy's class and config class given a name (matching the policy class' `name` attribute)."""
|
||||
if name == "tdmpc":
|
||||
from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig
|
||||
from lerobot.common.policies.tdmpc.modeling_tdmpc import TDMPCPolicy
|
||||
|
||||
return TDMPCPolicy
|
||||
return TDMPCPolicy, TDMPCConfig
|
||||
elif name == "diffusion":
|
||||
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
|
||||
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
|
||||
|
||||
return DiffusionPolicy
|
||||
return DiffusionPolicy, DiffusionConfig
|
||||
elif name == "act":
|
||||
from lerobot.common.policies.act.configuration_act import ACTConfig
|
||||
from lerobot.common.policies.act.modeling_act import ACTPolicy
|
||||
|
||||
return ACTPolicy
|
||||
return ACTPolicy, ACTConfig
|
||||
elif name == "vqbet":
|
||||
from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig
|
||||
from lerobot.common.policies.vqbet.modeling_vqbet import VQBeTPolicy
|
||||
|
||||
return VQBeTPolicy
|
||||
elif name == "pi0":
|
||||
from lerobot.common.policies.pi0.modeling_pi0 import PI0Policy
|
||||
return VQBeTPolicy, VQBeTConfig
|
||||
elif name == "sac":
|
||||
from lerobot.common.policies.sac.configuration_sac import SACConfig
|
||||
from lerobot.common.policies.sac.modeling_sac import SACPolicy
|
||||
|
||||
return PI0Policy
|
||||
return SACPolicy, SACConfig
|
||||
else:
|
||||
raise NotImplementedError(f"Policy with name {name} is not implemented.")
|
||||
|
||||
|
||||
def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
||||
if policy_type == "tdmpc":
|
||||
return TDMPCConfig(**kwargs)
|
||||
elif policy_type == "diffusion":
|
||||
return DiffusionConfig(**kwargs)
|
||||
elif policy_type == "act":
|
||||
return ACTConfig(**kwargs)
|
||||
elif policy_type == "vqbet":
|
||||
return VQBeTConfig(**kwargs)
|
||||
elif policy_type == "pi0":
|
||||
return PI0Config(**kwargs)
|
||||
else:
|
||||
raise ValueError(f"Policy type '{policy_type}' is not available.")
|
||||
|
||||
|
||||
def make_policy(
|
||||
cfg: PreTrainedConfig,
|
||||
device: str | torch.device,
|
||||
ds_meta: LeRobotDatasetMetadata | None = None,
|
||||
env_cfg: EnvConfig | None = None,
|
||||
) -> PreTrainedPolicy:
|
||||
hydra_cfg: DictConfig, pretrained_policy_name_or_path: str | None = None, dataset_stats=None
|
||||
) -> Policy:
|
||||
"""Make an instance of a policy class.
|
||||
|
||||
This function exists because (for now) we need to parse features from either a dataset or an environment
|
||||
in order to properly dimension and instantiate a policy for that dataset or environment.
|
||||
|
||||
Args:
|
||||
cfg (PreTrainedConfig): The config of the policy to make. If `pretrained_path` is set, the policy will
|
||||
be loaded with the weights from that path.
|
||||
device (str): the device to load the policy onto.
|
||||
ds_meta (LeRobotDatasetMetadata | None, optional): Dataset metadata to take input/output shapes and
|
||||
statistics to use for (un)normalization of inputs/outputs in the policy. Defaults to None.
|
||||
env_cfg (EnvConfig | None, optional): The config of a gym environment to parse features from. Must be
|
||||
provided if ds_meta is not. Defaults to None.
|
||||
|
||||
Raises:
|
||||
ValueError: Either ds_meta or env and env_cfg must be provided.
|
||||
NotImplementedError: if the policy.type is 'vqbet' and the device 'mps' (due to an incompatibility)
|
||||
|
||||
Returns:
|
||||
PreTrainedPolicy: _description_
|
||||
hydra_cfg: A parsed Hydra configuration (see scripts). If `pretrained_policy_name_or_path` is
|
||||
provided, only `hydra_cfg.policy.name` is used while everything else is ignored.
|
||||
pretrained_policy_name_or_path: Either the repo ID of a model hosted on the Hub or a path to a
|
||||
directory containing weights saved using `Policy.save_pretrained`. Note that providing this
|
||||
argument overrides everything in `hydra_cfg.policy` apart from `hydra_cfg.policy.name`.
|
||||
dataset_stats: Dataset statistics to use for (un)normalization of inputs/outputs in the policy. Must
|
||||
be provided when initializing a new policy, and must not be provided when loading a pretrained
|
||||
policy. Therefore, this argument is mutually exclusive with `pretrained_policy_name_or_path`.
|
||||
"""
|
||||
if bool(ds_meta) == bool(env_cfg):
|
||||
raise ValueError("Either one of a dataset metadata or a sim env must be provided.")
|
||||
# if not (pretrained_policy_name_or_path is None) ^ (dataset_stats is None):
|
||||
# raise ValueError(
|
||||
# "Exactly one of `pretrained_policy_name_or_path` and `dataset_stats` must be provided."
|
||||
# )
|
||||
|
||||
# NOTE: Currently, if you try to run vqbet with mps backend, you'll get this error.
|
||||
# TODO(aliberts, rcadene): Implement a check_backend_compatibility in policies?
|
||||
# NotImplementedError: The operator 'aten::unique_dim' is not currently implemented for the MPS device. If
|
||||
# you want this op to be added in priority during the prototype phase of this feature, please comment on
|
||||
# https://github.com/pytorch/pytorch/issues/77764. As a temporary fix, you can set the environment
|
||||
# variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be
|
||||
# slower than running natively on MPS.
|
||||
if cfg.type == "vqbet" and str(device) == "mps":
|
||||
raise NotImplementedError(
|
||||
"Current implementation of VQBeT does not support `mps` backend. "
|
||||
"Please use `cpu` or `cuda` backend."
|
||||
)
|
||||
policy_cls, policy_cfg_class = get_policy_and_config_classes(hydra_cfg.policy.name)
|
||||
|
||||
policy_cls = get_policy_class(cfg.type)
|
||||
|
||||
kwargs = {}
|
||||
if ds_meta is not None:
|
||||
features = dataset_to_policy_features(ds_meta.features)
|
||||
kwargs["dataset_stats"] = ds_meta.stats
|
||||
policy_cfg = _policy_cfg_from_hydra_cfg(policy_cfg_class, hydra_cfg)
|
||||
if pretrained_policy_name_or_path is None:
|
||||
# Make a fresh policy.
|
||||
policy = policy_cls(policy_cfg, dataset_stats)
|
||||
else:
|
||||
if not cfg.pretrained_path:
|
||||
logging.warning(
|
||||
"You are instantiating a policy from scratch and its features are parsed from an environment "
|
||||
"rather than a dataset. Normalization modules inside the policy will have infinite values "
|
||||
"by default without stats from a dataset."
|
||||
)
|
||||
features = env_to_policy_features(env_cfg)
|
||||
|
||||
cfg.output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
|
||||
cfg.input_features = {key: ft for key, ft in features.items() if key not in cfg.output_features}
|
||||
kwargs["config"] = cfg
|
||||
|
||||
if cfg.pretrained_path:
|
||||
# Load a pretrained policy and override the config if needed (for example, if there are inference-time
|
||||
# hyperparameters that we want to vary).
|
||||
kwargs["pretrained_name_or_path"] = cfg.pretrained_path
|
||||
policy = policy_cls.from_pretrained(**kwargs)
|
||||
else:
|
||||
# Make a fresh policy.
|
||||
policy = policy_cls(**kwargs)
|
||||
# TODO(alexander-soare): This hack makes use of huggingface_hub's tooling to load the policy with,
|
||||
# pretrained weights which are then loaded into a fresh policy with the desired config. This PR in
|
||||
# huggingface_hub should make it possible to avoid the hack:
|
||||
# https://github.com/huggingface/huggingface_hub/pull/2274.
|
||||
policy = policy_cls(policy_cfg)
|
||||
policy.load_state_dict(policy_cls.from_pretrained(pretrained_policy_name_or_path).state_dict())
|
||||
|
||||
policy.to(device)
|
||||
assert isinstance(policy, nn.Module)
|
||||
|
||||
# policy = torch.compile(policy, mode="reduce-overhead")
|
||||
policy.to(get_safe_torch_device(hydra_cfg.device))
|
||||
|
||||
return policy
|
||||
|
||||
@@ -0,0 +1,35 @@
|
||||
import json
|
||||
import os
|
||||
from dataclasses import asdict, dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class ClassifierConfig:
|
||||
"""Configuration for the Classifier model."""
|
||||
|
||||
num_classes: int = 2
|
||||
hidden_dim: int = 256
|
||||
dropout_rate: float = 0.1
|
||||
model_name: str = "microsoft/resnet-50"
|
||||
device: str = "cpu"
|
||||
model_type: str = "cnn" # "transformer" or "cnn"
|
||||
num_cameras: int = 2
|
||||
|
||||
def save_pretrained(self, save_dir):
|
||||
"""Save config to json file."""
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
|
||||
# Convert to dict and save as JSON
|
||||
config_dict = asdict(self)
|
||||
with open(os.path.join(save_dir, "config.json"), "w") as f:
|
||||
json.dump(config_dict, f, indent=2)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path):
|
||||
"""Load config from json file."""
|
||||
config_file = os.path.join(pretrained_model_name_or_path, "config.json")
|
||||
|
||||
with open(config_file) as f:
|
||||
config_dict = json.load(f)
|
||||
|
||||
return cls(**config_dict)
|
||||
@@ -0,0 +1,151 @@
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from huggingface_hub import PyTorchModelHubMixin
|
||||
from torch import Tensor, nn
|
||||
|
||||
from .configuration_classifier import ClassifierConfig
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ClassifierOutput:
|
||||
"""Wrapper for classifier outputs with additional metadata."""
|
||||
|
||||
def __init__(
|
||||
self, logits: Tensor, probabilities: Optional[Tensor] = None, hidden_states: Optional[Tensor] = None
|
||||
):
|
||||
self.logits = logits
|
||||
self.probabilities = probabilities
|
||||
self.hidden_states = hidden_states
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
f"ClassifierOutput(logits={self.logits}, "
|
||||
f"probabilities={self.probabilities}, "
|
||||
f"hidden_states={self.hidden_states})"
|
||||
)
|
||||
|
||||
|
||||
class Classifier(
|
||||
nn.Module,
|
||||
PyTorchModelHubMixin,
|
||||
# Add Hub metadata
|
||||
library_name="lerobot",
|
||||
repo_url="https://github.com/huggingface/lerobot",
|
||||
tags=["robotics", "vision-classifier"],
|
||||
):
|
||||
"""Image classifier built on top of a pre-trained encoder."""
|
||||
|
||||
# Add name attribute for factory
|
||||
name = "classifier"
|
||||
|
||||
def __init__(self, config: ClassifierConfig):
|
||||
from transformers import AutoImageProcessor, AutoModel
|
||||
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.processor = AutoImageProcessor.from_pretrained(self.config.model_name, trust_remote_code=True)
|
||||
encoder = AutoModel.from_pretrained(self.config.model_name, trust_remote_code=True)
|
||||
# Extract vision model if we're given a multimodal model
|
||||
if hasattr(encoder, "vision_model"):
|
||||
logging.info("Multimodal model detected - using vision encoder only")
|
||||
self.encoder = encoder.vision_model
|
||||
self.vision_config = encoder.config.vision_config
|
||||
else:
|
||||
self.encoder = encoder
|
||||
self.vision_config = getattr(encoder, "config", None)
|
||||
|
||||
# Model type from config
|
||||
self.is_cnn = self.config.model_type == "cnn"
|
||||
|
||||
# For CNNs, initialize backbone
|
||||
if self.is_cnn:
|
||||
self._setup_cnn_backbone()
|
||||
|
||||
self._freeze_encoder()
|
||||
self._build_classifier_head()
|
||||
|
||||
def _setup_cnn_backbone(self):
|
||||
"""Set up CNN encoder"""
|
||||
if hasattr(self.encoder, "fc"):
|
||||
self.feature_dim = self.encoder.fc.in_features
|
||||
self.encoder = nn.Sequential(*list(self.encoder.children())[:-1])
|
||||
elif hasattr(self.encoder.config, "hidden_sizes"):
|
||||
self.feature_dim = self.encoder.config.hidden_sizes[-1] # Last channel dimension
|
||||
else:
|
||||
raise ValueError("Unsupported CNN architecture")
|
||||
|
||||
self.encoder = self.encoder.to(self.config.device)
|
||||
|
||||
def _freeze_encoder(self) -> None:
|
||||
"""Freeze the encoder parameters."""
|
||||
for param in self.encoder.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def _build_classifier_head(self) -> None:
|
||||
"""Initialize the classifier head architecture."""
|
||||
# Get input dimension based on model type
|
||||
if self.is_cnn:
|
||||
input_dim = self.feature_dim
|
||||
else: # Transformer models
|
||||
if hasattr(self.encoder.config, "hidden_size"):
|
||||
input_dim = self.encoder.config.hidden_size
|
||||
else:
|
||||
raise ValueError("Unsupported transformer architecture since hidden_size is not found")
|
||||
|
||||
self.classifier_head = nn.Sequential(
|
||||
nn.Linear(input_dim * self.config.num_cameras, self.config.hidden_dim),
|
||||
nn.Dropout(self.config.dropout_rate),
|
||||
nn.LayerNorm(self.config.hidden_dim),
|
||||
nn.ReLU(),
|
||||
nn.Linear(self.config.hidden_dim, 1 if self.config.num_classes == 2 else self.config.num_classes),
|
||||
)
|
||||
self.classifier_head = self.classifier_head.to(self.config.device)
|
||||
|
||||
def _get_encoder_output(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Extract the appropriate output from the encoder."""
|
||||
# Process images with the processor (handles resizing and normalization)
|
||||
processed = self.processor(
|
||||
images=x, # LeRobotDataset already provides proper tensor format
|
||||
return_tensors="pt",
|
||||
)
|
||||
processed = processed["pixel_values"].to(x.device)
|
||||
|
||||
with torch.no_grad():
|
||||
if self.is_cnn:
|
||||
# The HF ResNet applies pooling internally
|
||||
outputs = self.encoder(processed)
|
||||
# Get pooled output directly
|
||||
features = outputs.pooler_output
|
||||
|
||||
if features.dim() > 2:
|
||||
features = features.squeeze(-1).squeeze(-1)
|
||||
return features
|
||||
else: # Transformer models
|
||||
outputs = self.encoder(processed)
|
||||
if hasattr(outputs, "pooler_output") and outputs.pooler_output is not None:
|
||||
return outputs.pooler_output
|
||||
return outputs.last_hidden_state[:, 0, :]
|
||||
|
||||
def forward(self, xs: torch.Tensor) -> ClassifierOutput:
|
||||
"""Forward pass of the classifier."""
|
||||
# For training, we expect input to be a tensor directly from LeRobotDataset
|
||||
encoder_outputs = torch.hstack([self._get_encoder_output(x) for x in xs])
|
||||
logits = self.classifier_head(encoder_outputs)
|
||||
|
||||
if self.config.num_classes == 2:
|
||||
logits = logits.squeeze(-1)
|
||||
probabilities = torch.sigmoid(logits)
|
||||
else:
|
||||
probabilities = torch.softmax(logits, dim=-1)
|
||||
|
||||
return ClassifierOutput(logits=logits, probabilities=probabilities, hidden_states=encoder_outputs)
|
||||
|
||||
def predict_reward(self, x):
|
||||
if self.config.num_classes == 2:
|
||||
return (self.forward(x).probabilities > 0.5).float()
|
||||
else:
|
||||
return torch.argmax(self.forward(x).probabilities, dim=1)
|
||||
23
lerobot/common/policies/hilserl/configuration_hilserl.py
Normal file
23
lerobot/common/policies/hilserl/configuration_hilserl.py
Normal file
@@ -0,0 +1,23 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class HILSerlConfig:
|
||||
pass
|
||||
29
lerobot/common/policies/hilserl/modeling_hilserl.py
Normal file
29
lerobot/common/policies/hilserl/modeling_hilserl.py
Normal file
@@ -0,0 +1,29 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch.nn as nn
|
||||
from huggingface_hub import PyTorchModelHubMixin
|
||||
|
||||
|
||||
class HILSerlPolicy(
|
||||
nn.Module,
|
||||
PyTorchModelHubMixin,
|
||||
library_name="lerobot",
|
||||
repo_url="https://github.com/huggingface/lerobot",
|
||||
tags=["robotics", "hilserl"],
|
||||
):
|
||||
pass
|
||||
@@ -13,16 +13,13 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
|
||||
|
||||
def create_stats_buffers(
|
||||
features: dict[str, PolicyFeature],
|
||||
norm_map: dict[str, NormalizationMode],
|
||||
shapes: dict[str, list[int]],
|
||||
modes: dict[str, str],
|
||||
stats: dict[str, dict[str, Tensor]] | None = None,
|
||||
) -> dict[str, dict[str, nn.ParameterDict]]:
|
||||
"""
|
||||
@@ -37,16 +34,12 @@ def create_stats_buffers(
|
||||
"""
|
||||
stats_buffers = {}
|
||||
|
||||
for key, ft in features.items():
|
||||
norm_mode = norm_map.get(ft.type, NormalizationMode.IDENTITY)
|
||||
if norm_mode is NormalizationMode.IDENTITY:
|
||||
continue
|
||||
for key, mode in modes.items():
|
||||
assert mode in ["mean_std", "min_max"]
|
||||
|
||||
assert isinstance(norm_mode, NormalizationMode)
|
||||
shape = tuple(shapes[key])
|
||||
|
||||
shape = tuple(ft.shape)
|
||||
|
||||
if ft.type is FeatureType.VISUAL:
|
||||
if "image" in key:
|
||||
# sanity checks
|
||||
assert len(shape) == 3, f"number of dimensions of {key} != 3 ({shape=}"
|
||||
c, h, w = shape
|
||||
@@ -59,7 +52,7 @@ def create_stats_buffers(
|
||||
# we assert they are not infinity anymore.
|
||||
|
||||
buffer = {}
|
||||
if norm_mode is NormalizationMode.MEAN_STD:
|
||||
if mode == "mean_std":
|
||||
mean = torch.ones(shape, dtype=torch.float32) * torch.inf
|
||||
std = torch.ones(shape, dtype=torch.float32) * torch.inf
|
||||
buffer = nn.ParameterDict(
|
||||
@@ -68,7 +61,7 @@ def create_stats_buffers(
|
||||
"std": nn.Parameter(std, requires_grad=False),
|
||||
}
|
||||
)
|
||||
elif norm_mode is NormalizationMode.MIN_MAX:
|
||||
elif mode == "min_max":
|
||||
min = torch.ones(shape, dtype=torch.float32) * torch.inf
|
||||
max = torch.ones(shape, dtype=torch.float32) * torch.inf
|
||||
buffer = nn.ParameterDict(
|
||||
@@ -78,29 +71,17 @@ def create_stats_buffers(
|
||||
}
|
||||
)
|
||||
|
||||
# TODO(aliberts, rcadene): harmonize this to only use one framework (np or torch)
|
||||
if stats:
|
||||
if isinstance(stats[key]["mean"], np.ndarray):
|
||||
if norm_mode is NormalizationMode.MEAN_STD:
|
||||
buffer["mean"].data = torch.from_numpy(stats[key]["mean"]).to(dtype=torch.float32)
|
||||
buffer["std"].data = torch.from_numpy(stats[key]["std"]).to(dtype=torch.float32)
|
||||
elif norm_mode is NormalizationMode.MIN_MAX:
|
||||
buffer["min"].data = torch.from_numpy(stats[key]["min"]).to(dtype=torch.float32)
|
||||
buffer["max"].data = torch.from_numpy(stats[key]["max"]).to(dtype=torch.float32)
|
||||
elif isinstance(stats[key]["mean"], torch.Tensor):
|
||||
# Note: The clone is needed to make sure that the logic in save_pretrained doesn't see duplicated
|
||||
# tensors anywhere (for example, when we use the same stats for normalization and
|
||||
# unnormalization). See the logic here
|
||||
# https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97.
|
||||
if norm_mode is NormalizationMode.MEAN_STD:
|
||||
buffer["mean"].data = stats[key]["mean"].clone().to(dtype=torch.float32)
|
||||
buffer["std"].data = stats[key]["std"].clone().to(dtype=torch.float32)
|
||||
elif norm_mode is NormalizationMode.MIN_MAX:
|
||||
buffer["min"].data = stats[key]["min"].clone().to(dtype=torch.float32)
|
||||
buffer["max"].data = stats[key]["max"].clone().to(dtype=torch.float32)
|
||||
else:
|
||||
type_ = type(stats[key]["mean"])
|
||||
raise ValueError(f"np.ndarray or torch.Tensor expected, but type is '{type_}' instead.")
|
||||
if stats is not None:
|
||||
# Note: The clone is needed to make sure that the logic in save_pretrained doesn't see duplicated
|
||||
# tensors anywhere (for example, when we use the same stats for normalization and
|
||||
# unnormalization). See the logic here
|
||||
# https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97.
|
||||
if mode == "mean_std":
|
||||
buffer["mean"].data = stats[key]["mean"].clone()
|
||||
buffer["std"].data = stats[key]["std"].clone()
|
||||
elif mode == "min_max":
|
||||
buffer["min"].data = stats[key]["min"].clone()
|
||||
buffer["max"].data = stats[key]["max"].clone()
|
||||
|
||||
stats_buffers[key] = buffer
|
||||
return stats_buffers
|
||||
@@ -118,8 +99,8 @@ class Normalize(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
features: dict[str, PolicyFeature],
|
||||
norm_map: dict[str, NormalizationMode],
|
||||
shapes: dict[str, list[int]],
|
||||
modes: dict[str, str],
|
||||
stats: dict[str, dict[str, Tensor]] | None = None,
|
||||
):
|
||||
"""
|
||||
@@ -141,10 +122,10 @@ class Normalize(nn.Module):
|
||||
dataset is not needed to get the stats, since they are already in the policy state_dict.
|
||||
"""
|
||||
super().__init__()
|
||||
self.features = features
|
||||
self.norm_map = norm_map
|
||||
self.shapes = shapes
|
||||
self.modes = modes
|
||||
self.stats = stats
|
||||
stats_buffers = create_stats_buffers(features, norm_map, stats)
|
||||
stats_buffers = create_stats_buffers(shapes, modes, stats)
|
||||
for key, buffer in stats_buffers.items():
|
||||
setattr(self, "buffer_" + key.replace(".", "_"), buffer)
|
||||
|
||||
@@ -152,24 +133,16 @@ class Normalize(nn.Module):
|
||||
@torch.no_grad
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
batch = dict(batch) # shallow copy avoids mutating the input batch
|
||||
for key, ft in self.features.items():
|
||||
if key not in batch:
|
||||
# FIXME(aliberts, rcadene): This might lead to silent fail!
|
||||
continue
|
||||
|
||||
norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY)
|
||||
if norm_mode is NormalizationMode.IDENTITY:
|
||||
continue
|
||||
|
||||
for key, mode in self.modes.items():
|
||||
buffer = getattr(self, "buffer_" + key.replace(".", "_"))
|
||||
|
||||
if norm_mode is NormalizationMode.MEAN_STD:
|
||||
if mode == "mean_std":
|
||||
mean = buffer["mean"]
|
||||
std = buffer["std"]
|
||||
assert not torch.isinf(mean).any(), _no_stats_error_str("mean")
|
||||
assert not torch.isinf(std).any(), _no_stats_error_str("std")
|
||||
batch[key] = (batch[key] - mean) / (std + 1e-8)
|
||||
elif norm_mode is NormalizationMode.MIN_MAX:
|
||||
elif mode == "min_max":
|
||||
min = buffer["min"]
|
||||
max = buffer["max"]
|
||||
assert not torch.isinf(min).any(), _no_stats_error_str("min")
|
||||
@@ -179,7 +152,7 @@ class Normalize(nn.Module):
|
||||
# normalize to [-1, 1]
|
||||
batch[key] = batch[key] * 2 - 1
|
||||
else:
|
||||
raise ValueError(norm_mode)
|
||||
raise ValueError(mode)
|
||||
return batch
|
||||
|
||||
|
||||
@@ -191,8 +164,8 @@ class Unnormalize(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
features: dict[str, PolicyFeature],
|
||||
norm_map: dict[str, NormalizationMode],
|
||||
shapes: dict[str, list[int]],
|
||||
modes: dict[str, str],
|
||||
stats: dict[str, dict[str, Tensor]] | None = None,
|
||||
):
|
||||
"""
|
||||
@@ -214,11 +187,11 @@ class Unnormalize(nn.Module):
|
||||
dataset is not needed to get the stats, since they are already in the policy state_dict.
|
||||
"""
|
||||
super().__init__()
|
||||
self.features = features
|
||||
self.norm_map = norm_map
|
||||
self.shapes = shapes
|
||||
self.modes = modes
|
||||
self.stats = stats
|
||||
# `self.buffer_observation_state["mean"]` contains `torch.tensor(state_dim)`
|
||||
stats_buffers = create_stats_buffers(features, norm_map, stats)
|
||||
stats_buffers = create_stats_buffers(shapes, modes, stats)
|
||||
for key, buffer in stats_buffers.items():
|
||||
setattr(self, "buffer_" + key.replace(".", "_"), buffer)
|
||||
|
||||
@@ -226,23 +199,16 @@ class Unnormalize(nn.Module):
|
||||
@torch.no_grad
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
batch = dict(batch) # shallow copy avoids mutating the input batch
|
||||
for key, ft in self.features.items():
|
||||
if key not in batch:
|
||||
continue
|
||||
|
||||
norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY)
|
||||
if norm_mode is NormalizationMode.IDENTITY:
|
||||
continue
|
||||
|
||||
for key, mode in self.modes.items():
|
||||
buffer = getattr(self, "buffer_" + key.replace(".", "_"))
|
||||
|
||||
if norm_mode is NormalizationMode.MEAN_STD:
|
||||
if mode == "mean_std":
|
||||
mean = buffer["mean"]
|
||||
std = buffer["std"]
|
||||
assert not torch.isinf(mean).any(), _no_stats_error_str("mean")
|
||||
assert not torch.isinf(std).any(), _no_stats_error_str("std")
|
||||
batch[key] = batch[key] * std + mean
|
||||
elif norm_mode is NormalizationMode.MIN_MAX:
|
||||
elif mode == "min_max":
|
||||
min = buffer["min"]
|
||||
max = buffer["max"]
|
||||
assert not torch.isinf(min).any(), _no_stats_error_str("min")
|
||||
@@ -250,5 +216,5 @@ class Unnormalize(nn.Module):
|
||||
batch[key] = (batch[key] + 1) / 2
|
||||
batch[key] = batch[key] * (max - min) + min
|
||||
else:
|
||||
raise ValueError(norm_mode)
|
||||
raise ValueError(mode)
|
||||
return batch
|
||||
|
||||
@@ -1,134 +0,0 @@
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.common.optim.optimizers import AdamWConfig
|
||||
from lerobot.common.optim.schedulers import (
|
||||
CosineDecayWithWarmupSchedulerConfig,
|
||||
)
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("pi0")
|
||||
@dataclass
|
||||
class PI0Config(PreTrainedConfig):
|
||||
# Input / output structure.
|
||||
n_obs_steps: int = 1
|
||||
chunk_size: int = 50
|
||||
n_action_steps: int = 50
|
||||
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
default_factory=lambda: {
|
||||
"VISUAL": NormalizationMode.IDENTITY,
|
||||
"STATE": NormalizationMode.MEAN_STD,
|
||||
"ACTION": NormalizationMode.MEAN_STD,
|
||||
}
|
||||
)
|
||||
|
||||
# Shorter state and action vectors will be padded
|
||||
max_state_dim: int = 32
|
||||
max_action_dim: int = 32
|
||||
|
||||
# Image preprocessing
|
||||
resize_imgs_with_padding: tuple[int, int] = (224, 224)
|
||||
|
||||
# Add empty images. Used by pi0_aloha_sim which adds the empty
|
||||
# left and right wrist cameras in addition to the top camera.
|
||||
empty_cameras: int = 0
|
||||
|
||||
# Converts the joint and gripper values from the standard Aloha space to
|
||||
# the space used by the pi internal runtime which was used to train the base model.
|
||||
adapt_to_pi_aloha: bool = False
|
||||
|
||||
# Converts joint dimensions to deltas with respect to the current state before passing to the model.
|
||||
# Gripper dimensions will remain in absolute values.
|
||||
use_delta_joint_actions_aloha: bool = False
|
||||
|
||||
# Tokenizer
|
||||
tokenizer_max_length: int = 48
|
||||
|
||||
# Projector
|
||||
proj_width: int = 1024
|
||||
|
||||
# Decoding
|
||||
num_steps: int = 10
|
||||
|
||||
# Attention utils
|
||||
use_cache: bool = True
|
||||
attention_implementation: str = "eager" # or fa2, flex
|
||||
|
||||
# Finetuning settings
|
||||
freeze_vision_encoder: bool = True
|
||||
train_expert_only: bool = False
|
||||
train_state_proj: bool = True
|
||||
|
||||
# Training presets
|
||||
optimizer_lr: float = 2.5e-5
|
||||
optimizer_betas: tuple[float, float] = (0.9, 0.95)
|
||||
optimizer_eps: float = 1e-8
|
||||
optimizer_weight_decay: float = 1e-10
|
||||
|
||||
scheduler_warmup_steps: int = 1_000
|
||||
scheduler_decay_steps: int = 30_000
|
||||
scheduler_decay_lr: float = 2.5e-6
|
||||
|
||||
# TODO: Add EMA
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
|
||||
"""Input validation (not exhaustive)."""
|
||||
if self.n_action_steps > self.chunk_size:
|
||||
raise ValueError(
|
||||
f"The chunk size is the upper bound for the number of action steps per model invocation. Got "
|
||||
f"{self.n_action_steps} for `n_action_steps` and {self.chunk_size} for `chunk_size`."
|
||||
)
|
||||
if self.n_obs_steps != 1:
|
||||
raise ValueError(
|
||||
f"Multiple observation steps not handled yet. Got `nobs_steps={self.n_obs_steps}`"
|
||||
)
|
||||
|
||||
if self.use_delta_joint_actions_aloha:
|
||||
raise NotImplementedError(
|
||||
"`use_delta_joint_actions_aloha` is used by pi0 for aloha real models. It is not ported yet in LeRobot."
|
||||
)
|
||||
|
||||
def validate_features(self) -> None:
|
||||
# TODO: implement value error
|
||||
# if not self.image_features and not self.env_state_feature:
|
||||
# raise ValueError("You must provide at least one image or the environment state among the inputs.")
|
||||
|
||||
for i in range(self.empty_cameras):
|
||||
key = f"observation.images.empty_camera_{i}"
|
||||
empty_camera = PolicyFeature(
|
||||
type=FeatureType.VISUAL,
|
||||
shape=(3, 480, 640),
|
||||
)
|
||||
self.input_features[key] = empty_camera
|
||||
|
||||
def get_optimizer_preset(self) -> AdamWConfig:
|
||||
return AdamWConfig(
|
||||
lr=self.optimizer_lr,
|
||||
betas=self.optimizer_betas,
|
||||
eps=self.optimizer_eps,
|
||||
weight_decay=self.optimizer_weight_decay,
|
||||
)
|
||||
|
||||
def get_scheduler_preset(self):
|
||||
return CosineDecayWithWarmupSchedulerConfig(
|
||||
peak_lr=self.optimizer_lr,
|
||||
decay_lr=self.scheduler_decay_lr,
|
||||
num_warmup_steps=self.scheduler_warmup_steps,
|
||||
num_decay_steps=self.scheduler_decay_steps,
|
||||
)
|
||||
|
||||
@property
|
||||
def observation_delta_indices(self) -> None:
|
||||
return None
|
||||
|
||||
@property
|
||||
def action_delta_indices(self) -> list:
|
||||
return list(range(self.chunk_size))
|
||||
|
||||
@property
|
||||
def reward_delta_indices(self) -> None:
|
||||
return None
|
||||
@@ -1,68 +0,0 @@
|
||||
import torch
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.common.policies.factory import make_policy
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
|
||||
torch.backends.cudnn.benchmark = True
|
||||
|
||||
|
||||
def main():
|
||||
device = "cuda"
|
||||
dataset_repo_id = "danaaubakirova/koch_test"
|
||||
# model_name = "pi0_base"
|
||||
# ckpt_torch_dir = Path.home() / f".cache/openpi/openpi-assets/checkpoints/{model_name}_pytorch"
|
||||
ckpt_torch_dir = "lerobot/pi0"
|
||||
|
||||
dataset = LeRobotDataset(dataset_repo_id, episodes=[0])
|
||||
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
num_workers=0,
|
||||
batch_size=1,
|
||||
)
|
||||
|
||||
batch = next(iter(dataloader))
|
||||
|
||||
# To device
|
||||
for k in batch:
|
||||
if isinstance(batch[k], torch.Tensor):
|
||||
batch[k] = batch[k].to(device=device, dtype=torch.float32)
|
||||
|
||||
cfg = PreTrainedConfig.from_pretrained(ckpt_torch_dir)
|
||||
cfg.pretrained_path = ckpt_torch_dir
|
||||
policy = make_policy(cfg, device, ds_meta=dataset.meta)
|
||||
|
||||
# policy = torch.compile(policy, mode="reduce-overhead")
|
||||
|
||||
warmup_iters = 10
|
||||
benchmark_iters = 30
|
||||
|
||||
# Warmup
|
||||
for _ in range(warmup_iters):
|
||||
torch.cuda.synchronize()
|
||||
policy.select_action(batch)
|
||||
policy.reset()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
# Benchmark
|
||||
start_event = torch.cuda.Event(enable_timing=True)
|
||||
end_event = torch.cuda.Event(enable_timing=True)
|
||||
|
||||
start_event.record()
|
||||
for _ in range(benchmark_iters):
|
||||
policy.select_action(batch)
|
||||
policy.reset()
|
||||
end_event.record()
|
||||
|
||||
# Synchronize and measure time
|
||||
torch.cuda.synchronize()
|
||||
elapsed_time_ms = start_event.elapsed_time(end_event)
|
||||
|
||||
avg_time_per_iter = elapsed_time_ms / benchmark_iters
|
||||
print(f"Average execution time per iteration: {avg_time_per_iter:.3f} ms")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
with torch.inference_mode():
|
||||
main()
|
||||
@@ -1,117 +0,0 @@
|
||||
import json
|
||||
import pickle
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDatasetMetadata
|
||||
from lerobot.common.policies.factory import make_policy
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
|
||||
|
||||
def display(tensor: torch.Tensor):
|
||||
if tensor.dtype == torch.bool:
|
||||
tensor = tensor.float()
|
||||
print(f"Shape: {tensor.shape}")
|
||||
print(f"Mean: {tensor.mean().item()}")
|
||||
print(f"Std: {tensor.std().item()}")
|
||||
print(f"Min: {tensor.min().item()}")
|
||||
print(f"Max: {tensor.max().item()}")
|
||||
|
||||
|
||||
def main():
|
||||
num_motors = 14
|
||||
device = "cuda"
|
||||
# model_name = "pi0_aloha_towel"
|
||||
model_name = "pi0_aloha_sim"
|
||||
|
||||
if model_name == "pi0_aloha_towel":
|
||||
dataset_repo_id = "lerobot/aloha_static_towel"
|
||||
else:
|
||||
dataset_repo_id = "lerobot/aloha_sim_transfer_cube_human"
|
||||
|
||||
ckpt_torch_dir = Path.home() / f".cache/openpi/openpi-assets/checkpoints/{model_name}_pytorch"
|
||||
ckpt_jax_dir = Path.home() / f".cache/openpi/openpi-assets/checkpoints/{model_name}"
|
||||
save_dir = Path(f"../openpi/data/{model_name}/save")
|
||||
|
||||
with open(save_dir / "example.pkl", "rb") as f:
|
||||
example = pickle.load(f)
|
||||
with open(save_dir / "outputs.pkl", "rb") as f:
|
||||
outputs = pickle.load(f)
|
||||
with open(save_dir / "noise.pkl", "rb") as f:
|
||||
noise = pickle.load(f)
|
||||
|
||||
with open(ckpt_jax_dir / "assets/norm_stats.json") as f:
|
||||
norm_stats = json.load(f)
|
||||
|
||||
# Override stats
|
||||
dataset_meta = LeRobotDatasetMetadata(dataset_repo_id)
|
||||
dataset_meta.stats["observation.state"]["mean"] = torch.tensor(
|
||||
norm_stats["norm_stats"]["state"]["mean"][:num_motors], dtype=torch.float32
|
||||
)
|
||||
dataset_meta.stats["observation.state"]["std"] = torch.tensor(
|
||||
norm_stats["norm_stats"]["state"]["std"][:num_motors], dtype=torch.float32
|
||||
)
|
||||
|
||||
# Create LeRobot batch from Jax
|
||||
batch = {}
|
||||
for cam_key, uint_chw_array in example["images"].items():
|
||||
batch[f"observation.images.{cam_key}"] = torch.from_numpy(uint_chw_array) / 255.0
|
||||
batch["observation.state"] = torch.from_numpy(example["state"])
|
||||
batch["action"] = torch.from_numpy(outputs["actions"])
|
||||
batch["task"] = example["prompt"]
|
||||
|
||||
if model_name == "pi0_aloha_towel":
|
||||
del batch["observation.images.cam_low"]
|
||||
elif model_name == "pi0_aloha_sim":
|
||||
batch["observation.images.top"] = batch["observation.images.cam_high"]
|
||||
del batch["observation.images.cam_high"]
|
||||
|
||||
# Batchify
|
||||
for key in batch:
|
||||
if isinstance(batch[key], torch.Tensor):
|
||||
batch[key] = batch[key].unsqueeze(0)
|
||||
elif isinstance(batch[key], str):
|
||||
batch[key] = [batch[key]]
|
||||
else:
|
||||
raise ValueError(f"{key}, {batch[key]}")
|
||||
|
||||
# To device
|
||||
for k in batch:
|
||||
if isinstance(batch[k], torch.Tensor):
|
||||
batch[k] = batch[k].to(device=device, dtype=torch.float32)
|
||||
|
||||
noise = torch.from_numpy(noise).to(device=device, dtype=torch.float32)
|
||||
|
||||
from lerobot.common import policies # noqa
|
||||
|
||||
cfg = PreTrainedConfig.from_pretrained(ckpt_torch_dir)
|
||||
cfg.pretrained_path = ckpt_torch_dir
|
||||
policy = make_policy(cfg, device, dataset_meta)
|
||||
|
||||
# loss_dict = policy.forward(batch, noise=noise, time=time_beta)
|
||||
# loss_dict["loss"].backward()
|
||||
# print("losses")
|
||||
# display(loss_dict["losses_after_forward"])
|
||||
# print("pi_losses")
|
||||
# display(pi_losses)
|
||||
|
||||
actions = []
|
||||
for _ in range(50):
|
||||
action = policy.select_action(batch, noise=noise)
|
||||
actions.append(action)
|
||||
|
||||
actions = torch.stack(actions, dim=1)
|
||||
pi_actions = batch["action"]
|
||||
print("actions")
|
||||
display(actions)
|
||||
print()
|
||||
print("pi_actions")
|
||||
display(pi_actions)
|
||||
print("atol=3e-2", torch.allclose(actions, pi_actions, atol=3e-2))
|
||||
print("atol=2e-2", torch.allclose(actions, pi_actions, atol=2e-2))
|
||||
print("atol=1e-2", torch.allclose(actions, pi_actions, atol=1e-2))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,70 +0,0 @@
|
||||
from transformers import GemmaConfig, PaliGemmaConfig
|
||||
|
||||
|
||||
def get_paligemma_config(precision: str):
|
||||
config = {
|
||||
"image_token_index": None,
|
||||
"pad_token_id": 0,
|
||||
"bos_token_id": 2,
|
||||
"eos_token_id": 1,
|
||||
}
|
||||
|
||||
# image_sizes = {"2b-test": 224, "3b-224px": 224, "3b-448px": 448, "3b-896px": 896}
|
||||
|
||||
image_size = 224 # image_sizes[variant]
|
||||
patch_size = 14
|
||||
num_image_tokens = (image_size**2) // (patch_size**2)
|
||||
|
||||
config["image_token_index"] = 257152
|
||||
text_config = {
|
||||
"vocab_size": 257152,
|
||||
"num_hidden_layers": 18,
|
||||
"num_key_value_heads": 1,
|
||||
"head_dim": 256,
|
||||
"torch_dtype": precision,
|
||||
"hidden_size": 2048,
|
||||
"hidden_activation": "gelu_pytorch_tanh",
|
||||
"num_attention_heads": 8,
|
||||
"intermediate_size": 16384,
|
||||
"is_encoder_decoder": False,
|
||||
}
|
||||
vision_config = {
|
||||
"torch_dtype": precision,
|
||||
"image_size": image_size,
|
||||
"patch_size": patch_size,
|
||||
"num_image_tokens": num_image_tokens,
|
||||
"hidden_size": 1152,
|
||||
"intermediate_size": 4304,
|
||||
"num_hidden_layers": 27,
|
||||
"num_attention_heads": 16,
|
||||
"projector_hidden_act": "gelu_fast",
|
||||
"vision_use_head": False,
|
||||
}
|
||||
final_config = PaliGemmaConfig(text_config=text_config, vision_config=vision_config, **config)
|
||||
return final_config
|
||||
|
||||
|
||||
def get_gemma_config(precision: str):
|
||||
config = {
|
||||
"image_token_index": None,
|
||||
"pad_token_id": 0,
|
||||
"bos_token_id": 2,
|
||||
"eos_token_id": 1,
|
||||
}
|
||||
|
||||
config["image_token_index"] = 257152
|
||||
text_config = {
|
||||
"vocab_size": 257152,
|
||||
"num_hidden_layers": 18,
|
||||
"num_key_value_heads": 1,
|
||||
"head_dim": 256,
|
||||
"torch_dtype": precision,
|
||||
"hidden_size": 1024,
|
||||
"hidden_activation": "gelu_pytorch_tanh",
|
||||
"num_attention_heads": 8,
|
||||
"intermediate_size": 4096,
|
||||
"is_encoder_decoder": False,
|
||||
}
|
||||
final_config = GemmaConfig()
|
||||
final_config.update(text_config)
|
||||
return final_config
|
||||
@@ -1,423 +0,0 @@
|
||||
"""
|
||||
Convert pi0 parameters from Jax to Pytorch
|
||||
|
||||
Follow [README of openpi](https://github.com/Physical-Intelligence/openpi) to create a new environment
|
||||
and install the required libraries.
|
||||
|
||||
```bash
|
||||
cd ~/code/openpi
|
||||
source .venv/bin/activate
|
||||
```
|
||||
|
||||
Example downloading parameters:
|
||||
```bash
|
||||
python
|
||||
>>> import openpi.shared.download as download
|
||||
>>> path='s3://openpi-assets/checkpoints/pi0_base/params'
|
||||
>>> download.maybe_download(path)
|
||||
```
|
||||
|
||||
Converting pi0_base:
|
||||
```python
|
||||
python lerobot/common/policies/pi0/conversion_scripts/convert_pi0_to_hf_lerobot.py \
|
||||
--checkpoint_dir /home/remi_cadene/.cache/openpi/openpi-assets/checkpoints/pi0_base/params \
|
||||
--output_path /home/remi_cadene/.cache/openpi/openpi-assets/checkpoints/pi0_base_pytorch
|
||||
```
|
||||
|
||||
```python
|
||||
python lerobot/common/policies/pi0/conversion_scripts/convert_pi0_to_hf_lerobot.py \
|
||||
--checkpoint_dir /home/remi_cadene/.cache/openpi/openpi-assets/checkpoints/pi0_aloha_sim/params \
|
||||
--output_path /home/remi_cadene/.cache/openpi/openpi-assets/checkpoints/pi0_aloha_sim_pytorch
|
||||
```
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import pathlib
|
||||
|
||||
import jax
|
||||
import numpy as np
|
||||
import orbax.checkpoint as ocp
|
||||
import torch
|
||||
from jax.sharding import SingleDeviceSharding
|
||||
|
||||
from lerobot.common.policies.pi0.configuration_pi0 import PI0Config
|
||||
from lerobot.common.policies.pi0.conversion_scripts.conversion_utils import (
|
||||
get_gemma_config,
|
||||
get_paligemma_config,
|
||||
)
|
||||
from lerobot.common.policies.pi0.modeling_pi0 import PI0Policy
|
||||
|
||||
PRECISIONS = {"bfloat16": torch.bfloat16, "float32": torch.float32, "float16": torch.float16}
|
||||
|
||||
|
||||
def slice_paligemma_state_dict(state_dict, config):
|
||||
suffix = "/value" if "img/embedding/kernel/value" in state_dict else ""
|
||||
|
||||
# fmt: off
|
||||
# patch embeddings
|
||||
state_dict["paligemma.vision_tower.vision_model.embeddings.patch_embedding.weight"] = state_dict.pop(f"img/embedding/kernel{suffix}").transpose(
|
||||
3, 2, 0, 1
|
||||
)
|
||||
state_dict["paligemma.vision_tower.vision_model.embeddings.patch_embedding.bias"] = state_dict.pop(f"img/embedding/bias{suffix}")
|
||||
# positional embeddings
|
||||
state_dict["paligemma.vision_tower.vision_model.embeddings.position_embedding.weight"] = state_dict.pop(f"img/pos_embedding{suffix}").reshape(
|
||||
-1, config.vision_config.hidden_size
|
||||
)
|
||||
|
||||
# extract vision layers to be sliced at index 0. There are 27 layers in the base model.
|
||||
encoderblock_layernorm0_scale = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_0/scale{suffix}")
|
||||
encoderblock_layernorm0_bias = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_0/bias{suffix}")
|
||||
encoderblock_layernorm1_scale = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_1/scale{suffix}")
|
||||
encoderblock_layernorm1_bias = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_1/bias{suffix}")
|
||||
|
||||
encoderblock_mlp_dense0_kernel= state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_0/kernel{suffix}")
|
||||
encoderblock_mlp_dense0_bias= state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_0/bias{suffix}")
|
||||
encoderblock_mlp_dense1_kernel= state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_1/kernel{suffix}")
|
||||
encoderblock_mlp_dense1_bias= state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_1/bias{suffix}")
|
||||
|
||||
encoderblock_attention_0_key_kernel = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/key/kernel{suffix}")
|
||||
encoderblock_attention_0_key_bias = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/key/bias{suffix}")
|
||||
encoderblock_attention_0_value_kernel = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/value/kernel{suffix}")
|
||||
encoderblock_attention_0_value_bias = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/value/bias{suffix}")
|
||||
encoderblock_attention_0_query_kernel = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/query/kernel{suffix}")
|
||||
encoderblock_attention_0_query_bias = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/query/bias{suffix}")
|
||||
encoderblock_attention_0_out_kernel = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/out/kernel{suffix}")
|
||||
encoderblock_attention_0_out_bias = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/out/bias{suffix}")
|
||||
|
||||
for i in range(config.vision_config.num_hidden_layers):
|
||||
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.layer_norm1.weight"] = encoderblock_layernorm0_scale[i].transpose()
|
||||
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.layer_norm1.bias"] = encoderblock_layernorm0_bias[i]
|
||||
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.layer_norm2.weight"] = encoderblock_layernorm1_scale[i].transpose()
|
||||
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.layer_norm2.bias"] = encoderblock_layernorm1_bias[i]
|
||||
|
||||
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.mlp.fc1.weight"] = encoderblock_mlp_dense0_kernel[i].transpose()
|
||||
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.mlp.fc1.bias"] = encoderblock_mlp_dense0_bias[i]
|
||||
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.mlp.fc2.weight"] = encoderblock_mlp_dense1_kernel[i].transpose()
|
||||
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.mlp.fc2.bias"] = encoderblock_mlp_dense1_bias[i]
|
||||
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.k_proj.weight"] = encoderblock_attention_0_key_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose()
|
||||
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.k_proj.bias"] = encoderblock_attention_0_key_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1)
|
||||
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.v_proj.weight"] = encoderblock_attention_0_value_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose()
|
||||
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.v_proj.bias"] = encoderblock_attention_0_value_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1)
|
||||
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.q_proj.weight"] = encoderblock_attention_0_query_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose()
|
||||
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.q_proj.bias"] = encoderblock_attention_0_query_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1)
|
||||
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.out_proj.weight"] = encoderblock_attention_0_out_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose()
|
||||
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.out_proj.bias"] = encoderblock_attention_0_out_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1)
|
||||
|
||||
state_dict["paligemma.vision_tower.vision_model.post_layernorm.weight"] = state_dict.pop(f"img/Transformer/encoder_norm/scale{suffix}").transpose()
|
||||
state_dict["paligemma.vision_tower.vision_model.post_layernorm.bias"] = state_dict.pop(f"img/Transformer/encoder_norm/bias{suffix}")
|
||||
|
||||
# multimodal projector
|
||||
|
||||
state_dict['paligemma.multi_modal_projector.linear.weight'] = state_dict.pop(f"img/head/kernel{suffix}").transpose()
|
||||
state_dict['paligemma.multi_modal_projector.linear.bias'] = state_dict.pop(f"img/head/bias{suffix}")
|
||||
|
||||
# text decoder (gemma)
|
||||
embedding_vector = state_dict.pop(f"llm/embedder/input_embedding{suffix}")
|
||||
state_dict["paligemma.language_model.model.embed_tokens.weight"] = embedding_vector
|
||||
|
||||
# pop the einsum attention + mlp representations. There are 18 layers in gemma-2b.
|
||||
|
||||
llm_attention_attn_vec_einsum = state_dict.pop(f"llm/layers/attn/attn_vec_einsum/w{suffix}")
|
||||
llm_attention_kv_einsum = state_dict.pop(f"llm/layers/attn/kv_einsum/w{suffix}")
|
||||
llm_attention_q_einsum = state_dict.pop(f"llm/layers/attn/q_einsum/w{suffix}")
|
||||
|
||||
llm_mlp_gating_einsum = state_dict.pop(f"llm/layers/mlp/gating_einsum{suffix}")
|
||||
llm_mlp_linear = state_dict.pop(f"llm/layers/mlp/linear{suffix}")
|
||||
# TODO verify correctness of layer norm loading
|
||||
|
||||
llm_input_layernorm = state_dict.pop(f"llm/layers/pre_attention_norm/scale{suffix}")
|
||||
llm_post_attention_layernorm = state_dict.pop(f"llm/layers/pre_ffw_norm/scale{suffix}")
|
||||
|
||||
for i in range(config.text_config.num_hidden_layers):
|
||||
# llm_attention_q_einsum[i].shape = (8, 2048, 256)
|
||||
q_proj_weight_reshaped = llm_attention_q_einsum[i].transpose(0, 2, 1).reshape(config.text_config.num_attention_heads * config.text_config.head_dim, config.text_config.hidden_size)
|
||||
|
||||
state_dict[f"paligemma.language_model.model.layers.{i}.self_attn.q_proj.weight"] = q_proj_weight_reshaped
|
||||
|
||||
# llm_attention_kv_einsum[i, 0, 0].shape = (2048, 256)
|
||||
k_proj_weight_reshaped = llm_attention_kv_einsum[i, 0, 0].transpose()
|
||||
state_dict[f"paligemma.language_model.model.layers.{i}.self_attn.k_proj.weight"] = k_proj_weight_reshaped
|
||||
# llm_attention_kv_einsum[i, 1, 0].shape = (2048, 256)
|
||||
v_proj_weight_reshaped = llm_attention_kv_einsum[i, 1, 0].transpose()
|
||||
state_dict[f"paligemma.language_model.model.layers.{i}.self_attn.v_proj.weight"] = v_proj_weight_reshaped
|
||||
|
||||
# output projection.
|
||||
|
||||
# llm_attention_attn_vec_einsum[i].shape = (8, 256, 2048)
|
||||
o_proj_weight_reshaped = llm_attention_attn_vec_einsum[i].transpose(2, 0, 1).reshape(config.text_config.num_attention_heads * config.text_config.head_dim, config.text_config.hidden_size)
|
||||
|
||||
state_dict[f"paligemma.language_model.model.layers.{i}.self_attn.o_proj.weight"] = o_proj_weight_reshaped
|
||||
# mlp layers
|
||||
gate_proj_weight = llm_mlp_gating_einsum[i, 0]
|
||||
state_dict[f"paligemma.language_model.model.layers.{i}.mlp.gate_proj.weight"] = gate_proj_weight.transpose()
|
||||
up_proj_weight = llm_mlp_gating_einsum[i, 1]
|
||||
state_dict[f"paligemma.language_model.model.layers.{i}.mlp.up_proj.weight"] = up_proj_weight.transpose()
|
||||
state_dict[f"paligemma.language_model.model.layers.{i}.mlp.down_proj.weight"] = llm_mlp_linear[i].transpose()
|
||||
state_dict[f"paligemma.language_model.model.layers.{i}.input_layernorm.weight"] = llm_input_layernorm[i]
|
||||
state_dict[f"paligemma.language_model.model.layers.{i}.post_attention_layernorm.weight"] = llm_post_attention_layernorm[i]
|
||||
|
||||
state_dict["paligemma.language_model.model.norm.weight"] = state_dict.pop(f"llm/final_norm/scale{suffix}")
|
||||
state_dict["paligemma.language_model.lm_head.weight"] = embedding_vector # weights are tied.
|
||||
|
||||
# fmt: on
|
||||
expert_dict = {}
|
||||
final_state_dict = {}
|
||||
for key, value in state_dict.items():
|
||||
if key not in [
|
||||
f"llm/final_norm_1/scale{suffix}",
|
||||
f"llm/layers/attn/attn_vec_einsum_1/w{suffix}",
|
||||
f"llm/layers/attn/kv_einsum_1/w{suffix}",
|
||||
f"llm/layers/attn/q_einsum_1/w{suffix}",
|
||||
f"llm/layers/mlp_1/gating_einsum{suffix}",
|
||||
f"llm/layers/mlp_1/linear{suffix}",
|
||||
f"llm/layers/pre_attention_norm_1/scale{suffix}",
|
||||
f"llm/layers/pre_ffw_norm_1/scale{suffix}",
|
||||
]:
|
||||
final_state_dict[key] = torch.from_numpy(value)
|
||||
else:
|
||||
expert_dict[key] = value
|
||||
|
||||
return final_state_dict, expert_dict
|
||||
|
||||
|
||||
def slice_gemma_state_dict(state_dict, config, num_expert=1):
|
||||
# fmt: off
|
||||
# text decoder (gemma)
|
||||
# no embedding vector, the expert just has the decoder layers
|
||||
|
||||
embedding_vector = torch.zeros([config.vocab_size, config.hidden_size])
|
||||
state_dict["gemma_expert.model.embed_tokens.weight"] = embedding_vector
|
||||
|
||||
# pop the einsum attention + mlp representations. There are 18 layers in gemma-2b.
|
||||
|
||||
suffix = "/value" if f"llm/layers/attn/attn_vec_einsum_{num_expert}/w/value" in state_dict else ""
|
||||
|
||||
llm_attention_attn_vec_einsum = state_dict.pop(f"llm/layers/attn/attn_vec_einsum_{num_expert}/w{suffix}")
|
||||
llm_attention_kv_einsum = state_dict.pop(f"llm/layers/attn/kv_einsum_{num_expert}/w{suffix}")
|
||||
llm_attention_q_einsum = state_dict.pop(f"llm/layers/attn/q_einsum_{num_expert}/w{suffix}")
|
||||
|
||||
llm_mlp_gating_einsum = state_dict.pop(f"llm/layers/mlp_{num_expert}/gating_einsum{suffix}")
|
||||
llm_mlp_linear = state_dict.pop(f"llm/layers/mlp_{num_expert}/linear{suffix}")
|
||||
# TODO verify correctness of layer norm loading
|
||||
|
||||
llm_input_layernorm = state_dict.pop(f"llm/layers/pre_attention_norm_{num_expert}/scale{suffix}")
|
||||
llm_post_attention_layernorm = state_dict.pop(f"llm/layers/pre_ffw_norm_{num_expert}/scale{suffix}")
|
||||
|
||||
for i in range(config.num_hidden_layers):
|
||||
q_proj_weight_reshaped = llm_attention_q_einsum[i].transpose(0, 2, 1).reshape(config.num_attention_heads * config.head_dim, config.hidden_size)
|
||||
|
||||
state_dict[f"gemma_expert.model.layers.{i}.self_attn.q_proj.weight"] = q_proj_weight_reshaped
|
||||
|
||||
k_proj_weight_reshaped = llm_attention_kv_einsum[i, 0, 0].transpose()
|
||||
state_dict[f"gemma_expert.model.layers.{i}.self_attn.k_proj.weight"] = k_proj_weight_reshaped
|
||||
v_proj_weight_reshaped = llm_attention_kv_einsum[i, 1, 0].transpose()
|
||||
state_dict[f"gemma_expert.model.layers.{i}.self_attn.v_proj.weight"] = v_proj_weight_reshaped
|
||||
|
||||
# output projection.
|
||||
|
||||
# llm_attention_attn_vec_einsum[i].shape = (8, 256, 1024)
|
||||
o_proj_weight_reshaped = llm_attention_attn_vec_einsum[i].reshape(config.num_attention_heads * config.head_dim, config.hidden_size).transpose(1,0)# .transpose(2, 0, 1).reshape(config.num_attention_heads * config.head_dim, config.hidden_size).transpose(1, 0)
|
||||
|
||||
state_dict[f"gemma_expert.model.layers.{i}.self_attn.o_proj.weight"] = o_proj_weight_reshaped
|
||||
# mlp layers
|
||||
gate_proj_weight = llm_mlp_gating_einsum[i, 0]
|
||||
state_dict[f"gemma_expert.model.layers.{i}.mlp.gate_proj.weight"] = gate_proj_weight.transpose()
|
||||
up_proj_weight = llm_mlp_gating_einsum[i, 1]
|
||||
state_dict[f"gemma_expert.model.layers.{i}.mlp.up_proj.weight"] = up_proj_weight.transpose()
|
||||
state_dict[f"gemma_expert.model.layers.{i}.mlp.down_proj.weight"] = llm_mlp_linear[i].transpose()
|
||||
state_dict[f"gemma_expert.model.layers.{i}.input_layernorm.weight"] = llm_input_layernorm[i]
|
||||
state_dict[f"gemma_expert.model.layers.{i}.post_attention_layernorm.weight"] = llm_post_attention_layernorm[i]
|
||||
|
||||
state_dict["gemma_expert.model.norm.weight"] = state_dict.pop(f"llm/final_norm_{num_expert}/scale{suffix}")
|
||||
state_dict["gemma_expert.lm_head.weight"] = embedding_vector # weights are tied. (and zeros here)
|
||||
|
||||
# fmt: on
|
||||
final_state_dict = {}
|
||||
for key, value in state_dict.items():
|
||||
if not isinstance(value, torch.Tensor):
|
||||
final_state_dict[key] = torch.from_numpy(value)
|
||||
else:
|
||||
final_state_dict[key] = value
|
||||
return final_state_dict
|
||||
|
||||
|
||||
def flatten_for_memory(tree, parent_key=""):
|
||||
out = {}
|
||||
for k, v in tree.items():
|
||||
new_key = f"{parent_key}/{k}" if parent_key else k
|
||||
if isinstance(v, dict):
|
||||
out.update(flatten_for_memory(v, new_key))
|
||||
else:
|
||||
out[new_key] = np.array(v) # Ensure conversion to np.array for consistency
|
||||
return out
|
||||
|
||||
|
||||
def flatten_for_npz(tree, parent_key=""):
|
||||
out = {}
|
||||
for k, v in tree.items():
|
||||
new_key = f"{parent_key}/{k}" if parent_key else k
|
||||
if isinstance(v, dict):
|
||||
out.update(flatten_for_npz(v, new_key))
|
||||
else:
|
||||
# bf16/f32 here?
|
||||
out[new_key] = np.array(v)
|
||||
return out
|
||||
|
||||
|
||||
def slice_initial_orbax_checkpoint(checkpoint_dir: str):
|
||||
params_path = pathlib.Path(checkpoint_dir).resolve()
|
||||
checkpointer = ocp.PyTreeCheckpointer()
|
||||
|
||||
metadata = checkpointer.metadata(params_path)
|
||||
print("Metadata keys:", list(metadata.keys()))
|
||||
|
||||
params_name = "params"
|
||||
|
||||
item = {params_name: metadata[params_name]}
|
||||
device = jax.local_devices()[0] # Use the first local device
|
||||
sharding = SingleDeviceSharding(device)
|
||||
restored = checkpointer.restore(
|
||||
params_path,
|
||||
ocp.args.PyTreeRestore(
|
||||
item=item,
|
||||
restore_args=jax.tree_util.tree_map(
|
||||
lambda _: ocp.ArrayRestoreArgs(
|
||||
restore_type=jax.Array, # or np.ndarray, but bf16 is annoying about it
|
||||
sharding=sharding,
|
||||
),
|
||||
item,
|
||||
),
|
||||
transforms={},
|
||||
),
|
||||
)
|
||||
params = restored[params_name]
|
||||
|
||||
# get params for PaliGemma
|
||||
pali_params = params["PaliGemma"]
|
||||
del params["PaliGemma"]
|
||||
pali_params_flat = flatten_for_npz(pali_params)
|
||||
return {"paligemma_params": pali_params_flat, "projection_params": params}
|
||||
|
||||
|
||||
def update_keys_with_prefix(d: dict, prefix: str) -> dict:
|
||||
"""Update dictionary keys by adding a prefix."""
|
||||
return {f"{prefix}{key}": value for key, value in d.items()}
|
||||
|
||||
|
||||
def convert_pi0_checkpoint(checkpoint_dir: str, precision: str, tokenizer_id: str, output_path: str):
|
||||
# Break down orbax ckpts - they are in OCDBT
|
||||
initial_params = slice_initial_orbax_checkpoint(checkpoint_dir=checkpoint_dir)
|
||||
# process projection params
|
||||
keys = [
|
||||
"state_proj",
|
||||
"action_in_proj",
|
||||
"action_out_proj",
|
||||
"action_time_mlp_in",
|
||||
"action_time_mlp_out",
|
||||
]
|
||||
|
||||
projection_params = {}
|
||||
for key in keys:
|
||||
kernel_params = initial_params["projection_params"][key]["kernel"]
|
||||
bias_params = initial_params["projection_params"][key]["bias"]
|
||||
if isinstance(kernel_params, dict):
|
||||
weight = kernel_params["value"]
|
||||
bias = bias_params["value"]
|
||||
else:
|
||||
weight = kernel_params
|
||||
bias = bias_params
|
||||
projection_params[f"{key}.weight"] = torch.from_numpy(np.array(weight)).T
|
||||
projection_params[f"{key}.bias"] = torch.from_numpy(np.array(bias))
|
||||
|
||||
# Process PaliGemma weights
|
||||
paligemma_config = get_paligemma_config(precision)
|
||||
paligemma_params, gemma_raw_dictionary = slice_paligemma_state_dict(
|
||||
initial_params["paligemma_params"], paligemma_config
|
||||
)
|
||||
|
||||
# Process Gemma weights (at this stage they are unused)
|
||||
gemma_config = get_gemma_config(precision)
|
||||
gemma_params = slice_gemma_state_dict(gemma_raw_dictionary, config=gemma_config)
|
||||
|
||||
# Instantiate model from configs
|
||||
|
||||
if "pi0_aloha_sim" in checkpoint_dir:
|
||||
pi0_config = PI0Config(
|
||||
empty_cameras=2,
|
||||
adapt_to_pi_aloha=True,
|
||||
use_delta_joint_actions_aloha=False,
|
||||
)
|
||||
elif "pi0_aloha_towel" in checkpoint_dir:
|
||||
pi0_config = PI0Config(
|
||||
adapt_to_pi_aloha=True,
|
||||
use_delta_joint_actions_aloha=True,
|
||||
)
|
||||
elif "pi0_base" in checkpoint_dir:
|
||||
pi0_config = PI0Config(
|
||||
empty_cameras=0,
|
||||
adapt_to_pi_aloha=False,
|
||||
use_delta_joint_actions_aloha=False,
|
||||
)
|
||||
else:
|
||||
raise ValueError()
|
||||
|
||||
# gemma_config=gemma_config, paligemma_config=paligemma_config)
|
||||
pi0_model = PI0Policy(pi0_config)
|
||||
|
||||
paligemma_params = update_keys_with_prefix(paligemma_params, "model.paligemma_with_expert.")
|
||||
gemma_params = update_keys_with_prefix(gemma_params, "model.paligemma_with_expert.")
|
||||
projection_params = update_keys_with_prefix(projection_params, "model.")
|
||||
|
||||
# load state dict
|
||||
torch_dtype = PRECISIONS[precision]
|
||||
pi0_model.load_state_dict({**paligemma_params, **gemma_params, **projection_params})
|
||||
pi0_model = pi0_model.to(torch_dtype)
|
||||
# pi0_tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)
|
||||
|
||||
pi0_model.save_pretrained(output_path, safe_serialization=True)
|
||||
# pi0_tokenizer.save_pretrained(output_path, dtype=torch_dtype)
|
||||
|
||||
# assert that model loads properly
|
||||
del pi0_model
|
||||
PI0Policy.from_pretrained(output_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--checkpoint_dir",
|
||||
default="/raid/pablo/.cache/openpi/openpi-assets/checkpoints/pi0_aloha_sim/params",
|
||||
type=str,
|
||||
help="Path to the ocdbt checkpoint",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--precision",
|
||||
choices=["float32", "bfloat16", "float16"],
|
||||
default="float32",
|
||||
type=str,
|
||||
help="Precision identifier for model conversion - should match the base checkpoint precision.",
|
||||
)
|
||||
# tokenizer is identical to paligemma, it appears
|
||||
|
||||
parser.add_argument(
|
||||
"--tokenizer_hub_id",
|
||||
default="google/paligemma-3b-pt-224",
|
||||
type=str,
|
||||
help="Hub path to the tokenizer to save",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--output_path",
|
||||
required=True,
|
||||
type=str,
|
||||
help="Path to save converted weights to",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
convert_pi0_checkpoint(
|
||||
checkpoint_dir=args.checkpoint_dir,
|
||||
precision=args.precision,
|
||||
tokenizer_id=args.tokenizer_hub_id,
|
||||
output_path=args.output_path,
|
||||
)
|
||||
@@ -1,127 +0,0 @@
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from packaging.version import Version
|
||||
|
||||
if Version(torch.__version__) > Version("2.5.0"):
|
||||
# Ffex attention is only available from torch 2.5 onwards
|
||||
from torch.nn.attention.flex_attention import (
|
||||
_mask_mod_signature,
|
||||
_round_up_to_multiple,
|
||||
create_block_mask,
|
||||
create_mask,
|
||||
flex_attention,
|
||||
)
|
||||
|
||||
|
||||
# @torch.compile(dynamic=False)
|
||||
def flex_attention_forward(
|
||||
attention_mask: torch.Tensor,
|
||||
batch_size: int,
|
||||
head_dim: int,
|
||||
query_states: torch.Tensor,
|
||||
key_states: torch.Tensor,
|
||||
value_states: torch.Tensor,
|
||||
scaling=None,
|
||||
):
|
||||
"""
|
||||
This is defined out of classes to make compile happy.
|
||||
"""
|
||||
|
||||
original_dtype = query_states.dtype
|
||||
num_att_heads = 8
|
||||
num_key_value_heads = 1
|
||||
num_key_value_groups = num_att_heads // num_key_value_heads
|
||||
|
||||
key_states = key_states[:, :, :, None, :]
|
||||
key_states = key_states.expand(
|
||||
batch_size, key_states.shape[1], num_key_value_heads, num_key_value_groups, head_dim
|
||||
)
|
||||
key_states = key_states.reshape(
|
||||
batch_size, key_states.shape[1], num_key_value_heads * num_key_value_groups, head_dim
|
||||
)
|
||||
|
||||
value_states = value_states[:, :, :, None, :]
|
||||
value_states = value_states.expand(
|
||||
batch_size, value_states.shape[1], num_key_value_heads, num_key_value_groups, head_dim
|
||||
)
|
||||
value_states = value_states.reshape(
|
||||
batch_size, value_states.shape[1], num_key_value_heads * num_key_value_groups, head_dim
|
||||
)
|
||||
|
||||
query_states = query_states.transpose(1, 2)
|
||||
key_states = key_states.transpose(1, 2)
|
||||
value_states = value_states.transpose(1, 2)
|
||||
|
||||
query_states = query_states.to(torch.float32)
|
||||
key_states = key_states.to(torch.float32)
|
||||
value_states = value_states.to(torch.float32)
|
||||
|
||||
causal_mask = attention_mask
|
||||
if causal_mask is not None:
|
||||
causal_mask = causal_mask[:, None, :, : key_states.shape[2]]
|
||||
|
||||
if causal_mask.shape[1] == 1 and query_states.shape[1] > 1:
|
||||
causal_mask = causal_mask.expand(-1, query_states.shape[1], -1, -1)
|
||||
|
||||
def precomputed_mask_factory(precomputed_mask: torch.Tensor) -> _mask_mod_signature:
|
||||
def mask_mod(b, h, q_idx, kv_idx):
|
||||
# Danger zone: if b,h,q_idx,kv_idx exceed the shape, device-side assert occurs.
|
||||
return precomputed_mask[b][h][q_idx][kv_idx]
|
||||
|
||||
return mask_mod
|
||||
|
||||
b_mask, h_mask, q_len, kv_len = causal_mask.shape # The shape of your mask
|
||||
|
||||
block_size = 128
|
||||
q_len_rounded = _round_up_to_multiple(q_len, block_size)
|
||||
kv_len_rounded = _round_up_to_multiple(kv_len, block_size)
|
||||
|
||||
# *CRITICAL* we do need to expand here, else we get a CUDA index error
|
||||
|
||||
pad_q = q_len_rounded - q_len
|
||||
pad_k = kv_len_rounded - kv_len
|
||||
|
||||
padded_causal_mask = F.pad(causal_mask, (0, pad_k, 0, pad_q), value=0.0)
|
||||
mask_mod_fn_orig = precomputed_mask_factory(padded_causal_mask)
|
||||
|
||||
mask_4d = create_mask(
|
||||
mod_fn=mask_mod_fn_orig,
|
||||
B=b_mask,
|
||||
H=h_mask,
|
||||
Q_LEN=q_len_rounded,
|
||||
KV_LEN=kv_len_rounded,
|
||||
device=causal_mask.device,
|
||||
_compile=False,
|
||||
)
|
||||
|
||||
mask_mod_fn_padded = precomputed_mask_factory(mask_4d)
|
||||
block_mask = create_block_mask(
|
||||
mask_mod=mask_mod_fn_padded,
|
||||
B=b_mask,
|
||||
H=h_mask,
|
||||
Q_LEN=q_len_rounded,
|
||||
KV_LEN=kv_len_rounded,
|
||||
BLOCK_SIZE=block_size,
|
||||
device=causal_mask.device,
|
||||
_compile=False,
|
||||
)
|
||||
|
||||
# mask is applied inside the kernel, ideally more efficiently than score_mod.
|
||||
attn_output, attention_weights = flex_attention(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
block_mask=block_mask,
|
||||
enable_gqa=True, # because we shaped query/key states for GQA
|
||||
scale=head_dim**-0.5 if scaling is None else scaling,
|
||||
return_lse=True,
|
||||
)
|
||||
|
||||
attn_output = attn_output.to(dtype=original_dtype)
|
||||
attn_output = attn_output.transpose(1, 2).contiguous() # [B, Q_LEN, H, head_dim]
|
||||
attn_output = attn_output.reshape(
|
||||
batch_size,
|
||||
-1,
|
||||
attn_output.shape[2] * attn_output.shape[3], # merges [H, head_dim]
|
||||
)
|
||||
return attn_output
|
||||
@@ -1,732 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
π0: A Vision-Language-Action Flow Model for General Robot Control
|
||||
|
||||
[Paper](https://www.physicalintelligence.company/download/pi0.pdf)
|
||||
[Jax code](https://github.com/Physical-Intelligence/openpi)
|
||||
|
||||
Designed by Physical Intelligence. Ported from Jax by Hugging Face.
|
||||
|
||||
Install pi0 extra dependencies:
|
||||
```bash
|
||||
pip install -e ".[pi0]"
|
||||
```
|
||||
|
||||
Example of finetuning the pi0 pretrained model (`pi0_base` in `openpi`):
|
||||
```bash
|
||||
python lerobot/scripts/train.py \
|
||||
--policy.path=lerobot/pi0 \
|
||||
--dataset.repo_id=danaaubakirova/koch_test
|
||||
```
|
||||
|
||||
Example of finetuning the pi0 neural network with PaliGemma and expert Gemma
|
||||
pretrained with VLM default parameters before pi0 finetuning:
|
||||
```bash
|
||||
python lerobot/scripts/train.py \
|
||||
--policy.type=pi0 \
|
||||
--dataset.repo_id=danaaubakirova/koch_test
|
||||
```
|
||||
|
||||
Example of using the pi0 pretrained model outside LeRobot training framework:
|
||||
```python
|
||||
policy = Pi0Policy.from_pretrained("lerobot/pi0")
|
||||
```
|
||||
|
||||
"""
|
||||
|
||||
import math
|
||||
from collections import deque
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from torch import Tensor, nn
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from lerobot.common.constants import ACTION, OBS_ROBOT
|
||||
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
||||
from lerobot.common.policies.pi0.configuration_pi0 import PI0Config
|
||||
from lerobot.common.policies.pi0.paligemma_with_expert import (
|
||||
PaliGemmaWithExpertConfig,
|
||||
PaliGemmaWithExpertModel,
|
||||
)
|
||||
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.common.utils.utils import get_safe_dtype
|
||||
|
||||
|
||||
def create_sinusoidal_pos_embedding(
|
||||
time: torch.tensor, dimension: int, min_period: float, max_period: float, device="cpu"
|
||||
) -> Tensor:
|
||||
"""Computes sine-cosine positional embedding vectors for scalar positions."""
|
||||
if dimension % 2 != 0:
|
||||
raise ValueError(f"dimension ({dimension}) must be divisible by 2")
|
||||
|
||||
if time.ndim != 1:
|
||||
raise ValueError("The time tensor is expected to be of shape `(batch_size, )`.")
|
||||
|
||||
dtype = get_safe_dtype(torch.float64, device.type)
|
||||
fraction = torch.linspace(0.0, 1.0, dimension // 2, dtype=dtype, device=device)
|
||||
period = min_period * (max_period / min_period) ** fraction
|
||||
|
||||
# Compute the outer product
|
||||
scaling_factor = 1.0 / period * 2 * math.pi
|
||||
sin_input = scaling_factor[None, :] * time[:, None]
|
||||
pos_emb = torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
|
||||
return pos_emb
|
||||
|
||||
|
||||
def sample_beta(alpha, beta, bsize, device):
|
||||
gamma1 = torch.empty((bsize,), device=device).uniform_(0, 1).pow(1 / alpha)
|
||||
gamma2 = torch.empty((bsize,), device=device).uniform_(0, 1).pow(1 / beta)
|
||||
return gamma1 / (gamma1 + gamma2)
|
||||
|
||||
|
||||
def make_att_2d_masks(pad_masks, att_masks):
|
||||
"""Copied from big_vision.
|
||||
|
||||
Tokens can attend to valid inputs tokens which have a cumulative mask_ar
|
||||
smaller or equal to theirs. This way `mask_ar` int[B, N] can be used to
|
||||
setup several types of attention, for example:
|
||||
|
||||
[[1 1 1 1 1 1]]: pure causal attention.
|
||||
|
||||
[[0 0 0 1 1 1]]: prefix-lm attention. The first 3 tokens can attend between
|
||||
themselves and the last 3 tokens have a causal attention. The first
|
||||
entry could also be a 1 without changing behaviour.
|
||||
|
||||
[[1 0 1 0 1 0 0 1 0 0]]: causal attention between 4 blocks. Tokens of a
|
||||
block can attend all previous blocks and all tokens on the same block.
|
||||
|
||||
Args:
|
||||
input_mask: bool[B, N] true if its part of the input, false if padding.
|
||||
mask_ar: int32[B, N] mask that's 1 where previous tokens cannot depend on
|
||||
it and 0 where it shares the same attention mask as the previous token.
|
||||
"""
|
||||
if att_masks.ndim != 2:
|
||||
raise ValueError(att_masks.ndim)
|
||||
if pad_masks.ndim != 2:
|
||||
raise ValueError(pad_masks.ndim)
|
||||
|
||||
cumsum = torch.cumsum(att_masks, dim=1)
|
||||
att_2d_masks = cumsum[:, None, :] <= cumsum[:, :, None]
|
||||
pad_2d_masks = pad_masks[:, None, :] * pad_masks[:, :, None]
|
||||
att_2d_masks = att_2d_masks & pad_2d_masks
|
||||
return att_2d_masks
|
||||
|
||||
|
||||
def resize_with_pad(img, width, height, pad_value=-1):
|
||||
# assume no-op when width height fits already
|
||||
if img.ndim != 4:
|
||||
raise ValueError(f"(b,c,h,w) expected, but {img.shape}")
|
||||
|
||||
cur_height, cur_width = img.shape[2:]
|
||||
|
||||
ratio = max(cur_width / width, cur_height / height)
|
||||
resized_height = int(cur_height / ratio)
|
||||
resized_width = int(cur_width / ratio)
|
||||
resized_img = F.interpolate(
|
||||
img, size=(resized_height, resized_width), mode="bilinear", align_corners=False
|
||||
)
|
||||
|
||||
pad_height = max(0, int(height - resized_height))
|
||||
pad_width = max(0, int(width - resized_width))
|
||||
|
||||
# pad on left and top of image
|
||||
padded_img = F.pad(resized_img, (pad_width, 0, pad_height, 0), value=pad_value)
|
||||
return padded_img
|
||||
|
||||
|
||||
def pad_vector(vector, new_dim):
|
||||
"""Can be (batch_size x sequence_length x features_dimension)
|
||||
or (batch_size x features_dimension)
|
||||
"""
|
||||
if vector.shape[-1] == new_dim:
|
||||
return vector
|
||||
shape = list(vector.shape)
|
||||
current_dim = shape[-1]
|
||||
shape[-1] = new_dim
|
||||
new_vector = torch.zeros(*shape, dtype=vector.dtype, device=vector.device)
|
||||
new_vector[..., :current_dim] = vector
|
||||
return new_vector
|
||||
|
||||
|
||||
def normalize(x, min_val, max_val):
|
||||
return (x - min_val) / (max_val - min_val)
|
||||
|
||||
|
||||
def unnormalize(x, min_val, max_val):
|
||||
return x * (max_val - min_val) + min_val
|
||||
|
||||
|
||||
def safe_arcsin(value):
|
||||
# This ensures that the input stays within
|
||||
# [−1,1] to avoid invalid values for arcsin
|
||||
return torch.arcsin(torch.clamp(value, -1.0, 1.0))
|
||||
|
||||
|
||||
def aloha_gripper_to_angular(value):
|
||||
# Aloha transforms the gripper positions into a linear space. The following code
|
||||
# reverses this transformation to be consistent with pi0 which is pretrained in
|
||||
# angular space.
|
||||
#
|
||||
# These values are coming from the Aloha code:
|
||||
# PUPPET_GRIPPER_POSITION_OPEN, PUPPET_GRIPPER_POSITION_CLOSED
|
||||
value = unnormalize(value, min_val=0.01844, max_val=0.05800)
|
||||
|
||||
# This is the inverse of the angular to linear transformation inside the Interbotix code.
|
||||
def linear_to_radian(linear_position, arm_length, horn_radius):
|
||||
value = (horn_radius**2 + linear_position**2 - arm_length**2) / (2 * horn_radius * linear_position)
|
||||
return safe_arcsin(value)
|
||||
|
||||
# The constants are taken from the Interbotix code.
|
||||
value = linear_to_radian(value, arm_length=0.036, horn_radius=0.022)
|
||||
|
||||
# Normalize to [0, 1].
|
||||
# The values 0.4 and 1.5 were measured on an actual Trossen robot.
|
||||
return normalize(value, min_val=0.4, max_val=1.5)
|
||||
|
||||
|
||||
def aloha_gripper_from_angular(value):
|
||||
# Convert from the gripper position used by pi0 to the gripper position that is used by Aloha.
|
||||
# Note that the units are still angular but the range is different.
|
||||
|
||||
# The values 0.4 and 1.5 were measured on an actual Trossen robot.
|
||||
value = unnormalize(value, min_val=0.4, max_val=1.5)
|
||||
|
||||
# These values are coming from the Aloha code:
|
||||
# PUPPET_GRIPPER_JOINT_OPEN, PUPPET_GRIPPER_JOINT_CLOSE
|
||||
return normalize(value, min_val=-0.6213, max_val=1.4910)
|
||||
|
||||
|
||||
def aloha_gripper_from_angular_inv(value):
|
||||
# Directly inverts the gripper_from_angular function.
|
||||
value = unnormalize(value, min_val=-0.6213, max_val=1.4910)
|
||||
return normalize(value, min_val=0.4, max_val=1.5)
|
||||
|
||||
|
||||
class PI0Policy(PreTrainedPolicy):
|
||||
"""Wrapper class around PI0FlowMatching model to train and run inference within LeRobot."""
|
||||
|
||||
config_class = PI0Config
|
||||
name = "pi0"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: PI0Config,
|
||||
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
config: Policy configuration class instance or None, in which case the default instantiation of
|
||||
the configuration class is used.
|
||||
dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
|
||||
that they will be passed with a call to `load_state_dict` before the policy is used.
|
||||
"""
|
||||
|
||||
super().__init__(config)
|
||||
config.validate_features()
|
||||
self.config = config
|
||||
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
|
||||
self.normalize_targets = Normalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
self.unnormalize_outputs = Unnormalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
|
||||
self.language_tokenizer = AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224")
|
||||
self.model = PI0FlowMatching(config)
|
||||
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
"""This should be called whenever the environment is reset."""
|
||||
self._action_queue = deque([], maxlen=self.config.n_action_steps)
|
||||
|
||||
def get_optim_params(self) -> dict:
|
||||
return self.parameters()
|
||||
|
||||
@torch.no_grad
|
||||
def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
|
||||
"""Select a single action given environment observations.
|
||||
|
||||
This method wraps `select_actions` in order to return one action at a time for execution in the
|
||||
environment. It works by managing the actions in a queue and only calling `select_actions` when the
|
||||
queue is empty.
|
||||
"""
|
||||
self.eval()
|
||||
|
||||
if self.config.adapt_to_pi_aloha:
|
||||
batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT])
|
||||
|
||||
batch = self.normalize_inputs(batch)
|
||||
|
||||
# Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by
|
||||
# querying the policy.
|
||||
if len(self._action_queue) == 0:
|
||||
images, img_masks = self.prepare_images(batch)
|
||||
state = self.prepare_state(batch)
|
||||
lang_tokens, lang_masks = self.prepare_language(batch)
|
||||
|
||||
actions = self.model.sample_actions(
|
||||
images, img_masks, lang_tokens, lang_masks, state, noise=noise
|
||||
)
|
||||
|
||||
# Unpad actions
|
||||
original_action_dim = self.config.action_feature.shape[0]
|
||||
actions = actions[:, :, :original_action_dim]
|
||||
|
||||
actions = self.unnormalize_outputs({"action": actions})["action"]
|
||||
|
||||
if self.config.adapt_to_pi_aloha:
|
||||
actions = self._pi_aloha_encode_actions(actions)
|
||||
|
||||
# `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue
|
||||
# effectively has shape (n_action_steps, batch_size, *), hence the transpose.
|
||||
self._action_queue.extend(actions.transpose(0, 1))
|
||||
return self._action_queue.popleft()
|
||||
|
||||
def forward(self, batch: dict[str, Tensor], noise=None, time=None) -> tuple[Tensor, dict[str, Tensor]]:
|
||||
"""Do a full training forward pass to compute the loss"""
|
||||
if self.config.adapt_to_pi_aloha:
|
||||
batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT])
|
||||
batch[ACTION] = self._pi_aloha_encode_actions_inv(batch[ACTION])
|
||||
|
||||
batch = self.normalize_inputs(batch)
|
||||
batch = self.normalize_targets(batch)
|
||||
|
||||
images, img_masks = self.prepare_images(batch)
|
||||
state = self.prepare_state(batch)
|
||||
lang_tokens, lang_masks = self.prepare_language(batch)
|
||||
actions = self.prepare_action(batch)
|
||||
actions_is_pad = batch.get("actions_id_pad")
|
||||
|
||||
loss_dict = {}
|
||||
losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions, noise, time)
|
||||
loss_dict["losses_after_forward"] = losses.clone()
|
||||
|
||||
if actions_is_pad is not None:
|
||||
in_episode_bound = ~actions_is_pad
|
||||
losses = losses * in_episode_bound.unsqueeze(-1)
|
||||
loss_dict["losses_after_in_ep_bound"] = losses.clone()
|
||||
|
||||
# Remove padding
|
||||
losses = losses[:, :, : self.config.max_action_dim]
|
||||
loss_dict["losses_after_rm_padding"] = losses.clone()
|
||||
|
||||
# For backward pass
|
||||
loss = losses.mean()
|
||||
# For logging
|
||||
loss_dict["l2_loss"] = loss.item()
|
||||
|
||||
return loss, loss_dict
|
||||
|
||||
def prepare_images(self, batch):
|
||||
"""Apply Pi0 preprocessing to the images, like resizing to 224x224 and padding to keep aspect ratio, and
|
||||
convert pixel range from [0.0, 1.0] to [-1.0, 1.0] as requested by SigLIP.
|
||||
"""
|
||||
images = []
|
||||
img_masks = []
|
||||
|
||||
present_img_keys = [key for key in self.config.image_features if key in batch]
|
||||
missing_img_keys = [key for key in self.config.image_features if key not in batch]
|
||||
|
||||
if len(present_img_keys) == 0:
|
||||
raise ValueError(
|
||||
f"All image features are missing from the batch. At least one expected. (batch: {batch.keys()}) (image_features:{self.config.image_features})"
|
||||
)
|
||||
|
||||
# Preprocess image features present in the batch
|
||||
for key in present_img_keys:
|
||||
img = batch[key]
|
||||
|
||||
if self.config.resize_imgs_with_padding is not None:
|
||||
img = resize_with_pad(img, *self.config.resize_imgs_with_padding, pad_value=0)
|
||||
|
||||
# Normalize from range [0,1] to [-1,1] as expacted by siglip
|
||||
img = img * 2.0 - 1.0
|
||||
|
||||
bsize = img.shape[0]
|
||||
device = img.device
|
||||
mask = torch.ones(bsize, dtype=torch.bool, device=device)
|
||||
images.append(img)
|
||||
img_masks.append(mask)
|
||||
|
||||
# Create image features not present in the batch
|
||||
# as fully 0 padded images.
|
||||
for num_empty_cameras in range(len(missing_img_keys)):
|
||||
if num_empty_cameras >= self.config.empty_cameras:
|
||||
break
|
||||
img = torch.ones_like(img) * -1
|
||||
mask = torch.zeros_like(mask)
|
||||
images.append(img)
|
||||
img_masks.append(mask)
|
||||
|
||||
return images, img_masks
|
||||
|
||||
def prepare_language(self, batch) -> tuple[Tensor, Tensor]:
|
||||
"""Tokenize the text input"""
|
||||
device = batch[OBS_ROBOT].device
|
||||
tasks = batch["task"]
|
||||
|
||||
# PaliGemma prompt has to end with a new line
|
||||
tasks = [task if task.endswith("\n") else f"{task}\n" for task in tasks]
|
||||
|
||||
tokenized_prompt = self.language_tokenizer.__call__(
|
||||
tasks,
|
||||
padding="max_length",
|
||||
padding_side="right",
|
||||
max_length=self.config.tokenizer_max_length,
|
||||
return_tensors="pt",
|
||||
)
|
||||
lang_tokens = tokenized_prompt["input_ids"].to(device=device)
|
||||
lang_masks = tokenized_prompt["attention_mask"].to(device=device, dtype=torch.bool)
|
||||
|
||||
return lang_tokens, lang_masks
|
||||
|
||||
def _pi_aloha_decode_state(self, state):
|
||||
# Flip the joints.
|
||||
for motor_idx in [1, 2, 8, 9]:
|
||||
state[:, motor_idx] *= -1
|
||||
# Reverse the gripper transformation that is being applied by the Aloha runtime.
|
||||
for motor_idx in [6, 13]:
|
||||
state[:, motor_idx] = aloha_gripper_to_angular(state[:, motor_idx])
|
||||
return state
|
||||
|
||||
def _pi_aloha_encode_actions(self, actions):
|
||||
# Flip the joints.
|
||||
for motor_idx in [1, 2, 8, 9]:
|
||||
actions[:, :, motor_idx] *= -1
|
||||
# Reverse the gripper transformation that is being applied by the Aloha runtime.
|
||||
for motor_idx in [6, 13]:
|
||||
actions[:, :, motor_idx] = aloha_gripper_from_angular(actions[:, :, motor_idx])
|
||||
return actions
|
||||
|
||||
def _pi_aloha_encode_actions_inv(self, actions):
|
||||
# Flip the joints again.
|
||||
for motor_idx in [1, 2, 8, 9]:
|
||||
actions[:, :, motor_idx] *= -1
|
||||
# Reverse the gripper transformation that is being applied by the Aloha runtime.
|
||||
for motor_idx in [6, 13]:
|
||||
actions[:, :, motor_idx] = aloha_gripper_from_angular_inv(actions[:, :, motor_idx])
|
||||
return actions
|
||||
|
||||
def prepare_state(self, batch):
|
||||
"""Pad state"""
|
||||
state = pad_vector(batch[OBS_ROBOT], self.config.max_state_dim)
|
||||
return state
|
||||
|
||||
def prepare_action(self, batch):
|
||||
"""Pad action"""
|
||||
actions = pad_vector(batch[ACTION], self.config.max_action_dim)
|
||||
return actions
|
||||
|
||||
|
||||
class PI0FlowMatching(nn.Module):
|
||||
"""
|
||||
π0: A Vision-Language-Action Flow Model for General Robot Control
|
||||
|
||||
[Paper](https://www.physicalintelligence.company/download/pi0.pdf)
|
||||
[Jax code](https://github.com/Physical-Intelligence/openpi)
|
||||
|
||||
Designed by Physical Intelligence. Ported from Jax by Hugging Face.
|
||||
┌──────────────────────────────┐
|
||||
│ actions │
|
||||
│ ▲ │
|
||||
│ ┌┴─────┐ │
|
||||
│ kv cache │Gemma │ │
|
||||
│ ┌──────────►│Expert│ │
|
||||
│ │ │ │ │
|
||||
│ ┌┴────────┐ │x 10 │ │
|
||||
│ │ │ └▲──▲──┘ │
|
||||
│ │PaliGemma│ │ │ │
|
||||
│ │ │ │ robot state │
|
||||
│ │ │ noise │
|
||||
│ └▲──▲─────┘ │
|
||||
│ │ │ │
|
||||
│ │ image(s) │
|
||||
│ language tokens │
|
||||
└──────────────────────────────┘
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
paligemma_with_export_config = PaliGemmaWithExpertConfig(
|
||||
freeze_vision_encoder=self.config.freeze_vision_encoder,
|
||||
train_expert_only=self.config.train_expert_only,
|
||||
attention_implementation=self.config.attention_implementation,
|
||||
)
|
||||
self.paligemma_with_expert = PaliGemmaWithExpertModel(paligemma_with_export_config)
|
||||
|
||||
# Projections are float32
|
||||
self.state_proj = nn.Linear(self.config.max_state_dim, self.config.proj_width)
|
||||
self.action_in_proj = nn.Linear(self.config.max_action_dim, self.config.proj_width)
|
||||
self.action_out_proj = nn.Linear(self.config.proj_width, self.config.max_action_dim)
|
||||
|
||||
self.action_time_mlp_in = nn.Linear(self.config.proj_width * 2, self.config.proj_width)
|
||||
self.action_time_mlp_out = nn.Linear(self.config.proj_width, self.config.proj_width)
|
||||
|
||||
self.set_requires_grad()
|
||||
|
||||
def set_requires_grad(self):
|
||||
for params in self.state_proj.parameters():
|
||||
params.requires_grad = self.config.train_state_proj
|
||||
|
||||
def sample_noise(self, shape, device):
|
||||
noise = torch.normal(
|
||||
mean=0.0,
|
||||
std=1.0,
|
||||
size=shape,
|
||||
dtype=torch.float32,
|
||||
device=device,
|
||||
)
|
||||
return noise
|
||||
|
||||
def sample_time(self, bsize, device):
|
||||
time_beta = sample_beta(1.5, 1.0, bsize, device)
|
||||
time = time_beta * 0.999 + 0.001
|
||||
return time.to(dtype=torch.float32, device=device)
|
||||
|
||||
def embed_prefix(
|
||||
self, images, img_masks, lang_tokens, lang_masks
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""Embed images with SigLIP and language tokens with embedding layer to prepare
|
||||
for PaliGemma transformer processing.
|
||||
"""
|
||||
# TODO: avoid list in python and torch.cat ; prefer pre-allocation with torch.empty
|
||||
embs = []
|
||||
pad_masks = []
|
||||
att_masks = []
|
||||
|
||||
# TODO: remove for loop
|
||||
for (
|
||||
img,
|
||||
img_mask,
|
||||
) in zip(images, img_masks, strict=False):
|
||||
img_emb = self.paligemma_with_expert.embed_image(img)
|
||||
img_emb = img_emb.to(dtype=torch.bfloat16)
|
||||
|
||||
# Normalize image embeddings
|
||||
img_emb_dim = img_emb.shape[-1]
|
||||
img_emb = img_emb * torch.tensor(img_emb_dim**0.5, dtype=img_emb.dtype, device=img_emb.device)
|
||||
|
||||
bsize, num_img_embs = img_emb.shape[:2]
|
||||
img_mask = img_mask[:, None].expand(bsize, num_img_embs)
|
||||
|
||||
embs.append(img_emb)
|
||||
pad_masks.append(img_mask)
|
||||
|
||||
# Create attention masks so that image tokens attend to each other
|
||||
att_masks += [0] * num_img_embs
|
||||
|
||||
lang_emb = self.paligemma_with_expert.embed_language_tokens(lang_tokens)
|
||||
|
||||
# Normalize language embeddings
|
||||
lang_emb_dim = lang_emb.shape[-1]
|
||||
lang_emb = lang_emb * math.sqrt(lang_emb_dim)
|
||||
|
||||
embs.append(lang_emb)
|
||||
pad_masks.append(lang_masks)
|
||||
|
||||
# full attention between image and language inputs
|
||||
num_lang_embs = lang_emb.shape[1]
|
||||
att_masks += [0] * num_lang_embs
|
||||
|
||||
embs = torch.cat(embs, dim=1)
|
||||
pad_masks = torch.cat(pad_masks, dim=1)
|
||||
att_masks = torch.tensor(att_masks, dtype=torch.bool, device=pad_masks.device)
|
||||
att_masks = att_masks[None, :].expand(bsize, len(att_masks))
|
||||
|
||||
return embs, pad_masks, att_masks
|
||||
|
||||
def embed_suffix(self, state, noisy_actions, timestep):
|
||||
"""Embed state, noisy_actions, timestep to prepare for Expert Gemma processing."""
|
||||
embs = []
|
||||
pad_masks = []
|
||||
att_masks = []
|
||||
|
||||
# Embed state
|
||||
state_emb = self.state_proj(state)
|
||||
state_emb = state_emb.to(dtype=torch.bfloat16)
|
||||
embs.append(state_emb[:, None, :])
|
||||
bsize = state_emb.shape[0]
|
||||
dtype = state_emb.dtype
|
||||
device = state_emb.device
|
||||
|
||||
state_mask = torch.ones(bsize, 1, dtype=torch.bool, device=device)
|
||||
pad_masks.append(state_mask)
|
||||
|
||||
# Set attention masks so that image and language inputs do not attend to state or actions
|
||||
att_masks += [1]
|
||||
|
||||
# Embed timestep using sine-cosine positional encoding with sensitivity in the range [0, 1]
|
||||
time_emb = create_sinusoidal_pos_embedding(
|
||||
timestep, self.config.proj_width, min_period=4e-3, max_period=4.0, device=device
|
||||
)
|
||||
time_emb = time_emb.type(dtype=dtype)
|
||||
|
||||
# Fuse timestep + action information using an MLP
|
||||
action_emb = self.action_in_proj(noisy_actions)
|
||||
|
||||
time_emb = time_emb[:, None, :].expand_as(action_emb)
|
||||
action_time_emb = torch.cat([action_emb, time_emb], dim=2)
|
||||
|
||||
action_time_emb = self.action_time_mlp_in(action_time_emb)
|
||||
action_time_emb = F.silu(action_time_emb) # swish == silu
|
||||
action_time_emb = self.action_time_mlp_out(action_time_emb)
|
||||
|
||||
# Add to input tokens
|
||||
embs.append(action_time_emb)
|
||||
|
||||
bsize, action_time_dim = action_time_emb.shape[:2]
|
||||
action_time_mask = torch.ones(bsize, action_time_dim, dtype=torch.bool, device=device)
|
||||
pad_masks.append(action_time_mask)
|
||||
|
||||
# Set attention masks so that image, language and state inputs do not attend to action tokens
|
||||
att_masks += [1] + ([0] * (self.config.n_action_steps - 1))
|
||||
|
||||
embs = torch.cat(embs, dim=1)
|
||||
pad_masks = torch.cat(pad_masks, dim=1)
|
||||
att_masks = torch.tensor(att_masks, dtype=embs.dtype, device=embs.device)
|
||||
att_masks = att_masks[None, :].expand(bsize, len(att_masks))
|
||||
|
||||
return embs, pad_masks, att_masks
|
||||
|
||||
def forward(
|
||||
self, images, img_masks, lang_tokens, lang_masks, state, actions, noise=None, time=None
|
||||
) -> Tensor:
|
||||
"""Do a full training forward pass and compute the loss (batch_size x num_steps x num_motors)"""
|
||||
if noise is None:
|
||||
noise = self.sample_noise(actions.shape, actions.device)
|
||||
|
||||
if time is None:
|
||||
time = self.sample_time(actions.shape[0], actions.device)
|
||||
|
||||
time_expanded = time[:, None, None]
|
||||
x_t = time_expanded * noise + (1 - time_expanded) * actions
|
||||
u_t = noise - actions
|
||||
|
||||
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
|
||||
images, img_masks, lang_tokens, lang_masks
|
||||
)
|
||||
suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(state, x_t, time)
|
||||
|
||||
pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1)
|
||||
att_masks = torch.cat([prefix_att_masks, suffix_att_masks], dim=1)
|
||||
|
||||
att_2d_masks = make_att_2d_masks(pad_masks, att_masks)
|
||||
position_ids = torch.cumsum(pad_masks, dim=1) - 1
|
||||
|
||||
(_, suffix_out), _ = self.paligemma_with_expert.forward(
|
||||
attention_mask=att_2d_masks,
|
||||
position_ids=position_ids,
|
||||
past_key_values=None,
|
||||
inputs_embeds=[prefix_embs, suffix_embs],
|
||||
use_cache=False,
|
||||
fill_kv_cache=False,
|
||||
)
|
||||
suffix_out = suffix_out[:, -self.config.n_action_steps :]
|
||||
# Original openpi code, upcast attention output
|
||||
suffix_out = suffix_out.to(dtype=torch.float32)
|
||||
v_t = self.action_out_proj(suffix_out)
|
||||
|
||||
losses = F.mse_loss(u_t, v_t, reduction="none")
|
||||
return losses
|
||||
|
||||
def sample_actions(self, images, img_masks, lang_tokens, lang_masks, state, noise=None) -> Tensor:
|
||||
"""Do a full inference forward and compute the action (batch_size x num_steps x num_motors)"""
|
||||
bsize = state.shape[0]
|
||||
device = state.device
|
||||
|
||||
if noise is None:
|
||||
actions_shape = (bsize, self.config.n_action_steps, self.config.max_action_dim)
|
||||
noise = self.sample_noise(actions_shape, device)
|
||||
|
||||
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
|
||||
images, img_masks, lang_tokens, lang_masks
|
||||
)
|
||||
prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
|
||||
prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
|
||||
|
||||
# Compute image and language key value cache
|
||||
_, past_key_values = self.paligemma_with_expert.forward(
|
||||
attention_mask=prefix_att_2d_masks,
|
||||
position_ids=prefix_position_ids,
|
||||
past_key_values=None,
|
||||
inputs_embeds=[prefix_embs, None],
|
||||
use_cache=self.config.use_cache,
|
||||
fill_kv_cache=True,
|
||||
)
|
||||
|
||||
dt = -1.0 / self.config.num_steps
|
||||
dt = torch.tensor(dt, dtype=torch.float32, device=device)
|
||||
|
||||
x_t = noise
|
||||
time = torch.tensor(1.0, dtype=torch.float32, device=device)
|
||||
while time >= -dt / 2:
|
||||
expanded_time = time.expand(bsize)
|
||||
v_t = self.denoise_step(
|
||||
state,
|
||||
prefix_pad_masks,
|
||||
past_key_values,
|
||||
x_t,
|
||||
expanded_time,
|
||||
)
|
||||
|
||||
# Euler step
|
||||
x_t += dt * v_t
|
||||
time += dt
|
||||
return x_t
|
||||
|
||||
def denoise_step(
|
||||
self,
|
||||
state,
|
||||
prefix_pad_masks,
|
||||
past_key_values,
|
||||
x_t,
|
||||
timestep,
|
||||
):
|
||||
"""Apply one denoising step of the noise `x_t` at a given timestep."""
|
||||
suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(state, x_t, timestep)
|
||||
|
||||
suffix_len = suffix_pad_masks.shape[1]
|
||||
batch_size = prefix_pad_masks.shape[0]
|
||||
prefix_len = prefix_pad_masks.shape[1]
|
||||
prefix_pad_2d_masks = prefix_pad_masks[:, None, :].expand(batch_size, suffix_len, prefix_len)
|
||||
|
||||
suffix_att_2d_masks = make_att_2d_masks(suffix_pad_masks, suffix_att_masks)
|
||||
|
||||
full_att_2d_masks = torch.cat([prefix_pad_2d_masks, suffix_att_2d_masks], dim=2)
|
||||
|
||||
prefix_offsets = torch.sum(prefix_pad_masks, dim=-1)[:, None]
|
||||
position_ids = prefix_offsets + torch.cumsum(suffix_pad_masks, dim=1) - 1
|
||||
|
||||
outputs_embeds, _ = self.paligemma_with_expert.forward(
|
||||
attention_mask=full_att_2d_masks,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=[None, suffix_embs],
|
||||
use_cache=self.config.use_cache,
|
||||
fill_kv_cache=False,
|
||||
)
|
||||
suffix_out = outputs_embeds[1]
|
||||
suffix_out = suffix_out[:, -self.config.n_action_steps :]
|
||||
suffix_out = suffix_out.to(dtype=torch.float32)
|
||||
v_t = self.action_out_proj(suffix_out)
|
||||
return v_t
|
||||
@@ -1,403 +0,0 @@
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.version
|
||||
from pytest import Cache
|
||||
from torch import nn
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
GemmaForCausalLM,
|
||||
PaliGemmaForConditionalGeneration,
|
||||
PretrainedConfig,
|
||||
PreTrainedModel,
|
||||
)
|
||||
from transformers.models.auto import CONFIG_MAPPING
|
||||
|
||||
from lerobot.common.policies.pi0.flex_attention import flex_attention_forward
|
||||
|
||||
|
||||
def apply_rope(x, positions, max_wavelength=10_000):
|
||||
"""
|
||||
Applies RoPE positions [B, L] to x [B, L, H, D].
|
||||
"""
|
||||
d_half = x.shape[-1] // 2
|
||||
device = x.device
|
||||
dtype = x.dtype
|
||||
x = x.to(torch.float32)
|
||||
|
||||
freq_exponents = (2.0 / x.shape[-1]) * torch.arange(d_half, dtype=torch.float32, device=device)
|
||||
timescale = max_wavelength**freq_exponents
|
||||
radians = positions[..., None].to(torch.float32) / timescale[None, None, :].to(torch.float32)
|
||||
|
||||
radians = radians[..., None, :]
|
||||
|
||||
sin = torch.sin(radians) # .to(dtype=dtype)
|
||||
cos = torch.cos(radians) # .to(dtype=dtype)
|
||||
|
||||
x1, x2 = x.split(d_half, dim=-1)
|
||||
res = torch.empty_like(x)
|
||||
res[..., :d_half] = x1 * cos - x2 * sin
|
||||
res[..., d_half:] = x2 * cos + x1 * sin
|
||||
|
||||
return res.to(dtype)
|
||||
|
||||
|
||||
class PaliGemmaWithExpertConfig(PretrainedConfig):
|
||||
model_type = "PaliGemmaWithExpertModel"
|
||||
sub_configs = {"paligemma_config": AutoConfig, "gemma_expert_config": AutoConfig}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
paligemma_config: dict | None = None,
|
||||
gemma_expert_config: dict | None = None,
|
||||
freeze_vision_encoder: bool = True,
|
||||
train_expert_only: bool = True,
|
||||
attention_implementation: str = "eager",
|
||||
**kwargs,
|
||||
):
|
||||
self.freeze_vision_encoder = freeze_vision_encoder
|
||||
self.train_expert_only = train_expert_only
|
||||
self.attention_implementation = attention_implementation
|
||||
|
||||
if paligemma_config is None:
|
||||
# Default config from Pi0
|
||||
self.paligemma_config = CONFIG_MAPPING["paligemma"](
|
||||
transformers_version="4.48.1",
|
||||
_vocab_size=257152,
|
||||
bos_token_id=2,
|
||||
eos_token_id=1,
|
||||
hidden_size=2048,
|
||||
image_token_index=257152,
|
||||
model_type="paligemma",
|
||||
pad_token_id=0,
|
||||
projection_dim=2048,
|
||||
text_config={
|
||||
"hidden_activation": "gelu_pytorch_tanh",
|
||||
"hidden_size": 2048,
|
||||
"intermediate_size": 16384,
|
||||
"model_type": "gemma",
|
||||
"num_attention_heads": 8,
|
||||
"num_hidden_layers": 18,
|
||||
"num_image_tokens": 256,
|
||||
"num_key_value_heads": 1,
|
||||
"torch_dtype": "float32",
|
||||
"vocab_size": 257152,
|
||||
},
|
||||
vision_config={
|
||||
"hidden_size": 1152,
|
||||
"intermediate_size": 4304,
|
||||
"model_type": "siglip_vision_model",
|
||||
"num_attention_heads": 16,
|
||||
"num_hidden_layers": 27,
|
||||
"num_image_tokens": 256,
|
||||
"patch_size": 14,
|
||||
"projection_dim": 2048,
|
||||
"projector_hidden_act": "gelu_fast",
|
||||
"torch_dtype": "float32",
|
||||
"vision_use_head": False,
|
||||
},
|
||||
)
|
||||
elif isinstance(self.paligemma_config, dict):
|
||||
# Override Pi0 default config for PaliGemma
|
||||
if "model_type" not in gemma_expert_config:
|
||||
paligemma_config["model_type"] = "paligemma"
|
||||
|
||||
cfg_cls = CONFIG_MAPPING[paligemma_config["model_type"]]
|
||||
self.paligemma_config = cfg_cls(**paligemma_config)
|
||||
|
||||
if gemma_expert_config is None:
|
||||
# Default config from Pi0
|
||||
self.gemma_expert_config = CONFIG_MAPPING["gemma"](
|
||||
attention_bias=False,
|
||||
attention_dropout=0.0,
|
||||
bos_token_id=2,
|
||||
eos_token_id=1,
|
||||
head_dim=256,
|
||||
hidden_act="gelu_pytorch_tanh",
|
||||
hidden_activation="gelu_pytorch_tanh",
|
||||
hidden_size=1024,
|
||||
initializer_range=0.02,
|
||||
intermediate_size=4096,
|
||||
max_position_embeddings=8192,
|
||||
model_type="gemma",
|
||||
num_attention_heads=8,
|
||||
num_hidden_layers=18,
|
||||
num_key_value_heads=1,
|
||||
pad_token_id=0,
|
||||
rms_norm_eps=1e-06,
|
||||
rope_theta=10000.0,
|
||||
torch_dtype="float32",
|
||||
transformers_version="4.48.1",
|
||||
use_cache=True,
|
||||
vocab_size=257152,
|
||||
)
|
||||
elif isinstance(self.gemma_expert_config, dict):
|
||||
# Override Pi0 default config for Gemma Expert
|
||||
if "model_type" not in gemma_expert_config:
|
||||
gemma_expert_config["model_type"] = "gemma"
|
||||
|
||||
cfg_cls = CONFIG_MAPPING[paligemma_config["model_type"]]
|
||||
self.gemma_expert_config = cfg_cls(**gemma_expert_config)
|
||||
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
if self.train_expert_only and not self.freeze_vision_encoder:
|
||||
raise ValueError(
|
||||
"You set `freeze_vision_encoder=False` and `train_expert_only=True` which are not compatible."
|
||||
)
|
||||
|
||||
if self.attention_implementation not in ["eager", "fa2", "flex"]:
|
||||
raise ValueError(
|
||||
f"Wrong value provided for `attention_implementation` ({self.attention_implementation}). Expected 'eager', 'fa2' or 'flex'."
|
||||
)
|
||||
|
||||
|
||||
class PaliGemmaWithExpertModel(PreTrainedModel):
|
||||
config_class = PaliGemmaWithExpertConfig
|
||||
|
||||
def __init__(self, config: PaliGemmaWithExpertConfig):
|
||||
super().__init__(config=config)
|
||||
self.config = config
|
||||
self.paligemma = PaliGemmaForConditionalGeneration(config=config.paligemma_config)
|
||||
self.gemma_expert = GemmaForCausalLM(config=config.gemma_expert_config)
|
||||
# Remove unused embed_tokens
|
||||
self.gemma_expert.model.embed_tokens = None
|
||||
|
||||
self.to_bfloat16_like_physical_intelligence()
|
||||
self.set_requires_grad()
|
||||
|
||||
def set_requires_grad(self):
|
||||
if self.config.freeze_vision_encoder:
|
||||
self.paligemma.vision_tower.eval()
|
||||
for params in self.paligemma.vision_tower.parameters():
|
||||
params.requires_grad = False
|
||||
|
||||
if self.config.train_expert_only:
|
||||
self.paligemma.eval()
|
||||
for params in self.paligemma.parameters():
|
||||
params.requires_grad = False
|
||||
|
||||
def train(self, mode: bool = True):
|
||||
super().train(mode)
|
||||
|
||||
if self.config.freeze_vision_encoder:
|
||||
self.paligemma.vision_tower.eval()
|
||||
|
||||
if self.config.train_expert_only:
|
||||
self.paligemma.eval()
|
||||
|
||||
def to_bfloat16_like_physical_intelligence(self):
|
||||
self.paligemma = self.paligemma.to(dtype=torch.bfloat16)
|
||||
|
||||
params_to_change_dtype = [
|
||||
"language_model.model.layers",
|
||||
"gemma_expert.model.layers",
|
||||
"vision_tower",
|
||||
"multi_modal",
|
||||
]
|
||||
for name, param in self.named_parameters():
|
||||
if any(selector in name for selector in params_to_change_dtype):
|
||||
param.data = param.data.to(dtype=torch.bfloat16)
|
||||
|
||||
def embed_image(self, image: torch.Tensor):
|
||||
return self.paligemma.get_image_features(image)
|
||||
|
||||
def embed_language_tokens(self, tokens: torch.Tensor):
|
||||
return self.paligemma.language_model.model.embed_tokens(tokens)
|
||||
|
||||
# TODO: break down this huge forward into modules or functions
|
||||
def forward(
|
||||
self,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None,
|
||||
inputs_embeds: List[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
fill_kv_cache: Optional[bool] = None,
|
||||
):
|
||||
models = [self.paligemma.language_model.model, self.gemma_expert.model]
|
||||
|
||||
for hidden_states in inputs_embeds:
|
||||
# TODO this is very inefficient
|
||||
# dtype is always the same, batch size too (if > 1 len)
|
||||
# device could be trickier in multi gpu edge cases but that's it
|
||||
if hidden_states is None:
|
||||
continue
|
||||
batch_size = hidden_states.shape[0]
|
||||
|
||||
# RMSNorm
|
||||
num_layers = self.paligemma.config.text_config.num_hidden_layers
|
||||
head_dim = self.paligemma.config.text_config.head_dim
|
||||
for layer_idx in range(num_layers):
|
||||
query_states = []
|
||||
key_states = []
|
||||
value_states = []
|
||||
for i, hidden_states in enumerate(inputs_embeds):
|
||||
if hidden_states is None:
|
||||
continue
|
||||
layer = models[i].layers[layer_idx]
|
||||
# normalizer = torch.tensor(models[i].config.hidden_size**0.5, dtype=hidden_states.dtype)
|
||||
# hidden_states = hidden_states * normalizer
|
||||
hidden_states = layer.input_layernorm(hidden_states)
|
||||
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
hidden_shape = (*input_shape, -1, layer.self_attn.head_dim)
|
||||
|
||||
hidden_states = hidden_states.to(dtype=torch.bfloat16)
|
||||
query_state = layer.self_attn.q_proj(hidden_states).view(hidden_shape)
|
||||
key_state = layer.self_attn.k_proj(hidden_states).view(hidden_shape)
|
||||
value_state = layer.self_attn.v_proj(hidden_states).view(hidden_shape)
|
||||
|
||||
query_states.append(query_state)
|
||||
key_states.append(key_state)
|
||||
value_states.append(value_state)
|
||||
|
||||
# B,L,H,D with L sequence length, H number of heads, D head dim
|
||||
# concatenate on the number of embeddings/tokens
|
||||
query_states = torch.cat(query_states, dim=1)
|
||||
key_states = torch.cat(key_states, dim=1)
|
||||
value_states = torch.cat(value_states, dim=1)
|
||||
|
||||
query_states = apply_rope(query_states, position_ids)
|
||||
key_states = apply_rope(key_states, position_ids)
|
||||
|
||||
if use_cache and past_key_values is None:
|
||||
past_key_values = {}
|
||||
|
||||
if use_cache:
|
||||
if fill_kv_cache:
|
||||
past_key_values[layer_idx] = {
|
||||
"key_states": key_states,
|
||||
"value_states": value_states,
|
||||
}
|
||||
else:
|
||||
# TODO here, some optimization can be done - similar to a `StaticCache` we can declare the `max_len` before.
|
||||
# so we create an empty cache, with just one cuda malloc, and if (in autoregressive case) we reach
|
||||
# the max len, then we (for instance) double the cache size. This implementation already exists
|
||||
# in `transformers`. (molbap)
|
||||
key_states = torch.cat([past_key_values[layer_idx]["key_states"], key_states], dim=1)
|
||||
value_states = torch.cat(
|
||||
[past_key_values[layer_idx]["value_states"], value_states], dim=1
|
||||
)
|
||||
|
||||
attention_interface = self.get_attention_interface()
|
||||
att_output = attention_interface(
|
||||
attention_mask, batch_size, head_dim, query_states, key_states, value_states
|
||||
)
|
||||
att_output = att_output.to(dtype=torch.bfloat16)
|
||||
|
||||
# first part of att_output is prefix (up to sequence length, [:, 0:prefix_seq_len])
|
||||
outputs_embeds = []
|
||||
start = 0
|
||||
for i, hidden_states in enumerate(inputs_embeds):
|
||||
layer = models[i].layers[layer_idx]
|
||||
|
||||
if hidden_states is not None:
|
||||
end = start + hidden_states.shape[1]
|
||||
|
||||
if att_output.dtype != layer.self_attn.o_proj.weight.dtype:
|
||||
att_output = att_output.to(layer.self_attn.o_proj.weight.dtype)
|
||||
out_emb = layer.self_attn.o_proj(att_output[:, start:end])
|
||||
|
||||
# TODO: first dropout (by default 0.0)
|
||||
|
||||
# first residual
|
||||
out_emb += hidden_states
|
||||
after_first_residual = out_emb.clone()
|
||||
|
||||
out_emb = layer.post_attention_layernorm(out_emb)
|
||||
out_emb = layer.mlp(out_emb)
|
||||
|
||||
# TODO: second dropout (by default 0.0)
|
||||
|
||||
# second residual
|
||||
out_emb += after_first_residual
|
||||
|
||||
outputs_embeds.append(out_emb)
|
||||
|
||||
start = end
|
||||
else:
|
||||
outputs_embeds.append(None)
|
||||
|
||||
inputs_embeds = outputs_embeds
|
||||
|
||||
# final norm
|
||||
outputs_embeds = []
|
||||
for i, hidden_states in enumerate(inputs_embeds):
|
||||
if hidden_states is not None:
|
||||
out_emb = models[i].norm(hidden_states)
|
||||
outputs_embeds.append(out_emb)
|
||||
else:
|
||||
outputs_embeds.append(None)
|
||||
|
||||
return outputs_embeds, past_key_values
|
||||
|
||||
def get_attention_interface(self):
|
||||
if self.config.attention_implementation == "fa2":
|
||||
attention_interface = self.flash_attention_forward
|
||||
elif self.config.attention_implementation == "flex":
|
||||
attention_interface = flex_attention_forward
|
||||
else:
|
||||
attention_interface = self.eager_attention_forward
|
||||
return attention_interface
|
||||
|
||||
def flash_attention_forward(
|
||||
self, attention_mask, batch_size, head_dim, query_states, key_states, value_states
|
||||
):
|
||||
raise NotImplementedError("FA2 is not implemented (yet)")
|
||||
|
||||
def eager_attention_forward(
|
||||
self, attention_mask, batch_size, head_dim, query_states, key_states, value_states
|
||||
):
|
||||
num_att_heads = self.config.paligemma_config.text_config.num_attention_heads
|
||||
num_key_value_heads = self.config.paligemma_config.text_config.num_key_value_heads
|
||||
num_key_value_groups = num_att_heads // num_key_value_heads
|
||||
|
||||
# query_states: batch_size, sequence_length, num_att_head, head_dim
|
||||
# key_states: batch_size, sequence_length, num_key_value_head, head_dim
|
||||
# value_states: batch_size, sequence_length, num_key_value_head, head_dim
|
||||
sequence_length = key_states.shape[1]
|
||||
|
||||
key_states = key_states[:, :, :, None, :].expand(
|
||||
batch_size, sequence_length, num_key_value_heads, num_key_value_groups, head_dim
|
||||
)
|
||||
key_states = key_states.reshape(
|
||||
batch_size, sequence_length, num_key_value_heads * num_key_value_groups, head_dim
|
||||
)
|
||||
|
||||
value_states = value_states[:, :, :, None, :].expand(
|
||||
batch_size, sequence_length, num_key_value_heads, num_key_value_groups, head_dim
|
||||
)
|
||||
value_states = value_states.reshape(
|
||||
batch_size, sequence_length, num_key_value_heads * num_key_value_groups, head_dim
|
||||
)
|
||||
|
||||
# Attention here is upcasted to float32 to match the original eager implementation.
|
||||
|
||||
query_states = query_states.to(dtype=torch.float32)
|
||||
key_states = key_states.to(dtype=torch.float32)
|
||||
|
||||
query_states = query_states.transpose(1, 2)
|
||||
key_states = key_states.transpose(1, 2)
|
||||
|
||||
att_weights = torch.matmul(query_states, key_states.transpose(2, 3))
|
||||
att_weights *= head_dim**-0.5
|
||||
big_neg = -2.3819763e38 # See gemma/modules.py
|
||||
|
||||
masked_att_weights = torch.where(attention_mask[:, None, :, :], att_weights, big_neg)
|
||||
|
||||
probs = nn.functional.softmax(masked_att_weights, dim=-1)
|
||||
probs = probs.to(dtype=value_states.dtype)
|
||||
|
||||
# probs: batch_size, num_key_value_head, num_att_head, sequence_length, sequence_length
|
||||
# value_states: batch_size, sequence_length, num_att_heads, head_dim
|
||||
|
||||
att_output = torch.matmul(probs, value_states.permute(0, 2, 1, 3))
|
||||
|
||||
att_output = att_output.permute(0, 2, 1, 3)
|
||||
# we use -1 because sequence length can change
|
||||
att_output = att_output.reshape(batch_size, -1, num_key_value_heads * num_key_value_groups * head_dim)
|
||||
|
||||
return att_output
|
||||
75
lerobot/common/policies/policy_protocol.py
Normal file
75
lerobot/common/policies/policy_protocol.py
Normal file
@@ -0,0 +1,75 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""A protocol that all policies should follow.
|
||||
|
||||
This provides a mechanism for type-hinting and isinstance checks without requiring the policies classes
|
||||
subclass a base class.
|
||||
|
||||
The protocol structure, method signatures, and docstrings should be used by developers as a reference for
|
||||
how to implement new policies.
|
||||
"""
|
||||
|
||||
from typing import Protocol, runtime_checkable
|
||||
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class Policy(Protocol):
|
||||
"""The required interface for implementing a policy.
|
||||
|
||||
We also expect all policies to subclass torch.nn.Module and PyTorchModelHubMixin.
|
||||
"""
|
||||
|
||||
name: str
|
||||
|
||||
def __init__(self, cfg, dataset_stats: dict[str, dict[str, Tensor]] | None = None):
|
||||
"""
|
||||
Args:
|
||||
cfg: Policy configuration class instance or None, in which case the default instantiation of the
|
||||
configuration class is used.
|
||||
dataset_stats: Dataset statistics to be used for normalization.
|
||||
"""
|
||||
|
||||
def reset(self):
|
||||
"""To be called whenever the environment is reset.
|
||||
|
||||
Does things like clearing caches.
|
||||
"""
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict:
|
||||
"""Run the batch through the model and compute the loss for training or validation.
|
||||
|
||||
Returns a dictionary with "loss" and potentially other information. Apart from "loss" which is a Tensor, all
|
||||
other items should be logging-friendly, native Python types.
|
||||
"""
|
||||
|
||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Return one action to run in the environment (potentially in batch mode).
|
||||
|
||||
When the model uses a history of observations, or outputs a sequence of actions, this method deals
|
||||
with caching.
|
||||
"""
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class PolicyWithUpdate(Policy, Protocol):
|
||||
def update(self):
|
||||
"""An update method that is to be called after a training optimization step.
|
||||
|
||||
Implements an additional updates the model parameters may need (for example, doing an EMA step for a
|
||||
target model, or incrementing an internal buffer).
|
||||
"""
|
||||
@@ -1,187 +0,0 @@
|
||||
import abc
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Type, TypeVar
|
||||
|
||||
import packaging
|
||||
import safetensors
|
||||
from huggingface_hub import hf_hub_download
|
||||
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
|
||||
from huggingface_hub.errors import HfHubHTTPError
|
||||
from safetensors.torch import load_model as load_model_as_safetensor
|
||||
from safetensors.torch import save_model as save_model_as_safetensor
|
||||
from torch import Tensor, nn
|
||||
|
||||
from lerobot.common.utils.hub import HubMixin
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
|
||||
T = TypeVar("T", bound="PreTrainedPolicy")
|
||||
|
||||
DEFAULT_POLICY_CARD = """
|
||||
---
|
||||
# For reference on model card metadata, see the spec: https://github.com/huggingface/hub-docs/blob/main/modelcard.md?plain=1
|
||||
# Doc / guide: https://huggingface.co/docs/hub/model-cards
|
||||
{{ card_data }}
|
||||
---
|
||||
|
||||
This policy has been pushed to the Hub using [LeRobot](https://github.com/huggingface/lerobot):
|
||||
- Docs: {{ docs_url | default("[More Information Needed]", true) }}
|
||||
"""
|
||||
|
||||
|
||||
class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
|
||||
"""
|
||||
Base class for policy models.
|
||||
"""
|
||||
|
||||
config_class: None
|
||||
name: None
|
||||
|
||||
def __init__(self, config: PreTrainedConfig, *inputs, **kwargs):
|
||||
super().__init__()
|
||||
if not isinstance(config, PreTrainedConfig):
|
||||
raise ValueError(
|
||||
f"Parameter config in `{self.__class__.__name__}(config)` should be an instance of class "
|
||||
"`PreTrainedConfig`. To create a model from a pretrained model use "
|
||||
f"`model = {self.__class__.__name__}.from_pretrained(PRETRAINED_MODEL_NAME)`"
|
||||
)
|
||||
self.config = config
|
||||
|
||||
def __init_subclass__(cls, **kwargs):
|
||||
super().__init_subclass__(**kwargs)
|
||||
if not getattr(cls, "config_class", None):
|
||||
raise TypeError(f"Class {cls.__name__} must define 'config_class'")
|
||||
if not getattr(cls, "name", None):
|
||||
raise TypeError(f"Class {cls.__name__} must define 'name'")
|
||||
|
||||
def _save_pretrained(self, save_directory: Path) -> None:
|
||||
self.config._save_pretrained(save_directory)
|
||||
model_to_save = self.module if hasattr(self, "module") else self
|
||||
save_model_as_safetensor(model_to_save, str(save_directory / SAFETENSORS_SINGLE_FILE))
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
cls: Type[T],
|
||||
pretrained_name_or_path: str | Path,
|
||||
*,
|
||||
config: PreTrainedConfig | None = None,
|
||||
force_download: bool = False,
|
||||
resume_download: bool | None = None,
|
||||
proxies: dict | None = None,
|
||||
token: str | bool | None = None,
|
||||
cache_dir: str | Path | None = None,
|
||||
local_files_only: bool = False,
|
||||
revision: str | None = None,
|
||||
map_location: str = "cpu",
|
||||
strict: bool = False,
|
||||
**kwargs,
|
||||
) -> T:
|
||||
"""
|
||||
The policy is set in evaluation mode by default using `policy.eval()` (dropout modules are
|
||||
deactivated). To train it, you should first set it back in training mode with `policy.train()`.
|
||||
"""
|
||||
if config is None:
|
||||
config = PreTrainedConfig.from_pretrained(
|
||||
pretrained_name_or_path=pretrained_name_or_path,
|
||||
force_download=force_download,
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
token=token,
|
||||
cache_dir=cache_dir,
|
||||
local_files_only=local_files_only,
|
||||
revision=revision,
|
||||
**kwargs,
|
||||
)
|
||||
model_id = str(pretrained_name_or_path)
|
||||
instance = cls(config, **kwargs)
|
||||
if os.path.isdir(model_id):
|
||||
print("Loading weights from local directory")
|
||||
model_file = os.path.join(model_id, SAFETENSORS_SINGLE_FILE)
|
||||
policy = cls._load_as_safetensor(instance, model_file, map_location, strict)
|
||||
else:
|
||||
try:
|
||||
model_file = hf_hub_download(
|
||||
repo_id=model_id,
|
||||
filename=SAFETENSORS_SINGLE_FILE,
|
||||
revision=revision,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
resume_download=resume_download,
|
||||
token=token,
|
||||
local_files_only=local_files_only,
|
||||
)
|
||||
policy = cls._load_as_safetensor(instance, model_file, map_location, strict)
|
||||
except HfHubHTTPError as e:
|
||||
raise FileNotFoundError(
|
||||
f"{SAFETENSORS_SINGLE_FILE} not found on the HuggingFace Hub in {model_id}"
|
||||
) from e
|
||||
|
||||
policy.to(map_location)
|
||||
policy.eval()
|
||||
return policy
|
||||
|
||||
@classmethod
|
||||
def _load_as_safetensor(cls, model: T, model_file: str, map_location: str, strict: bool) -> T:
|
||||
if packaging.version.parse(safetensors.__version__) < packaging.version.parse("0.4.3"):
|
||||
load_model_as_safetensor(model, model_file, strict=strict)
|
||||
if map_location != "cpu":
|
||||
logging.warning(
|
||||
"Loading model weights on other devices than 'cpu' is not supported natively in your version of safetensors."
|
||||
" This means that the model is loaded on 'cpu' first and then copied to the device."
|
||||
" This leads to a slower loading time."
|
||||
" Please update safetensors to version 0.4.3 or above for improved performance."
|
||||
)
|
||||
model.to(map_location)
|
||||
else:
|
||||
safetensors.torch.load_model(model, model_file, strict=strict, device=map_location)
|
||||
return model
|
||||
|
||||
# def generate_model_card(self, *args, **kwargs) -> ModelCard:
|
||||
# card = ModelCard.from_template(
|
||||
# card_data=self._hub_mixin_info.model_card_data,
|
||||
# template_str=self._hub_mixin_info.model_card_template,
|
||||
# repo_url=self._hub_mixin_info.repo_url,
|
||||
# docs_url=self._hub_mixin_info.docs_url,
|
||||
# **kwargs,
|
||||
# )
|
||||
# return card
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_optim_params(self) -> dict:
|
||||
"""
|
||||
Returns the policy-specific parameters dict to be passed on to the optimizer.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def reset(self):
|
||||
"""To be called whenever the environment is reset.
|
||||
|
||||
Does things like clearing caches.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
# TODO(aliberts, rcadene): split into 'forward' and 'compute_loss'?
|
||||
@abc.abstractmethod
|
||||
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict | None]:
|
||||
"""_summary_
|
||||
|
||||
Args:
|
||||
batch (dict[str, Tensor]): _description_
|
||||
|
||||
Returns:
|
||||
tuple[Tensor, dict | None]: The loss and potentially other information. Apart from the loss which
|
||||
is a Tensor, all other items should be logging-friendly, native Python types.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Return one action to run in the environment (potentially in batch mode).
|
||||
|
||||
When the model uses a history of observations, or outputs a sequence of actions, this method deals
|
||||
with caching.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
83
lerobot/common/policies/sac/configuration_sac.py
Normal file
83
lerobot/common/policies/sac/configuration_sac.py
Normal file
@@ -0,0 +1,83 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
|
||||
@dataclass
|
||||
class SACConfig:
|
||||
input_shapes: dict[str, list[int]] = field(
|
||||
default_factory=lambda: {
|
||||
"observation.image": [3, 84, 84],
|
||||
"observation.state": [4],
|
||||
}
|
||||
)
|
||||
output_shapes: dict[str, list[int]] = field(
|
||||
default_factory=lambda: {
|
||||
"action": [2],
|
||||
}
|
||||
)
|
||||
input_normalization_modes: dict[str, str] = field(
|
||||
default_factory=lambda: {
|
||||
"observation.image": "mean_std",
|
||||
"observation.state": "min_max",
|
||||
"observation.environment_state": "min_max",
|
||||
}
|
||||
)
|
||||
output_normalization_modes: dict[str, str] = field(default_factory=lambda: {"action": "min_max"})
|
||||
output_normalization_params: dict[str, dict[str, list[float]]] = field(
|
||||
default_factory=lambda: {
|
||||
"action": {"min": [-1, -1], "max": [1, 1]},
|
||||
}
|
||||
)
|
||||
camera_number: int = 1
|
||||
# Add type annotations for these fields:
|
||||
image_encoder_hidden_dim: int = 32
|
||||
shared_encoder: bool = False
|
||||
discount: float = 0.99
|
||||
temperature_init: float = 1.0
|
||||
num_critics: int = 2
|
||||
num_subsample_critics: int | None = None
|
||||
critic_lr: float = 3e-4
|
||||
actor_lr: float = 3e-4
|
||||
temperature_lr: float = 3e-4
|
||||
critic_target_update_weight: float = 0.005
|
||||
utd_ratio: int = 1 # If you want enable utd_ratio, you need to set it to >1
|
||||
state_encoder_hidden_dim: int = 256
|
||||
latent_dim: int = 256
|
||||
target_entropy: float | None = None
|
||||
use_backup_entropy: bool = True
|
||||
critic_network_kwargs: dict[str, Any] = field(
|
||||
default_factory=lambda: {
|
||||
"hidden_dims": [256, 256],
|
||||
"activate_final": True,
|
||||
}
|
||||
)
|
||||
actor_network_kwargs: dict[str, Any] = field(
|
||||
default_factory=lambda: {
|
||||
"hidden_dims": [256, 256],
|
||||
"activate_final": True,
|
||||
}
|
||||
)
|
||||
policy_kwargs: dict[str, Any] = field(
|
||||
default_factory=lambda: {
|
||||
"use_tanh_squash": True,
|
||||
"log_std_min": -5,
|
||||
"log_std_max": 2,
|
||||
}
|
||||
)
|
||||
571
lerobot/common/policies/sac/modeling_sac.py
Normal file
571
lerobot/common/policies/sac/modeling_sac.py
Normal file
@@ -0,0 +1,571 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# TODO: (1) better device management
|
||||
|
||||
from collections import deque
|
||||
from typing import Callable, Optional, Sequence, Tuple, Union
|
||||
|
||||
import einops
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from huggingface_hub import PyTorchModelHubMixin
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
||||
from lerobot.common.policies.sac.configuration_sac import SACConfig
|
||||
|
||||
|
||||
class SACPolicy(
|
||||
nn.Module,
|
||||
PyTorchModelHubMixin,
|
||||
library_name="lerobot",
|
||||
repo_url="https://github.com/huggingface/lerobot",
|
||||
tags=["robotics", "RL", "SAC"],
|
||||
):
|
||||
name = "sac"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: SACConfig | None = None,
|
||||
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
|
||||
device: str = "cpu",
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if config is None:
|
||||
config = SACConfig()
|
||||
self.config = config
|
||||
if config.input_normalization_modes is not None:
|
||||
self.normalize_inputs = Normalize(
|
||||
config.input_shapes, config.input_normalization_modes, dataset_stats
|
||||
)
|
||||
else:
|
||||
self.normalize_inputs = nn.Identity()
|
||||
|
||||
output_normalization_params = {}
|
||||
for outer_key, inner_dict in config.output_normalization_params.items():
|
||||
output_normalization_params[outer_key] = {}
|
||||
for key, value in inner_dict.items():
|
||||
output_normalization_params[outer_key][key] = torch.tensor(value)
|
||||
|
||||
# HACK: This is hacky and should be removed
|
||||
dataset_stats = dataset_stats or output_normalization_params
|
||||
self.normalize_targets = Normalize(
|
||||
config.output_shapes, config.output_normalization_modes, dataset_stats
|
||||
)
|
||||
self.unnormalize_outputs = Unnormalize(
|
||||
config.output_shapes, config.output_normalization_modes, dataset_stats
|
||||
)
|
||||
|
||||
if config.shared_encoder:
|
||||
encoder_critic = SACObservationEncoder(config)
|
||||
encoder_actor: SACObservationEncoder = encoder_critic
|
||||
else:
|
||||
encoder_critic = SACObservationEncoder(config)
|
||||
encoder_actor = SACObservationEncoder(config)
|
||||
# Define networks
|
||||
critic_nets = []
|
||||
for _ in range(config.num_critics):
|
||||
critic_net = Critic(
|
||||
encoder=encoder_critic,
|
||||
network=MLP(
|
||||
input_dim=encoder_critic.output_dim + config.output_shapes["action"][0],
|
||||
**config.critic_network_kwargs,
|
||||
),
|
||||
device=device,
|
||||
)
|
||||
critic_nets.append(critic_net)
|
||||
|
||||
target_critic_nets = []
|
||||
for _ in range(config.num_critics):
|
||||
target_critic_net = Critic(
|
||||
encoder=encoder_critic,
|
||||
network=MLP(
|
||||
input_dim=encoder_critic.output_dim + config.output_shapes["action"][0],
|
||||
**config.critic_network_kwargs,
|
||||
),
|
||||
device=device,
|
||||
)
|
||||
target_critic_nets.append(target_critic_net)
|
||||
|
||||
self.critic_ensemble = create_critic_ensemble(
|
||||
critics=critic_nets, num_critics=config.num_critics, device=device
|
||||
)
|
||||
self.critic_target = create_critic_ensemble(
|
||||
critics=target_critic_nets, num_critics=config.num_critics, device=device
|
||||
)
|
||||
self.critic_target.load_state_dict(self.critic_ensemble.state_dict())
|
||||
|
||||
self.actor = Policy(
|
||||
encoder=encoder_actor,
|
||||
network=MLP(input_dim=encoder_actor.output_dim, **config.actor_network_kwargs),
|
||||
action_dim=config.output_shapes["action"][0],
|
||||
device=device,
|
||||
encoder_is_shared=config.shared_encoder,
|
||||
**config.policy_kwargs,
|
||||
)
|
||||
if config.target_entropy is None:
|
||||
config.target_entropy = -np.prod(config.output_shapes["action"][0]) / 2 # (-dim(A)/2)
|
||||
# TODO: Handle the case where the temparameter is a fixed
|
||||
self.log_alpha = torch.zeros(1, requires_grad=True, device=device)
|
||||
self.temperature = self.log_alpha.exp().item()
|
||||
|
||||
def reset(self):
|
||||
"""Reset the policy"""
|
||||
pass
|
||||
|
||||
@torch.no_grad()
|
||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Select action for inference/evaluation"""
|
||||
actions, _, _ = self.actor(batch)
|
||||
actions = self.unnormalize_outputs({"action": actions})["action"]
|
||||
return actions
|
||||
|
||||
def critic_forward(self, observations: dict[str, Tensor], actions: Tensor, use_target: bool = False) -> Tensor:
|
||||
"""Forward pass through a critic network ensemble
|
||||
|
||||
Args:
|
||||
observations: Dictionary of observations
|
||||
actions: Action tensor
|
||||
use_target: If True, use target critics, otherwise use ensemble critics
|
||||
|
||||
Returns:
|
||||
Tensor of Q-values from all critics
|
||||
"""
|
||||
critics = self.critic_target if use_target else self.critic_ensemble
|
||||
q_values = torch.stack([critic(observations, actions) for critic in critics])
|
||||
return q_values
|
||||
|
||||
|
||||
def critic_forward(
|
||||
self, observations: dict[str, Tensor], actions: Tensor, use_target: bool = False
|
||||
) -> Tensor:
|
||||
"""Forward pass through a critic network ensemble
|
||||
|
||||
Args:
|
||||
observations: Dictionary of observations
|
||||
actions: Action tensor
|
||||
use_target: If True, use target critics, otherwise use ensemble critics
|
||||
|
||||
Returns:
|
||||
Tensor of Q-values from all critics
|
||||
"""
|
||||
critics = self.critic_target if use_target else self.critic_ensemble
|
||||
q_values = torch.stack([critic(observations, actions) for critic in critics])
|
||||
return q_values
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]: ...
|
||||
def update_target_networks(self):
|
||||
"""Update target networks with exponential moving average"""
|
||||
for target_critic, critic in zip(self.critic_target, self.critic_ensemble, strict=False):
|
||||
for target_param, param in zip(target_critic.parameters(), critic.parameters(), strict=False):
|
||||
target_param.data.copy_(
|
||||
param.data * self.config.critic_target_update_weight
|
||||
+ target_param.data * (1.0 - self.config.critic_target_update_weight)
|
||||
)
|
||||
|
||||
def compute_loss_critic(self, observations, actions, rewards, next_observations, done) -> Tensor:
|
||||
temperature = self.log_alpha.exp().item()
|
||||
with torch.no_grad():
|
||||
next_action_preds, next_log_probs, _ = self.actor(next_observations)
|
||||
|
||||
# 2- compute q targets
|
||||
q_targets = self.critic_forward(
|
||||
observations=next_observations, actions=next_action_preds, use_target=True
|
||||
)
|
||||
|
||||
# subsample critics to prevent overfitting if use high UTD (update to date)
|
||||
if self.config.num_subsample_critics is not None:
|
||||
indices = torch.randperm(self.config.num_critics)
|
||||
indices = indices[: self.config.num_subsample_critics]
|
||||
q_targets = q_targets[indices]
|
||||
|
||||
# critics subsample size
|
||||
min_q, _ = q_targets.min(dim=0) # Get values from min operation
|
||||
if self.config.use_backup_entropy:
|
||||
min_q = min_q - (temperature * next_log_probs)
|
||||
|
||||
td_target = rewards + (1 - done) * self.config.discount * min_q
|
||||
|
||||
# 3- compute predicted qs
|
||||
q_preds = self.critic_forward(observations, actions, use_target=False)
|
||||
|
||||
# 4- Calculate loss
|
||||
# Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
|
||||
td_target_duplicate = einops.repeat(td_target, "b -> e b", e=q_preds.shape[0])
|
||||
# You compute the mean loss of the batch for each critic and then to compute the final loss you sum them up
|
||||
critics_loss = (
|
||||
F.mse_loss(
|
||||
input=q_preds,
|
||||
target=td_target_duplicate,
|
||||
reduction="none",
|
||||
).mean(1)
|
||||
).sum()
|
||||
return critics_loss
|
||||
|
||||
def compute_loss_temperature(self, observations) -> Tensor:
|
||||
"""Compute the temperature loss"""
|
||||
# calculate temperature loss
|
||||
with torch.no_grad():
|
||||
_, log_probs, _ = self.actor(observations)
|
||||
temperature_loss = (-self.log_alpha.exp() * (log_probs + self.config.target_entropy)).mean()
|
||||
return temperature_loss
|
||||
|
||||
def compute_loss_actor(self, observations) -> Tensor:
|
||||
temperature = self.log_alpha.exp().item()
|
||||
|
||||
actions_pi, log_probs, _ = self.actor(observations)
|
||||
|
||||
q_preds = self.critic_forward(observations, actions_pi, use_target=False)
|
||||
min_q_preds = q_preds.min(dim=0)[0]
|
||||
|
||||
actor_loss = ((temperature * log_probs) - min_q_preds).mean()
|
||||
return actor_loss
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
input_dim: int,
|
||||
hidden_dims: list[int],
|
||||
activations: Callable[[torch.Tensor], torch.Tensor] | str = nn.SiLU(),
|
||||
activate_final: bool = False,
|
||||
dropout_rate: Optional[float] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.activate_final = activate_final
|
||||
layers = []
|
||||
|
||||
# First layer uses input_dim
|
||||
layers.append(nn.Linear(input_dim, hidden_dims[0]))
|
||||
|
||||
# Add activation after first layer
|
||||
if dropout_rate is not None and dropout_rate > 0:
|
||||
layers.append(nn.Dropout(p=dropout_rate))
|
||||
layers.append(nn.LayerNorm(hidden_dims[0]))
|
||||
layers.append(activations if isinstance(activations, nn.Module) else getattr(nn, activations)())
|
||||
|
||||
# Rest of the layers
|
||||
for i in range(1, len(hidden_dims)):
|
||||
layers.append(nn.Linear(hidden_dims[i - 1], hidden_dims[i]))
|
||||
|
||||
if i + 1 < len(hidden_dims) or activate_final:
|
||||
if dropout_rate is not None and dropout_rate > 0:
|
||||
layers.append(nn.Dropout(p=dropout_rate))
|
||||
layers.append(nn.LayerNorm(hidden_dims[i]))
|
||||
layers.append(
|
||||
activations if isinstance(activations, nn.Module) else getattr(nn, activations)()
|
||||
)
|
||||
|
||||
self.net = nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.net(x)
|
||||
|
||||
|
||||
class Critic(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
encoder: Optional[nn.Module],
|
||||
network: nn.Module,
|
||||
init_final: Optional[float] = None,
|
||||
device: str = "cpu",
|
||||
):
|
||||
super().__init__()
|
||||
self.device = torch.device(device)
|
||||
self.encoder = encoder
|
||||
self.network = network
|
||||
self.init_final = init_final
|
||||
|
||||
# Find the last Linear layer's output dimension
|
||||
for layer in reversed(network.net):
|
||||
if isinstance(layer, nn.Linear):
|
||||
out_features = layer.out_features
|
||||
break
|
||||
|
||||
# Output layer
|
||||
if init_final is not None:
|
||||
self.output_layer = nn.Linear(out_features, 1)
|
||||
nn.init.uniform_(self.output_layer.weight, -init_final, init_final)
|
||||
nn.init.uniform_(self.output_layer.bias, -init_final, init_final)
|
||||
else:
|
||||
self.output_layer = nn.Linear(out_features, 1)
|
||||
orthogonal_init()(self.output_layer.weight)
|
||||
|
||||
self.to(self.device)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
observations: dict[str, torch.Tensor],
|
||||
actions: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
# Move each tensor in observations to device
|
||||
observations = {k: v.to(self.device) for k, v in observations.items()}
|
||||
actions = actions.to(self.device)
|
||||
|
||||
obs_enc = observations if self.encoder is None else self.encoder(observations)
|
||||
|
||||
inputs = torch.cat([obs_enc, actions], dim=-1)
|
||||
x = self.network(inputs)
|
||||
value = self.output_layer(x)
|
||||
return value.squeeze(-1)
|
||||
|
||||
class Policy(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
encoder: Optional[nn.Module],
|
||||
network: nn.Module,
|
||||
action_dim: int,
|
||||
log_std_min: float = -5,
|
||||
log_std_max: float = 2,
|
||||
fixed_std: Optional[torch.Tensor] = None,
|
||||
init_final: Optional[float] = None,
|
||||
use_tanh_squash: bool = False,
|
||||
device: str = "cpu",
|
||||
encoder_is_shared: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.device = torch.device(device)
|
||||
self.encoder = encoder
|
||||
self.network = network
|
||||
self.action_dim = action_dim
|
||||
self.log_std_min = log_std_min
|
||||
self.log_std_max = log_std_max
|
||||
self.fixed_std = fixed_std.to(self.device) if fixed_std is not None else None
|
||||
self.use_tanh_squash = use_tanh_squash
|
||||
self.parameters_to_optimize = []
|
||||
|
||||
self.parameters_to_optimize += list(self.network.parameters())
|
||||
|
||||
if self.encoder is not None and not encoder_is_shared:
|
||||
self.parameters_to_optimize += list(self.encoder.parameters())
|
||||
# Find the last Linear layer's output dimension
|
||||
for layer in reversed(network.net):
|
||||
if isinstance(layer, nn.Linear):
|
||||
out_features = layer.out_features
|
||||
break
|
||||
# Mean layer
|
||||
self.mean_layer = nn.Linear(out_features, action_dim)
|
||||
if init_final is not None:
|
||||
nn.init.uniform_(self.mean_layer.weight, -init_final, init_final)
|
||||
nn.init.uniform_(self.mean_layer.bias, -init_final, init_final)
|
||||
else:
|
||||
orthogonal_init()(self.mean_layer.weight)
|
||||
|
||||
self.parameters_to_optimize += list(self.mean_layer.parameters())
|
||||
# Standard deviation layer or parameter
|
||||
if fixed_std is None:
|
||||
self.std_layer = nn.Linear(out_features, action_dim)
|
||||
if init_final is not None:
|
||||
nn.init.uniform_(self.std_layer.weight, -init_final, init_final)
|
||||
nn.init.uniform_(self.std_layer.bias, -init_final, init_final)
|
||||
else:
|
||||
orthogonal_init()(self.std_layer.weight)
|
||||
self.parameters_to_optimize += list(self.std_layer.parameters())
|
||||
|
||||
self.to(self.device)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
observations: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
|
||||
# Encode observations if encoder exists
|
||||
obs_enc = observations if self.encoder is None else self.encoder(observations)
|
||||
|
||||
# Get network outputs
|
||||
outputs = self.network(obs_enc)
|
||||
means = self.mean_layer(outputs)
|
||||
|
||||
# Compute standard deviations
|
||||
if self.fixed_std is None:
|
||||
log_std = self.std_layer(outputs)
|
||||
assert not torch.isnan(log_std).any(), "[ERROR] log_std became NaN after std_layer!"
|
||||
|
||||
if self.use_tanh_squash:
|
||||
log_std = torch.tanh(log_std)
|
||||
log_std = self.log_std_min + 0.5 * (self.log_std_max - self.log_std_min) * (log_std + 1.0)
|
||||
else:
|
||||
log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max)
|
||||
else:
|
||||
log_std = self.fixed_std.expand_as(means)
|
||||
|
||||
# uses tanh activation function to squash the action to be in the range of [-1, 1]
|
||||
normal = torch.distributions.Normal(means, torch.exp(log_std))
|
||||
x_t = normal.rsample() # Reparameterization trick (mean + std * N(0,1))
|
||||
log_probs = normal.log_prob(x_t) # Base log probability before Tanh
|
||||
|
||||
if self.use_tanh_squash:
|
||||
actions = torch.tanh(x_t)
|
||||
log_probs -= torch.log((1 - actions.pow(2)) + 1e-6) # Adjust log-probs for Tanh
|
||||
else:
|
||||
actions = x_t # No Tanh; raw Gaussian sample
|
||||
|
||||
log_probs = log_probs.sum(-1) # Sum over action dimensions
|
||||
means = torch.tanh(means) if self.use_tanh_squash else means
|
||||
return actions, log_probs, means
|
||||
|
||||
def get_features(self, observations: torch.Tensor) -> torch.Tensor:
|
||||
"""Get encoded features from observations"""
|
||||
observations = observations.to(self.device)
|
||||
if self.encoder is not None:
|
||||
with torch.inference_mode():
|
||||
return self.encoder(observations)
|
||||
return observations
|
||||
|
||||
|
||||
class SACObservationEncoder(nn.Module):
|
||||
"""Encode image and/or state vector observations.
|
||||
TODO(ke-wang): The original work allows for (1) stacking multiple history frames and (2) using pretrained resnet encoders.
|
||||
"""
|
||||
|
||||
def __init__(self, config: SACConfig):
|
||||
"""
|
||||
Creates encoders for pixel and/or state modalities.
|
||||
"""
|
||||
super().__init__()
|
||||
self.config = config
|
||||
if "observation.image" in config.input_shapes:
|
||||
self.image_enc_layers = nn.Sequential(
|
||||
nn.Conv2d(
|
||||
in_channels=config.input_shapes["observation.image"][0],
|
||||
out_channels=config.image_encoder_hidden_dim,
|
||||
kernel_size=7,
|
||||
stride=2,
|
||||
),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(
|
||||
in_channels=config.image_encoder_hidden_dim,
|
||||
out_channels=config.image_encoder_hidden_dim,
|
||||
kernel_size=5,
|
||||
stride=2,
|
||||
),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(
|
||||
in_channels=config.image_encoder_hidden_dim,
|
||||
out_channels=config.image_encoder_hidden_dim,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(
|
||||
in_channels=config.image_encoder_hidden_dim,
|
||||
out_channels=config.image_encoder_hidden_dim,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
),
|
||||
nn.ReLU(),
|
||||
)
|
||||
self.camera_number = config.camera_number
|
||||
self.aggregation_size: int = 0
|
||||
|
||||
dummy_batch = torch.zeros(1, *config.input_shapes["observation.image"])
|
||||
with torch.inference_mode():
|
||||
out_shape = self.image_enc_layers(dummy_batch).shape[1:]
|
||||
self.image_enc_layers.extend(
|
||||
sequential=nn.Sequential(
|
||||
nn.Flatten(),
|
||||
nn.Linear(
|
||||
in_features=np.prod(out_shape) * self.camera_number, out_features=config.latent_dim
|
||||
),
|
||||
nn.LayerNorm(normalized_shape=config.latent_dim),
|
||||
nn.Tanh(),
|
||||
)
|
||||
)
|
||||
|
||||
self.aggregation_size += config.latent_dim * self.camera_number
|
||||
if "observation.state" in config.input_shapes:
|
||||
self.state_enc_layers = nn.Sequential(
|
||||
nn.Linear(
|
||||
in_features=config.input_shapes["observation.state"][0], out_features=config.latent_dim
|
||||
),
|
||||
nn.LayerNorm(normalized_shape=config.latent_dim),
|
||||
nn.Tanh(),
|
||||
)
|
||||
self.aggregation_size += config.latent_dim
|
||||
|
||||
if "observation.environment_state" in config.input_shapes:
|
||||
self.env_state_enc_layers = nn.Sequential(
|
||||
nn.Linear(
|
||||
in_features=config.input_shapes["observation.environment_state"][0],
|
||||
out_features=config.latent_dim,
|
||||
),
|
||||
nn.LayerNorm(normalized_shape=config.latent_dim),
|
||||
nn.Tanh(),
|
||||
)
|
||||
|
||||
self.aggregation_size += config.latent_dim
|
||||
self.aggregation_layer = nn.Linear(in_features=self.aggregation_size, out_features=config.latent_dim)
|
||||
|
||||
def forward(self, obs_dict: dict[str, Tensor]) -> Tensor:
|
||||
"""Encode the image and/or state vector.
|
||||
|
||||
Each modality is encoded into a feature vector of size (latent_dim,) and then a uniform mean is taken
|
||||
over all features.
|
||||
"""
|
||||
feat = []
|
||||
# Concatenate all images along the channel dimension.
|
||||
image_keys = [k for k in self.config.input_shapes if k.startswith("observation.image")]
|
||||
for image_key in image_keys:
|
||||
feat.append(flatten_forward_unflatten(self.image_enc_layers, obs_dict[image_key]))
|
||||
if "observation.environment_state" in self.config.input_shapes:
|
||||
feat.append(self.env_state_enc_layers(obs_dict["observation.environment_state"]))
|
||||
if "observation.state" in self.config.input_shapes:
|
||||
feat.append(self.state_enc_layers(obs_dict["observation.state"]))
|
||||
# TODO(ke-wang): currently average over all features, concatenate all features maybe a better way
|
||||
# return torch.stack(feat, dim=0).mean(0)
|
||||
features = torch.cat(tensors=feat, dim=-1)
|
||||
features = self.aggregation_layer(features)
|
||||
|
||||
return features
|
||||
|
||||
@property
|
||||
def output_dim(self) -> int:
|
||||
"""Returns the dimension of the encoder output"""
|
||||
return self.config.latent_dim
|
||||
|
||||
|
||||
def orthogonal_init():
|
||||
return lambda x: torch.nn.init.orthogonal_(x, gain=1.0)
|
||||
|
||||
|
||||
def create_critic_ensemble(critics: list[nn.Module], num_critics: int, device: str = "cpu") -> nn.ModuleList:
|
||||
"""Creates an ensemble of critic networks"""
|
||||
assert len(critics) == num_critics, f"Expected {num_critics} critics, got {len(critics)}"
|
||||
return nn.ModuleList(critics).to(device)
|
||||
|
||||
# borrowed from tdmpc
|
||||
def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tensor) -> Tensor:
|
||||
"""Helper to temporarily flatten extra dims at the start of the image tensor.
|
||||
|
||||
Args:
|
||||
fn: Callable that the image tensor will be passed to. It should accept (B, C, H, W) and return
|
||||
(B, *), where * is any number of dimensions.
|
||||
image_tensor: An image tensor of shape (**, C, H, W), where ** is any number of dimensions and
|
||||
can be more than 1 dimensions, generally different from *.
|
||||
Returns:
|
||||
A return value from the callable reshaped to (**, *).
|
||||
"""
|
||||
if image_tensor.ndim == 4:
|
||||
return fn(image_tensor)
|
||||
start_dims = image_tensor.shape[:-3]
|
||||
inp = torch.flatten(image_tensor, end_dim=-4)
|
||||
flat_out = fn(inp)
|
||||
return torch.reshape(flat_out, (*start_dims, *flat_out.shape[1:]))
|
||||
@@ -16,14 +16,9 @@
|
||||
# limitations under the License.
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.common.optim.optimizers import AdamConfig
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import NormalizationMode
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("tdmpc")
|
||||
@dataclass
|
||||
class TDMPCConfig(PreTrainedConfig):
|
||||
class TDMPCConfig:
|
||||
"""Configuration class for TDMPCPolicy.
|
||||
|
||||
Defaults are configured for training with xarm_lift_medium_replay providing proprioceptive and single
|
||||
@@ -76,7 +71,7 @@ class TDMPCConfig(PreTrainedConfig):
|
||||
n_pi_samples: Number of samples to draw from the policy / world model rollout every CEM iteration. Can
|
||||
be zero.
|
||||
uncertainty_regularizer_coeff: Coefficient for the uncertainty regularization used when estimating
|
||||
trajectory values (this is the λ coefficient in eqn 4 of FOWM).
|
||||
trajectory values (this is the λ coeffiecient in eqn 4 of FOWM).
|
||||
n_elites: The number of elite samples to use for updating the gaussian parameters every CEM iteration.
|
||||
elite_weighting_temperature: The temperature to use for softmax weighting (by trajectory value) of the
|
||||
elites, when updating the gaussian parameters for CEM.
|
||||
@@ -107,19 +102,27 @@ class TDMPCConfig(PreTrainedConfig):
|
||||
"""
|
||||
|
||||
# Input / output structure.
|
||||
n_obs_steps: int = 1
|
||||
n_action_repeats: int = 2
|
||||
horizon: int = 5
|
||||
n_action_steps: int = 1
|
||||
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
input_shapes: dict[str, list[int]] = field(
|
||||
default_factory=lambda: {
|
||||
"VISUAL": NormalizationMode.IDENTITY,
|
||||
"STATE": NormalizationMode.IDENTITY,
|
||||
"ENV": NormalizationMode.IDENTITY,
|
||||
"ACTION": NormalizationMode.MIN_MAX,
|
||||
"observation.image": [3, 84, 84],
|
||||
"observation.state": [4],
|
||||
}
|
||||
)
|
||||
output_shapes: dict[str, list[int]] = field(
|
||||
default_factory=lambda: {
|
||||
"action": [4],
|
||||
}
|
||||
)
|
||||
|
||||
# Normalization / Unnormalization
|
||||
input_normalization_modes: dict[str, str] | None = None
|
||||
output_normalization_modes: dict[str, str] = field(
|
||||
default_factory=lambda: {"action": "min_max"},
|
||||
)
|
||||
|
||||
# Architecture / modeling.
|
||||
# Neural networks.
|
||||
@@ -156,27 +159,32 @@ class TDMPCConfig(PreTrainedConfig):
|
||||
# Target model.
|
||||
target_model_momentum: float = 0.995
|
||||
|
||||
# Training presets
|
||||
optimizer_lr: float = 3e-4
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
|
||||
"""Input validation (not exhaustive)."""
|
||||
# There should only be one image key.
|
||||
image_keys = {k for k in self.input_shapes if k.startswith("observation.image")}
|
||||
if len(image_keys) > 1:
|
||||
raise ValueError(
|
||||
f"{self.__class__.__name__} handles at most one image for now. Got image keys {image_keys}."
|
||||
)
|
||||
if len(image_keys) > 0:
|
||||
image_key = next(iter(image_keys))
|
||||
if self.input_shapes[image_key][-2] != self.input_shapes[image_key][-1]:
|
||||
# TODO(alexander-soare): This limitation is solely because of code in the random shift
|
||||
# augmentation. It should be able to be removed.
|
||||
raise ValueError(
|
||||
f"Only square images are handled now. Got image shape {self.input_shapes[image_key]}."
|
||||
)
|
||||
if self.n_gaussian_samples <= 0:
|
||||
raise ValueError(
|
||||
f"The number of gaussian samples for CEM should be non-zero. Got `{self.n_gaussian_samples=}`"
|
||||
f"The number of guassian samples for CEM should be non-zero. Got `{self.n_gaussian_samples=}`"
|
||||
)
|
||||
if self.normalization_mapping["ACTION"] is not NormalizationMode.MIN_MAX:
|
||||
if self.output_normalization_modes != {"action": "min_max"}:
|
||||
raise ValueError(
|
||||
"TD-MPC assumes the action space dimensions to all be in [-1, 1]. Therefore it is strongly "
|
||||
f"advised that you stick with the default. See {self.__class__.__name__} docstring for more "
|
||||
"information."
|
||||
)
|
||||
if self.n_obs_steps != 1:
|
||||
raise ValueError(
|
||||
f"Multiple observation steps not handled yet. Got `nobs_steps={self.n_obs_steps}`"
|
||||
)
|
||||
if self.n_action_steps > 1:
|
||||
if self.n_action_repeats != 1:
|
||||
raise ValueError(
|
||||
@@ -186,35 +194,3 @@ class TDMPCConfig(PreTrainedConfig):
|
||||
raise ValueError("If `n_action_steps > 1`, `use_mpc` must be set to `True`.")
|
||||
if self.n_action_steps > self.horizon:
|
||||
raise ValueError("`n_action_steps` must be less than or equal to `horizon`.")
|
||||
|
||||
def get_optimizer_preset(self) -> AdamConfig:
|
||||
return AdamConfig(lr=self.optimizer_lr)
|
||||
|
||||
def get_scheduler_preset(self) -> None:
|
||||
return None
|
||||
|
||||
def validate_features(self) -> None:
|
||||
# There should only be one image key.
|
||||
if len(self.image_features) > 1:
|
||||
raise ValueError(
|
||||
f"{self.__class__.__name__} handles at most one image for now. Got image keys {self.image_features}."
|
||||
)
|
||||
|
||||
if len(self.image_features) > 0:
|
||||
image_ft = next(iter(self.image_features.values()))
|
||||
if image_ft.shape[-2] != image_ft.shape[-1]:
|
||||
# TODO(alexander-soare): This limitation is solely because of code in the random shift
|
||||
# augmentation. It should be able to be removed.
|
||||
raise ValueError(f"Only square images are handled now. Got image shape {image_ft.shape}.")
|
||||
|
||||
@property
|
||||
def observation_delta_indices(self) -> list:
|
||||
return list(range(self.horizon + 1))
|
||||
|
||||
@property
|
||||
def action_delta_indices(self) -> list:
|
||||
return list(range(self.horizon))
|
||||
|
||||
@property
|
||||
def reward_delta_indices(self) -> None:
|
||||
return list(range(self.horizon))
|
||||
|
||||
@@ -33,16 +33,21 @@ import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from huggingface_hub import PyTorchModelHubMixin
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.common.constants import OBS_ENV, OBS_ROBOT
|
||||
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
||||
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig
|
||||
from lerobot.common.policies.utils import get_device_from_parameters, get_output_shape, populate_queues
|
||||
from lerobot.common.policies.utils import get_device_from_parameters, populate_queues
|
||||
|
||||
|
||||
class TDMPCPolicy(PreTrainedPolicy):
|
||||
class TDMPCPolicy(
|
||||
nn.Module,
|
||||
PyTorchModelHubMixin,
|
||||
library_name="lerobot",
|
||||
repo_url="https://github.com/huggingface/lerobot",
|
||||
tags=["robotics", "tdmpc"],
|
||||
):
|
||||
"""Implementation of TD-MPC learning + inference.
|
||||
|
||||
Please note several warnings for this policy.
|
||||
@@ -60,10 +65,11 @@ class TDMPCPolicy(PreTrainedPolicy):
|
||||
match our xarm environment.
|
||||
"""
|
||||
|
||||
config_class = TDMPCConfig
|
||||
name = "tdmpc"
|
||||
|
||||
def __init__(self, config: TDMPCConfig, dataset_stats: dict[str, dict[str, Tensor]] | None = None):
|
||||
def __init__(
|
||||
self, config: TDMPCConfig | None = None, dataset_stats: dict[str, dict[str, Tensor]] | None = None
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
config: Policy configuration class instance or None, in which case the default instantiation of
|
||||
@@ -71,27 +77,41 @@ class TDMPCPolicy(PreTrainedPolicy):
|
||||
dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
|
||||
that they will be passed with a call to `load_state_dict` before the policy is used.
|
||||
"""
|
||||
super().__init__(config)
|
||||
config.validate_features()
|
||||
super().__init__()
|
||||
|
||||
if config is None:
|
||||
config = TDMPCConfig()
|
||||
self.config = config
|
||||
|
||||
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
|
||||
self.normalize_targets = Normalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
self.unnormalize_outputs = Unnormalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
|
||||
self.model = TDMPCTOLD(config)
|
||||
self.model_target = deepcopy(self.model)
|
||||
for param in self.model_target.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
self.reset()
|
||||
if config.input_normalization_modes is not None:
|
||||
self.normalize_inputs = Normalize(
|
||||
config.input_shapes, config.input_normalization_modes, dataset_stats
|
||||
)
|
||||
else:
|
||||
self.normalize_inputs = nn.Identity()
|
||||
self.normalize_targets = Normalize(
|
||||
config.output_shapes, config.output_normalization_modes, dataset_stats
|
||||
)
|
||||
self.unnormalize_outputs = Unnormalize(
|
||||
config.output_shapes, config.output_normalization_modes, dataset_stats
|
||||
)
|
||||
|
||||
def get_optim_params(self) -> dict:
|
||||
return self.parameters()
|
||||
image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
|
||||
# Note: This check is covered in the post-init of the config but have a sanity check just in case.
|
||||
self._use_image = False
|
||||
self._use_env_state = False
|
||||
if len(image_keys) > 0:
|
||||
assert len(image_keys) == 1
|
||||
self._use_image = True
|
||||
self.input_image_key = image_keys[0]
|
||||
if "observation.environment_state" in config.input_shapes:
|
||||
self._use_env_state = True
|
||||
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
"""
|
||||
@@ -102,9 +122,9 @@ class TDMPCPolicy(PreTrainedPolicy):
|
||||
"observation.state": deque(maxlen=1),
|
||||
"action": deque(maxlen=max(self.config.n_action_steps, self.config.n_action_repeats)),
|
||||
}
|
||||
if self.config.image_features:
|
||||
if self._use_image:
|
||||
self._queues["observation.image"] = deque(maxlen=1)
|
||||
if self.config.env_state_feature:
|
||||
if self._use_env_state:
|
||||
self._queues["observation.environment_state"] = deque(maxlen=1)
|
||||
# Previous mean obtained from the cross-entropy method (CEM) used during MPC. It is used to warm start
|
||||
# CEM for the next step.
|
||||
@@ -114,9 +134,9 @@ class TDMPCPolicy(PreTrainedPolicy):
|
||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Select a single action given environment observations."""
|
||||
batch = self.normalize_inputs(batch)
|
||||
if self.config.image_features:
|
||||
if self._use_image:
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
batch["observation.image"] = batch[next(iter(self.config.image_features))]
|
||||
batch["observation.image"] = batch[self.input_image_key]
|
||||
|
||||
self._queues = populate_queues(self._queues, batch)
|
||||
|
||||
@@ -131,9 +151,9 @@ class TDMPCPolicy(PreTrainedPolicy):
|
||||
|
||||
# NOTE: Order of observations matters here.
|
||||
encode_keys = []
|
||||
if self.config.image_features:
|
||||
if self._use_image:
|
||||
encode_keys.append("observation.image")
|
||||
if self.config.env_state_feature:
|
||||
if self._use_env_state:
|
||||
encode_keys.append("observation.environment_state")
|
||||
encode_keys.append("observation.state")
|
||||
z = self.model.encode({k: batch[k] for k in encode_keys})
|
||||
@@ -176,7 +196,7 @@ class TDMPCPolicy(PreTrainedPolicy):
|
||||
self.config.horizon,
|
||||
self.config.n_pi_samples,
|
||||
batch_size,
|
||||
self.config.action_feature.shape[0],
|
||||
self.config.output_shapes["action"][0],
|
||||
device=device,
|
||||
)
|
||||
if self.config.n_pi_samples > 0:
|
||||
@@ -195,7 +215,7 @@ class TDMPCPolicy(PreTrainedPolicy):
|
||||
# algorithm.
|
||||
# The initial mean and standard deviation for the cross-entropy method (CEM).
|
||||
mean = torch.zeros(
|
||||
self.config.horizon, batch_size, self.config.action_feature.shape[0], device=device
|
||||
self.config.horizon, batch_size, self.config.output_shapes["action"][0], device=device
|
||||
)
|
||||
# Maybe warm start CEM with the mean from the previous step.
|
||||
if self._prev_mean is not None:
|
||||
@@ -208,7 +228,7 @@ class TDMPCPolicy(PreTrainedPolicy):
|
||||
self.config.horizon,
|
||||
self.config.n_gaussian_samples,
|
||||
batch_size,
|
||||
self.config.action_feature.shape[0],
|
||||
self.config.output_shapes["action"][0],
|
||||
device=std.device,
|
||||
)
|
||||
gaussian_actions = torch.clamp(mean.unsqueeze(1) + std.unsqueeze(1) * std_normal_noise, -1, 1)
|
||||
@@ -302,7 +322,7 @@ class TDMPCPolicy(PreTrainedPolicy):
|
||||
G -= running_discount * self.config.uncertainty_regularizer_coeff * terminal_values.std(0)
|
||||
return G
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]:
|
||||
"""Run the batch through the model and compute the loss.
|
||||
|
||||
Returns a dictionary with loss as a tensor, and other information as native floats.
|
||||
@@ -310,16 +330,16 @@ class TDMPCPolicy(PreTrainedPolicy):
|
||||
device = get_device_from_parameters(self)
|
||||
|
||||
batch = self.normalize_inputs(batch)
|
||||
if self.config.image_features:
|
||||
if self._use_image:
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
batch["observation.image"] = batch[next(iter(self.config.image_features))]
|
||||
batch["observation.image"] = batch[self.input_image_key]
|
||||
batch = self.normalize_targets(batch)
|
||||
|
||||
info = {}
|
||||
|
||||
# (b, t) -> (t, b)
|
||||
for key in batch:
|
||||
if isinstance(batch[key], torch.Tensor) and batch[key].ndim > 1:
|
||||
if batch[key].ndim > 1:
|
||||
batch[key] = batch[key].transpose(1, 0)
|
||||
|
||||
action = batch["action"] # (t, b, action_dim)
|
||||
@@ -327,7 +347,7 @@ class TDMPCPolicy(PreTrainedPolicy):
|
||||
observations = {k: v for k, v in batch.items() if k.startswith("observation.")}
|
||||
|
||||
# Apply random image augmentations.
|
||||
if self.config.image_features and self.config.max_random_shift_ratio > 0:
|
||||
if self._use_image and self.config.max_random_shift_ratio > 0:
|
||||
observations["observation.image"] = flatten_forward_unflatten(
|
||||
partial(random_shifts_aug, max_random_shift_ratio=self.config.max_random_shift_ratio),
|
||||
observations["observation.image"],
|
||||
@@ -340,7 +360,7 @@ class TDMPCPolicy(PreTrainedPolicy):
|
||||
current_observation[k] = observations[k][0]
|
||||
next_observations[k] = observations[k][1:]
|
||||
horizon, batch_size = next_observations[
|
||||
"observation.image" if self.config.image_features else "observation.environment_state"
|
||||
"observation.image" if self._use_image else "observation.environment_state"
|
||||
].shape[:2]
|
||||
|
||||
# Run latent rollout using the latent dynamics model and policy model.
|
||||
@@ -495,16 +515,17 @@ class TDMPCPolicy(PreTrainedPolicy):
|
||||
"Q_value_loss": q_value_loss.item(),
|
||||
"V_value_loss": v_value_loss.item(),
|
||||
"pi_loss": pi_loss.item(),
|
||||
"loss": loss,
|
||||
"sum_loss": loss.item() * self.config.horizon,
|
||||
}
|
||||
)
|
||||
|
||||
# Undo (b, t) -> (t, b).
|
||||
for key in batch:
|
||||
if isinstance(batch[key], torch.Tensor) and batch[key].ndim > 1:
|
||||
if batch[key].ndim > 1:
|
||||
batch[key] = batch[key].transpose(1, 0)
|
||||
|
||||
return loss, info
|
||||
return info
|
||||
|
||||
def update(self):
|
||||
"""Update the target model's parameters with an EMA step."""
|
||||
@@ -522,7 +543,7 @@ class TDMPCTOLD(nn.Module):
|
||||
self.config = config
|
||||
self._encoder = TDMPCObservationEncoder(config)
|
||||
self._dynamics = nn.Sequential(
|
||||
nn.Linear(config.latent_dim + config.action_feature.shape[0], config.mlp_dim),
|
||||
nn.Linear(config.latent_dim + config.output_shapes["action"][0], config.mlp_dim),
|
||||
nn.LayerNorm(config.mlp_dim),
|
||||
nn.Mish(),
|
||||
nn.Linear(config.mlp_dim, config.mlp_dim),
|
||||
@@ -533,7 +554,7 @@ class TDMPCTOLD(nn.Module):
|
||||
nn.Sigmoid(),
|
||||
)
|
||||
self._reward = nn.Sequential(
|
||||
nn.Linear(config.latent_dim + config.action_feature.shape[0], config.mlp_dim),
|
||||
nn.Linear(config.latent_dim + config.output_shapes["action"][0], config.mlp_dim),
|
||||
nn.LayerNorm(config.mlp_dim),
|
||||
nn.Mish(),
|
||||
nn.Linear(config.mlp_dim, config.mlp_dim),
|
||||
@@ -548,12 +569,12 @@ class TDMPCTOLD(nn.Module):
|
||||
nn.Linear(config.mlp_dim, config.mlp_dim),
|
||||
nn.LayerNorm(config.mlp_dim),
|
||||
nn.Mish(),
|
||||
nn.Linear(config.mlp_dim, config.action_feature.shape[0]),
|
||||
nn.Linear(config.mlp_dim, config.output_shapes["action"][0]),
|
||||
)
|
||||
self._Qs = nn.ModuleList(
|
||||
[
|
||||
nn.Sequential(
|
||||
nn.Linear(config.latent_dim + config.action_feature.shape[0], config.mlp_dim),
|
||||
nn.Linear(config.latent_dim + config.output_shapes["action"][0], config.mlp_dim),
|
||||
nn.LayerNorm(config.mlp_dim),
|
||||
nn.Tanh(),
|
||||
nn.Linear(config.mlp_dim, config.mlp_dim),
|
||||
@@ -594,9 +615,9 @@ class TDMPCTOLD(nn.Module):
|
||||
|
||||
self.apply(_apply_fn)
|
||||
for m in [self._reward, *self._Qs]:
|
||||
assert isinstance(m[-1], nn.Linear), (
|
||||
"Sanity check. The last linear layer needs 0 initialization on weights."
|
||||
)
|
||||
assert isinstance(
|
||||
m[-1], nn.Linear
|
||||
), "Sanity check. The last linear layer needs 0 initialization on weights."
|
||||
nn.init.zeros_(m[-1].weight)
|
||||
nn.init.zeros_(m[-1].bias) # this has already been done, but keep this line here for good measure
|
||||
|
||||
@@ -693,13 +714,10 @@ class TDMPCObservationEncoder(nn.Module):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
if config.image_features:
|
||||
if "observation.image" in config.input_shapes:
|
||||
self.image_enc_layers = nn.Sequential(
|
||||
nn.Conv2d(
|
||||
next(iter(config.image_features.values())).shape[0],
|
||||
config.image_encoder_hidden_dim,
|
||||
7,
|
||||
stride=2,
|
||||
config.input_shapes["observation.image"][0], config.image_encoder_hidden_dim, 7, stride=2
|
||||
),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 5, stride=2),
|
||||
@@ -709,8 +727,9 @@ class TDMPCObservationEncoder(nn.Module):
|
||||
nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 3, stride=2),
|
||||
nn.ReLU(),
|
||||
)
|
||||
dummy_shape = (1, *next(iter(config.image_features.values())).shape)
|
||||
out_shape = get_output_shape(self.image_enc_layers, dummy_shape)[1:]
|
||||
dummy_batch = torch.zeros(1, *config.input_shapes["observation.image"])
|
||||
with torch.inference_mode():
|
||||
out_shape = self.image_enc_layers(dummy_batch).shape[1:]
|
||||
self.image_enc_layers.extend(
|
||||
nn.Sequential(
|
||||
nn.Flatten(),
|
||||
@@ -719,19 +738,19 @@ class TDMPCObservationEncoder(nn.Module):
|
||||
nn.Sigmoid(),
|
||||
)
|
||||
)
|
||||
|
||||
if config.robot_state_feature:
|
||||
if "observation.state" in config.input_shapes:
|
||||
self.state_enc_layers = nn.Sequential(
|
||||
nn.Linear(config.robot_state_feature.shape[0], config.state_encoder_hidden_dim),
|
||||
nn.Linear(config.input_shapes["observation.state"][0], config.state_encoder_hidden_dim),
|
||||
nn.ELU(),
|
||||
nn.Linear(config.state_encoder_hidden_dim, config.latent_dim),
|
||||
nn.LayerNorm(config.latent_dim),
|
||||
nn.Sigmoid(),
|
||||
)
|
||||
|
||||
if config.env_state_feature:
|
||||
if "observation.environment_state" in config.input_shapes:
|
||||
self.env_state_enc_layers = nn.Sequential(
|
||||
nn.Linear(config.env_state_feature.shape[0], config.state_encoder_hidden_dim),
|
||||
nn.Linear(
|
||||
config.input_shapes["observation.environment_state"][0], config.state_encoder_hidden_dim
|
||||
),
|
||||
nn.ELU(),
|
||||
nn.Linear(config.state_encoder_hidden_dim, config.latent_dim),
|
||||
nn.LayerNorm(config.latent_dim),
|
||||
@@ -746,16 +765,12 @@ class TDMPCObservationEncoder(nn.Module):
|
||||
"""
|
||||
feat = []
|
||||
# NOTE: Order of observations matters here.
|
||||
if self.config.image_features:
|
||||
feat.append(
|
||||
flatten_forward_unflatten(
|
||||
self.image_enc_layers, obs_dict[next(iter(self.config.image_features))]
|
||||
)
|
||||
)
|
||||
if self.config.env_state_feature:
|
||||
feat.append(self.env_state_enc_layers(obs_dict[OBS_ENV]))
|
||||
if self.config.robot_state_feature:
|
||||
feat.append(self.state_enc_layers(obs_dict[OBS_ROBOT]))
|
||||
if "observation.image" in self.config.input_shapes:
|
||||
feat.append(flatten_forward_unflatten(self.image_enc_layers, obs_dict["observation.image"]))
|
||||
if "observation.environment_state" in self.config.input_shapes:
|
||||
feat.append(self.env_state_enc_layers(obs_dict["observation.environment_state"]))
|
||||
if "observation.state" in self.config.input_shapes:
|
||||
feat.append(self.state_enc_layers(obs_dict["observation.state"]))
|
||||
return torch.stack(feat, dim=0).mean(0)
|
||||
|
||||
|
||||
|
||||
@@ -13,7 +13,6 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
@@ -48,20 +47,3 @@ def get_dtype_from_parameters(module: nn.Module) -> torch.dtype:
|
||||
Note: assumes that all parameters have the same dtype.
|
||||
"""
|
||||
return next(iter(module.parameters())).dtype
|
||||
|
||||
|
||||
def get_output_shape(module: nn.Module, input_shape: tuple) -> tuple:
|
||||
"""
|
||||
Calculates the output shape of a PyTorch module given an input shape.
|
||||
|
||||
Args:
|
||||
module (nn.Module): a PyTorch module
|
||||
input_shape (tuple): A tuple representing the input shape, e.g., (batch_size, channels, height, width)
|
||||
|
||||
Returns:
|
||||
tuple: The output shape of the module.
|
||||
"""
|
||||
dummy_input = torch.zeros(size=input_shape)
|
||||
with torch.inference_mode():
|
||||
output = module(dummy_input)
|
||||
return tuple(output.shape)
|
||||
|
||||
@@ -18,15 +18,9 @@
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.common.optim.optimizers import AdamConfig
|
||||
from lerobot.common.optim.schedulers import VQBeTSchedulerConfig
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import NormalizationMode
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("vqbet")
|
||||
@dataclass
|
||||
class VQBeTConfig(PreTrainedConfig):
|
||||
class VQBeTConfig:
|
||||
"""Configuration class for VQ-BeT.
|
||||
|
||||
Defaults are configured for training with PushT providing proprioceptive and single camera observations.
|
||||
@@ -66,7 +60,7 @@ class VQBeTConfig(PreTrainedConfig):
|
||||
within the image size. If None, no cropping is done.
|
||||
crop_is_random: Whether the crop should be random at training time (it's always a center crop in eval
|
||||
mode).
|
||||
pretrained_backbone_weights: Pretrained weights from torchvision to initialize the backbone.
|
||||
pretrained_backbone_weights: Pretrained weights from torchvision to initalize the backbone.
|
||||
`None` means no pretrained weights.
|
||||
use_group_norm: Whether to replace batch normalization with group normalization in the backbone.
|
||||
The group sizes are set to be about 16 (to be precise, feature_dim // 16).
|
||||
@@ -96,13 +90,26 @@ class VQBeTConfig(PreTrainedConfig):
|
||||
n_action_pred_token: int = 3
|
||||
action_chunk_size: int = 5
|
||||
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
input_shapes: dict[str, list[int]] = field(
|
||||
default_factory=lambda: {
|
||||
"VISUAL": NormalizationMode.IDENTITY,
|
||||
"STATE": NormalizationMode.MIN_MAX,
|
||||
"ACTION": NormalizationMode.MIN_MAX,
|
||||
"observation.image": [3, 96, 96],
|
||||
"observation.state": [2],
|
||||
}
|
||||
)
|
||||
output_shapes: dict[str, list[int]] = field(
|
||||
default_factory=lambda: {
|
||||
"action": [2],
|
||||
}
|
||||
)
|
||||
|
||||
# Normalization / Unnormalization
|
||||
input_normalization_modes: dict[str, str] = field(
|
||||
default_factory=lambda: {
|
||||
"observation.image": "mean_std",
|
||||
"observation.state": "min_max",
|
||||
}
|
||||
)
|
||||
output_normalization_modes: dict[str, str] = field(default_factory=lambda: {"action": "min_max"})
|
||||
|
||||
# Architecture / modeling.
|
||||
# Vision backbone.
|
||||
@@ -132,69 +139,29 @@ class VQBeTConfig(PreTrainedConfig):
|
||||
bet_softmax_temperature: float = 0.1
|
||||
sequentially_select: bool = False
|
||||
|
||||
# Training presets
|
||||
optimizer_lr: float = 1e-4
|
||||
optimizer_betas: tuple = (0.95, 0.999)
|
||||
optimizer_eps: float = 1e-8
|
||||
optimizer_weight_decay: float = 1e-6
|
||||
optimizer_vqvae_lr: float = 1e-3
|
||||
optimizer_vqvae_weight_decay: float = 1e-4
|
||||
scheduler_warmup_steps: int = 500
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
|
||||
"""Input validation (not exhaustive)."""
|
||||
if not self.vision_backbone.startswith("resnet"):
|
||||
raise ValueError(
|
||||
f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
|
||||
)
|
||||
|
||||
def get_optimizer_preset(self) -> AdamConfig:
|
||||
return AdamConfig(
|
||||
lr=self.optimizer_lr,
|
||||
betas=self.optimizer_betas,
|
||||
eps=self.optimizer_eps,
|
||||
weight_decay=self.optimizer_weight_decay,
|
||||
)
|
||||
|
||||
def get_scheduler_preset(self) -> VQBeTSchedulerConfig:
|
||||
return VQBeTSchedulerConfig(
|
||||
num_warmup_steps=self.scheduler_warmup_steps,
|
||||
num_vqvae_training_steps=self.n_vqvae_training_steps,
|
||||
)
|
||||
|
||||
def validate_features(self) -> None:
|
||||
# Note: this check was previously performed inside VQBeTRgbEncoder in the form of
|
||||
# assert len(image_keys) == 1
|
||||
if not len(self.image_features) == 1:
|
||||
raise ValueError("You must provide only one image among the inputs.")
|
||||
|
||||
image_keys = {k for k in self.input_shapes if k.startswith("observation.image")}
|
||||
if self.crop_shape is not None:
|
||||
for key, image_ft in self.image_features.items():
|
||||
if self.crop_shape[0] > image_ft.shape[1] or self.crop_shape[1] > image_ft.shape[2]:
|
||||
for image_key in image_keys:
|
||||
if (
|
||||
self.crop_shape[0] > self.input_shapes[image_key][1]
|
||||
or self.crop_shape[1] > self.input_shapes[image_key][2]
|
||||
):
|
||||
raise ValueError(
|
||||
f"`crop_shape` should fit within the images shapes. Got {self.crop_shape} "
|
||||
f"for `crop_shape` and {image_ft.shape} for "
|
||||
f"`{key}`."
|
||||
f"`crop_shape` should fit within `input_shapes[{image_key}]`. Got {self.crop_shape} "
|
||||
f"for `crop_shape` and {self.input_shapes[image_key]} for "
|
||||
"`input_shapes[{image_key}]`."
|
||||
)
|
||||
|
||||
# Check that all input images have the same shape.
|
||||
first_image_key, first_image_ft = next(iter(self.image_features.items()))
|
||||
for key, image_ft in self.image_features.items():
|
||||
if image_ft.shape != first_image_ft.shape:
|
||||
first_image_key = next(iter(image_keys))
|
||||
for image_key in image_keys:
|
||||
if self.input_shapes[image_key] != self.input_shapes[first_image_key]:
|
||||
raise ValueError(
|
||||
f"`{key}` does not match `{first_image_key}`, but we expect all image shapes to match."
|
||||
f"`input_shapes[{image_key}]` does not match `input_shapes[{first_image_key}]`, but we "
|
||||
"expect all image shapes to match."
|
||||
)
|
||||
|
||||
@property
|
||||
def observation_delta_indices(self) -> list:
|
||||
return list(range(1 - self.n_obs_steps, 1))
|
||||
|
||||
@property
|
||||
def action_delta_indices(self) -> list:
|
||||
return list(range(1 - self.n_obs_steps, self.n_action_pred_token + self.action_chunk_size - 1))
|
||||
|
||||
@property
|
||||
def reward_delta_indices(self) -> None:
|
||||
return None
|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
import warnings
|
||||
from collections import deque
|
||||
from typing import Callable, List
|
||||
@@ -25,23 +26,29 @@ import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
import torchvision
|
||||
from huggingface_hub import PyTorchModelHubMixin
|
||||
from torch import Tensor, nn
|
||||
from torch.optim.lr_scheduler import LambdaLR
|
||||
|
||||
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
||||
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.common.policies.utils import get_device_from_parameters, get_output_shape, populate_queues
|
||||
from lerobot.common.policies.utils import get_device_from_parameters, populate_queues
|
||||
from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig
|
||||
from lerobot.common.policies.vqbet.vqbet_utils import GPT, ResidualVQ
|
||||
|
||||
# ruff: noqa: N806
|
||||
|
||||
|
||||
class VQBeTPolicy(PreTrainedPolicy):
|
||||
class VQBeTPolicy(
|
||||
nn.Module,
|
||||
PyTorchModelHubMixin,
|
||||
library_name="lerobot",
|
||||
repo_url="https://github.com/huggingface/lerobot",
|
||||
tags=["robotics", "vqbet"],
|
||||
):
|
||||
"""
|
||||
VQ-BeT Policy as per "Behavior Generation with Latent Actions"
|
||||
"""
|
||||
|
||||
config_class = VQBeTConfig
|
||||
name = "vqbet"
|
||||
|
||||
def __init__(
|
||||
@@ -56,62 +63,26 @@ class VQBeTPolicy(PreTrainedPolicy):
|
||||
dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
|
||||
that they will be passed with a call to `load_state_dict` before the policy is used.
|
||||
"""
|
||||
super().__init__(config)
|
||||
config.validate_features()
|
||||
super().__init__()
|
||||
if config is None:
|
||||
config = VQBeTConfig()
|
||||
self.config = config
|
||||
|
||||
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
|
||||
self.normalize_inputs = Normalize(
|
||||
config.input_shapes, config.input_normalization_modes, dataset_stats
|
||||
)
|
||||
self.normalize_targets = Normalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
config.output_shapes, config.output_normalization_modes, dataset_stats
|
||||
)
|
||||
self.unnormalize_outputs = Unnormalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
config.output_shapes, config.output_normalization_modes, dataset_stats
|
||||
)
|
||||
|
||||
self.vqbet = VQBeTModel(config)
|
||||
|
||||
self.expected_image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
|
||||
|
||||
self.reset()
|
||||
|
||||
def get_optim_params(self) -> dict:
|
||||
vqvae_params = (
|
||||
list(self.vqbet.action_head.vqvae_model.encoder.parameters())
|
||||
+ list(self.vqbet.action_head.vqvae_model.decoder.parameters())
|
||||
+ list(self.vqbet.action_head.vqvae_model.vq_layer.parameters())
|
||||
)
|
||||
decay_params, no_decay_params = self.vqbet.policy.configure_parameters()
|
||||
decay_params = (
|
||||
decay_params
|
||||
+ list(self.vqbet.rgb_encoder.parameters())
|
||||
+ list(self.vqbet.state_projector.parameters())
|
||||
+ list(self.vqbet.rgb_feature_projector.parameters())
|
||||
+ [self.vqbet.action_token]
|
||||
+ list(self.vqbet.action_head.map_to_cbet_preds_offset.parameters())
|
||||
)
|
||||
|
||||
if self.config.sequentially_select:
|
||||
decay_params = (
|
||||
decay_params
|
||||
+ list(self.vqbet.action_head.map_to_cbet_preds_primary_bin.parameters())
|
||||
+ list(self.vqbet.action_head.map_to_cbet_preds_secondary_bin.parameters())
|
||||
)
|
||||
else:
|
||||
decay_params = decay_params + list(self.vqbet.action_head.map_to_cbet_preds_bin.parameters())
|
||||
|
||||
return [
|
||||
{
|
||||
"params": decay_params,
|
||||
},
|
||||
{
|
||||
"params": vqvae_params,
|
||||
"weight_decay": self.config.optimizer_vqvae_weight_decay,
|
||||
"lr": self.config.optimizer_vqvae_lr,
|
||||
},
|
||||
{
|
||||
"params": no_decay_params,
|
||||
"weight_decay": 0.0,
|
||||
},
|
||||
]
|
||||
|
||||
def reset(self):
|
||||
"""
|
||||
Clear observation and action queues. Should be called on `env.reset()`
|
||||
@@ -134,7 +105,7 @@ class VQBeTPolicy(PreTrainedPolicy):
|
||||
|
||||
batch = self.normalize_inputs(batch)
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
batch["observation.images"] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
|
||||
batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
|
||||
# Note: It's important that this happens after stacking the images into a single key.
|
||||
self._queues = populate_queues(self._queues, batch)
|
||||
|
||||
@@ -156,11 +127,11 @@ class VQBeTPolicy(PreTrainedPolicy):
|
||||
action = self._queues["action"].popleft()
|
||||
return action
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
"""Run the batch through the model and compute the loss for training or validation."""
|
||||
batch = self.normalize_inputs(batch)
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
batch["observation.images"] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
|
||||
batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
|
||||
batch = self.normalize_targets(batch)
|
||||
# VQ-BeT discretizes action using VQ-VAE before training BeT (please refer to section 3.2 in the VQ-BeT paper https://arxiv.org/pdf/2403.03181)
|
||||
if not self.vqbet.action_head.vqvae_model.discretized.item():
|
||||
@@ -170,16 +141,16 @@ class VQBeTPolicy(PreTrainedPolicy):
|
||||
loss, n_different_codes, n_different_combinations, recon_l1_error = (
|
||||
self.vqbet.action_head.discretize(self.config.n_vqvae_training_steps, batch["action"])
|
||||
)
|
||||
return loss, {
|
||||
return {
|
||||
"loss": loss,
|
||||
"n_different_codes": n_different_codes,
|
||||
"n_different_combinations": n_different_combinations,
|
||||
"recon_l1_error": recon_l1_error,
|
||||
}
|
||||
# if Residual VQ is already trained, VQ-BeT trains its GPT and bin prediction head / offset prediction head parts.
|
||||
_, loss_dict = self.vqbet(batch, rollout=False)
|
||||
loss = loss_dict.pop("loss")
|
||||
|
||||
return loss, loss_dict
|
||||
return loss_dict
|
||||
|
||||
|
||||
class SpatialSoftmax(nn.Module):
|
||||
@@ -317,14 +288,14 @@ class VQBeTModel(nn.Module):
|
||||
self.config = config
|
||||
|
||||
self.rgb_encoder = VQBeTRgbEncoder(config)
|
||||
self.num_images = len(self.config.image_features)
|
||||
self.num_images = len([k for k in config.input_shapes if k.startswith("observation.image")])
|
||||
# This action query token is used as a prompt for querying action chunks. Please refer to "A_Q" in the image above.
|
||||
# Note: During the forward pass, this token is repeated as many times as needed. The authors also experimented with initializing the necessary number of tokens independently and observed inferior results.
|
||||
self.action_token = nn.Parameter(torch.randn(1, 1, self.config.gpt_input_dim))
|
||||
|
||||
# To input state and observation features into GPT layers, we first project the features to fit the shape of input size of GPT.
|
||||
self.state_projector = MLP(
|
||||
config.robot_state_feature.shape[0], hidden_channels=[self.config.gpt_input_dim]
|
||||
config.input_shapes["observation.state"][0], hidden_channels=[self.config.gpt_input_dim]
|
||||
)
|
||||
self.rgb_feature_projector = MLP(
|
||||
self.rgb_encoder.feature_dim, hidden_channels=[self.config.gpt_input_dim]
|
||||
@@ -342,7 +313,7 @@ class VQBeTModel(nn.Module):
|
||||
torch.row_stack([torch.arange(i, i + self.config.action_chunk_size) for i in range(num_tokens)]),
|
||||
)
|
||||
|
||||
def forward(self, batch: dict[str, Tensor], rollout: bool) -> tuple[dict, dict]:
|
||||
def forward(self, batch: dict[str, Tensor], rollout: bool) -> Tensor:
|
||||
# Input validation.
|
||||
assert set(batch).issuperset({"observation.state", "observation.images"})
|
||||
batch_size, n_obs_steps = batch["observation.state"].shape[:2]
|
||||
@@ -379,10 +350,10 @@ class VQBeTModel(nn.Module):
|
||||
|
||||
# get action features (pass through GPT)
|
||||
features = self.policy(input_tokens)
|
||||
# len(self.config.input_features) is the number of different observation modes.
|
||||
# len(self.config.input_shapes) is the number of different observation modes.
|
||||
# this line gets the index of action prompt tokens.
|
||||
historical_act_pred_index = np.arange(0, n_obs_steps) * (len(self.config.input_features) + 1) + len(
|
||||
self.config.input_features
|
||||
historical_act_pred_index = np.arange(0, n_obs_steps) * (len(self.config.input_shapes) + 1) + len(
|
||||
self.config.input_shapes
|
||||
)
|
||||
|
||||
# only extract the output tokens at the position of action query:
|
||||
@@ -421,7 +392,7 @@ class VQBeTHead(nn.Module):
|
||||
|
||||
self.map_to_cbet_preds_offset: output the predicted offsets for all the codes in all the layers.
|
||||
The input dimension of ` self.map_to_cbet_preds_offset` is same with the output of GPT,
|
||||
and the output dimension of ` self.map_to_cbet_preds_offset` is `self.vqvae_model.vqvae_num_layers (=fixed as 2) * self.config.vqvae_n_embed * config.action_chunk_size * config.action_feature.shape[0]`.
|
||||
and the output dimension of ` self.map_to_cbet_preds_offset` is `self.vqvae_model.vqvae_num_layers (=fixed as 2) * self.config.vqvae_n_embed * config.action_chunk_size * config.output_shapes["action"][0]`.
|
||||
"""
|
||||
|
||||
super().__init__()
|
||||
@@ -448,7 +419,7 @@ class VQBeTHead(nn.Module):
|
||||
self.vqvae_model.vqvae_num_layers
|
||||
* self.config.vqvae_n_embed
|
||||
* config.action_chunk_size
|
||||
* config.action_feature.shape[0],
|
||||
* config.output_shapes["action"][0],
|
||||
],
|
||||
)
|
||||
# loss
|
||||
@@ -482,10 +453,10 @@ class VQBeTHead(nn.Module):
|
||||
param.requires_grad = False
|
||||
return loss, n_different_codes, n_different_combinations, recon_l1_error
|
||||
|
||||
def forward(self, x, **kwargs) -> dict:
|
||||
def forward(self, x, **kwargs):
|
||||
# N is the batch size, and T is number of action query tokens, which are process through same GPT
|
||||
N, T, _ = x.shape
|
||||
# we calculate N and T side parallelly. Thus, the dimensions would be
|
||||
# we calculate N and T side parallely. Thus, the dimensions would be
|
||||
# (batch size * number of action query tokens, action chunk size, action dimension)
|
||||
x = einops.rearrange(x, "N T WA -> (N T) WA")
|
||||
|
||||
@@ -652,6 +623,84 @@ class VQBeTHead(nn.Module):
|
||||
return loss_dict
|
||||
|
||||
|
||||
class VQBeTOptimizer(torch.optim.Adam):
|
||||
def __init__(self, policy, cfg):
|
||||
vqvae_params = (
|
||||
list(policy.vqbet.action_head.vqvae_model.encoder.parameters())
|
||||
+ list(policy.vqbet.action_head.vqvae_model.decoder.parameters())
|
||||
+ list(policy.vqbet.action_head.vqvae_model.vq_layer.parameters())
|
||||
)
|
||||
decay_params, no_decay_params = policy.vqbet.policy.configure_parameters()
|
||||
decay_params = (
|
||||
decay_params
|
||||
+ list(policy.vqbet.rgb_encoder.parameters())
|
||||
+ list(policy.vqbet.state_projector.parameters())
|
||||
+ list(policy.vqbet.rgb_feature_projector.parameters())
|
||||
+ [policy.vqbet.action_token]
|
||||
+ list(policy.vqbet.action_head.map_to_cbet_preds_offset.parameters())
|
||||
)
|
||||
|
||||
if cfg.policy.sequentially_select:
|
||||
decay_params = (
|
||||
decay_params
|
||||
+ list(policy.vqbet.action_head.map_to_cbet_preds_primary_bin.parameters())
|
||||
+ list(policy.vqbet.action_head.map_to_cbet_preds_secondary_bin.parameters())
|
||||
)
|
||||
else:
|
||||
decay_params = decay_params + list(policy.vqbet.action_head.map_to_cbet_preds_bin.parameters())
|
||||
|
||||
optim_groups = [
|
||||
{
|
||||
"params": decay_params,
|
||||
"weight_decay": cfg.training.adam_weight_decay,
|
||||
"lr": cfg.training.lr,
|
||||
},
|
||||
{
|
||||
"params": vqvae_params,
|
||||
"weight_decay": 0.0001,
|
||||
"lr": cfg.training.vqvae_lr,
|
||||
},
|
||||
{
|
||||
"params": no_decay_params,
|
||||
"weight_decay": 0.0,
|
||||
"lr": cfg.training.lr,
|
||||
},
|
||||
]
|
||||
super().__init__(
|
||||
optim_groups,
|
||||
cfg.training.lr,
|
||||
cfg.training.adam_betas,
|
||||
cfg.training.adam_eps,
|
||||
)
|
||||
|
||||
|
||||
class VQBeTScheduler(nn.Module):
|
||||
def __init__(self, optimizer, cfg):
|
||||
super().__init__()
|
||||
n_vqvae_training_steps = cfg.training.n_vqvae_training_steps
|
||||
|
||||
num_warmup_steps = cfg.training.lr_warmup_steps
|
||||
num_training_steps = cfg.training.offline_steps
|
||||
num_cycles = 0.5
|
||||
|
||||
def lr_lambda(current_step):
|
||||
if current_step < n_vqvae_training_steps:
|
||||
return float(1)
|
||||
else:
|
||||
current_step = current_step - n_vqvae_training_steps
|
||||
if current_step < num_warmup_steps:
|
||||
return float(current_step) / float(max(1, num_warmup_steps))
|
||||
progress = float(current_step - num_warmup_steps) / float(
|
||||
max(1, num_training_steps - num_warmup_steps)
|
||||
)
|
||||
return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))
|
||||
|
||||
self.lr_scheduler = LambdaLR(optimizer, lr_lambda, -1)
|
||||
|
||||
def step(self):
|
||||
self.lr_scheduler.step()
|
||||
|
||||
|
||||
class VQBeTRgbEncoder(nn.Module):
|
||||
"""Encode an RGB image into a 1D feature vector.
|
||||
|
||||
@@ -694,15 +743,19 @@ class VQBeTRgbEncoder(nn.Module):
|
||||
|
||||
# Set up pooling and final layers.
|
||||
# Use a dry run to get the feature map shape.
|
||||
# The dummy input should take the number of image channels from `config.image_features` and it should
|
||||
# The dummy input should take the number of image channels from `config.input_shapes` and it should
|
||||
# use the height and width from `config.crop_shape` if it is provided, otherwise it should use the
|
||||
# height and width from `config.image_features`.
|
||||
|
||||
images_shape = next(iter(config.image_features.values())).shape
|
||||
dummy_shape_h_w = config.crop_shape if config.crop_shape is not None else images_shape[1:]
|
||||
dummy_shape = (1, images_shape[0], *dummy_shape_h_w)
|
||||
feature_map_shape = get_output_shape(self.backbone, dummy_shape)[1:]
|
||||
|
||||
# height and width from `config.input_shapes`.
|
||||
image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
|
||||
assert len(image_keys) == 1
|
||||
image_key = image_keys[0]
|
||||
dummy_input_h_w = (
|
||||
config.crop_shape if config.crop_shape is not None else config.input_shapes[image_key][1:]
|
||||
)
|
||||
dummy_input = torch.zeros(size=(1, config.input_shapes[image_key][0], *dummy_input_h_w))
|
||||
with torch.inference_mode():
|
||||
dummy_feature_map = self.backbone(dummy_input)
|
||||
feature_map_shape = tuple(dummy_feature_map.shape[1:])
|
||||
self.pool = SpatialSoftmax(feature_map_shape, num_kp=config.spatial_softmax_num_keypoints)
|
||||
self.feature_dim = config.spatial_softmax_num_keypoints * 2
|
||||
self.out = nn.Linear(config.spatial_softmax_num_keypoints * 2, self.feature_dim)
|
||||
@@ -772,7 +825,7 @@ class VqVae(nn.Module):
|
||||
Encoder and decoder are MLPs consisting of an input, output layer, and hidden layer, respectively.
|
||||
The vq_layer uses residual VQs.
|
||||
|
||||
This class contains functions for training the encoder and decoder along with the residual VQ layer (for training phase 1),
|
||||
This class contains functions for training the encoder and decoder along with the residual VQ layer (for trainign phase 1),
|
||||
as well as functions to help BeT training part in training phase 2.
|
||||
"""
|
||||
|
||||
@@ -791,7 +844,7 @@ class VqVae(nn.Module):
|
||||
)
|
||||
|
||||
self.encoder = MLP(
|
||||
in_channels=self.config.action_feature.shape[0] * self.config.action_chunk_size,
|
||||
in_channels=self.config.output_shapes["action"][0] * self.config.action_chunk_size,
|
||||
hidden_channels=[
|
||||
config.vqvae_enc_hidden_dim,
|
||||
config.vqvae_enc_hidden_dim,
|
||||
@@ -803,7 +856,7 @@ class VqVae(nn.Module):
|
||||
hidden_channels=[
|
||||
config.vqvae_enc_hidden_dim,
|
||||
config.vqvae_enc_hidden_dim,
|
||||
self.config.action_feature.shape[0] * self.config.action_chunk_size,
|
||||
self.config.output_shapes["action"][0] * self.config.action_chunk_size,
|
||||
],
|
||||
)
|
||||
|
||||
@@ -819,9 +872,9 @@ class VqVae(nn.Module):
|
||||
# given latent vector, this function outputs the decoded action.
|
||||
output = self.decoder(latent)
|
||||
if self.config.action_chunk_size == 1:
|
||||
return einops.rearrange(output, "N (T A) -> N T A", A=self.config.action_feature.shape[0])
|
||||
return einops.rearrange(output, "N (T A) -> N T A", A=self.config.output_shapes["action"][0])
|
||||
else:
|
||||
return einops.rearrange(output, "N (T A) -> N T A", A=self.config.action_feature.shape[0])
|
||||
return einops.rearrange(output, "N (T A) -> N T A", A=self.config.output_shapes["action"][0])
|
||||
|
||||
def get_code(self, state):
|
||||
# in phase 2 of VQ-BeT training, we need a `ground truth labels of action data` to calculate the Focal loss for code prediction head. (please refer to section 3.3 in the paper https://arxiv.org/pdf/2403.03181)
|
||||
|
||||
@@ -38,7 +38,7 @@ from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig
|
||||
This file is part of a VQ-BeT that utilizes code from the following repositories:
|
||||
|
||||
- Vector Quantize PyTorch code is licensed under the MIT License:
|
||||
Original source: https://github.com/lucidrains/vector-quantize-pytorch
|
||||
Origianl source: https://github.com/lucidrains/vector-quantize-pytorch
|
||||
|
||||
- nanoGPT part is an adaptation of Andrej Karpathy's nanoGPT implementation in PyTorch.
|
||||
Original source: https://github.com/karpathy/nanoGPT
|
||||
@@ -203,9 +203,9 @@ class GPT(nn.Module):
|
||||
def forward(self, input, targets=None):
|
||||
device = input.device
|
||||
b, t, d = input.size()
|
||||
assert t <= self.config.gpt_block_size, (
|
||||
f"Cannot forward sequence of length {t}, block size is only {self.config.gpt_block_size}"
|
||||
)
|
||||
assert (
|
||||
t <= self.config.gpt_block_size
|
||||
), f"Cannot forward sequence of length {t}, block size is only {self.config.gpt_block_size}"
|
||||
|
||||
# positional encodings that are added to the input embeddings
|
||||
pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0) # shape (1, t)
|
||||
@@ -273,10 +273,10 @@ class GPT(nn.Module):
|
||||
assert len(inter_params) == 0, "parameters {} made it into both decay/no_decay sets!".format(
|
||||
str(inter_params)
|
||||
)
|
||||
assert len(param_dict.keys() - union_params) == 0, (
|
||||
"parameters {} were not separated into either decay/no_decay set!".format(
|
||||
str(param_dict.keys() - union_params),
|
||||
)
|
||||
assert (
|
||||
len(param_dict.keys() - union_params) == 0
|
||||
), "parameters {} were not separated into either decay/no_decay set!".format(
|
||||
str(param_dict.keys() - union_params),
|
||||
)
|
||||
|
||||
decay = [param_dict[pn] for pn in sorted(decay)]
|
||||
@@ -289,7 +289,7 @@ class GPT(nn.Module):
|
||||
This file is a part for Residual Vector Quantization that utilizes code from the following repository:
|
||||
|
||||
- Phil Wang's vector-quantize-pytorch implementation in PyTorch.
|
||||
Original source: https://github.com/lucidrains/vector-quantize-pytorch
|
||||
Origianl source: https://github.com/lucidrains/vector-quantize-pytorch
|
||||
|
||||
- The vector-quantize-pytorch code is licensed under the MIT License:
|
||||
|
||||
@@ -419,9 +419,9 @@ class ResidualVQ(nn.Module):
|
||||
# and the network should be able to reconstruct
|
||||
|
||||
if quantize_dim < self.num_quantizers:
|
||||
assert self.quantize_dropout > 0.0, (
|
||||
"quantize dropout must be greater than 0 if you wish to reconstruct from a signal with less fine quantizations"
|
||||
)
|
||||
assert (
|
||||
self.quantize_dropout > 0.0
|
||||
), "quantize dropout must be greater than 0 if you wish to reconstruct from a signal with less fine quantizations"
|
||||
indices = F.pad(indices, (0, self.num_quantizers - quantize_dim), value=-1)
|
||||
|
||||
# get ready for gathering
|
||||
@@ -472,9 +472,9 @@ class ResidualVQ(nn.Module):
|
||||
all_indices = []
|
||||
|
||||
if return_loss:
|
||||
assert not torch.any(indices == -1), (
|
||||
"some of the residual vq indices were dropped out. please use indices derived when the module is in eval mode to derive cross entropy loss"
|
||||
)
|
||||
assert not torch.any(
|
||||
indices == -1
|
||||
), "some of the residual vq indices were dropped out. please use indices derived when the module is in eval mode to derive cross entropy loss"
|
||||
ce_losses = []
|
||||
|
||||
should_quantize_dropout = self.training and self.quantize_dropout and not return_loss
|
||||
@@ -887,9 +887,9 @@ class VectorQuantize(nn.Module):
|
||||
# only calculate orthogonal loss for the activated codes for this batch
|
||||
|
||||
if self.orthogonal_reg_active_codes_only:
|
||||
assert not (is_multiheaded and self.separate_codebook_per_head), (
|
||||
"orthogonal regularization for only active codes not compatible with multi-headed with separate codebooks yet"
|
||||
)
|
||||
assert not (
|
||||
is_multiheaded and self.separate_codebook_per_head
|
||||
), "orthogonal regularization for only active codes not compatible with multi-headed with separate codebooks yet"
|
||||
unique_code_ids = torch.unique(embed_ind)
|
||||
codebook = codebook[:, unique_code_ids]
|
||||
|
||||
@@ -999,9 +999,9 @@ def gumbel_sample(
|
||||
ind = sampling_logits.argmax(dim=dim)
|
||||
one_hot = F.one_hot(ind, size).type(dtype)
|
||||
|
||||
assert not (reinmax and not straight_through), (
|
||||
"reinmax can only be turned on if using straight through gumbel softmax"
|
||||
)
|
||||
assert not (
|
||||
reinmax and not straight_through
|
||||
), "reinmax can only be turned on if using straight through gumbel softmax"
|
||||
|
||||
if not straight_through or temperature <= 0.0 or not training:
|
||||
return ind, one_hot
|
||||
@@ -1209,9 +1209,9 @@ class EuclideanCodebook(nn.Module):
|
||||
self.gumbel_sample = gumbel_sample
|
||||
self.sample_codebook_temp = sample_codebook_temp
|
||||
|
||||
assert not (use_ddp and num_codebooks > 1 and kmeans_init), (
|
||||
"kmeans init is not compatible with multiple codebooks in distributed environment for now"
|
||||
)
|
||||
assert not (
|
||||
use_ddp and num_codebooks > 1 and kmeans_init
|
||||
), "kmeans init is not compatible with multiple codebooks in distributed environment for now"
|
||||
|
||||
self.sample_fn = sample_vectors_distributed if use_ddp and sync_kmeans else batched_sample_vectors
|
||||
self.kmeans_all_reduce_fn = distributed.all_reduce if use_ddp and sync_kmeans else noop
|
||||
@@ -1349,9 +1349,9 @@ class EuclideanCodebook(nn.Module):
|
||||
|
||||
# calculate distributed variance
|
||||
|
||||
variance_number = reduce((data - batch_mean) ** 2, "h n d -> h 1 d", "sum")
|
||||
distributed.all_reduce(variance_number)
|
||||
batch_variance = variance_number / num_vectors
|
||||
variance_numer = reduce((data - batch_mean) ** 2, "h n d -> h 1 d", "sum")
|
||||
distributed.all_reduce(variance_numer)
|
||||
batch_variance = variance_numer / num_vectors
|
||||
|
||||
self.update_with_decay("batch_variance", batch_variance, self.affine_param_batch_decay)
|
||||
|
||||
|
||||
@@ -1,100 +0,0 @@
|
||||
import abc
|
||||
from dataclasses import dataclass
|
||||
|
||||
import draccus
|
||||
|
||||
|
||||
@dataclass
|
||||
class CameraConfig(draccus.ChoiceRegistry, abc.ABC):
|
||||
@property
|
||||
def type(self) -> str:
|
||||
return self.get_choice_name(self.__class__)
|
||||
|
||||
|
||||
@CameraConfig.register_subclass("opencv")
|
||||
@dataclass
|
||||
class OpenCVCameraConfig(CameraConfig):
|
||||
"""
|
||||
Example of tested options for Intel Real Sense D405:
|
||||
|
||||
```python
|
||||
OpenCVCameraConfig(0, 30, 640, 480)
|
||||
OpenCVCameraConfig(0, 60, 640, 480)
|
||||
OpenCVCameraConfig(0, 90, 640, 480)
|
||||
OpenCVCameraConfig(0, 30, 1280, 720)
|
||||
```
|
||||
"""
|
||||
|
||||
camera_index: int
|
||||
fps: int | None = None
|
||||
width: int | None = None
|
||||
height: int | None = None
|
||||
color_mode: str = "rgb"
|
||||
channels: int | None = None
|
||||
rotation: int | None = None
|
||||
mock: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
if self.color_mode not in ["rgb", "bgr"]:
|
||||
raise ValueError(
|
||||
f"`color_mode` is expected to be 'rgb' or 'bgr', but {self.color_mode} is provided."
|
||||
)
|
||||
|
||||
self.channels = 3
|
||||
|
||||
if self.rotation not in [-90, None, 90, 180]:
|
||||
raise ValueError(f"`rotation` must be in [-90, None, 90, 180] (got {self.rotation})")
|
||||
|
||||
|
||||
@CameraConfig.register_subclass("intelrealsense")
|
||||
@dataclass
|
||||
class IntelRealSenseCameraConfig(CameraConfig):
|
||||
"""
|
||||
Example of tested options for Intel Real Sense D405:
|
||||
|
||||
```python
|
||||
IntelRealSenseCameraConfig(128422271347, 30, 640, 480)
|
||||
IntelRealSenseCameraConfig(128422271347, 60, 640, 480)
|
||||
IntelRealSenseCameraConfig(128422271347, 90, 640, 480)
|
||||
IntelRealSenseCameraConfig(128422271347, 30, 1280, 720)
|
||||
IntelRealSenseCameraConfig(128422271347, 30, 640, 480, use_depth=True)
|
||||
IntelRealSenseCameraConfig(128422271347, 30, 640, 480, rotation=90)
|
||||
```
|
||||
"""
|
||||
|
||||
name: str | None = None
|
||||
serial_number: int | None = None
|
||||
fps: int | None = None
|
||||
width: int | None = None
|
||||
height: int | None = None
|
||||
color_mode: str = "rgb"
|
||||
channels: int | None = None
|
||||
use_depth: bool = False
|
||||
force_hardware_reset: bool = True
|
||||
rotation: int | None = None
|
||||
mock: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
# bool is stronger than is None, since it works with empty strings
|
||||
if bool(self.name) and bool(self.serial_number):
|
||||
raise ValueError(
|
||||
f"One of them must be set: name or serial_number, but {self.name=} and {self.serial_number=} provided."
|
||||
)
|
||||
|
||||
if self.color_mode not in ["rgb", "bgr"]:
|
||||
raise ValueError(
|
||||
f"`color_mode` is expected to be 'rgb' or 'bgr', but {self.color_mode} is provided."
|
||||
)
|
||||
|
||||
self.channels = 3
|
||||
|
||||
at_least_one_is_not_none = self.fps is not None or self.width is not None or self.height is not None
|
||||
at_least_one_is_none = self.fps is None or self.width is None or self.height is None
|
||||
if at_least_one_is_not_none and at_least_one_is_none:
|
||||
raise ValueError(
|
||||
"For `fps`, `width` and `height`, either all of them need to be set, or none of them, "
|
||||
f"but {self.fps=}, {self.width=}, {self.height=} were provided."
|
||||
)
|
||||
|
||||
if self.rotation not in [-90, None, 90, 180]:
|
||||
raise ValueError(f"`rotation` must be in [-90, None, 90, 180] (got {self.rotation})")
|
||||
@@ -11,13 +11,13 @@ import threading
|
||||
import time
|
||||
import traceback
|
||||
from collections import Counter
|
||||
from dataclasses import dataclass, replace
|
||||
from pathlib import Path
|
||||
from threading import Thread
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from lerobot.common.robot_devices.cameras.configs import IntelRealSenseCameraConfig
|
||||
from lerobot.common.robot_devices.utils import (
|
||||
RobotDeviceAlreadyConnectedError,
|
||||
RobotDeviceNotConnectedError,
|
||||
@@ -94,10 +94,7 @@ def save_images_from_cameras(
|
||||
cameras = []
|
||||
for cam_sn in serial_numbers:
|
||||
print(f"{cam_sn=}")
|
||||
config = IntelRealSenseCameraConfig(
|
||||
serial_number=cam_sn, fps=fps, width=width, height=height, mock=mock
|
||||
)
|
||||
camera = IntelRealSenseCamera(config)
|
||||
camera = IntelRealSenseCamera(cam_sn, fps=fps, width=width, height=height, mock=mock)
|
||||
camera.connect()
|
||||
print(
|
||||
f"IntelRealSenseCamera({camera.serial_number}, fps={camera.fps}, width={camera.width}, height={camera.height}, color_mode={camera.color_mode})"
|
||||
@@ -152,6 +149,51 @@ def save_images_from_cameras(
|
||||
camera.disconnect()
|
||||
|
||||
|
||||
@dataclass
|
||||
class IntelRealSenseCameraConfig:
|
||||
"""
|
||||
Example of tested options for Intel Real Sense D405:
|
||||
|
||||
```python
|
||||
IntelRealSenseCameraConfig(30, 640, 480)
|
||||
IntelRealSenseCameraConfig(60, 640, 480)
|
||||
IntelRealSenseCameraConfig(90, 640, 480)
|
||||
IntelRealSenseCameraConfig(30, 1280, 720)
|
||||
IntelRealSenseCameraConfig(30, 640, 480, use_depth=True)
|
||||
IntelRealSenseCameraConfig(30, 640, 480, rotation=90)
|
||||
```
|
||||
"""
|
||||
|
||||
fps: int | None = None
|
||||
width: int | None = None
|
||||
height: int | None = None
|
||||
color_mode: str = "rgb"
|
||||
channels: int | None = None
|
||||
use_depth: bool = False
|
||||
force_hardware_reset: bool = True
|
||||
rotation: int | None = None
|
||||
mock: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
if self.color_mode not in ["rgb", "bgr"]:
|
||||
raise ValueError(
|
||||
f"`color_mode` is expected to be 'rgb' or 'bgr', but {self.color_mode} is provided."
|
||||
)
|
||||
|
||||
self.channels = 3
|
||||
|
||||
at_least_one_is_not_none = self.fps is not None or self.width is not None or self.height is not None
|
||||
at_least_one_is_none = self.fps is None or self.width is None or self.height is None
|
||||
if at_least_one_is_not_none and at_least_one_is_none:
|
||||
raise ValueError(
|
||||
"For `fps`, `width` and `height`, either all of them need to be set, or none of them, "
|
||||
f"but {self.fps=}, {self.width=}, {self.height=} were provided."
|
||||
)
|
||||
|
||||
if self.rotation not in [-90, None, 90, 180]:
|
||||
raise ValueError(f"`rotation` must be in [-90, None, 90, 180] (got {self.rotation})")
|
||||
|
||||
|
||||
class IntelRealSenseCamera:
|
||||
"""
|
||||
The IntelRealSenseCamera class is similar to OpenCVCamera class but adds additional features for Intel Real Sense cameras:
|
||||
@@ -167,35 +209,33 @@ class IntelRealSenseCamera:
|
||||
When an IntelRealSenseCamera is instantiated, if no specific config is provided, the default fps, width, height and color_mode
|
||||
of the given camera will be used.
|
||||
|
||||
Example of instantiating with a serial number:
|
||||
Example of usage:
|
||||
```python
|
||||
from lerobot.common.robot_devices.cameras.configs import IntelRealSenseCameraConfig
|
||||
|
||||
config = IntelRealSenseCameraConfig(serial_number=128422271347)
|
||||
camera = IntelRealSenseCamera(config)
|
||||
# Instantiate with its serial number
|
||||
camera = IntelRealSenseCamera(128422271347)
|
||||
# Or by its name if it's unique
|
||||
camera = IntelRealSenseCamera.init_from_name("Intel RealSense D405")
|
||||
camera.connect()
|
||||
color_image = camera.read()
|
||||
# when done using the camera, consider disconnecting
|
||||
camera.disconnect()
|
||||
```
|
||||
|
||||
Example of instantiating with a name if it's unique:
|
||||
```
|
||||
config = IntelRealSenseCameraConfig(name="Intel RealSense D405")
|
||||
```
|
||||
|
||||
Example of changing default fps, width, height and color_mode:
|
||||
```python
|
||||
config = IntelRealSenseCameraConfig(serial_number=128422271347, fps=30, width=1280, height=720)
|
||||
config = IntelRealSenseCameraConfig(serial_number=128422271347, fps=90, width=640, height=480)
|
||||
config = IntelRealSenseCameraConfig(serial_number=128422271347, fps=90, width=640, height=480, color_mode="bgr")
|
||||
# Note: might error out upon `camera.connect()` if these settings are not compatible with the camera
|
||||
camera = IntelRealSenseCamera(serial_number, fps=30, width=1280, height=720)
|
||||
camera = connect() # applies the settings, might error out if these settings are not compatible with the camera
|
||||
|
||||
camera = IntelRealSenseCamera(serial_number, fps=90, width=640, height=480)
|
||||
camera = connect()
|
||||
|
||||
camera = IntelRealSenseCamera(serial_number, fps=90, width=640, height=480, color_mode="bgr")
|
||||
camera = connect()
|
||||
```
|
||||
|
||||
Example of returning depth:
|
||||
```python
|
||||
config = IntelRealSenseCameraConfig(serial_number=128422271347, use_depth=True)
|
||||
camera = IntelRealSenseCamera(config)
|
||||
camera = IntelRealSenseCamera(serial_number, use_depth=True)
|
||||
camera.connect()
|
||||
color_image, depth_map = camera.read()
|
||||
```
|
||||
@@ -203,13 +243,17 @@ class IntelRealSenseCamera:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: IntelRealSenseCameraConfig,
|
||||
serial_number: int,
|
||||
config: IntelRealSenseCameraConfig | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
self.config = config
|
||||
if config.name is not None:
|
||||
self.serial_number = self.find_serial_number_from_name(config.name)
|
||||
else:
|
||||
self.serial_number = config.serial_number
|
||||
if config is None:
|
||||
config = IntelRealSenseCameraConfig()
|
||||
|
||||
# Overwrite the config arguments using kwargs
|
||||
config = replace(config, **kwargs)
|
||||
|
||||
self.serial_number = serial_number
|
||||
self.fps = config.fps
|
||||
self.width = config.width
|
||||
self.height = config.height
|
||||
@@ -241,7 +285,8 @@ class IntelRealSenseCamera:
|
||||
elif config.rotation == 180:
|
||||
self.rotation = cv2.ROTATE_180
|
||||
|
||||
def find_serial_number_from_name(self, name):
|
||||
@classmethod
|
||||
def init_from_name(cls, name: str, config: IntelRealSenseCameraConfig | None = None, **kwargs):
|
||||
camera_infos = find_cameras()
|
||||
camera_names = [cam["name"] for cam in camera_infos]
|
||||
this_name_count = Counter(camera_names)[name]
|
||||
@@ -254,7 +299,13 @@ class IntelRealSenseCamera:
|
||||
name_to_serial_dict = {cam["name"]: cam["serial_number"] for cam in camera_infos}
|
||||
cam_sn = name_to_serial_dict[name]
|
||||
|
||||
return cam_sn
|
||||
if config is None:
|
||||
config = IntelRealSenseCameraConfig()
|
||||
|
||||
# Overwrite the config arguments using kwargs
|
||||
config = replace(config, **kwargs)
|
||||
|
||||
return cls(serial_number=cam_sn, config=config, **kwargs)
|
||||
|
||||
def connect(self):
|
||||
if self.is_connected:
|
||||
|
||||
@@ -9,13 +9,13 @@ import platform
|
||||
import shutil
|
||||
import threading
|
||||
import time
|
||||
from dataclasses import dataclass, replace
|
||||
from pathlib import Path
|
||||
from threading import Thread
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from lerobot.common.robot_devices.cameras.configs import OpenCVCameraConfig
|
||||
from lerobot.common.robot_devices.utils import (
|
||||
RobotDeviceAlreadyConnectedError,
|
||||
RobotDeviceNotConnectedError,
|
||||
@@ -126,8 +126,7 @@ def save_images_from_cameras(
|
||||
print("Connecting cameras")
|
||||
cameras = []
|
||||
for cam_idx in camera_ids:
|
||||
config = OpenCVCameraConfig(camera_index=cam_idx, fps=fps, width=width, height=height, mock=mock)
|
||||
camera = OpenCVCamera(config)
|
||||
camera = OpenCVCamera(cam_idx, fps=fps, width=width, height=height, mock=mock)
|
||||
camera.connect()
|
||||
print(
|
||||
f"OpenCVCamera({camera.camera_index}, fps={camera.fps}, width={camera.width}, "
|
||||
@@ -176,6 +175,39 @@ def save_images_from_cameras(
|
||||
print(f"Images have been saved to {images_dir}")
|
||||
|
||||
|
||||
@dataclass
|
||||
class OpenCVCameraConfig:
|
||||
"""
|
||||
Example of tested options for Intel Real Sense D405:
|
||||
|
||||
```python
|
||||
OpenCVCameraConfig(30, 640, 480)
|
||||
OpenCVCameraConfig(60, 640, 480)
|
||||
OpenCVCameraConfig(90, 640, 480)
|
||||
OpenCVCameraConfig(30, 1280, 720)
|
||||
```
|
||||
"""
|
||||
|
||||
fps: int | None = None
|
||||
width: int | None = None
|
||||
height: int | None = None
|
||||
color_mode: str = "rgb"
|
||||
channels: int | None = None
|
||||
rotation: int | None = None
|
||||
mock: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
if self.color_mode not in ["rgb", "bgr"]:
|
||||
raise ValueError(
|
||||
f"`color_mode` is expected to be 'rgb' or 'bgr', but {self.color_mode} is provided."
|
||||
)
|
||||
|
||||
self.channels = 3
|
||||
|
||||
if self.rotation not in [-90, None, 90, 180]:
|
||||
raise ValueError(f"`rotation` must be in [-90, None, 90, 180] (got {self.rotation})")
|
||||
|
||||
|
||||
class OpenCVCamera:
|
||||
"""
|
||||
The OpenCVCamera class allows to efficiently record images from cameras. It relies on opencv2 to communicate
|
||||
@@ -195,10 +227,7 @@ class OpenCVCamera:
|
||||
|
||||
Example of usage:
|
||||
```python
|
||||
from lerobot.common.robot_devices.cameras.configs import OpenCVCameraConfig
|
||||
|
||||
config = OpenCVCameraConfig(camera_index=0)
|
||||
camera = OpenCVCamera(config)
|
||||
camera = OpenCVCamera(camera_index=0)
|
||||
camera.connect()
|
||||
color_image = camera.read()
|
||||
# when done using the camera, consider disconnecting
|
||||
@@ -207,16 +236,25 @@ class OpenCVCamera:
|
||||
|
||||
Example of changing default fps, width, height and color_mode:
|
||||
```python
|
||||
config = OpenCVCameraConfig(camera_index=0, fps=30, width=1280, height=720)
|
||||
config = OpenCVCameraConfig(camera_index=0, fps=90, width=640, height=480)
|
||||
config = OpenCVCameraConfig(camera_index=0, fps=90, width=640, height=480, color_mode="bgr")
|
||||
# Note: might error out open `camera.connect()` if these settings are not compatible with the camera
|
||||
camera = OpenCVCamera(0, fps=30, width=1280, height=720)
|
||||
camera = connect() # applies the settings, might error out if these settings are not compatible with the camera
|
||||
|
||||
camera = OpenCVCamera(0, fps=90, width=640, height=480)
|
||||
camera = connect()
|
||||
|
||||
camera = OpenCVCamera(0, fps=90, width=640, height=480, color_mode="bgr")
|
||||
camera = connect()
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self, config: OpenCVCameraConfig):
|
||||
self.config = config
|
||||
self.camera_index = config.camera_index
|
||||
def __init__(self, camera_index: int | str, config: OpenCVCameraConfig | None = None, **kwargs):
|
||||
if config is None:
|
||||
config = OpenCVCameraConfig()
|
||||
|
||||
# Overwrite config arguments using kwargs
|
||||
config = replace(config, **kwargs)
|
||||
|
||||
self.camera_index = camera_index
|
||||
self.port = None
|
||||
|
||||
# Linux uses ports for connecting to cameras
|
||||
@@ -228,7 +266,7 @@ class OpenCVCamera:
|
||||
# Retrieve the camera index from a potentially symlinked path
|
||||
self.camera_index = get_camera_index_from_unix_port(self.port)
|
||||
else:
|
||||
raise ValueError(f"Please check the provided camera_index: {self.camera_index}")
|
||||
raise ValueError(f"Please check the provided camera_index: {camera_index}")
|
||||
|
||||
self.fps = config.fps
|
||||
self.width = config.width
|
||||
|
||||
@@ -2,12 +2,6 @@ from typing import Protocol
|
||||
|
||||
import numpy as np
|
||||
|
||||
from lerobot.common.robot_devices.cameras.configs import (
|
||||
CameraConfig,
|
||||
IntelRealSenseCameraConfig,
|
||||
OpenCVCameraConfig,
|
||||
)
|
||||
|
||||
|
||||
# Defines a camera type
|
||||
class Camera(Protocol):
|
||||
@@ -15,39 +9,3 @@ class Camera(Protocol):
|
||||
def read(self, temporary_color: str | None = None) -> np.ndarray: ...
|
||||
def async_read(self) -> np.ndarray: ...
|
||||
def disconnect(self): ...
|
||||
|
||||
|
||||
def make_cameras_from_configs(camera_configs: dict[str, CameraConfig]) -> list[Camera]:
|
||||
cameras = {}
|
||||
|
||||
for key, cfg in camera_configs.items():
|
||||
if cfg.type == "opencv":
|
||||
from lerobot.common.robot_devices.cameras.opencv import OpenCVCamera
|
||||
|
||||
cameras[key] = OpenCVCamera(cfg)
|
||||
|
||||
elif cfg.type == "intelrealsense":
|
||||
from lerobot.common.robot_devices.cameras.intelrealsense import IntelRealSenseCamera
|
||||
|
||||
cameras[key] = IntelRealSenseCamera(cfg)
|
||||
else:
|
||||
raise ValueError(f"The motor type '{cfg.type}' is not valid.")
|
||||
|
||||
return cameras
|
||||
|
||||
|
||||
def make_camera(camera_type, **kwargs) -> Camera:
|
||||
if camera_type == "opencv":
|
||||
from lerobot.common.robot_devices.cameras.opencv import OpenCVCamera
|
||||
|
||||
config = OpenCVCameraConfig(**kwargs)
|
||||
return OpenCVCamera(config)
|
||||
|
||||
elif camera_type == "intelrealsense":
|
||||
from lerobot.common.robot_devices.cameras.intelrealsense import IntelRealSenseCamera
|
||||
|
||||
config = IntelRealSenseCameraConfig(**kwargs)
|
||||
return IntelRealSenseCamera(config)
|
||||
|
||||
else:
|
||||
raise ValueError(f"The camera type '{camera_type}' is not valid.")
|
||||
|
||||
@@ -1,144 +0,0 @@
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
import draccus
|
||||
|
||||
from lerobot.common.robot_devices.robots.configs import RobotConfig
|
||||
from lerobot.common.utils.utils import auto_select_torch_device, is_amp_available, is_torch_device_available
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
|
||||
|
||||
@dataclass
|
||||
class ControlConfig(draccus.ChoiceRegistry):
|
||||
pass
|
||||
|
||||
|
||||
@ControlConfig.register_subclass("calibrate")
|
||||
@dataclass
|
||||
class CalibrateControlConfig(ControlConfig):
|
||||
# List of arms to calibrate (e.g. `--arms='["left_follower","right_follower"]' left_leader`)
|
||||
arms: list[str] | None = None
|
||||
|
||||
|
||||
@ControlConfig.register_subclass("teleoperate")
|
||||
@dataclass
|
||||
class TeleoperateControlConfig(ControlConfig):
|
||||
# Limit the maximum frames per second. By default, no limit.
|
||||
fps: int | None = None
|
||||
teleop_time_s: float | None = None
|
||||
# Display all cameras on screen
|
||||
display_cameras: bool = True
|
||||
|
||||
|
||||
@ControlConfig.register_subclass("record")
|
||||
@dataclass
|
||||
class RecordControlConfig(ControlConfig):
|
||||
# Dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`).
|
||||
repo_id: str
|
||||
# A short but accurate description of the task performed during the recording (e.g. "Pick the Lego block and drop it in the box on the right.")
|
||||
single_task: str
|
||||
# Root directory where the dataset will be stored (e.g. 'dataset/path').
|
||||
root: str | Path | None = None
|
||||
policy: PreTrainedConfig | None = None
|
||||
# TODO(rcadene, aliberts): By default, use device and use_amp values from policy checkpoint.
|
||||
device: str | None = None # cuda | cpu | mps
|
||||
# `use_amp` determines whether to use Automatic Mixed Precision (AMP) for training and evaluation. With AMP,
|
||||
# automatic gradient scaling is used.
|
||||
use_amp: bool | None = None
|
||||
# Limit the frames per second. By default, uses the policy fps.
|
||||
fps: int | None = None
|
||||
# Number of seconds before starting data collection. It allows the robot devices to warmup and synchronize.
|
||||
warmup_time_s: int | float = 10
|
||||
# Number of seconds for data recording for each episode.
|
||||
episode_time_s: int | float = 60
|
||||
# Number of seconds for resetting the environment after each episode.
|
||||
reset_time_s: int | float = 60
|
||||
# Number of episodes to record.
|
||||
num_episodes: int = 50
|
||||
# Encode frames in the dataset into video
|
||||
video: bool = True
|
||||
# Upload dataset to Hugging Face hub.
|
||||
push_to_hub: bool = True
|
||||
# Upload on private repository on the Hugging Face hub.
|
||||
private: bool = False
|
||||
# Add tags to your dataset on the hub.
|
||||
tags: list[str] | None = None
|
||||
# Number of subprocesses handling the saving of frames as PNG. Set to 0 to use threads only;
|
||||
# set to ≥1 to use subprocesses, each using threads to write images. The best number of processes
|
||||
# and threads depends on your system. We recommend 4 threads per camera with 0 processes.
|
||||
# If fps is unstable, adjust the thread count. If still unstable, try using 1 or more subprocesses.
|
||||
num_image_writer_processes: int = 0
|
||||
# Number of threads writing the frames as png images on disk, per camera.
|
||||
# Too many threads might cause unstable teleoperation fps due to main thread being blocked.
|
||||
# Not enough threads might cause low camera fps.
|
||||
num_image_writer_threads_per_camera: int = 4
|
||||
# Display all cameras on screen
|
||||
display_cameras: bool = True
|
||||
# Use vocal synthesis to read events.
|
||||
play_sounds: bool = True
|
||||
# Resume recording on an existing dataset.
|
||||
resume: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
# HACK: We parse again the cli args here to get the pretrained path if there was one.
|
||||
policy_path = parser.get_path_arg("control.policy")
|
||||
if policy_path:
|
||||
cli_overrides = parser.get_cli_overrides("control.policy")
|
||||
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
|
||||
self.policy.pretrained_path = policy_path
|
||||
|
||||
# When no device or use_amp are given, use the one from training config.
|
||||
if self.device is None or self.use_amp is None:
|
||||
train_cfg = TrainPipelineConfig.from_pretrained(policy_path)
|
||||
if self.device is None:
|
||||
self.device = train_cfg.device
|
||||
if self.use_amp is None:
|
||||
self.use_amp = train_cfg.use_amp
|
||||
|
||||
# Automatically switch to available device if necessary
|
||||
if not is_torch_device_available(self.device):
|
||||
auto_device = auto_select_torch_device()
|
||||
logging.warning(f"Device '{self.device}' is not available. Switching to '{auto_device}'.")
|
||||
self.device = auto_device
|
||||
|
||||
# Automatically deactivate AMP if necessary
|
||||
if self.use_amp and not is_amp_available(self.device):
|
||||
logging.warning(
|
||||
f"Automatic Mixed Precision (amp) is not available on device '{self.device}'. Deactivating AMP."
|
||||
)
|
||||
self.use_amp = False
|
||||
|
||||
|
||||
@ControlConfig.register_subclass("replay")
|
||||
@dataclass
|
||||
class ReplayControlConfig(ControlConfig):
|
||||
# Dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`).
|
||||
repo_id: str
|
||||
# Index of the episode to replay.
|
||||
episode: int
|
||||
# Root directory where the dataset will be stored (e.g. 'dataset/path').
|
||||
root: str | Path | None = None
|
||||
# Limit the frames per second. By default, uses the dataset fps.
|
||||
fps: int | None = None
|
||||
# Use vocal synthesis to read events.
|
||||
play_sounds: bool = True
|
||||
|
||||
|
||||
@ControlConfig.register_subclass("remote_robot")
|
||||
@dataclass
|
||||
class RemoteRobotConfig(ControlConfig):
|
||||
log_interval: int = 100
|
||||
|
||||
|
||||
@dataclass
|
||||
class ControlPipelineConfig:
|
||||
robot: RobotConfig
|
||||
control: ControlConfig
|
||||
|
||||
@classmethod
|
||||
def __get_path_fields__(cls) -> list[str]:
|
||||
"""This enables the parser to load config from the policy using `--policy.path=local/dir`"""
|
||||
return ["control.policy"]
|
||||
@@ -11,16 +11,20 @@ from copy import copy
|
||||
from functools import cache
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
import tqdm
|
||||
from deepdiff import DeepDiff
|
||||
from termcolor import colored
|
||||
|
||||
from lerobot.common.datasets.image_writer import safe_stop_image_writer
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.common.datasets.utils import get_features_from_robot
|
||||
from lerobot.common.policies.factory import make_policy
|
||||
from lerobot.common.robot_devices.robots.utils import Robot
|
||||
from lerobot.common.robot_devices.utils import busy_wait
|
||||
from lerobot.common.utils.utils import get_safe_torch_device, has_method
|
||||
from lerobot.common.utils.utils import get_safe_torch_device, init_hydra_config, set_global_seed
|
||||
from lerobot.scripts.eval import get_pretrained_policy_path
|
||||
|
||||
|
||||
def log_control_info(robot: Robot, dt_s, episode_index=None, frame_index=None, fps=None):
|
||||
@@ -32,7 +36,7 @@ def log_control_info(robot: Robot, dt_s, episode_index=None, frame_index=None, f
|
||||
|
||||
def log_dt(shortname, dt_val_s):
|
||||
nonlocal log_items, fps
|
||||
info_str = f"{shortname}:{dt_val_s * 1000:5.2f} ({1 / dt_val_s:3.1f}hz)"
|
||||
info_str = f"{shortname}:{dt_val_s * 1000:5.2f} ({1/ dt_val_s:3.1f}hz)"
|
||||
if fps is not None:
|
||||
actual_fps = 1 / dt_val_s
|
||||
if actual_fps < fps - 1:
|
||||
@@ -86,6 +90,10 @@ def is_headless():
|
||||
return True
|
||||
|
||||
|
||||
def has_method(_object: object, method_name: str):
|
||||
return hasattr(_object, method_name) and callable(getattr(_object, method_name))
|
||||
|
||||
|
||||
def predict_action(observation, policy, device, use_amp):
|
||||
observation = copy(observation)
|
||||
with (
|
||||
@@ -113,14 +121,22 @@ def predict_action(observation, policy, device, use_amp):
|
||||
return action
|
||||
|
||||
|
||||
def init_keyboard_listener():
|
||||
# Allow to exit early while recording an episode or resetting the environment,
|
||||
# by tapping the right arrow key '->'. This might require a sudo permission
|
||||
# to allow your terminal to monitor keyboard events.
|
||||
def init_keyboard_listener(assign_rewards=False):
|
||||
"""
|
||||
Initializes a keyboard listener to enable early termination of an episode
|
||||
or environment reset by pressing the right arrow key ('->'). This may require
|
||||
sudo permissions to allow the terminal to monitor keyboard events.
|
||||
|
||||
Args:
|
||||
assign_rewards (bool): If True, allows annotating the collected trajectory
|
||||
with a binary reward at the end of the episode to indicate success.
|
||||
"""
|
||||
events = {}
|
||||
events["exit_early"] = False
|
||||
events["rerecord_episode"] = False
|
||||
events["stop_recording"] = False
|
||||
if assign_rewards:
|
||||
events["next.reward"] = 0
|
||||
|
||||
if is_headless():
|
||||
logging.warning(
|
||||
@@ -145,6 +161,13 @@ def init_keyboard_listener():
|
||||
print("Escape key pressed. Stopping data recording...")
|
||||
events["stop_recording"] = True
|
||||
events["exit_early"] = True
|
||||
elif assign_rewards and key == keyboard.Key.space:
|
||||
events["next.reward"] = 1 if events["next.reward"] == 0 else 0
|
||||
print(
|
||||
"Space key pressed. Assigning new reward to the subsequent frames. New reward:",
|
||||
events["next.reward"],
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error handling key press: {e}")
|
||||
|
||||
@@ -154,6 +177,26 @@ def init_keyboard_listener():
|
||||
return listener, events
|
||||
|
||||
|
||||
def init_policy(pretrained_policy_name_or_path, policy_overrides):
|
||||
"""Instantiate the policy and load fps, device and use_amp from config yaml"""
|
||||
pretrained_policy_path = get_pretrained_policy_path(pretrained_policy_name_or_path)
|
||||
hydra_cfg = init_hydra_config(pretrained_policy_path / "config.yaml", policy_overrides)
|
||||
policy = make_policy(hydra_cfg=hydra_cfg, pretrained_policy_name_or_path=pretrained_policy_path)
|
||||
|
||||
# Check device is available
|
||||
device = get_safe_torch_device(hydra_cfg.device, log=True)
|
||||
use_amp = hydra_cfg.use_amp
|
||||
policy_fps = hydra_cfg.env.fps
|
||||
|
||||
policy.eval()
|
||||
policy.to(device)
|
||||
|
||||
torch.backends.cudnn.benchmark = True
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
set_global_seed(hydra_cfg.seed)
|
||||
return policy, policy_fps, device, use_amp
|
||||
|
||||
|
||||
def warmup_record(
|
||||
robot,
|
||||
events,
|
||||
@@ -182,7 +225,6 @@ def record_episode(
|
||||
device,
|
||||
use_amp,
|
||||
fps,
|
||||
single_task,
|
||||
):
|
||||
control_loop(
|
||||
robot=robot,
|
||||
@@ -195,7 +237,6 @@ def record_episode(
|
||||
use_amp=use_amp,
|
||||
fps=fps,
|
||||
teleoperate=policy is None,
|
||||
single_task=single_task,
|
||||
)
|
||||
|
||||
|
||||
@@ -208,10 +249,9 @@ def control_loop(
|
||||
dataset: LeRobotDataset | None = None,
|
||||
events=None,
|
||||
policy=None,
|
||||
device: torch.device | str | None = None,
|
||||
use_amp: bool | None = None,
|
||||
fps: int | None = None,
|
||||
single_task: str | None = None,
|
||||
device=None,
|
||||
use_amp=None,
|
||||
fps=None,
|
||||
):
|
||||
# TODO(rcadene): Add option to record logs
|
||||
if not robot.is_connected:
|
||||
@@ -226,15 +266,9 @@ def control_loop(
|
||||
if teleoperate and policy is not None:
|
||||
raise ValueError("When `teleoperate` is True, `policy` should be None.")
|
||||
|
||||
if dataset is not None and single_task is None:
|
||||
raise ValueError("You need to provide a task as argument in `single_task`.")
|
||||
|
||||
if dataset is not None and fps is not None and dataset.fps != fps:
|
||||
raise ValueError(f"The dataset fps should be equal to requested fps ({dataset['fps']} != {fps}).")
|
||||
|
||||
if isinstance(device, str):
|
||||
device = get_safe_torch_device(device)
|
||||
|
||||
timestamp = 0
|
||||
start_episode_t = time.perf_counter()
|
||||
while timestamp < control_time_s:
|
||||
@@ -253,7 +287,9 @@ def control_loop(
|
||||
action = {"action": action}
|
||||
|
||||
if dataset is not None:
|
||||
frame = {**observation, **action, "task": single_task}
|
||||
frame = {**observation, **action}
|
||||
if "next.reward" in events:
|
||||
frame["next.reward"] = events["next.reward"]
|
||||
dataset.add_frame(frame)
|
||||
|
||||
if display_cameras and not is_headless():
|
||||
@@ -275,18 +311,34 @@ def control_loop(
|
||||
break
|
||||
|
||||
|
||||
def reset_environment(robot, events, reset_time_s, fps):
|
||||
def reset_environment(robot, events, reset_time_s):
|
||||
# TODO(rcadene): refactor warmup_record and reset_environment
|
||||
# TODO(alibets): allow for teleop during reset
|
||||
if has_method(robot, "teleop_safety_stop"):
|
||||
robot.teleop_safety_stop()
|
||||
|
||||
control_loop(
|
||||
robot=robot,
|
||||
control_time_s=reset_time_s,
|
||||
events=events,
|
||||
fps=fps,
|
||||
teleoperate=True,
|
||||
)
|
||||
timestamp = 0
|
||||
start_vencod_t = time.perf_counter()
|
||||
if "next.reward" in events:
|
||||
events["next.reward"] = 0
|
||||
|
||||
# Wait if necessary
|
||||
with tqdm.tqdm(total=reset_time_s, desc="Waiting") as pbar:
|
||||
while timestamp < reset_time_s:
|
||||
time.sleep(1)
|
||||
timestamp = time.perf_counter() - start_vencod_t
|
||||
pbar.update(1)
|
||||
if events["exit_early"]:
|
||||
events["exit_early"] = False
|
||||
break
|
||||
|
||||
|
||||
def reset_follower_position(robot: Robot, target_position):
|
||||
current_position = robot.follower_arms["main"].read("Present_Position")
|
||||
trajectory = torch.from_numpy(np.linspace(current_position, target_position, 30)) # NOTE: 30 is just an aribtrary number
|
||||
for pose in trajectory:
|
||||
robot.send_action(pose)
|
||||
busy_wait(0.015)
|
||||
|
||||
|
||||
def stop_recording(robot, listener, display_cameras):
|
||||
@@ -300,31 +352,35 @@ def stop_recording(robot, listener, display_cameras):
|
||||
cv2.destroyAllWindows()
|
||||
|
||||
|
||||
def sanity_check_dataset_name(repo_id, policy_cfg):
|
||||
def sanity_check_dataset_name(repo_id, policy):
|
||||
_, dataset_name = repo_id.split("/")
|
||||
# either repo_id doesnt start with "eval_" and there is no policy
|
||||
# or repo_id starts with "eval_" and there is a policy
|
||||
|
||||
# Check if dataset_name starts with "eval_" but policy is missing
|
||||
if dataset_name.startswith("eval_") and policy_cfg is None:
|
||||
if dataset_name.startswith("eval_") and policy is None:
|
||||
raise ValueError(
|
||||
f"Your dataset name begins with 'eval_' ({dataset_name}), but no policy is provided ({policy_cfg.type})."
|
||||
f"Your dataset name begins with 'eval_' ({dataset_name}), but no policy is provided."
|
||||
)
|
||||
|
||||
# Check if dataset_name does not start with "eval_" but policy is provided
|
||||
if not dataset_name.startswith("eval_") and policy_cfg is not None:
|
||||
if not dataset_name.startswith("eval_") and policy is not None:
|
||||
raise ValueError(
|
||||
f"Your dataset name does not begin with 'eval_' ({dataset_name}), but a policy is provided ({policy_cfg.type})."
|
||||
f"Your dataset name does not begin with 'eval_' ({dataset_name}), but a policy is provided ({policy})."
|
||||
)
|
||||
|
||||
|
||||
def sanity_check_dataset_robot_compatibility(
|
||||
dataset: LeRobotDataset, robot: Robot, fps: int, use_videos: bool
|
||||
dataset: LeRobotDataset, robot: Robot, fps: int, use_videos: bool, extra_features: dict = None
|
||||
) -> None:
|
||||
features_from_robot = get_features_from_robot(robot, use_videos)
|
||||
if extra_features is not None:
|
||||
features_from_robot.update(extra_features)
|
||||
|
||||
fields = [
|
||||
("robot_type", dataset.meta.robot_type, robot.robot_type),
|
||||
("fps", dataset.fps, fps),
|
||||
("features", dataset.features, get_features_from_robot(robot, use_videos)),
|
||||
("features", dataset.features, features_from_robot),
|
||||
]
|
||||
|
||||
mismatches = []
|
||||
|
||||
@@ -1,27 +0,0 @@
|
||||
import abc
|
||||
from dataclasses import dataclass
|
||||
|
||||
import draccus
|
||||
|
||||
|
||||
@dataclass
|
||||
class MotorsBusConfig(draccus.ChoiceRegistry, abc.ABC):
|
||||
@property
|
||||
def type(self) -> str:
|
||||
return self.get_choice_name(self.__class__)
|
||||
|
||||
|
||||
@MotorsBusConfig.register_subclass("dynamixel")
|
||||
@dataclass
|
||||
class DynamixelMotorsBusConfig(MotorsBusConfig):
|
||||
port: str
|
||||
motors: dict[str, tuple[int, str]]
|
||||
mock: bool = False
|
||||
|
||||
|
||||
@MotorsBusConfig.register_subclass("feetech")
|
||||
@dataclass
|
||||
class FeetechMotorsBusConfig(MotorsBusConfig):
|
||||
port: str
|
||||
motors: dict[str, tuple[int, str]]
|
||||
mock: bool = False
|
||||
@@ -8,7 +8,6 @@ from copy import deepcopy
|
||||
import numpy as np
|
||||
import tqdm
|
||||
|
||||
from lerobot.common.robot_devices.motors.configs import DynamixelMotorsBusConfig
|
||||
from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError
|
||||
from lerobot.common.utils.utils import capture_timestamp_utc
|
||||
|
||||
@@ -242,7 +241,7 @@ class DriveMode(enum.Enum):
|
||||
class CalibrationMode(enum.Enum):
|
||||
# Joints with rotational motions are expressed in degrees in nominal range of [-180, 180]
|
||||
DEGREE = 0
|
||||
# Joints with linear motions (like gripper of Aloha) are expressed in nominal range of [0, 100]
|
||||
# Joints with linear motions (like gripper of Aloha) are experessed in nominal range of [0, 100]
|
||||
LINEAR = 1
|
||||
|
||||
|
||||
@@ -253,6 +252,7 @@ class JointOutOfRangeError(Exception):
|
||||
|
||||
|
||||
class DynamixelMotorsBus:
|
||||
# TODO(rcadene): Add a script to find the motor indices without DynamixelWizzard2
|
||||
"""
|
||||
The DynamixelMotorsBus class allows to efficiently read and write to the attached motors. It relies on
|
||||
the python dynamixel sdk to communicate with the motors. For more info, see the [Dynamixel SDK Documentation](https://emanual.robotis.com/docs/en/software/dynamixel/dynamixel_sdk/sample_code/python_read_write_protocol_2_0/#python-read-write-protocol-20).
|
||||
@@ -274,11 +274,10 @@ class DynamixelMotorsBus:
|
||||
motor_index = 6
|
||||
motor_model = "xl330-m288"
|
||||
|
||||
config = DynamixelMotorsBusConfig(
|
||||
motors_bus = DynamixelMotorsBus(
|
||||
port="/dev/tty.usbmodem575E0031751",
|
||||
motors={motor_name: (motor_index, motor_model)},
|
||||
)
|
||||
motors_bus = DynamixelMotorsBus(config)
|
||||
motors_bus.connect()
|
||||
|
||||
position = motors_bus.read("Present_Position")
|
||||
@@ -294,14 +293,23 @@ class DynamixelMotorsBus:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: DynamixelMotorsBusConfig,
|
||||
port: str,
|
||||
motors: dict[str, tuple[int, str]],
|
||||
extra_model_control_table: dict[str, list[tuple]] | None = None,
|
||||
extra_model_resolution: dict[str, int] | None = None,
|
||||
mock=False,
|
||||
):
|
||||
self.port = config.port
|
||||
self.motors = config.motors
|
||||
self.mock = config.mock
|
||||
self.port = port
|
||||
self.motors = motors
|
||||
self.mock = mock
|
||||
|
||||
self.model_ctrl_table = deepcopy(MODEL_CONTROL_TABLE)
|
||||
if extra_model_control_table:
|
||||
self.model_ctrl_table.update(extra_model_control_table)
|
||||
|
||||
self.model_resolution = deepcopy(MODEL_RESOLUTION)
|
||||
if extra_model_resolution:
|
||||
self.model_resolution.update(extra_model_resolution)
|
||||
|
||||
self.port_handler = None
|
||||
self.packet_handler = None
|
||||
@@ -610,7 +618,7 @@ class DynamixelMotorsBus:
|
||||
# 0-centered resolution range (e.g. [-2048, 2048] for resolution=4096)
|
||||
values[i] = values[i] / HALF_TURN_DEGREE * (resolution // 2)
|
||||
|
||||
# Subtract the homing offsets to come back to actual motor range of values
|
||||
# Substract the homing offsets to come back to actual motor range of values
|
||||
# which can be arbitrary.
|
||||
values[i] -= homing_offset
|
||||
|
||||
|
||||
@@ -8,7 +8,6 @@ from copy import deepcopy
|
||||
import numpy as np
|
||||
import tqdm
|
||||
|
||||
from lerobot.common.robot_devices.motors.configs import FeetechMotorsBusConfig
|
||||
from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError
|
||||
from lerobot.common.utils.utils import capture_timestamp_utc
|
||||
|
||||
@@ -221,7 +220,7 @@ class DriveMode(enum.Enum):
|
||||
class CalibrationMode(enum.Enum):
|
||||
# Joints with rotational motions are expressed in degrees in nominal range of [-180, 180]
|
||||
DEGREE = 0
|
||||
# Joints with linear motions (like gripper of Aloha) are expressed in nominal range of [0, 100]
|
||||
# Joints with linear motions (like gripper of Aloha) are experessed in nominal range of [0, 100]
|
||||
LINEAR = 1
|
||||
|
||||
|
||||
@@ -253,11 +252,10 @@ class FeetechMotorsBus:
|
||||
motor_index = 6
|
||||
motor_model = "sts3215"
|
||||
|
||||
config = FeetechMotorsBusConfig(
|
||||
motors_bus = FeetechMotorsBus(
|
||||
port="/dev/tty.usbmodem575E0031751",
|
||||
motors={motor_name: (motor_index, motor_model)},
|
||||
)
|
||||
motors_bus = FeetechMotorsBus(config)
|
||||
motors_bus.connect()
|
||||
|
||||
position = motors_bus.read("Present_Position")
|
||||
@@ -273,14 +271,23 @@ class FeetechMotorsBus:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: FeetechMotorsBusConfig,
|
||||
port: str,
|
||||
motors: dict[str, tuple[int, str]],
|
||||
extra_model_control_table: dict[str, list[tuple]] | None = None,
|
||||
extra_model_resolution: dict[str, int] | None = None,
|
||||
mock=False,
|
||||
):
|
||||
self.port = config.port
|
||||
self.motors = config.motors
|
||||
self.mock = config.mock
|
||||
self.port = port
|
||||
self.motors = motors
|
||||
self.mock = mock
|
||||
|
||||
self.model_ctrl_table = deepcopy(MODEL_CONTROL_TABLE)
|
||||
if extra_model_control_table:
|
||||
self.model_ctrl_table.update(extra_model_control_table)
|
||||
|
||||
self.model_resolution = deepcopy(MODEL_RESOLUTION)
|
||||
if extra_model_resolution:
|
||||
self.model_resolution.update(extra_model_resolution)
|
||||
|
||||
self.port_handler = None
|
||||
self.packet_handler = None
|
||||
@@ -591,7 +598,7 @@ class FeetechMotorsBus:
|
||||
# 0-centered resolution range (e.g. [-2048, 2048] for resolution=4096)
|
||||
values[i] = values[i] / HALF_TURN_DEGREE * (resolution // 2)
|
||||
|
||||
# Subtract the homing offsets to come back to actual motor range of values
|
||||
# Substract the homing offsets to come back to actual motor range of values
|
||||
# which can be arbitrary.
|
||||
values[i] -= homing_offset
|
||||
|
||||
@@ -632,7 +639,7 @@ class FeetechMotorsBus:
|
||||
track["prev"][idx] = values[i]
|
||||
continue
|
||||
|
||||
# Detect a full rotation occurred
|
||||
# Detect a full rotation occured
|
||||
if abs(track["prev"][idx] - values[i]) > 2048:
|
||||
# Position went below 0 and got reset to 4095
|
||||
if track["prev"][idx] < values[i]:
|
||||
@@ -717,10 +724,6 @@ class FeetechMotorsBus:
|
||||
group_key = get_group_sync_key(data_name, motor_names)
|
||||
|
||||
if data_name not in self.group_readers:
|
||||
# Very Important to flush the buffer!
|
||||
self.port_handler.ser.reset_output_buffer()
|
||||
self.port_handler.ser.reset_input_buffer()
|
||||
|
||||
# create new group reader
|
||||
self.group_readers[group_key] = scs.GroupSyncRead(
|
||||
self.port_handler, self.packet_handler, addr, bytes
|
||||
|
||||
@@ -1,11 +1,5 @@
|
||||
from typing import Protocol
|
||||
|
||||
from lerobot.common.robot_devices.motors.configs import (
|
||||
DynamixelMotorsBusConfig,
|
||||
FeetechMotorsBusConfig,
|
||||
MotorsBusConfig,
|
||||
)
|
||||
|
||||
|
||||
class MotorsBus(Protocol):
|
||||
def motor_names(self): ...
|
||||
@@ -14,40 +8,3 @@ class MotorsBus(Protocol):
|
||||
def revert_calibration(self): ...
|
||||
def read(self): ...
|
||||
def write(self): ...
|
||||
|
||||
|
||||
def make_motors_buses_from_configs(motors_bus_configs: dict[str, MotorsBusConfig]) -> list[MotorsBus]:
|
||||
motors_buses = {}
|
||||
|
||||
for key, cfg in motors_bus_configs.items():
|
||||
if cfg.type == "dynamixel":
|
||||
from lerobot.common.robot_devices.motors.dynamixel import DynamixelMotorsBus
|
||||
|
||||
motors_buses[key] = DynamixelMotorsBus(cfg)
|
||||
|
||||
elif cfg.type == "feetech":
|
||||
from lerobot.common.robot_devices.motors.feetech import FeetechMotorsBus
|
||||
|
||||
motors_buses[key] = FeetechMotorsBus(cfg)
|
||||
|
||||
else:
|
||||
raise ValueError(f"The motor type '{cfg.type}' is not valid.")
|
||||
|
||||
return motors_buses
|
||||
|
||||
|
||||
def make_motors_bus(motor_type: str, **kwargs) -> MotorsBus:
|
||||
if motor_type == "dynamixel":
|
||||
from lerobot.common.robot_devices.motors.dynamixel import DynamixelMotorsBus
|
||||
|
||||
config = DynamixelMotorsBusConfig(**kwargs)
|
||||
return DynamixelMotorsBus(config)
|
||||
|
||||
elif motor_type == "feetech":
|
||||
from lerobot.common.robot_devices.motors.feetech import FeetechMotorsBus
|
||||
|
||||
config = FeetechMotorsBusConfig(**kwargs)
|
||||
return FeetechMotorsBus(config)
|
||||
|
||||
else:
|
||||
raise ValueError(f"The motor type '{motor_type}' is not valid.")
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user