Compare commits
34 Commits
user/rcade
...
chore/bump
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7d420de118 | ||
|
|
bf6f89a5b5 | ||
|
|
8861546ad8 | ||
|
|
9c1a893ee3 | ||
|
|
e81c36cf74 | ||
|
|
ed83cbd4f2 | ||
|
|
2a33b9ad87 | ||
|
|
6e85aa13ec | ||
|
|
af05a1725c | ||
|
|
800c4a847f | ||
|
|
bba8c4c0d4 | ||
|
|
68b369e321 | ||
|
|
8d60ac3ffc | ||
|
|
659ec4434d | ||
|
|
da265ca920 | ||
|
|
a1809ad3de | ||
|
|
8699a28be0 | ||
|
|
65db5afe1c | ||
|
|
75d5fa4604 | ||
|
|
e64fad2224 | ||
|
|
eecf32e77a | ||
|
|
3354d919fc | ||
|
|
aca464ca72 | ||
|
|
fe483b1d0d | ||
|
|
ddeade077e | ||
|
|
c4c2ce04e7 | ||
|
|
2cb0bf5d41 | ||
|
|
b86a2c0b47 | ||
|
|
c574eb4984 | ||
|
|
1e49cc4d60 | ||
|
|
e71095960f | ||
|
|
90e099b39f | ||
|
|
334deb985d | ||
|
|
8548a87bd4 |
12
.github/workflows/build-docker-images.yml
vendored
12
.github/workflows/build-docker-images.yml
vendored
@@ -8,6 +8,8 @@ on:
|
||||
schedule:
|
||||
- cron: "0 1 * * *"
|
||||
|
||||
permissions: {}
|
||||
|
||||
env:
|
||||
PYTHON_VERSION: "3.10"
|
||||
|
||||
@@ -25,11 +27,14 @@ jobs:
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
with:
|
||||
cache-binary: false
|
||||
|
||||
- name: Check out code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
lfs: true
|
||||
persist-credentials: false
|
||||
|
||||
- name: Login to DockerHub
|
||||
uses: docker/login-action@v3
|
||||
@@ -60,11 +65,14 @@ jobs:
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
with:
|
||||
cache-binary: false
|
||||
|
||||
- name: Check out code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
lfs: true
|
||||
persist-credentials: false
|
||||
|
||||
- name: Login to DockerHub
|
||||
uses: docker/login-action@v3
|
||||
@@ -89,9 +97,13 @@ jobs:
|
||||
steps:
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
with:
|
||||
cache-binary: false
|
||||
|
||||
- name: Check out code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Login to DockerHub
|
||||
uses: docker/login-action@v3
|
||||
|
||||
2
.github/workflows/nightly-tests.yml
vendored
2
.github/workflows/nightly-tests.yml
vendored
@@ -7,6 +7,8 @@ on:
|
||||
schedule:
|
||||
- cron: "0 2 * * *"
|
||||
|
||||
permissions: {}
|
||||
|
||||
# env:
|
||||
# SLACK_API_TOKEN: ${{ secrets.SLACK_API_TOKEN }}
|
||||
jobs:
|
||||
|
||||
161
.github/workflows/pr_style_bot.yml
vendored
Normal file
161
.github/workflows/pr_style_bot.yml
vendored
Normal file
@@ -0,0 +1,161 @@
|
||||
# Adapted from https://github.com/huggingface/diffusers/blob/main/.github/workflows/pr_style_bot.yml
|
||||
name: PR Style Bot
|
||||
|
||||
on:
|
||||
issue_comment:
|
||||
types: [created]
|
||||
|
||||
permissions: {}
|
||||
|
||||
env:
|
||||
PYTHON_VERSION: "3.10"
|
||||
|
||||
jobs:
|
||||
check-permissions:
|
||||
if: >
|
||||
contains(github.event.comment.body, '@bot /style') &&
|
||||
github.event.issue.pull_request != null
|
||||
runs-on: ubuntu-latest
|
||||
outputs:
|
||||
is_authorized: ${{ steps.check_user_permission.outputs.has_permission }}
|
||||
steps:
|
||||
- name: Check user permission
|
||||
id: check_user_permission
|
||||
uses: actions/github-script@v6
|
||||
with:
|
||||
script: |
|
||||
const comment_user = context.payload.comment.user.login;
|
||||
const { data: permission } = await github.rest.repos.getCollaboratorPermissionLevel({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
username: comment_user
|
||||
});
|
||||
|
||||
const authorized =
|
||||
permission.permission === 'admin' ||
|
||||
permission.permission === 'write';
|
||||
|
||||
console.log(
|
||||
`User ${comment_user} has permission level: ${permission.permission}, ` +
|
||||
`authorized: ${authorized} (admins & maintainers allowed)`
|
||||
);
|
||||
|
||||
core.setOutput('has_permission', authorized);
|
||||
|
||||
run-style-bot:
|
||||
needs: check-permissions
|
||||
if: needs.check-permissions.outputs.is_authorized == 'true'
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
contents: write
|
||||
pull-requests: write
|
||||
steps:
|
||||
- name: Extract PR details
|
||||
id: pr_info
|
||||
uses: actions/github-script@v6
|
||||
with:
|
||||
script: |
|
||||
const prNumber = context.payload.issue.number;
|
||||
const { data: pr } = await github.rest.pulls.get({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
pull_number: prNumber
|
||||
});
|
||||
|
||||
// We capture both the branch ref and the "full_name" of the head repo
|
||||
// so that we can check out the correct repository & branch (including forks).
|
||||
core.setOutput("prNumber", prNumber);
|
||||
core.setOutput("headRef", pr.head.ref);
|
||||
core.setOutput("headRepoFullName", pr.head.repo.full_name);
|
||||
|
||||
- name: Check out PR branch
|
||||
uses: actions/checkout@v4
|
||||
env:
|
||||
HEADREPOFULLNAME: ${{ steps.pr_info.outputs.headRepoFullName }}
|
||||
HEADREF: ${{ steps.pr_info.outputs.headRef }}
|
||||
with:
|
||||
persist-credentials: true
|
||||
# Instead of checking out the base repo, use the contributor's repo name
|
||||
repository: ${{ env.HEADREPOFULLNAME }}
|
||||
ref: ${{ env.HEADREF }}
|
||||
# You may need fetch-depth: 0 for being able to push
|
||||
fetch-depth: 0
|
||||
token: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Debug
|
||||
env:
|
||||
HEADREPOFULLNAME: ${{ steps.pr_info.outputs.headRepoFullName }}
|
||||
HEADREF: ${{ steps.pr_info.outputs.headRef }}
|
||||
PRNUMBER: ${{ steps.pr_info.outputs.prNumber }}
|
||||
run: |
|
||||
echo "PR number: ${PRNUMBER}"
|
||||
echo "Head Ref: ${HEADREF}"
|
||||
echo "Head Repo Full Name: ${HEADREPOFULLNAME}"
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: ${{ env.PYTHON_VERSION }}
|
||||
|
||||
- name: Get Ruff Version from pre-commit-config.yaml
|
||||
id: get-ruff-version
|
||||
run: |
|
||||
RUFF_VERSION=$(awk '/repo: https:\/\/github.com\/astral-sh\/ruff-pre-commit/{flag=1;next}/rev:/{if(flag){print $2;exit}}' .pre-commit-config.yaml)
|
||||
echo "ruff_version=${RUFF_VERSION}" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Install Ruff
|
||||
env:
|
||||
RUFF_VERSION: ${{ steps.get-ruff-version.outputs.ruff_version }}
|
||||
run: python -m pip install "ruff==${RUFF_VERSION}"
|
||||
|
||||
- name: Ruff check
|
||||
run: ruff check --fix
|
||||
|
||||
- name: Ruff format
|
||||
run: ruff format
|
||||
|
||||
- name: Commit and push changes
|
||||
id: commit_and_push
|
||||
env:
|
||||
HEADREPOFULLNAME: ${{ steps.pr_info.outputs.headRepoFullName }}
|
||||
HEADREF: ${{ steps.pr_info.outputs.headRef }}
|
||||
PRNUMBER: ${{ steps.pr_info.outputs.prNumber }}
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
run: |
|
||||
echo "HEADREPOFULLNAME: ${HEADREPOFULLNAME}, HEADREF: ${HEADREF}"
|
||||
# Configure git with the Actions bot user
|
||||
git config user.name "github-actions[bot]"
|
||||
git config user.email "github-actions[bot]@users.noreply.github.com"
|
||||
git config --local lfs.https://github.com/.locksverify false
|
||||
|
||||
# Make sure your 'origin' remote is set to the contributor's fork
|
||||
git remote set-url origin "https://x-access-token:${GITHUB_TOKEN}@github.com/${HEADREPOFULLNAME}.git"
|
||||
|
||||
# If there are changes after running style/quality, commit them
|
||||
if [ -n "$(git status --porcelain)" ]; then
|
||||
git add .
|
||||
git commit -m "Apply style fixes"
|
||||
# Push to the original contributor's forked branch
|
||||
git push origin HEAD:${HEADREF}
|
||||
echo "changes_pushed=true" >> $GITHUB_OUTPUT
|
||||
else
|
||||
echo "No changes to commit."
|
||||
echo "changes_pushed=false" >> $GITHUB_OUTPUT
|
||||
fi
|
||||
|
||||
- name: Comment on PR with workflow run link
|
||||
if: steps.commit_and_push.outputs.changes_pushed == 'true'
|
||||
uses: actions/github-script@v6
|
||||
with:
|
||||
script: |
|
||||
const prNumber = parseInt(process.env.prNumber, 10);
|
||||
const runUrl = `${process.env.GITHUB_SERVER_URL}/${process.env.GITHUB_REPOSITORY}/actions/runs/${process.env.GITHUB_RUN_ID}`
|
||||
|
||||
await github.rest.issues.createComment({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
issue_number: prNumber,
|
||||
body: `Style fixes have been applied. [View the workflow run here](${runUrl}).`
|
||||
});
|
||||
env:
|
||||
prNumber: ${{ steps.pr_info.outputs.prNumber }}
|
||||
60
.github/workflows/quality.yml
vendored
60
.github/workflows/quality.yml
vendored
@@ -4,12 +4,12 @@ on:
|
||||
workflow_dispatch:
|
||||
workflow_call:
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
|
||||
permissions: {}
|
||||
|
||||
env:
|
||||
PYTHON_VERSION: "3.10"
|
||||
|
||||
@@ -19,7 +19,9 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout Repository
|
||||
uses: actions/checkout@v3
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
@@ -30,55 +32,27 @@ jobs:
|
||||
id: get-ruff-version
|
||||
run: |
|
||||
RUFF_VERSION=$(awk '/repo: https:\/\/github.com\/astral-sh\/ruff-pre-commit/{flag=1;next}/rev:/{if(flag){print $2;exit}}' .pre-commit-config.yaml)
|
||||
echo "RUFF_VERSION=${RUFF_VERSION}" >> $GITHUB_ENV
|
||||
echo "ruff_version=${RUFF_VERSION}" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Install Ruff
|
||||
run: python -m pip install "ruff==${{ env.RUFF_VERSION }}"
|
||||
env:
|
||||
RUFF_VERSION: ${{ steps.get-ruff-version.outputs.ruff_version }}
|
||||
run: python -m pip install "ruff==${RUFF_VERSION}"
|
||||
|
||||
- name: Ruff check
|
||||
run: ruff check
|
||||
run: ruff check --output-format=github
|
||||
|
||||
- name: Ruff format
|
||||
run: ruff format --diff
|
||||
|
||||
|
||||
poetry_check:
|
||||
name: Poetry check
|
||||
typos:
|
||||
name: Typos
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout Repository
|
||||
uses: actions/checkout@v3
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install poetry
|
||||
run: pipx install "poetry<2.0.0"
|
||||
|
||||
- name: Poetry check
|
||||
run: poetry check
|
||||
|
||||
|
||||
poetry_relax:
|
||||
name: Poetry relax
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout Repository
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Install poetry
|
||||
run: pipx install "poetry<2.0.0"
|
||||
|
||||
- name: Install poetry-relax
|
||||
run: poetry self add poetry-relax
|
||||
|
||||
- name: Poetry relax
|
||||
id: poetry_relax
|
||||
run: |
|
||||
output=$(poetry relax --check 2>&1)
|
||||
if echo "$output" | grep -q "Proposing updates"; then
|
||||
echo "$output"
|
||||
echo ""
|
||||
echo "Some dependencies have caret '^' version requirement added by poetry by default."
|
||||
echo "Please replace them with '>='. You can do this by hand or use poetry-relax to do this."
|
||||
exit 1
|
||||
else
|
||||
echo "$output"
|
||||
fi
|
||||
- name: typos-action
|
||||
uses: crate-ci/typos@v1.29.10
|
||||
|
||||
17
.github/workflows/test-docker-build.yml
vendored
17
.github/workflows/test-docker-build.yml
vendored
@@ -4,12 +4,12 @@ name: Test Dockerfiles
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
paths:
|
||||
# Run only when DockerFile files are modified
|
||||
- "docker/**"
|
||||
|
||||
permissions: {}
|
||||
|
||||
env:
|
||||
PYTHON_VERSION: "3.10"
|
||||
|
||||
@@ -22,6 +22,8 @@ jobs:
|
||||
steps:
|
||||
- name: Check out code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Get changed files
|
||||
id: changed-files
|
||||
@@ -30,21 +32,18 @@ jobs:
|
||||
files: docker/**
|
||||
json: "true"
|
||||
|
||||
- name: Run step if only the files listed above change
|
||||
- name: Run step if only the files listed above change # zizmor: ignore[template-injection]
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
id: set-matrix
|
||||
env:
|
||||
ALL_CHANGED_FILES: ${{ steps.changed-files.outputs.all_changed_files }}
|
||||
run: |
|
||||
echo "matrix=${{ steps.changed-files.outputs.all_changed_files}}" >> $GITHUB_OUTPUT
|
||||
|
||||
|
||||
build_modified_dockerfiles:
|
||||
name: Build modified Docker images
|
||||
needs: get_changed_files
|
||||
runs-on:
|
||||
group: aws-general-8-plus
|
||||
if: ${{ needs.get_changed_files.outputs.matrix }} != ''
|
||||
if: needs.get_changed_files.outputs.matrix != ''
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
@@ -52,9 +51,13 @@ jobs:
|
||||
steps:
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
with:
|
||||
cache-binary: false
|
||||
|
||||
- name: Check out code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Build Docker image
|
||||
uses: docker/build-push-action@v5
|
||||
|
||||
76
.github/workflows/test.yml
vendored
76
.github/workflows/test.yml
vendored
@@ -2,14 +2,13 @@ name: Tests
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
paths:
|
||||
- "lerobot/**"
|
||||
- "tests/**"
|
||||
- "examples/**"
|
||||
- ".github/**"
|
||||
- "poetry.lock"
|
||||
- "pyproject.toml"
|
||||
- ".pre-commit-config.yaml"
|
||||
- "Makefile"
|
||||
- ".cache/**"
|
||||
push:
|
||||
@@ -20,10 +19,16 @@ on:
|
||||
- "tests/**"
|
||||
- "examples/**"
|
||||
- ".github/**"
|
||||
- "poetry.lock"
|
||||
- "pyproject.toml"
|
||||
- ".pre-commit-config.yaml"
|
||||
- "Makefile"
|
||||
- ".cache/**"
|
||||
|
||||
permissions: {}
|
||||
|
||||
env:
|
||||
UV_VERSION: "0.6.0"
|
||||
|
||||
jobs:
|
||||
pytest:
|
||||
name: Pytest
|
||||
@@ -34,6 +39,7 @@ jobs:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
lfs: true # Ensure LFS files are pulled
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install apt dependencies
|
||||
# portaudio19-dev is needed to install pyaudio
|
||||
@@ -41,25 +47,19 @@ jobs:
|
||||
sudo apt-get update && \
|
||||
sudo apt-get install -y libegl1-mesa-dev ffmpeg portaudio19-dev
|
||||
|
||||
- name: Install poetry
|
||||
run: |
|
||||
pipx install poetry && poetry config virtualenvs.in-project true
|
||||
echo "${{ github.workspace }}/.venv/bin" >> $GITHUB_PATH
|
||||
|
||||
# TODO(rcadene, aliberts): python 3.12 seems to be used in the tests, not python 3.10
|
||||
- name: Set up Python 3.10
|
||||
uses: actions/setup-python@v5
|
||||
- name: Install uv and python
|
||||
uses: astral-sh/setup-uv@v5
|
||||
with:
|
||||
enable-cache: true
|
||||
version: ${{ env.UV_VERSION }}
|
||||
python-version: "3.10"
|
||||
cache: "poetry"
|
||||
|
||||
- name: Install poetry dependencies
|
||||
run: |
|
||||
poetry install --all-extras
|
||||
- name: Install lerobot (all extras)
|
||||
run: uv sync --all-extras
|
||||
|
||||
- name: Test with pytest
|
||||
run: |
|
||||
pytest tests -v --cov=./lerobot --durations=0 \
|
||||
uv run pytest tests -v --cov=./lerobot --durations=0 \
|
||||
-W ignore::DeprecationWarning:imageio_ffmpeg._utils:7 \
|
||||
-W ignore::UserWarning:torch.utils.data.dataloader:558 \
|
||||
-W ignore::UserWarning:gymnasium.utils.env_checker:247 \
|
||||
@@ -74,28 +74,24 @@ jobs:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
lfs: true # Ensure LFS files are pulled
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install apt dependencies
|
||||
run: sudo apt-get update && sudo apt-get install -y ffmpeg
|
||||
|
||||
- name: Install poetry
|
||||
run: |
|
||||
pipx install poetry && poetry config virtualenvs.in-project true
|
||||
echo "${{ github.workspace }}/.venv/bin" >> $GITHUB_PATH
|
||||
|
||||
# TODO(rcadene, aliberts): python 3.12 seems to be used in the tests, not python 3.10
|
||||
- name: Set up Python 3.10
|
||||
uses: actions/setup-python@v5
|
||||
- name: Install uv and python
|
||||
uses: astral-sh/setup-uv@v5
|
||||
with:
|
||||
enable-cache: true
|
||||
version: ${{ env.UV_VERSION }}
|
||||
python-version: "3.10"
|
||||
|
||||
- name: Install poetry dependencies
|
||||
run: |
|
||||
poetry install --extras "test"
|
||||
- name: Install lerobot
|
||||
run: uv sync --extra "test"
|
||||
|
||||
- name: Test with pytest
|
||||
run: |
|
||||
pytest tests -v --cov=./lerobot --durations=0 \
|
||||
uv run pytest tests -v --cov=./lerobot --durations=0 \
|
||||
-W ignore::DeprecationWarning:imageio_ffmpeg._utils:7 \
|
||||
-W ignore::UserWarning:torch.utils.data.dataloader:558 \
|
||||
-W ignore::UserWarning:gymnasium.utils.env_checker:247 \
|
||||
@@ -110,6 +106,7 @@ jobs:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
lfs: true # Ensure LFS files are pulled
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install apt dependencies
|
||||
# portaudio19-dev is needed to install pyaudio
|
||||
@@ -117,20 +114,21 @@ jobs:
|
||||
sudo apt-get update && \
|
||||
sudo apt-get install -y libegl1-mesa-dev portaudio19-dev
|
||||
|
||||
- name: Install poetry
|
||||
run: |
|
||||
pipx install poetry && poetry config virtualenvs.in-project true
|
||||
echo "${{ github.workspace }}/.venv/bin" >> $GITHUB_PATH
|
||||
|
||||
- name: Set up Python 3.10
|
||||
uses: actions/setup-python@v5
|
||||
- name: Install uv and python
|
||||
uses: astral-sh/setup-uv@v5
|
||||
with:
|
||||
enable-cache: true
|
||||
version: ${{ env.UV_VERSION }}
|
||||
python-version: "3.10"
|
||||
cache: "poetry"
|
||||
|
||||
- name: Install poetry dependencies
|
||||
- name: Install lerobot (all extras)
|
||||
run: |
|
||||
poetry install --all-extras
|
||||
uv venv
|
||||
uv sync --all-extras
|
||||
|
||||
- name: venv
|
||||
run: |
|
||||
echo "PYTHON_PATH=${{ github.workspace }}/.venv/bin/python" >> $GITHUB_ENV
|
||||
|
||||
- name: Test end-to-end
|
||||
run: |
|
||||
|
||||
5
.github/workflows/trufflehog.yml
vendored
5
.github/workflows/trufflehog.yml
vendored
@@ -3,8 +3,7 @@ on:
|
||||
|
||||
name: Secret Leaks
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
permissions: {}
|
||||
|
||||
jobs:
|
||||
trufflehog:
|
||||
@@ -14,6 +13,8 @@ jobs:
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
persist-credentials: false
|
||||
|
||||
- name: Secret Scanning
|
||||
uses: trufflesecurity/trufflehog@main
|
||||
with:
|
||||
|
||||
4
.gitignore
vendored
4
.gitignore
vendored
@@ -49,6 +49,10 @@ share/python-wheels/
|
||||
*.egg
|
||||
MANIFEST
|
||||
|
||||
# uv/poetry lock files
|
||||
poetry.lock
|
||||
uv.lock
|
||||
|
||||
# PyInstaller
|
||||
# Usually these files are written by a python script from a template
|
||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
exclude: ^(tests/data)
|
||||
default_language_version:
|
||||
python: python3.10
|
||||
python: python3.12
|
||||
repos:
|
||||
##### Style / Misc. #####
|
||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||
rev: v5.0.0
|
||||
hooks:
|
||||
@@ -13,25 +14,34 @@ repos:
|
||||
- id: check-toml
|
||||
- id: end-of-file-fixer
|
||||
- id: trailing-whitespace
|
||||
- repo: https://github.com/crate-ci/typos
|
||||
rev: v1.30.0
|
||||
hooks:
|
||||
- id: typos
|
||||
args: [--force-exclude]
|
||||
- repo: https://github.com/asottile/pyupgrade
|
||||
rev: v3.19.0
|
||||
rev: v3.19.1
|
||||
hooks:
|
||||
- id: pyupgrade
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
rev: v0.8.2
|
||||
rev: v0.9.9
|
||||
hooks:
|
||||
- id: ruff
|
||||
args: [--fix]
|
||||
- id: ruff-format
|
||||
- repo: https://github.com/python-poetry/poetry
|
||||
rev: 1.8.0
|
||||
hooks:
|
||||
- id: poetry-check
|
||||
- id: poetry-lock
|
||||
args:
|
||||
- "--check"
|
||||
- "--no-update"
|
||||
|
||||
##### Security #####
|
||||
- repo: https://github.com/gitleaks/gitleaks
|
||||
rev: v8.21.2
|
||||
rev: v8.24.0
|
||||
hooks:
|
||||
- id: gitleaks
|
||||
- repo: https://github.com/woodruffw/zizmor-pre-commit
|
||||
rev: v1.4.1
|
||||
hooks:
|
||||
- id: zizmor
|
||||
- repo: https://github.com/PyCQA/bandit
|
||||
rev: 1.8.3
|
||||
hooks:
|
||||
- id: bandit
|
||||
args: ["-c", "pyproject.toml"]
|
||||
additional_dependencies: ["bandit[toml]"]
|
||||
|
||||
@@ -129,38 +129,71 @@ Follow these steps to start contributing:
|
||||
|
||||
🚨 **Do not** work on the `main` branch.
|
||||
|
||||
4. for development, we use `poetry` instead of just `pip` to easily track our dependencies.
|
||||
If you don't have it already, follow the [instructions](https://python-poetry.org/docs/#installation) to install it.
|
||||
4. for development, we advise to use a tool like `poetry` or `uv` instead of just `pip` to easily track our dependencies.
|
||||
Follow the instructions to [install poetry](https://python-poetry.org/docs/#installation) (use a version >=2.1.0) or to [install uv](https://docs.astral.sh/uv/getting-started/installation/#installation-methods) if you don't have one of them already.
|
||||
|
||||
Set up a development environment with conda or miniconda:
|
||||
```bash
|
||||
conda create -y -n lerobot-dev python=3.10 && conda activate lerobot-dev
|
||||
```
|
||||
|
||||
To develop on 🤗 LeRobot, you will at least need to install the `dev` and `test` extras dependencies along with the core library:
|
||||
If you're using `uv`, it can manage python versions so you can instead do:
|
||||
```bash
|
||||
poetry install --sync --extras "dev test"
|
||||
uv venv --python 3.10 && source .venv/bin/activate
|
||||
```
|
||||
|
||||
To develop on 🤗 LeRobot, you will at least need to install the `dev` and `test` extras dependencies along with the core library:
|
||||
|
||||
using `poetry`
|
||||
```bash
|
||||
poetry sync --extras "dev test"
|
||||
```
|
||||
|
||||
using `uv`
|
||||
```bash
|
||||
uv sync --extra dev --extra test
|
||||
```
|
||||
|
||||
You can also install the project with all its dependencies (including environments):
|
||||
|
||||
using `poetry`
|
||||
```bash
|
||||
poetry install --sync --all-extras
|
||||
poetry sync --all-extras
|
||||
```
|
||||
|
||||
using `uv`
|
||||
```bash
|
||||
uv sync --all-extras
|
||||
```
|
||||
|
||||
> **Note:** If you don't install simulation environments with `--all-extras`, the tests that require them will be skipped when running the pytest suite locally. However, they *will* be tested in the CI. In general, we advise you to install everything and test locally before pushing.
|
||||
|
||||
Whichever command you chose to install the project (e.g. `poetry install --sync --all-extras`), you should run it again when pulling code with an updated version of `pyproject.toml` and `poetry.lock` in order to synchronize your virtual environment with the new dependencies.
|
||||
Whichever command you chose to install the project (e.g. `poetry sync --all-extras`), you should run it again when pulling code with an updated version of `pyproject.toml` and `poetry.lock` in order to synchronize your virtual environment with the new dependencies.
|
||||
|
||||
The equivalent of `pip install some-package`, would just be:
|
||||
|
||||
using `poetry`
|
||||
```bash
|
||||
poetry add some-package
|
||||
```
|
||||
|
||||
When making changes to the poetry sections of the `pyproject.toml`, you should run the following command to lock dependencies.
|
||||
using `uv`
|
||||
```bash
|
||||
poetry lock --no-update
|
||||
uv add some-package
|
||||
```
|
||||
|
||||
When making changes to the poetry sections of the `pyproject.toml`, you should run the following command to lock dependencies.
|
||||
using `poetry`
|
||||
```bash
|
||||
poetry lock
|
||||
```
|
||||
|
||||
using `uv`
|
||||
```bash
|
||||
uv lock
|
||||
```
|
||||
|
||||
|
||||
5. Develop the features on your branch.
|
||||
|
||||
As you work on the features, you should make sure that the test suite
|
||||
@@ -195,7 +228,7 @@ Follow these steps to start contributing:
|
||||
git commit
|
||||
```
|
||||
|
||||
Note, if you already commited some changes that have a wrong formatting, you can use:
|
||||
Note, if you already committed some changes that have a wrong formatting, you can use:
|
||||
```bash
|
||||
pre-commit run --all-files
|
||||
```
|
||||
|
||||
47
Makefile
47
Makefile
@@ -2,10 +2,10 @@
|
||||
|
||||
PYTHON_PATH := $(shell which python)
|
||||
|
||||
# If Poetry is installed, redefine PYTHON_PATH to use the Poetry-managed Python
|
||||
POETRY_CHECK := $(shell command -v poetry)
|
||||
ifneq ($(POETRY_CHECK),)
|
||||
PYTHON_PATH := $(shell poetry run which python)
|
||||
# If uv is installed and a virtual environment exists, use it
|
||||
UV_CHECK := $(shell command -v uv)
|
||||
ifneq ($(UV_CHECK),)
|
||||
PYTHON_PATH := $(shell .venv/bin/python)
|
||||
endif
|
||||
|
||||
export PATH := $(dir $(PYTHON_PATH)):$(PATH)
|
||||
@@ -39,8 +39,8 @@ test-act-ete-train:
|
||||
--dataset.image_transforms.enable=true \
|
||||
--dataset.episodes="[0]" \
|
||||
--batch_size=2 \
|
||||
--offline.steps=4 \
|
||||
--online.steps=0 \
|
||||
--steps=4 \
|
||||
--eval_freq=2 \
|
||||
--eval.n_episodes=1 \
|
||||
--eval.batch_size=1 \
|
||||
--save_freq=2 \
|
||||
@@ -76,8 +76,8 @@ test-diffusion-ete-train:
|
||||
--dataset.image_transforms.enable=true \
|
||||
--dataset.episodes="[0]" \
|
||||
--batch_size=2 \
|
||||
--offline.steps=2 \
|
||||
--online.steps=0 \
|
||||
--steps=2 \
|
||||
--eval_freq=2 \
|
||||
--eval.n_episodes=1 \
|
||||
--eval.batch_size=1 \
|
||||
--save_checkpoint=true \
|
||||
@@ -106,8 +106,8 @@ test-tdmpc-ete-train:
|
||||
--dataset.image_transforms.enable=true \
|
||||
--dataset.episodes="[0]" \
|
||||
--batch_size=2 \
|
||||
--offline.steps=2 \
|
||||
--online.steps=0 \
|
||||
--steps=2 \
|
||||
--eval_freq=2 \
|
||||
--eval.n_episodes=1 \
|
||||
--eval.batch_size=1 \
|
||||
--save_checkpoint=true \
|
||||
@@ -126,30 +126,3 @@ test-tdmpc-ete-eval:
|
||||
--eval.n_episodes=1 \
|
||||
--eval.batch_size=1 \
|
||||
--device=$(DEVICE)
|
||||
|
||||
# TODO(rcadene): fix online buffer to storing "task"
|
||||
# test-tdmpc-ete-train-with-online:
|
||||
# python lerobot/scripts/train.py \
|
||||
# --policy.type=tdmpc \
|
||||
# --env.type=pusht \
|
||||
# --env.obs_type=environment_state_agent_pos \
|
||||
# --env.episode_length=5 \
|
||||
# --dataset.repo_id=lerobot/pusht_keypoints \
|
||||
# --dataset.image_transforms.enable=true \
|
||||
# --dataset.episodes="[0]" \
|
||||
# --batch_size=2 \
|
||||
# --offline.steps=2 \
|
||||
# --online.steps=20 \
|
||||
# --online.rollout_n_episodes=2 \
|
||||
# --online.rollout_batch_size=2 \
|
||||
# --online.steps_between_rollouts=10 \
|
||||
# --online.buffer_capacity=1000 \
|
||||
# --online.env_seed=10000 \
|
||||
# --save_checkpoint=false \
|
||||
# --save_freq=10 \
|
||||
# --log_freq=1 \
|
||||
# --eval.use_async_envs=true \
|
||||
# --eval.n_episodes=1 \
|
||||
# --eval.batch_size=1 \
|
||||
# --device=$(DEVICE) \
|
||||
# --output_dir=tests/outputs/tdmpc_online/
|
||||
|
||||
@@ -210,7 +210,7 @@ A `LeRobotDataset` is serialised using several widespread file formats for each
|
||||
- videos are stored in mp4 format to save space
|
||||
- metadata are stored in plain json/jsonl files
|
||||
|
||||
Dataset can be uploaded/downloaded from the HuggingFace hub seamlessly. To work on a local dataset, you can use the `local_files_only` argument and specify its location with the `root` argument if it's not in the default `~/.cache/huggingface/lerobot` location.
|
||||
Dataset can be uploaded/downloaded from the HuggingFace hub seamlessly. To work on a local dataset, you can specify its location with the `root` argument if it's not in the default `~/.cache/huggingface/lerobot` location.
|
||||
|
||||
### Evaluate a pretrained policy
|
||||
|
||||
|
||||
@@ -114,7 +114,7 @@ We tried to measure the most impactful parameters for both encoding and decoding
|
||||
|
||||
Additional encoding parameters exist that are not included in this benchmark. In particular:
|
||||
- `-preset` which allows for selecting encoding presets. This represents a collection of options that will provide a certain encoding speed to compression ratio. By leaving this parameter unspecified, it is considered to be `medium` for libx264 and libx265 and `8` for libsvtav1.
|
||||
- `-tune` which allows to optimize the encoding for certains aspects (e.g. film quality, fast decoding, etc.).
|
||||
- `-tune` which allows to optimize the encoding for certain aspects (e.g. film quality, fast decoding, etc.).
|
||||
|
||||
See the documentation mentioned above for more detailed info on these settings and for a more comprehensive list of other parameters.
|
||||
|
||||
|
||||
@@ -1,33 +1,29 @@
|
||||
# Configure image
|
||||
ARG PYTHON_VERSION=3.10
|
||||
|
||||
FROM python:${PYTHON_VERSION}-slim
|
||||
ARG PYTHON_VERSION
|
||||
ARG DEBIAN_FRONTEND=noninteractive
|
||||
|
||||
# Install apt dependencies
|
||||
# Configure environment variables
|
||||
ARG PYTHON_VERSION
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
ENV MUJOCO_GL="egl"
|
||||
ENV PATH="/opt/venv/bin:$PATH"
|
||||
|
||||
# Install dependencies and set up Python in a single layer
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
build-essential cmake git git-lfs \
|
||||
build-essential cmake git \
|
||||
libglib2.0-0 libgl1-mesa-glx libegl1-mesa ffmpeg \
|
||||
speech-dispatcher libgeos-dev \
|
||||
&& apt-get clean && rm -rf /var/lib/apt/lists/*
|
||||
&& ln -s /usr/bin/python${PYTHON_VERSION} /usr/bin/python \
|
||||
&& python -m venv /opt/venv \
|
||||
&& apt-get clean && rm -rf /var/lib/apt/lists/* \
|
||||
&& echo "source /opt/venv/bin/activate" >> /root/.bashrc
|
||||
|
||||
# Create virtual environment
|
||||
RUN ln -s /usr/bin/python${PYTHON_VERSION} /usr/bin/python
|
||||
RUN python -m venv /opt/venv
|
||||
ENV PATH="/opt/venv/bin:$PATH"
|
||||
RUN echo "source /opt/venv/bin/activate" >> /root/.bashrc
|
||||
|
||||
# Install LeRobot
|
||||
RUN git lfs install
|
||||
RUN git clone https://github.com/huggingface/lerobot.git /lerobot
|
||||
# Clone repository and install LeRobot in a single layer
|
||||
COPY . /lerobot
|
||||
WORKDIR /lerobot
|
||||
RUN pip install --upgrade --no-cache-dir pip
|
||||
RUN pip install --no-cache-dir ".[test, aloha, xarm, pusht, dynamixel]" \
|
||||
--extra-index-url https://download.pytorch.org/whl/cpu
|
||||
|
||||
# Set EGL as the rendering backend for MuJoCo
|
||||
ENV MUJOCO_GL="egl"
|
||||
RUN /opt/venv/bin/pip install --upgrade --no-cache-dir pip \
|
||||
&& /opt/venv/bin/pip install --no-cache-dir ".[test, aloha, xarm, pusht, dynamixel]" \
|
||||
--extra-index-url https://download.pytorch.org/whl/cpu
|
||||
|
||||
# Execute in bash shell rather than python
|
||||
CMD ["/bin/bash"]
|
||||
|
||||
@@ -58,7 +58,7 @@ RUN (type -p wget >/dev/null || (apt update && apt-get install wget -y)) \
|
||||
RUN ln -s /usr/bin/python3 /usr/bin/python
|
||||
|
||||
# Install poetry
|
||||
RUN curl -sSL https://install.python-poetry.org | python - --version 1.8.5
|
||||
RUN curl -sSL https://install.python-poetry.org | python -
|
||||
ENV PATH="/root/.local/bin:$PATH"
|
||||
RUN echo 'if [ "$HOME" != "/root" ]; then ln -sf /root/.local/bin/poetry $HOME/.local/bin/poetry; fi' >> /root/.bashrc
|
||||
RUN poetry config virtualenvs.create false
|
||||
|
||||
@@ -1,31 +1,24 @@
|
||||
FROM nvidia/cuda:12.4.1-base-ubuntu22.04
|
||||
|
||||
# Configure image
|
||||
# Configure environment variables
|
||||
ARG PYTHON_VERSION=3.10
|
||||
ARG DEBIAN_FRONTEND=noninteractive
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
ENV MUJOCO_GL="egl"
|
||||
ENV PATH="/opt/venv/bin:$PATH"
|
||||
|
||||
|
||||
# Install apt dependencies
|
||||
# Install dependencies and set up Python in a single layer
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
build-essential cmake git git-lfs \
|
||||
build-essential cmake git \
|
||||
libglib2.0-0 libgl1-mesa-glx libegl1-mesa ffmpeg \
|
||||
speech-dispatcher libgeos-dev \
|
||||
python${PYTHON_VERSION}-dev python${PYTHON_VERSION}-venv \
|
||||
&& apt-get clean && rm -rf /var/lib/apt/lists/*
|
||||
&& ln -s /usr/bin/python${PYTHON_VERSION} /usr/bin/python \
|
||||
&& python -m venv /opt/venv \
|
||||
&& apt-get clean && rm -rf /var/lib/apt/lists/* \
|
||||
&& echo "source /opt/venv/bin/activate" >> /root/.bashrc
|
||||
|
||||
|
||||
# Create virtual environment
|
||||
RUN ln -s /usr/bin/python${PYTHON_VERSION} /usr/bin/python
|
||||
RUN python -m venv /opt/venv
|
||||
ENV PATH="/opt/venv/bin:$PATH"
|
||||
RUN echo "source /opt/venv/bin/activate" >> /root/.bashrc
|
||||
|
||||
# Install LeRobot
|
||||
RUN git lfs install
|
||||
RUN git clone https://github.com/huggingface/lerobot.git /lerobot
|
||||
# Clone repository and install LeRobot in a single layer
|
||||
COPY . /lerobot
|
||||
WORKDIR /lerobot
|
||||
RUN pip install --upgrade --no-cache-dir pip
|
||||
RUN pip install --no-cache-dir ".[test, aloha, xarm, pusht, dynamixel]"
|
||||
|
||||
# Set EGL as the rendering backend for MuJoCo
|
||||
ENV MUJOCO_GL="egl"
|
||||
RUN /opt/venv/bin/pip install --upgrade --no-cache-dir pip \
|
||||
&& /opt/venv/bin/pip install --no-cache-dir ".[test, aloha, xarm, pusht, dynamixel]"
|
||||
|
||||
@@ -335,7 +335,7 @@ python lerobot/scripts/control_robot.py \
|
||||
--control.push_to_hub=true
|
||||
```
|
||||
|
||||
Note: You can resume recording by adding `--control.resume=true`. Also if you didn't push your dataset yet, add `--control.local_files_only=true`.
|
||||
Note: You can resume recording by adding `--control.resume=true`.
|
||||
|
||||
## H. Visualize a dataset
|
||||
|
||||
@@ -344,7 +344,7 @@ If you uploaded your dataset to the hub with `--control.push_to_hub=true`, you c
|
||||
echo ${HF_USER}/so100_test
|
||||
```
|
||||
|
||||
If you didn't upload with `--control.push_to_hub=false`, you can also visualize it locally with:
|
||||
If you didn't upload with `--control.push_to_hub=false`, you can also visualize it locally with (a window can be opened in the browser `http://127.0.0.1:9090` with the visualization tool):
|
||||
```bash
|
||||
python lerobot/scripts/visualize_dataset_html.py \
|
||||
--repo-id ${HF_USER}/so100_test \
|
||||
@@ -363,8 +363,6 @@ python lerobot/scripts/control_robot.py \
|
||||
--control.episode=0
|
||||
```
|
||||
|
||||
Note: If you didn't push your dataset yet, add `--control.local_files_only=true`.
|
||||
|
||||
## J. Train a policy
|
||||
|
||||
To train a policy to control your robot, use the [`python lerobot/scripts/train.py`](../lerobot/scripts/train.py) script. A few arguments are required. Here is an example command:
|
||||
@@ -378,8 +376,6 @@ python lerobot/scripts/train.py \
|
||||
--wandb.enable=true
|
||||
```
|
||||
|
||||
Note: If you didn't push your dataset yet, add `--control.local_files_only=true`.
|
||||
|
||||
Let's explain it:
|
||||
1. We provided the dataset as argument with `--dataset.repo_id=${HF_USER}/so100_test`.
|
||||
2. We provided the policy with `policy.type=act`. This loads configurations from [`configuration_act.py`](../lerobot/common/policies/act/configuration_act.py). Importantly, this policy will automatically adapt to the number of motor sates, motor actions and cameras of your robot (e.g. `laptop` and `phone`) which have been saved in your dataset.
|
||||
@@ -416,4 +412,4 @@ As you can see, it's almost the same command as previously used to record your t
|
||||
Follow this [previous tutorial](https://github.com/huggingface/lerobot/blob/main/examples/7_get_started_with_real_robot.md#4-train-a-policy-on-your-data) for a more in-depth tutorial on controlling real robots with LeRobot.
|
||||
|
||||
> [!TIP]
|
||||
> If you have any questions or need help, please reach out on Discord in the channel [`#so100-arm`](https://discord.com/channels/1216765309076115607/1237741463832363039).
|
||||
> If you have any questions or need help, please reach out on [Discord](https://discord.com/invite/s3KuuzsPFb) in the channel [`#so100-arm`](https://discord.com/channels/1216765309076115607/1237741463832363039).
|
||||
|
||||
463
examples/11_use_lekiwi.md
Normal file
463
examples/11_use_lekiwi.md
Normal file
@@ -0,0 +1,463 @@
|
||||
# Using the [LeKiwi](https://github.com/SIGRobotics-UIUC/LeKiwi) Robot with LeRobot
|
||||
|
||||
## Table of Contents
|
||||
|
||||
- [A. Source the parts](#a-source-the-parts)
|
||||
- [B. Install software Pi](#b-install-software-on-pi)
|
||||
- [C. Setup LeRobot laptop/pc](#c-install-lerobot-on-laptop)
|
||||
- [D. Assemble the arms](#d-assembly)
|
||||
- [E. Calibrate](#e-calibration)
|
||||
- [F. Teleoperate](#f-teleoperate)
|
||||
- [G. Record a dataset](#g-record-a-dataset)
|
||||
- [H. Visualize a dataset](#h-visualize-a-dataset)
|
||||
- [I. Replay an episode](#i-replay-an-episode)
|
||||
- [J. Train a policy](#j-train-a-policy)
|
||||
- [K. Evaluate your policy](#k-evaluate-your-policy)
|
||||
|
||||
> [!TIP]
|
||||
> If you have any questions or need help, please reach out on [Discord](https://discord.com/invite/s3KuuzsPFb) in the channel [`#mobile-so-100-arm`](https://discord.com/channels/1216765309076115607/1318390825528332371).
|
||||
|
||||
## A. Source the parts
|
||||
|
||||
Follow this [README](https://github.com/SIGRobotics-UIUC/LeKiwi). It contains the bill of materials, with a link to source the parts, as well as the instructions to 3D print the parts, and advice if it's your first time printing or if you don't own a 3D printer.
|
||||
|
||||
Before assembling, you will first need to configure your motors. To this end, we provide a nice script, so let's first install LeRobot. After configuration, we will also guide you through assembly.
|
||||
|
||||
## B. Install software on Pi
|
||||
Now we have to setup the remote PC that will run on the LeKiwi Robot. This is normally a Raspberry Pi, but can be any PC that can run on 5V and has enough usb ports (2 or more) for the cameras and motor control board.
|
||||
|
||||
### Install OS
|
||||
For setting up the Raspberry Pi and its SD-card see: [Setup PI](https://www.raspberrypi.com/documentation/computers/getting-started.html). Here is explained how to download the [Imager](https://www.raspberrypi.com/software/) to install Raspberry Pi OS or Ubuntu.
|
||||
|
||||
### Setup SSH
|
||||
After setting up your Pi, you should enable and setup [SSH](https://www.raspberrypi.com/news/coding-on-raspberry-pi-remotely-with-visual-studio-code/) (Secure Shell Protocol) so you can login into the Pi from your laptop without requiring a screen, keyboard and mouse in the Pi. A great tutorial on how to do this can be found [here](https://www.raspberrypi.com/documentation/computers/remote-access.html#ssh). Logging into your Pi can be done in your Command Prompt (cmd) or if you use VSCode you can use [this](https://marketplace.visualstudio.com/items?itemName=ms-vscode-remote.remote-ssh) extension.
|
||||
|
||||
### Install LeRobot
|
||||
|
||||
On your Raspberry Pi:
|
||||
|
||||
#### 1. [Install Miniconda](https://docs.anaconda.com/miniconda/install/#quick-command-line-install):
|
||||
|
||||
#### 2. Restart shell
|
||||
Copy paste in your shell: `source ~/.bashrc` or for Mac: `source ~/.bash_profile` or `source ~/.zshrc` if you're using zshell
|
||||
|
||||
#### 3. Create and activate a fresh conda environment for lerobot
|
||||
|
||||
<details>
|
||||
<summary><strong>Video install instructions</strong></summary>
|
||||
|
||||
<video src="https://github.com/user-attachments/assets/17172d3b-3b64-4b80-9cf1-b2b7c5cbd236"></video>
|
||||
|
||||
</details>
|
||||
|
||||
```bash
|
||||
conda create -y -n lerobot python=3.10
|
||||
```
|
||||
|
||||
Then activate your conda environment (do this each time you open a shell to use lerobot!):
|
||||
```bash
|
||||
conda activate lerobot
|
||||
```
|
||||
|
||||
#### 4. Clone LeRobot:
|
||||
```bash
|
||||
git clone https://github.com/huggingface/lerobot.git ~/lerobot
|
||||
```
|
||||
|
||||
#### 5. Install LeRobot with dependencies for the feetech motors:
|
||||
```bash
|
||||
cd ~/lerobot && pip install -e ".[feetech]"
|
||||
```
|
||||
|
||||
## C. Install LeRobot on laptop
|
||||
If you already have install LeRobot on your laptop you can skip this step, otherwise please follow along as we do the same steps we did on the Pi.
|
||||
|
||||
> [!TIP]
|
||||
> We use the Command Prompt (cmd) quite a lot. If you are not comfortable using the cmd or want to brush up using the command line you can have a look here: [Command line crash course](https://developer.mozilla.org/en-US/docs/Learn_web_development/Getting_started/Environment_setup/Command_line)
|
||||
|
||||
On your computer:
|
||||
|
||||
#### 1. [Install Miniconda](https://docs.anaconda.com/miniconda/install/#quick-command-line-install):
|
||||
|
||||
#### 2. Restart shell
|
||||
Copy paste in your shell: `source ~/.bashrc` or for Mac: `source ~/.bash_profile` or `source ~/.zshrc` if you're using zshell
|
||||
|
||||
#### 3. Create and activate a fresh conda environment for lerobot
|
||||
|
||||
<details>
|
||||
<summary><strong>Video install instructions</strong></summary>
|
||||
|
||||
<video src="https://github.com/user-attachments/assets/17172d3b-3b64-4b80-9cf1-b2b7c5cbd236"></video>
|
||||
|
||||
</details>
|
||||
|
||||
```bash
|
||||
conda create -y -n lerobot python=3.10
|
||||
```
|
||||
|
||||
Then activate your conda environment (do this each time you open a shell to use lerobot!):
|
||||
```bash
|
||||
conda activate lerobot
|
||||
```
|
||||
|
||||
#### 4. Clone LeRobot:
|
||||
```bash
|
||||
git clone https://github.com/huggingface/lerobot.git ~/lerobot
|
||||
```
|
||||
|
||||
#### 5. Install LeRobot with dependencies for the feetech motors:
|
||||
```bash
|
||||
cd ~/lerobot && pip install -e ".[feetech]"
|
||||
```
|
||||
|
||||
*EXTRA: For Linux only (not Mac)*: install extra dependencies for recording datasets:
|
||||
```bash
|
||||
conda install -y -c conda-forge ffmpeg
|
||||
pip uninstall -y opencv-python
|
||||
conda install -y -c conda-forge "opencv>=4.10.0"
|
||||
```
|
||||
Great :hugs:! You are now done installing LeRobot and we can begin assembling the SO100 arms and Mobile base :robot:.
|
||||
Every time you now want to use LeRobot you can go to the `~/lerobot` folder where we installed LeRobot and run one of the commands.
|
||||
|
||||
# D. Assembly
|
||||
|
||||
First we will assemble the two SO100 arms. One to attach to the mobile base and one for teleoperation. Then we will assemble the mobile base.
|
||||
|
||||
## SO100 Arms
|
||||
### Configure motors
|
||||
The instructions for configuring the motors can be found [Here](https://github.com/huggingface/lerobot/blob/main/examples/10_use_so100.md#c-configure-the-motors) in step C of the SO100 tutorial. Besides the ID's for the arm motors we also need to set the motor ID's for the mobile base. These needs to be in a specific order to work. Below an image of the motor ID's and motor mounting positions for the mobile base. Note that we only use one Motor Control board on LeKiwi. This means the motor ID's for the wheels are 7, 8 and 9.
|
||||
|
||||
<img src="../media/lekiwi/motor_ids.webp?raw=true" alt="Motor ID's for mobile robot" title="Motor ID's for mobile robot" width="60%">
|
||||
|
||||
### Assemble arms
|
||||
[Assemble arms instruction](https://github.com/huggingface/lerobot/blob/main/examples/10_use_so100.md#d-assemble-the-arms)
|
||||
|
||||
## Mobile base (LeKiwi)
|
||||
[Assemble LeKiwi](https://github.com/SIGRobotics-UIUC/LeKiwi)
|
||||
|
||||
### Update config
|
||||
Both config files on the LeKiwi LeRobot and on the laptop should be the same. First we should find the Ip address of the Raspberry Pi of the mobile manipulator. This is the same Ip address used in SSH. We also need the usb port of the control board of the leader arm on the laptop and the port of the control board on LeKiwi. We can find these ports with the following script.
|
||||
|
||||
#### a. Run the script to find port
|
||||
|
||||
<details>
|
||||
<summary><strong>Video finding port</strong></summary>
|
||||
<video src="https://github.com/user-attachments/assets/4a21a14d-2046-4805-93c4-ee97a30ba33f"></video>
|
||||
<video src="https://github.com/user-attachments/assets/1cc3aecf-c16d-4ff9-aec7-8c175afbbce2"></video>
|
||||
</details>
|
||||
|
||||
To find the port for each bus servo adapter, run the utility script:
|
||||
```bash
|
||||
python lerobot/scripts/find_motors_bus_port.py
|
||||
```
|
||||
|
||||
#### b. Example outputs
|
||||
|
||||
Example output when identifying the leader arm's port (e.g., `/dev/tty.usbmodem575E0031751` on Mac, or possibly `/dev/ttyACM0` on Linux):
|
||||
```
|
||||
Finding all available ports for the MotorBus.
|
||||
['/dev/tty.usbmodem575E0032081', '/dev/tty.usbmodem575E0031751']
|
||||
Remove the usb cable from your DynamixelMotorsBus and press Enter when done.
|
||||
|
||||
[...Disconnect leader arm and press Enter...]
|
||||
|
||||
The port of this DynamixelMotorsBus is /dev/tty.usbmodem575E0031751
|
||||
Reconnect the usb cable.
|
||||
```
|
||||
Example output when identifying the follower arm's port (e.g., `/dev/tty.usbmodem575E0032081`, or possibly `/dev/ttyACM1` on Linux):
|
||||
```
|
||||
Finding all available ports for the MotorBus.
|
||||
['/dev/tty.usbmodem575E0032081', '/dev/tty.usbmodem575E0031751']
|
||||
Remove the usb cable from your DynamixelMotorsBus and press Enter when done.
|
||||
|
||||
[...Disconnect follower arm and press Enter...]
|
||||
|
||||
The port of this DynamixelMotorsBus is /dev/tty.usbmodem575E0032081
|
||||
Reconnect the usb cable.
|
||||
```
|
||||
|
||||
#### c. Troubleshooting
|
||||
On Linux, you might need to give access to the USB ports by running:
|
||||
```bash
|
||||
sudo chmod 666 /dev/ttyACM0
|
||||
sudo chmod 666 /dev/ttyACM1
|
||||
```
|
||||
|
||||
#### d. Update config file
|
||||
|
||||
IMPORTANTLY: Now that you have your ports of leader and follower arm and ip address of the mobile-so100, update the **ip** in Network configuration, **port** in leader_arms and **port** in lekiwi. In the [`LeKiwiRobotConfig`](../lerobot/common/robot_devices/robots/configs.py) file. Where you will find something like:
|
||||
```python
|
||||
@RobotConfig.register_subclass("lekiwi")
|
||||
@dataclass
|
||||
class LeKiwiRobotConfig(RobotConfig):
|
||||
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
|
||||
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
|
||||
# the number of motors in your follower arms.
|
||||
max_relative_target: int | None = None
|
||||
|
||||
# Network Configuration
|
||||
ip: str = "172.17.133.91"
|
||||
port: int = 5555
|
||||
video_port: int = 5556
|
||||
|
||||
cameras: dict[str, CameraConfig] = field(
|
||||
default_factory=lambda: {
|
||||
"mobile": OpenCVCameraConfig(camera_index="/dev/video0", fps=30, width=640, height=480),
|
||||
"mobile2": OpenCVCameraConfig(camera_index="/dev/video2", fps=30, width=640, height=480),
|
||||
}
|
||||
)
|
||||
|
||||
calibration_dir: str = ".cache/calibration/lekiwi"
|
||||
|
||||
leader_arms: dict[str, MotorsBusConfig] = field(
|
||||
default_factory=lambda: {
|
||||
"main": FeetechMotorsBusConfig(
|
||||
port="/dev/tty.usbmodem585A0077581",
|
||||
motors={
|
||||
# name: (index, model)
|
||||
"shoulder_pan": [1, "sts3215"],
|
||||
"shoulder_lift": [2, "sts3215"],
|
||||
"elbow_flex": [3, "sts3215"],
|
||||
"wrist_flex": [4, "sts3215"],
|
||||
"wrist_roll": [5, "sts3215"],
|
||||
"gripper": [6, "sts3215"],
|
||||
},
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
follower_arms: dict[str, MotorsBusConfig] = field(
|
||||
default_factory=lambda: {
|
||||
"main": FeetechMotorsBusConfig(
|
||||
port="/dev/ttyACM0",
|
||||
motors={
|
||||
# name: (index, model)
|
||||
"shoulder_pan": [1, "sts3215"],
|
||||
"shoulder_lift": [2, "sts3215"],
|
||||
"elbow_flex": [3, "sts3215"],
|
||||
"wrist_flex": [4, "sts3215"],
|
||||
"wrist_roll": [5, "sts3215"],
|
||||
"gripper": [6, "sts3215"],
|
||||
"left_wheel": (7, "sts3215"),
|
||||
"back_wheel": (8, "sts3215"),
|
||||
"right_wheel": (9, "sts3215"),
|
||||
},
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
mock: bool = False
|
||||
```
|
||||
|
||||
# E. Calibration
|
||||
Now we have to calibrate the leader arm and the follower arm. The wheel motors don't have to be calibrated.
|
||||
|
||||
|
||||
### Calibrate follower arm (on mobile base)
|
||||
> [!IMPORTANT]
|
||||
> Contrarily to step 6 of the [assembly video](https://youtu.be/FioA2oeFZ5I?t=724) which illustrates the auto calibration, we will actually do manual calibration of follower for now.
|
||||
|
||||
You will need to move the follower arm to these positions sequentially:
|
||||
|
||||
| 1. Zero position | 2. Rotated position | 3. Rest position |
|
||||
|---|---|---|
|
||||
| <img src="../media/lekiwi/mobile_calib_zero.webp?raw=true" alt="SO-100 follower arm zero position" title="SO-100 follower arm zero position" style="width:100%;"> | <img src="../media/lekiwi/mobile_calib_rotated.webp?raw=true" alt="SO-100 follower arm rotated position" title="SO-100 follower arm rotated position" style="width:100%;"> | <img src="../media/lekiwi/mobile_calib_rest.webp?raw=true" alt="SO-100 follower arm rest position" title="SO-100 follower arm rest position" style="width:100%;"> |
|
||||
|
||||
Make sure the arm is connected to the Raspberry Pi and run this script (on the Raspberry Pi) to launch manual calibration:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py \
|
||||
--robot.type=lekiwi \
|
||||
--robot.cameras='{}' \
|
||||
--control.type=calibrate \
|
||||
--control.arms='["main_follower"]'
|
||||
```
|
||||
|
||||
### Calibrate leader arm
|
||||
Then to calibrate the leader arm (which is attached to the laptop/pc). You will need to move the leader arm to these positions sequentially:
|
||||
|
||||
| 1. Zero position | 2. Rotated position | 3. Rest position |
|
||||
|---|---|---|
|
||||
| <img src="../media/so100/leader_zero.webp?raw=true" alt="SO-100 leader arm zero position" title="SO-100 leader arm zero position" style="width:100%;"> | <img src="../media/so100/leader_rotated.webp?raw=true" alt="SO-100 leader arm rotated position" title="SO-100 leader arm rotated position" style="width:100%;"> | <img src="../media/so100/leader_rest.webp?raw=true" alt="SO-100 leader arm rest position" title="SO-100 leader arm rest position" style="width:100%;"> |
|
||||
|
||||
Run this script (on your laptop/pc) to launch manual calibration:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py \
|
||||
--robot.type=lekiwi \
|
||||
--robot.cameras='{}' \
|
||||
--control.type=calibrate \
|
||||
--control.arms='["main_leader"]'
|
||||
```
|
||||
|
||||
# F. Teleoperate
|
||||
To teleoperate SSH into your Raspberry Pi, and run `conda activate lerobot` and this script:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py \
|
||||
--robot.type=lekiwi \
|
||||
--control.type=remote_robot
|
||||
```
|
||||
|
||||
Then on your laptop, also run `conda activate lerobot` and this script:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py \
|
||||
--robot.type=lekiwi \
|
||||
--control.type=teleoperate \
|
||||
--control.fps=30
|
||||
```
|
||||
|
||||
You should see on your laptop something like this: ```[INFO] Connected to remote robot at tcp://172.17.133.91:5555 and video stream at tcp://172.17.133.91:5556.``` Now you can move the leader arm and use the keyboard (w,a,s,d) to drive forward, left, backwards, right. And use (z,x) to turn left or turn right. You can use (r,f) to increase and decrease the speed of the mobile robot. There are three speed modes, see the table below:
|
||||
| Speed Mode | Linear Speed (m/s) | Rotation Speed (deg/s) |
|
||||
|------------|-------------------|-----------------------|
|
||||
| Fast | 0.4 | 90 |
|
||||
| Medium | 0.25 | 60 |
|
||||
| Slow | 0.1 | 30 |
|
||||
|
||||
|
||||
| Key | Action |
|
||||
|------|--------------------------------|
|
||||
| W | Move forward |
|
||||
| A | Move left |
|
||||
| S | Move backward |
|
||||
| D | Move right |
|
||||
| Z | Turn left |
|
||||
| X | Turn right |
|
||||
| R | Increase speed |
|
||||
| F | Decrease speed |
|
||||
|
||||
> [!TIP]
|
||||
> If you use a different keyboard you can change the keys for each command in the [`LeKiwiRobotConfig`](../lerobot/common/robot_devices/robots/configs.py).
|
||||
|
||||
## Troubleshoot communication
|
||||
|
||||
If you are having trouble connecting to the Mobile SO100, follow these steps to diagnose and resolve the issue.
|
||||
|
||||
### 1. Verify IP Address Configuration
|
||||
Make sure that the correct ip for the Pi is set in the configuration file. To check the Raspberry Pi's IP address, run (on the Pi command line):
|
||||
```bash
|
||||
hostname -I
|
||||
```
|
||||
|
||||
### 2. Check if Pi is reachable from laptop/pc
|
||||
Try pinging the Raspberry Pi from your laptop:
|
||||
```bach
|
||||
ping <your_pi_ip_address>
|
||||
```
|
||||
|
||||
If the ping fails:
|
||||
- Ensure the Pi is powered on and connected to the same network.
|
||||
- Check if SSH is enabled on the Pi.
|
||||
|
||||
### 3. Try SSH connection
|
||||
If you can't SSH into the Pi, it might not be properly connected. Use:
|
||||
```bash
|
||||
ssh <your_pi_user_name>@<your_pi_ip_address>
|
||||
```
|
||||
If you get a connection error:
|
||||
- Ensure SSH is enabled on the Pi by running:
|
||||
```bash
|
||||
sudo raspi-config
|
||||
```
|
||||
Then navigate to: **Interfacing Options -> SSH** and enable it.
|
||||
|
||||
### 4. Same config file
|
||||
Make sure the configuration file on both your laptop/pc and the Raspberry Pi is the same.
|
||||
|
||||
# G. Record a dataset
|
||||
Once you're familiar with teleoperation, you can record your first dataset with LeKiwi.
|
||||
|
||||
If you want to use the Hugging Face hub features for uploading your dataset and you haven't previously done it, make sure you've logged in using a write-access token, which can be generated from the [Hugging Face settings](https://huggingface.co/settings/tokens):
|
||||
```bash
|
||||
huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
|
||||
```
|
||||
|
||||
Store your Hugging Face repository name in a variable to run these commands:
|
||||
```bash
|
||||
HF_USER=$(huggingface-cli whoami | head -n 1)
|
||||
echo $HF_USER
|
||||
```
|
||||
|
||||
Record 2 episodes and upload your dataset to the hub:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py \
|
||||
--robot.type=lekiwi \
|
||||
--control.type=record \
|
||||
--control.fps=30 \
|
||||
--control.single_task="Grasp a lego block and put it in the bin." \
|
||||
--control.repo_id=${HF_USER}/lekiwi_test \
|
||||
--control.tags='["tutorial"]' \
|
||||
--control.warmup_time_s=5 \
|
||||
--control.episode_time_s=30 \
|
||||
--control.reset_time_s=30 \
|
||||
--control.num_episodes=2 \
|
||||
--control.push_to_hub=true
|
||||
```
|
||||
|
||||
Note: You can resume recording by adding `--control.resume=true`.
|
||||
|
||||
# H. Visualize a dataset
|
||||
|
||||
If you uploaded your dataset to the hub with `--control.push_to_hub=true`, you can [visualize your dataset online](https://huggingface.co/spaces/lerobot/visualize_dataset) by copy pasting your repo id given by:
|
||||
```bash
|
||||
echo ${HF_USER}/lekiwi_test
|
||||
```
|
||||
|
||||
If you didn't upload with `--control.push_to_hub=false`, you can also visualize it locally with (a window can be opened in the browser `http://127.0.0.1:9090` with the visualization tool):
|
||||
```bash
|
||||
python lerobot/scripts/visualize_dataset_html.py \
|
||||
--repo-id ${HF_USER}/lekiwi_test \
|
||||
--local-files-only 1
|
||||
```
|
||||
|
||||
# I. Replay an episode
|
||||
Now try to replay the first episode on your robot:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py \
|
||||
--robot.type=lekiwi \
|
||||
--control.type=replay \
|
||||
--control.fps=30 \
|
||||
--control.repo_id=${HF_USER}/lekiwi_test \
|
||||
--control.episode=0
|
||||
```
|
||||
|
||||
## J. Train a policy
|
||||
|
||||
To train a policy to control your robot, use the [`python lerobot/scripts/train.py`](../lerobot/scripts/train.py) script. A few arguments are required. Here is an example command:
|
||||
```bash
|
||||
python lerobot/scripts/train.py \
|
||||
--dataset.repo_id=${HF_USER}/lekiwi_test \
|
||||
--policy.type=act \
|
||||
--output_dir=outputs/train/act_lekiwi_test \
|
||||
--job_name=act_lekiwi_test \
|
||||
--device=cuda \
|
||||
--wandb.enable=true
|
||||
```
|
||||
|
||||
Let's explain it:
|
||||
1. We provided the dataset as argument with `--dataset.repo_id=${HF_USER}/lekiwi_test`.
|
||||
2. We provided the policy with `policy.type=act`. This loads configurations from [`configuration_act.py`](../lerobot/common/policies/act/configuration_act.py). Importantly, this policy will automatically adapt to the number of motor sates, motor actions and cameras of your robot (e.g. `laptop` and `phone`) which have been saved in your dataset.
|
||||
4. We provided `device=cuda` since we are training on a Nvidia GPU, but you could use `device=mps` to train on Apple silicon.
|
||||
5. We provided `wandb.enable=true` to use [Weights and Biases](https://docs.wandb.ai/quickstart) for visualizing training plots. This is optional but if you use it, make sure you are logged in by running `wandb login`.
|
||||
|
||||
Training should take several hours. You will find checkpoints in `outputs/train/act_lekiwi_test/checkpoints`.
|
||||
|
||||
## K. Evaluate your policy
|
||||
|
||||
You can use the `record` function from [`lerobot/scripts/control_robot.py`](../lerobot/scripts/control_robot.py) but with a policy checkpoint as input. For instance, run this command to record 10 evaluation episodes:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py \
|
||||
--robot.type=lekiwi \
|
||||
--control.type=record \
|
||||
--control.fps=30 \
|
||||
--control.single_task="Drive to the red block and pick it up" \
|
||||
--control.repo_id=${HF_USER}/eval_act_lekiwi_test \
|
||||
--control.tags='["tutorial"]' \
|
||||
--control.warmup_time_s=5 \
|
||||
--control.episode_time_s=30 \
|
||||
--control.reset_time_s=30 \
|
||||
--control.num_episodes=10 \
|
||||
--control.push_to_hub=true \
|
||||
--control.policy.path=outputs/train/act_lekiwi_test/checkpoints/last/pretrained_model
|
||||
```
|
||||
|
||||
As you can see, it's almost the same command as previously used to record your training dataset. Two things changed:
|
||||
1. There is an additional `--control.policy.path` argument which indicates the path to your policy checkpoint with (e.g. `outputs/train/eval_act_lekiwi_test/checkpoints/last/pretrained_model`). You can also use the model repository if you uploaded a model checkpoint to the hub (e.g. `${HF_USER}/act_lekiwi_test`).
|
||||
2. The name of dataset begins by `eval` to reflect that you are running inference (e.g. `${HF_USER}/eval_act_lekiwi_test`).
|
||||
@@ -2,7 +2,7 @@ This tutorial explains how to use [Moss v1](https://github.com/jess-moss/moss-ro
|
||||
|
||||
## Source the parts
|
||||
|
||||
Follow this [README](https://github.com/jess-moss/moss-robot-arms). It contains the bill of materials, with link to source the parts, as well as the instructions to 3D print the parts, and advices if it's your first time printing or if you don't own a 3D printer already.
|
||||
Follow this [README](https://github.com/jess-moss/moss-robot-arms). It contains the bill of materials with link to source the parts, as well as the instructions to 3D print the parts and advice if it's your first time printing or if you don't own a 3D printer already.
|
||||
|
||||
**Important**: Before assembling, you will first need to configure your motors. To this end, we provide a nice script, so let's first install LeRobot. After configuration, we will also guide you through assembly.
|
||||
|
||||
@@ -256,7 +256,7 @@ python lerobot/scripts/control_robot.py \
|
||||
--control.push_to_hub=true
|
||||
```
|
||||
|
||||
Note: You can resume recording by adding `--control.resume=true`. Also if you didn't push your dataset yet, add `--control.local_files_only=true`.
|
||||
Note: You can resume recording by adding `--control.resume=true`.
|
||||
|
||||
## Visualize a dataset
|
||||
|
||||
@@ -284,8 +284,6 @@ python lerobot/scripts/control_robot.py \
|
||||
--control.episode=0
|
||||
```
|
||||
|
||||
Note: If you didn't push your dataset yet, add `--control.local_files_only=true`.
|
||||
|
||||
## Train a policy
|
||||
|
||||
To train a policy to control your robot, use the [`python lerobot/scripts/train.py`](../lerobot/scripts/train.py) script. A few arguments are required. Here is an example command:
|
||||
@@ -299,8 +297,6 @@ python lerobot/scripts/train.py \
|
||||
--wandb.enable=true
|
||||
```
|
||||
|
||||
Note: If you didn't push your dataset yet, add `--control.local_files_only=true`.
|
||||
|
||||
Let's explain it:
|
||||
1. We provided the dataset as argument with `--dataset.repo_id=${HF_USER}/moss_test`.
|
||||
2. We provided the policy with `policy.type=act`. This loads configurations from [`configuration_act.py`](../lerobot/common/policies/act/configuration_act.py). Importantly, this policy will automatically adapt to the number of motor sates, motor actions and cameras of your robot (e.g. `laptop` and `phone`) which have been saved in your dataset.
|
||||
|
||||
@@ -85,9 +85,8 @@ def main():
|
||||
done = False
|
||||
while not done:
|
||||
for batch in dataloader:
|
||||
batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
|
||||
output_dict = policy.forward(batch)
|
||||
loss = output_dict["loss"]
|
||||
batch = {k: (v.to(device) if isinstance(v, torch.Tensor) else v) for k, v in batch.items()}
|
||||
loss, _ = policy.forward(batch)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
@@ -161,13 +161,13 @@ python lerobot/scripts/train.py \
|
||||
```
|
||||
You should see from the logging that your training picks up from where it left off.
|
||||
|
||||
Another reason for which you might want to resume a run is simply to extend training and add more training steps. The number of training steps is set by the option `--offline.steps`, which is 100 000 by default.
|
||||
Another reason for which you might want to resume a run is simply to extend training and add more training steps. The number of training steps is set by the option `--steps`, which is 100 000 by default.
|
||||
You could double the number of steps of the previous run with:
|
||||
```bash
|
||||
python lerobot/scripts/train.py \
|
||||
--config_path=outputs/train/run_resumption/checkpoints/last/pretrained_model/ \
|
||||
--resume=true \
|
||||
--offline.steps=200000
|
||||
--steps=200000
|
||||
```
|
||||
|
||||
## Outputs of a run
|
||||
@@ -175,12 +175,16 @@ In the output directory, there will be a folder called `checkpoints` with the fo
|
||||
```bash
|
||||
outputs/train/run_resumption/checkpoints
|
||||
├── 000100 # checkpoint_dir for training step 100
|
||||
│ ├── pretrained_model
|
||||
│ │ ├── config.json # pretrained policy config
|
||||
│ │ ├── model.safetensors # model weights
|
||||
│ │ ├── train_config.json # train config
|
||||
│ │ └── README.md # model card
|
||||
│ └── training_state.pth # optimizer/scheduler/rng state and training step
|
||||
│ ├── pretrained_model/
|
||||
│ │ ├── config.json # policy config
|
||||
│ │ ├── model.safetensors # policy weights
|
||||
│ │ └── train_config.json # train config
|
||||
│ └── training_state/
|
||||
│ ├── optimizer_param_groups.json # optimizer param groups
|
||||
│ ├── optimizer_state.safetensors # optimizer state
|
||||
│ ├── rng_state.safetensors # rng states
|
||||
│ ├── scheduler_state.json # scheduler state
|
||||
│ └── training_step.json # training step
|
||||
├── 000200
|
||||
└── last -> 000200 # symlink to the last available checkpoint
|
||||
```
|
||||
@@ -250,7 +254,7 @@ python lerobot/scripts/train.py \
|
||||
python lerobot/scripts/train.py \
|
||||
--config_path=checkpoint/pretrained_model/ \
|
||||
--resume=true \
|
||||
--offline.steps=200000 # <- you can change some training parameters
|
||||
--steps=200000 # <- you can change some training parameters
|
||||
```
|
||||
|
||||
#### Fine-tuning
|
||||
|
||||
@@ -36,9 +36,14 @@ Using `pip`:
|
||||
pip install -e ".[dynamixel]"
|
||||
```
|
||||
|
||||
Or using `poetry`:
|
||||
Using `poetry`:
|
||||
```bash
|
||||
poetry install --sync --extras "dynamixel"
|
||||
poetry sync --extras "dynamixel"
|
||||
```
|
||||
|
||||
Using `uv`:
|
||||
```bash
|
||||
uv sync --extra "dynamixel"
|
||||
```
|
||||
|
||||
/!\ For Linux only, ffmpeg and opencv requires conda install for now. Run this exact sequence of commands:
|
||||
@@ -393,7 +398,7 @@ And here are the corresponding positions for the leader arm:
|
||||
|
||||
You can watch a [video tutorial of the calibration procedure](https://youtu.be/8drnU9uRY24) for more details.
|
||||
|
||||
During calibration, we count the number of full 360-degree rotations your motors have made since they were first used. That's why we ask yo to move to this arbitrary "zero" position. We don't actually "set" the zero position, so you don't need to be accurate. After calculating these "offsets" to shift the motor values around 0, we need to assess the rotation direction of each motor, which might differ. That's why we ask you to rotate all motors to roughly 90 degrees, to mesure if the values changed negatively or positively.
|
||||
During calibration, we count the number of full 360-degree rotations your motors have made since they were first used. That's why we ask yo to move to this arbitrary "zero" position. We don't actually "set" the zero position, so you don't need to be accurate. After calculating these "offsets" to shift the motor values around 0, we need to assess the rotation direction of each motor, which might differ. That's why we ask you to rotate all motors to roughly 90 degrees, to measure if the values changed negatively or positively.
|
||||
|
||||
Finally, the rest position ensures that the follower and leader arms are roughly aligned after calibration, preventing sudden movements that could damage the motors when starting teleoperation.
|
||||
|
||||
@@ -621,7 +626,7 @@ Finally, run this code to instantiate and connectyour camera:
|
||||
from lerobot.common.robot_devices.cameras.configs import OpenCVCameraConfig
|
||||
from lerobot.common.robot_devices.cameras.opencv import OpenCVCamera
|
||||
|
||||
camera_config = OpenCVCameraConfig(camera_index=0)
|
||||
config = OpenCVCameraConfig(camera_index=0)
|
||||
camera = OpenCVCamera(config)
|
||||
camera.connect()
|
||||
color_image = camera.read()
|
||||
@@ -658,18 +663,20 @@ camera.disconnect()
|
||||
|
||||
**Instantiate your robot with cameras**
|
||||
|
||||
Additionaly, you can set up your robot to work with your cameras.
|
||||
Additionally, you can set up your robot to work with your cameras.
|
||||
|
||||
Modify the following Python code with the appropriate camera names and configurations:
|
||||
```python
|
||||
robot = ManipulatorRobot(
|
||||
leader_arms={"main": leader_arm},
|
||||
follower_arms={"main": follower_arm},
|
||||
calibration_dir=".cache/calibration/koch",
|
||||
cameras={
|
||||
"laptop": OpenCVCameraConfig(0, fps=30, width=640, height=480),
|
||||
"phone": OpenCVCameraConfig(1, fps=30, width=640, height=480),
|
||||
},
|
||||
KochRobotConfig(
|
||||
leader_arms={"main": leader_arm},
|
||||
follower_arms={"main": follower_arm},
|
||||
calibration_dir=".cache/calibration/koch",
|
||||
cameras={
|
||||
"laptop": OpenCVCameraConfig(0, fps=30, width=640, height=480),
|
||||
"phone": OpenCVCameraConfig(1, fps=30, width=640, height=480),
|
||||
},
|
||||
)
|
||||
)
|
||||
robot.connect()
|
||||
```
|
||||
@@ -706,7 +713,7 @@ python lerobot/scripts/control_robot.py \
|
||||
|
||||
You will see a lot of lines appearing like this one:
|
||||
```
|
||||
INFO 2024-08-10 11:15:03 ol_robot.py:209 dt: 5.12 (195.1hz) dtRlead: 4.93 (203.0hz) dtRfoll: 0.19 (5239.0hz)
|
||||
INFO 2024-08-10 11:15:03 ol_robot.py:209 dt: 5.12 (195.1hz) dtRlead: 4.93 (203.0hz) dtWfoll: 0.19 (5239.0hz)
|
||||
```
|
||||
|
||||
It contains
|
||||
@@ -763,7 +770,7 @@ You can use the `record` function from [`lerobot/scripts/control_robot.py`](../l
|
||||
1. Frames from cameras are saved on disk in threads, and encoded into videos at the end of each episode recording.
|
||||
2. Video streams from cameras are displayed in window so that you can verify them.
|
||||
3. Data is stored with [`LeRobotDataset`](../lerobot/common/datasets/lerobot_dataset.py) format which is pushed to your Hugging Face page (unless `--control.push_to_hub=false` is provided).
|
||||
4. Checkpoints are done during recording, so if any issue occurs, you can resume recording by re-running the same command again with `--control.resume=true`. You might need to add `--control.local_files_only=true` if your dataset was not uploaded to hugging face hub. Also you will need to manually delete the dataset directory to start recording from scratch.
|
||||
4. Checkpoints are done during recording, so if any issue occurs, you can resume recording by re-running the same command again with `--control.resume=true`. You will need to manually delete the dataset directory if you want to start recording from scratch.
|
||||
5. Set the flow of data recording using command line arguments:
|
||||
- `--control.warmup_time_s=10` defines the number of seconds before starting data collection. It allows the robot devices to warmup and synchronize (10 seconds by default).
|
||||
- `--control.episode_time_s=60` defines the number of seconds for data recording for each episode (60 seconds by default).
|
||||
@@ -818,8 +825,8 @@ It contains:
|
||||
- `dtRlead: 5.06 (197.5hz)` which is the delta time of reading the present position of the leader arm.
|
||||
- `dtWfoll: 0.25 (3963.7hz)` which is the delta time of writing the goal position on the follower arm ; writing is asynchronous so it takes less time than reading.
|
||||
- `dtRfoll: 6.22 (160.7hz)` which is the delta time of reading the present position on the follower arm.
|
||||
- `dtRlaptop:32.57 (30.7hz) ` which is the delta time of capturing an image from the laptop camera in the thread running asynchrously.
|
||||
- `dtRphone:33.84 (29.5hz)` which is the delta time of capturing an image from the phone camera in the thread running asynchrously.
|
||||
- `dtRlaptop:32.57 (30.7hz) ` which is the delta time of capturing an image from the laptop camera in the thread running asynchronously.
|
||||
- `dtRphone:33.84 (29.5hz)` which is the delta time of capturing an image from the phone camera in the thread running asynchronously.
|
||||
|
||||
Troubleshooting:
|
||||
- On Linux, if you encounter a hanging issue when using cameras, uninstall opencv and re-install it with conda:
|
||||
@@ -839,7 +846,7 @@ At the end of data recording, your dataset will be uploaded on your Hugging Face
|
||||
echo https://huggingface.co/datasets/${HF_USER}/koch_test
|
||||
```
|
||||
|
||||
### b. Advices for recording dataset
|
||||
### b. Advice for recording dataset
|
||||
|
||||
Once you're comfortable with data recording, it's time to create a larger dataset for training. A good starting task is grasping an object at different locations and placing it in a bin. We suggest recording at least 50 episodes, with 10 episodes per location. Keep the cameras fixed and maintain consistent grasping behavior throughout the recordings.
|
||||
|
||||
@@ -878,8 +885,6 @@ python lerobot/scripts/control_robot.py \
|
||||
--control.episode=0
|
||||
```
|
||||
|
||||
Note: You might need to add `--control.local_files_only=true` if your dataset was not uploaded to hugging face hub.
|
||||
|
||||
Your robot should replicate movements similar to those you recorded. For example, check out [this video](https://x.com/RemiCadene/status/1793654950905680090) where we use `replay` on a Aloha robot from [Trossen Robotics](https://www.trossenrobotics.com).
|
||||
|
||||
## 4. Train a policy on your data
|
||||
@@ -897,8 +902,6 @@ python lerobot/scripts/train.py \
|
||||
--wandb.enable=true
|
||||
```
|
||||
|
||||
Note: You might need to add `--dataset.local_files_only=true` if your dataset was not uploaded to hugging face hub.
|
||||
|
||||
Let's explain it:
|
||||
1. We provided the dataset as argument with `--dataset.repo_id=${HF_USER}/koch_test`.
|
||||
2. We provided the policy with `policy.type=act`. This loads configurations from [`configuration_act.py`](../lerobot/common/policies/act/configuration_act.py). Importantly, this policy will automatically adapt to the number of motor sates, motor actions and cameras of your robot (e.g. `laptop` and `phone`) which have been saved in your dataset.
|
||||
|
||||
@@ -98,7 +98,7 @@ python lerobot/scripts/control_robot.py \
|
||||
```
|
||||
This is equivalent to running `stretch_robot_home.py`
|
||||
|
||||
> **Note:** If you run any of the LeRobot scripts below and Stretch is not poperly homed, it will automatically home/calibrate first.
|
||||
> **Note:** If you run any of the LeRobot scripts below and Stretch is not properly homed, it will automatically home/calibrate first.
|
||||
|
||||
**Teleoperate**
|
||||
Before trying teleoperation, you need activate the gamepad controller by pressing the middle button. For more info, see Stretch's [doc](https://docs.hello-robot.com/0.3/getting_started/hello_robot/#gamepad-teleoperation).
|
||||
|
||||
@@ -2,7 +2,7 @@ This tutorial explains how to use [Aloha and Aloha 2 stationary](https://www.tro
|
||||
|
||||
## Setup
|
||||
|
||||
Follow the [documentation from Trossen Robotics](https://docs.trossenrobotics.com/aloha_docs/getting_started/stationary/hardware_setup.html) for setting up the hardware and plugging the 4 arms and 4 cameras to your computer.
|
||||
Follow the [documentation from Trossen Robotics](https://docs.trossenrobotics.com/aloha_docs/2.0/getting_started/stationary/hardware_setup.html) for setting up the hardware and plugging the 4 arms and 4 cameras to your computer.
|
||||
|
||||
|
||||
## Install LeRobot
|
||||
@@ -172,10 +172,10 @@ python lerobot/scripts/control_robot.py \
|
||||
As you can see, it's almost the same command as previously used to record your training dataset. Two things changed:
|
||||
1. There is an additional `--control.policy.path` argument which indicates the path to your policy checkpoint with (e.g. `outputs/train/eval_act_aloha_test/checkpoints/last/pretrained_model`). You can also use the model repository if you uploaded a model checkpoint to the hub (e.g. `${HF_USER}/act_aloha_test`).
|
||||
2. The name of dataset begins by `eval` to reflect that you are running inference (e.g. `${HF_USER}/eval_act_aloha_test`).
|
||||
3. We use `--control.num_image_writer_processes=1` instead of the default value (`0`). On our computer, using a dedicated process to write images from the 4 cameras on disk allows to reach constent 30 fps during inference. Feel free to explore different values for `--control.num_image_writer_processes`.
|
||||
3. We use `--control.num_image_writer_processes=1` instead of the default value (`0`). On our computer, using a dedicated process to write images from the 4 cameras on disk allows to reach constant 30 fps during inference. Feel free to explore different values for `--control.num_image_writer_processes`.
|
||||
|
||||
## More
|
||||
|
||||
Follow this [previous tutorial](https://github.com/huggingface/lerobot/blob/main/examples/7_get_started_with_real_robot.md#4-train-a-policy-on-your-data) for a more in-depth explaination.
|
||||
Follow this [previous tutorial](https://github.com/huggingface/lerobot/blob/main/examples/7_get_started_with_real_robot.md#4-train-a-policy-on-your-data) for a more in-depth explanation.
|
||||
|
||||
If you have any question or need help, please reach out on Discord in the channel `#aloha-arm`.
|
||||
|
||||
@@ -75,9 +75,9 @@ def main():
|
||||
n_examples_evaluated = 0
|
||||
for batch in val_dataloader:
|
||||
batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
|
||||
output_dict = policy.forward(batch)
|
||||
loss, _ = policy.forward(batch)
|
||||
|
||||
loss_cumsum += output_dict["loss"].item()
|
||||
loss_cumsum += loss.item()
|
||||
n_examples_evaluated += batch["index"].shape[0]
|
||||
|
||||
# Calculate the average loss over the validation set.
|
||||
|
||||
@@ -2,9 +2,10 @@ import shutil
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from huggingface_hub import HfApi
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import LEROBOT_HOME, LeRobotDataset
|
||||
from lerobot.common.constants import HF_LEROBOT_HOME
|
||||
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
|
||||
from lerobot.common.datasets.push_dataset_to_hub._download_raw import download_raw
|
||||
|
||||
PUSHT_TASK = "Push the T-shaped blue block onto the T-shaped green target surface."
|
||||
@@ -89,9 +90,9 @@ def calculate_coverage(zarr_data):
|
||||
|
||||
num_frames = len(block_pos)
|
||||
|
||||
coverage = np.zeros((num_frames,))
|
||||
coverage = np.zeros((num_frames,), dtype=np.float32)
|
||||
# 8 keypoints with 2 coords each
|
||||
keypoints = np.zeros((num_frames, 16))
|
||||
keypoints = np.zeros((num_frames, 16), dtype=np.float32)
|
||||
|
||||
# Set x, y, theta (in radians)
|
||||
goal_pos_angle = np.array([256, 256, np.pi / 4])
|
||||
@@ -117,7 +118,7 @@ def calculate_coverage(zarr_data):
|
||||
intersection_area = goal_geom.intersection(block_geom).area
|
||||
goal_area = goal_geom.area
|
||||
coverage[i] = intersection_area / goal_area
|
||||
keypoints[i] = torch.from_numpy(PushTEnv.get_keypoints(block_shapes).flatten())
|
||||
keypoints[i] = PushTEnv.get_keypoints(block_shapes).flatten()
|
||||
|
||||
return coverage, keypoints
|
||||
|
||||
@@ -134,8 +135,8 @@ def main(raw_dir: Path, repo_id: str, mode: str = "video", push_to_hub: bool = T
|
||||
if mode not in ["video", "image", "keypoints"]:
|
||||
raise ValueError(mode)
|
||||
|
||||
if (LEROBOT_HOME / repo_id).exists():
|
||||
shutil.rmtree(LEROBOT_HOME / repo_id)
|
||||
if (HF_LEROBOT_HOME / repo_id).exists():
|
||||
shutil.rmtree(HF_LEROBOT_HOME / repo_id)
|
||||
|
||||
if not raw_dir.exists():
|
||||
download_raw(raw_dir, repo_id="lerobot-raw/pusht_raw")
|
||||
@@ -148,6 +149,10 @@ def main(raw_dir: Path, repo_id: str, mode: str = "video", push_to_hub: bool = T
|
||||
action = zarr_data["action"][:]
|
||||
image = zarr_data["img"] # (b, h, w, c)
|
||||
|
||||
if image.dtype == np.float32 and image.max() == np.float32(255):
|
||||
# HACK: images are loaded as float32 but they actually encode uint8 data
|
||||
image = image.astype(np.uint8)
|
||||
|
||||
episode_data_index = {
|
||||
"from": np.concatenate(([0], zarr_data.meta["episode_ends"][:-1])),
|
||||
"to": zarr_data.meta["episode_ends"],
|
||||
@@ -175,28 +180,30 @@ def main(raw_dir: Path, repo_id: str, mode: str = "video", push_to_hub: bool = T
|
||||
|
||||
for frame_idx in range(num_frames):
|
||||
i = from_idx + frame_idx
|
||||
idx = i + (frame_idx < num_frames - 1)
|
||||
frame = {
|
||||
"action": torch.from_numpy(action[i]),
|
||||
"action": action[i],
|
||||
# Shift reward and success by +1 until the last item of the episode
|
||||
"next.reward": reward[i + (frame_idx < num_frames - 1)],
|
||||
"next.success": success[i + (frame_idx < num_frames - 1)],
|
||||
"next.reward": reward[idx : idx + 1],
|
||||
"next.success": success[idx : idx + 1],
|
||||
"task": PUSHT_TASK,
|
||||
}
|
||||
|
||||
frame["observation.state"] = torch.from_numpy(agent_pos[i])
|
||||
frame["observation.state"] = agent_pos[i]
|
||||
|
||||
if mode == "keypoints":
|
||||
frame["observation.environment_state"] = torch.from_numpy(keypoints[i])
|
||||
frame["observation.environment_state"] = keypoints[i]
|
||||
else:
|
||||
frame["observation.image"] = torch.from_numpy(image[i])
|
||||
frame["observation.image"] = image[i]
|
||||
|
||||
dataset.add_frame(frame)
|
||||
|
||||
dataset.save_episode(task=PUSHT_TASK)
|
||||
|
||||
dataset.consolidate()
|
||||
dataset.save_episode()
|
||||
|
||||
if push_to_hub:
|
||||
dataset.push_to_hub()
|
||||
hub_api = HfApi()
|
||||
hub_api.create_tag(repo_id, tag=CODEBASE_VERSION, repo_type="dataset")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@@ -218,5 +225,5 @@ if __name__ == "__main__":
|
||||
main(raw_dir, repo_id=repo_id, mode=mode)
|
||||
|
||||
# Uncomment if you want to load the local dataset and explore it
|
||||
# dataset = LeRobotDataset(repo_id=repo_id, local_files_only=True)
|
||||
# dataset = LeRobotDataset(repo_id=repo_id)
|
||||
# breakpoint()
|
||||
|
||||
@@ -1,6 +1,32 @@
|
||||
# keys
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from huggingface_hub.constants import HF_HOME
|
||||
|
||||
OBS_ENV = "observation.environment_state"
|
||||
OBS_ROBOT = "observation.state"
|
||||
OBS_IMAGE = "observation.image"
|
||||
OBS_IMAGES = "observation.images"
|
||||
ACTION = "action"
|
||||
|
||||
# files & directories
|
||||
CHECKPOINTS_DIR = "checkpoints"
|
||||
LAST_CHECKPOINT_LINK = "last"
|
||||
PRETRAINED_MODEL_DIR = "pretrained_model"
|
||||
TRAINING_STATE_DIR = "training_state"
|
||||
RNG_STATE = "rng_state.safetensors"
|
||||
TRAINING_STEP = "training_step.json"
|
||||
OPTIMIZER_STATE = "optimizer_state.safetensors"
|
||||
OPTIMIZER_PARAM_GROUPS = "optimizer_param_groups.json"
|
||||
SCHEDULER_STATE = "scheduler_state.json"
|
||||
|
||||
# cache dir
|
||||
default_cache_path = Path(HF_HOME) / "lerobot"
|
||||
HF_LEROBOT_HOME = Path(os.getenv("HF_LEROBOT_HOME", default_cache_path)).expanduser()
|
||||
|
||||
if "LEROBOT_HOME" in os.environ:
|
||||
raise ValueError(
|
||||
f"You have a 'LEROBOT_HOME' environment variable set to '{os.getenv('LEROBOT_HOME')}'.\n"
|
||||
"'LEROBOT_HOME' is deprecated, please use 'HF_LEROBOT_HOME' instead."
|
||||
)
|
||||
|
||||
54
lerobot/common/datasets/backward_compatibility.py
Normal file
54
lerobot/common/datasets/backward_compatibility.py
Normal file
@@ -0,0 +1,54 @@
|
||||
import packaging.version
|
||||
|
||||
V2_MESSAGE = """
|
||||
The dataset you requested ({repo_id}) is in {version} format.
|
||||
|
||||
We introduced a new format since v2.0 which is not backward compatible with v1.x.
|
||||
Please, use our conversion script. Modify the following command with your own task description:
|
||||
```
|
||||
python lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py \\
|
||||
--repo-id {repo_id} \\
|
||||
--single-task "TASK DESCRIPTION." # <---- /!\\ Replace TASK DESCRIPTION /!\\
|
||||
```
|
||||
|
||||
A few examples to replace TASK DESCRIPTION: "Pick up the blue cube and place it into the bin.", "Insert the
|
||||
peg into the socket.", "Slide open the ziploc bag.", "Take the elevator to the 1st floor.", "Open the top
|
||||
cabinet, store the pot inside it then close the cabinet.", "Push the T-shaped block onto the T-shaped
|
||||
target.", "Grab the spray paint on the shelf and place it in the bin on top of the robot dog.", "Fold the
|
||||
sweatshirt.", ...
|
||||
|
||||
If you encounter a problem, contact LeRobot maintainers on [Discord](https://discord.com/invite/s3KuuzsPFb)
|
||||
or open an [issue on GitHub](https://github.com/huggingface/lerobot/issues/new/choose).
|
||||
"""
|
||||
|
||||
V21_MESSAGE = """
|
||||
The dataset you requested ({repo_id}) is in {version} format.
|
||||
While current version of LeRobot is backward-compatible with it, the version of your dataset still uses global
|
||||
stats instead of per-episode stats. Update your dataset stats to the new format using this command:
|
||||
```
|
||||
python lerobot/common/datasets/v21/convert_dataset_v20_to_v21.py --repo-id={repo_id}
|
||||
```
|
||||
|
||||
If you encounter a problem, contact LeRobot maintainers on [Discord](https://discord.com/invite/s3KuuzsPFb)
|
||||
or open an [issue on GitHub](https://github.com/huggingface/lerobot/issues/new/choose).
|
||||
"""
|
||||
|
||||
FUTURE_MESSAGE = """
|
||||
The dataset you requested ({repo_id}) is only available in {version} format.
|
||||
As we cannot ensure forward compatibility with it, please update your current version of lerobot.
|
||||
"""
|
||||
|
||||
|
||||
class CompatibilityError(Exception): ...
|
||||
|
||||
|
||||
class BackwardCompatibilityError(CompatibilityError):
|
||||
def __init__(self, repo_id: str, version: packaging.version.Version):
|
||||
message = V2_MESSAGE.format(repo_id=repo_id, version=version)
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class ForwardCompatibilityError(CompatibilityError):
|
||||
def __init__(self, repo_id: str, version: packaging.version.Version):
|
||||
message = FUTURE_MESSAGE.format(repo_id=repo_id, version=version)
|
||||
super().__init__(message)
|
||||
@@ -13,202 +13,164 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from copy import deepcopy
|
||||
from math import ceil
|
||||
import numpy as np
|
||||
|
||||
import einops
|
||||
import torch
|
||||
import tqdm
|
||||
from lerobot.common.datasets.utils import load_image_as_numpy
|
||||
|
||||
|
||||
def get_stats_einops_patterns(dataset, num_workers=0):
|
||||
"""These einops patterns will be used to aggregate batches and compute statistics.
|
||||
def estimate_num_samples(
|
||||
dataset_len: int, min_num_samples: int = 100, max_num_samples: int = 10_000, power: float = 0.75
|
||||
) -> int:
|
||||
"""Heuristic to estimate the number of samples based on dataset size.
|
||||
The power controls the sample growth relative to dataset size.
|
||||
Lower the power for less number of samples.
|
||||
|
||||
Note: We assume the images are in channel first format
|
||||
For default arguments, we have:
|
||||
- from 1 to ~500, num_samples=100
|
||||
- at 1000, num_samples=177
|
||||
- at 2000, num_samples=299
|
||||
- at 5000, num_samples=594
|
||||
- at 10000, num_samples=1000
|
||||
- at 20000, num_samples=1681
|
||||
"""
|
||||
if dataset_len < min_num_samples:
|
||||
min_num_samples = dataset_len
|
||||
return max(min_num_samples, min(int(dataset_len**power), max_num_samples))
|
||||
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
num_workers=num_workers,
|
||||
batch_size=2,
|
||||
shuffle=False,
|
||||
)
|
||||
batch = next(iter(dataloader))
|
||||
|
||||
stats_patterns = {}
|
||||
def sample_indices(data_len: int) -> list[int]:
|
||||
num_samples = estimate_num_samples(data_len)
|
||||
return np.round(np.linspace(0, data_len - 1, num_samples)).astype(int).tolist()
|
||||
|
||||
for key in dataset.features:
|
||||
# sanity check that tensors are not float64
|
||||
assert batch[key].dtype != torch.float64
|
||||
|
||||
# if isinstance(feats_type, (VideoFrame, Image)):
|
||||
if key in dataset.meta.camera_keys:
|
||||
# sanity check that images are channel first
|
||||
_, c, h, w = batch[key].shape
|
||||
assert c < h and c < w, f"expect channel first images, but instead {batch[key].shape}"
|
||||
def auto_downsample_height_width(img: np.ndarray, target_size: int = 150, max_size_threshold: int = 300):
|
||||
_, height, width = img.shape
|
||||
|
||||
# sanity check that images are float32 in range [0,1]
|
||||
assert batch[key].dtype == torch.float32, f"expect torch.float32, but instead {batch[key].dtype=}"
|
||||
assert batch[key].max() <= 1, f"expect pixels lower than 1, but instead {batch[key].max()=}"
|
||||
assert batch[key].min() >= 0, f"expect pixels greater than 1, but instead {batch[key].min()=}"
|
||||
if max(width, height) < max_size_threshold:
|
||||
# no downsampling needed
|
||||
return img
|
||||
|
||||
stats_patterns[key] = "b c h w -> c 1 1"
|
||||
elif batch[key].ndim == 2:
|
||||
stats_patterns[key] = "b c -> c "
|
||||
elif batch[key].ndim == 1:
|
||||
stats_patterns[key] = "b -> 1"
|
||||
downsample_factor = int(width / target_size) if width > height else int(height / target_size)
|
||||
return img[:, ::downsample_factor, ::downsample_factor]
|
||||
|
||||
|
||||
def sample_images(image_paths: list[str]) -> np.ndarray:
|
||||
sampled_indices = sample_indices(len(image_paths))
|
||||
|
||||
images = None
|
||||
for i, idx in enumerate(sampled_indices):
|
||||
path = image_paths[idx]
|
||||
# we load as uint8 to reduce memory usage
|
||||
img = load_image_as_numpy(path, dtype=np.uint8, channel_first=True)
|
||||
img = auto_downsample_height_width(img)
|
||||
|
||||
if images is None:
|
||||
images = np.empty((len(sampled_indices), *img.shape), dtype=np.uint8)
|
||||
|
||||
images[i] = img
|
||||
|
||||
return images
|
||||
|
||||
|
||||
def get_feature_stats(array: np.ndarray, axis: tuple, keepdims: bool) -> dict[str, np.ndarray]:
|
||||
return {
|
||||
"min": np.min(array, axis=axis, keepdims=keepdims),
|
||||
"max": np.max(array, axis=axis, keepdims=keepdims),
|
||||
"mean": np.mean(array, axis=axis, keepdims=keepdims),
|
||||
"std": np.std(array, axis=axis, keepdims=keepdims),
|
||||
"count": np.array([len(array)]),
|
||||
}
|
||||
|
||||
|
||||
def compute_episode_stats(episode_data: dict[str, list[str] | np.ndarray], features: dict) -> dict:
|
||||
ep_stats = {}
|
||||
for key, data in episode_data.items():
|
||||
if features[key]["dtype"] == "string":
|
||||
continue # HACK: we should receive np.arrays of strings
|
||||
elif features[key]["dtype"] in ["image", "video"]:
|
||||
ep_ft_array = sample_images(data) # data is a list of image paths
|
||||
axes_to_reduce = (0, 2, 3) # keep channel dim
|
||||
keepdims = True
|
||||
else:
|
||||
raise ValueError(f"{key}, {batch[key].shape}")
|
||||
ep_ft_array = data # data is already a np.ndarray
|
||||
axes_to_reduce = 0 # compute stats over the first axis
|
||||
keepdims = data.ndim == 1 # keep as np.array
|
||||
|
||||
return stats_patterns
|
||||
ep_stats[key] = get_feature_stats(ep_ft_array, axis=axes_to_reduce, keepdims=keepdims)
|
||||
|
||||
# finally, we normalize and remove batch dim for images
|
||||
if features[key]["dtype"] in ["image", "video"]:
|
||||
ep_stats[key] = {
|
||||
k: v if k == "count" else np.squeeze(v / 255.0, axis=0) for k, v in ep_stats[key].items()
|
||||
}
|
||||
|
||||
return ep_stats
|
||||
|
||||
|
||||
def compute_stats(dataset, batch_size=8, num_workers=8, max_num_samples=None):
|
||||
"""Compute mean/std and min/max statistics of all data keys in a LeRobotDataset."""
|
||||
if max_num_samples is None:
|
||||
max_num_samples = len(dataset)
|
||||
|
||||
# for more info on why we need to set the same number of workers, see `load_from_videos`
|
||||
stats_patterns = get_stats_einops_patterns(dataset, num_workers)
|
||||
|
||||
# mean and std will be computed incrementally while max and min will track the running value.
|
||||
mean, std, max, min = {}, {}, {}, {}
|
||||
for key in stats_patterns:
|
||||
mean[key] = torch.tensor(0.0).float()
|
||||
std[key] = torch.tensor(0.0).float()
|
||||
max[key] = torch.tensor(-float("inf")).float()
|
||||
min[key] = torch.tensor(float("inf")).float()
|
||||
|
||||
def create_seeded_dataloader(dataset, batch_size, seed):
|
||||
generator = torch.Generator()
|
||||
generator.manual_seed(seed)
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
num_workers=num_workers,
|
||||
batch_size=batch_size,
|
||||
shuffle=True,
|
||||
drop_last=False,
|
||||
generator=generator,
|
||||
)
|
||||
return dataloader
|
||||
|
||||
# Note: Due to be refactored soon. The point of storing `first_batch` is to make sure we don't get
|
||||
# surprises when rerunning the sampler.
|
||||
first_batch = None
|
||||
running_item_count = 0 # for online mean computation
|
||||
dataloader = create_seeded_dataloader(dataset, batch_size, seed=1337)
|
||||
for i, batch in enumerate(
|
||||
tqdm.tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute mean, min, max")
|
||||
):
|
||||
this_batch_size = len(batch["index"])
|
||||
running_item_count += this_batch_size
|
||||
if first_batch is None:
|
||||
first_batch = deepcopy(batch)
|
||||
for key, pattern in stats_patterns.items():
|
||||
batch[key] = batch[key].float()
|
||||
# Numerically stable update step for mean computation.
|
||||
batch_mean = einops.reduce(batch[key], pattern, "mean")
|
||||
# Hint: to update the mean we need x̄ₙ = (Nₙ₋₁x̄ₙ₋₁ + Bₙxₙ) / Nₙ, where the subscript represents
|
||||
# the update step, N is the running item count, B is this batch size, x̄ is the running mean,
|
||||
# and x is the current batch mean. Some rearrangement is then required to avoid risking
|
||||
# numerical overflow. Another hint: Nₙ₋₁ = Nₙ - Bₙ. Rearrangement yields
|
||||
# x̄ₙ = x̄ₙ₋₁ + Bₙ * (xₙ - x̄ₙ₋₁) / Nₙ
|
||||
mean[key] = mean[key] + this_batch_size * (batch_mean - mean[key]) / running_item_count
|
||||
max[key] = torch.maximum(max[key], einops.reduce(batch[key], pattern, "max"))
|
||||
min[key] = torch.minimum(min[key], einops.reduce(batch[key], pattern, "min"))
|
||||
|
||||
if i == ceil(max_num_samples / batch_size) - 1:
|
||||
break
|
||||
|
||||
first_batch_ = None
|
||||
running_item_count = 0 # for online std computation
|
||||
dataloader = create_seeded_dataloader(dataset, batch_size, seed=1337)
|
||||
for i, batch in enumerate(
|
||||
tqdm.tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute std")
|
||||
):
|
||||
this_batch_size = len(batch["index"])
|
||||
running_item_count += this_batch_size
|
||||
# Sanity check to make sure the batches are still in the same order as before.
|
||||
if first_batch_ is None:
|
||||
first_batch_ = deepcopy(batch)
|
||||
for key in stats_patterns:
|
||||
assert torch.equal(first_batch_[key], first_batch[key])
|
||||
for key, pattern in stats_patterns.items():
|
||||
batch[key] = batch[key].float()
|
||||
# Numerically stable update step for mean computation (where the mean is over squared
|
||||
# residuals).See notes in the mean computation loop above.
|
||||
batch_std = einops.reduce((batch[key] - mean[key]) ** 2, pattern, "mean")
|
||||
std[key] = std[key] + this_batch_size * (batch_std - std[key]) / running_item_count
|
||||
|
||||
if i == ceil(max_num_samples / batch_size) - 1:
|
||||
break
|
||||
|
||||
for key in stats_patterns:
|
||||
std[key] = torch.sqrt(std[key])
|
||||
|
||||
stats = {}
|
||||
for key in stats_patterns:
|
||||
stats[key] = {
|
||||
"mean": mean[key],
|
||||
"std": std[key],
|
||||
"max": max[key],
|
||||
"min": min[key],
|
||||
}
|
||||
return stats
|
||||
def _assert_type_and_shape(stats_list: list[dict[str, dict]]):
|
||||
for i in range(len(stats_list)):
|
||||
for fkey in stats_list[i]:
|
||||
for k, v in stats_list[i][fkey].items():
|
||||
if not isinstance(v, np.ndarray):
|
||||
raise ValueError(
|
||||
f"Stats must be composed of numpy array, but key '{k}' of feature '{fkey}' is of type '{type(v)}' instead."
|
||||
)
|
||||
if v.ndim == 0:
|
||||
raise ValueError("Number of dimensions must be at least 1, and is 0 instead.")
|
||||
if k == "count" and v.shape != (1,):
|
||||
raise ValueError(f"Shape of 'count' must be (1), but is {v.shape} instead.")
|
||||
if "image" in fkey and k != "count" and v.shape != (3, 1, 1):
|
||||
raise ValueError(f"Shape of '{k}' must be (3,1,1), but is {v.shape} instead.")
|
||||
|
||||
|
||||
def aggregate_stats(ls_datasets) -> dict[str, torch.Tensor]:
|
||||
"""Aggregate stats of multiple LeRobot datasets into one set of stats without recomputing from scratch.
|
||||
def aggregate_feature_stats(stats_ft_list: list[dict[str, dict]]) -> dict[str, dict[str, np.ndarray]]:
|
||||
"""Aggregates stats for a single feature."""
|
||||
means = np.stack([s["mean"] for s in stats_ft_list])
|
||||
variances = np.stack([s["std"] ** 2 for s in stats_ft_list])
|
||||
counts = np.stack([s["count"] for s in stats_ft_list])
|
||||
total_count = counts.sum(axis=0)
|
||||
|
||||
The final stats will have the union of all data keys from each of the datasets.
|
||||
# Prepare weighted mean by matching number of dimensions
|
||||
while counts.ndim < means.ndim:
|
||||
counts = np.expand_dims(counts, axis=-1)
|
||||
|
||||
The final stats will have the union of all data keys from each of the datasets. For instance:
|
||||
- new_max = max(max_dataset_0, max_dataset_1, ...)
|
||||
# Compute the weighted mean
|
||||
weighted_means = means * counts
|
||||
total_mean = weighted_means.sum(axis=0) / total_count
|
||||
|
||||
# Compute the variance using the parallel algorithm
|
||||
delta_means = means - total_mean
|
||||
weighted_variances = (variances + delta_means**2) * counts
|
||||
total_variance = weighted_variances.sum(axis=0) / total_count
|
||||
|
||||
return {
|
||||
"min": np.min(np.stack([s["min"] for s in stats_ft_list]), axis=0),
|
||||
"max": np.max(np.stack([s["max"] for s in stats_ft_list]), axis=0),
|
||||
"mean": total_mean,
|
||||
"std": np.sqrt(total_variance),
|
||||
"count": total_count,
|
||||
}
|
||||
|
||||
|
||||
def aggregate_stats(stats_list: list[dict[str, dict]]) -> dict[str, dict[str, np.ndarray]]:
|
||||
"""Aggregate stats from multiple compute_stats outputs into a single set of stats.
|
||||
|
||||
The final stats will have the union of all data keys from each of the stats dicts.
|
||||
|
||||
For instance:
|
||||
- new_min = min(min_dataset_0, min_dataset_1, ...)
|
||||
- new_mean = (mean of all data)
|
||||
- new_max = max(max_dataset_0, max_dataset_1, ...)
|
||||
- new_mean = (mean of all data, weighted by counts)
|
||||
- new_std = (std of all data)
|
||||
"""
|
||||
data_keys = set()
|
||||
for dataset in ls_datasets:
|
||||
data_keys.update(dataset.meta.stats.keys())
|
||||
stats = {k: {} for k in data_keys}
|
||||
for data_key in data_keys:
|
||||
for stat_key in ["min", "max"]:
|
||||
# compute `max(dataset_0["max"], dataset_1["max"], ...)`
|
||||
stats[data_key][stat_key] = einops.reduce(
|
||||
torch.stack(
|
||||
[ds.meta.stats[data_key][stat_key] for ds in ls_datasets if data_key in ds.meta.stats],
|
||||
dim=0,
|
||||
),
|
||||
"n ... -> ...",
|
||||
stat_key,
|
||||
)
|
||||
total_samples = sum(d.num_frames for d in ls_datasets if data_key in d.meta.stats)
|
||||
# Compute the "sum" statistic by multiplying each mean by the number of samples in the respective
|
||||
# dataset, then divide by total_samples to get the overall "mean".
|
||||
# NOTE: the brackets around (d.num_frames / total_samples) are needed tor minimize the risk of
|
||||
# numerical overflow!
|
||||
stats[data_key]["mean"] = sum(
|
||||
d.meta.stats[data_key]["mean"] * (d.num_frames / total_samples)
|
||||
for d in ls_datasets
|
||||
if data_key in d.meta.stats
|
||||
)
|
||||
# The derivation for standard deviation is a little more involved but is much in the same spirit as
|
||||
# the computation of the mean.
|
||||
# Given two sets of data where the statistics are known:
|
||||
# σ_combined = sqrt[ (n1 * (σ1^2 + d1^2) + n2 * (σ2^2 + d2^2)) / (n1 + n2) ]
|
||||
# where d1 = μ1 - μ_combined, d2 = μ2 - μ_combined
|
||||
# NOTE: the brackets around (d.num_frames / total_samples) are needed tor minimize the risk of
|
||||
# numerical overflow!
|
||||
stats[data_key]["std"] = torch.sqrt(
|
||||
sum(
|
||||
(
|
||||
d.meta.stats[data_key]["std"] ** 2
|
||||
+ (d.meta.stats[data_key]["mean"] - stats[data_key]["mean"]) ** 2
|
||||
)
|
||||
* (d.num_frames / total_samples)
|
||||
for d in ls_datasets
|
||||
if data_key in d.meta.stats
|
||||
)
|
||||
)
|
||||
return stats
|
||||
|
||||
_assert_type_and_shape(stats_list)
|
||||
|
||||
data_keys = {key for stats in stats_list for key in stats}
|
||||
aggregated_stats = {key: {} for key in data_keys}
|
||||
|
||||
for key in data_keys:
|
||||
stats_with_key = [stats[key] for stats in stats_list if key in stats]
|
||||
aggregated_stats[key] = aggregate_feature_stats(stats_with_key)
|
||||
|
||||
return aggregated_stats
|
||||
|
||||
@@ -83,15 +83,18 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas
|
||||
)
|
||||
|
||||
if isinstance(cfg.dataset.repo_id, str):
|
||||
ds_meta = LeRobotDatasetMetadata(cfg.dataset.repo_id, local_files_only=cfg.dataset.local_files_only)
|
||||
ds_meta = LeRobotDatasetMetadata(
|
||||
cfg.dataset.repo_id, root=cfg.dataset.root, revision=cfg.dataset.revision
|
||||
)
|
||||
delta_timestamps = resolve_delta_timestamps(cfg.policy, ds_meta)
|
||||
dataset = LeRobotDataset(
|
||||
cfg.dataset.repo_id,
|
||||
root=cfg.dataset.root,
|
||||
episodes=cfg.dataset.episodes,
|
||||
delta_timestamps=delta_timestamps,
|
||||
image_transforms=image_transforms,
|
||||
revision=cfg.dataset.revision,
|
||||
video_backend=cfg.dataset.video_backend,
|
||||
local_files_only=cfg.dataset.local_files_only,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError("The MultiLeRobotDataset isn't supported for now.")
|
||||
@@ -104,7 +107,7 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas
|
||||
)
|
||||
logging.info(
|
||||
"Multiple datasets were provided. Applied the following index mapping to the provided datasets: "
|
||||
f"{pformat(dataset.repo_id_to_index , indent=2)}"
|
||||
f"{pformat(dataset.repo_id_to_index, indent=2)}"
|
||||
)
|
||||
|
||||
if cfg.dataset.use_imagenet_stats:
|
||||
|
||||
@@ -38,22 +38,40 @@ def safe_stop_image_writer(func):
|
||||
return wrapper
|
||||
|
||||
|
||||
def image_array_to_image(image_array: np.ndarray) -> PIL.Image.Image:
|
||||
def image_array_to_pil_image(image_array: np.ndarray, range_check: bool = True) -> PIL.Image.Image:
|
||||
# TODO(aliberts): handle 1 channel and 4 for depth images
|
||||
if image_array.ndim == 3 and image_array.shape[0] in [1, 3]:
|
||||
if image_array.ndim != 3:
|
||||
raise ValueError(f"The array has {image_array.ndim} dimensions, but 3 is expected for an image.")
|
||||
|
||||
if image_array.shape[0] == 3:
|
||||
# Transpose from pytorch convention (C, H, W) to (H, W, C)
|
||||
image_array = image_array.transpose(1, 2, 0)
|
||||
|
||||
elif image_array.shape[-1] != 3:
|
||||
raise NotImplementedError(
|
||||
f"The image has {image_array.shape[-1]} channels, but 3 is required for now."
|
||||
)
|
||||
|
||||
if image_array.dtype != np.uint8:
|
||||
# Assume the image is in [0, 1] range for floating-point data
|
||||
image_array = np.clip(image_array, 0, 1)
|
||||
if range_check:
|
||||
max_ = image_array.max().item()
|
||||
min_ = image_array.min().item()
|
||||
if max_ > 1.0 or min_ < 0.0:
|
||||
raise ValueError(
|
||||
"The image data type is float, which requires values in the range [0.0, 1.0]. "
|
||||
f"However, the provided range is [{min_}, {max_}]. Please adjust the range or "
|
||||
"provide a uint8 image with values in the range [0, 255]."
|
||||
)
|
||||
|
||||
image_array = (image_array * 255).astype(np.uint8)
|
||||
|
||||
return PIL.Image.fromarray(image_array)
|
||||
|
||||
|
||||
def write_image(image: np.ndarray | PIL.Image.Image, fpath: Path):
|
||||
try:
|
||||
if isinstance(image, np.ndarray):
|
||||
img = image_array_to_image(image)
|
||||
img = image_array_to_pil_image(image)
|
||||
elif isinstance(image, PIL.Image.Image):
|
||||
img = image
|
||||
else:
|
||||
|
||||
@@ -13,50 +13,57 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import contextlib
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
from functools import cached_property
|
||||
from pathlib import Path
|
||||
from typing import Callable
|
||||
|
||||
import datasets
|
||||
import numpy as np
|
||||
import packaging.version
|
||||
import PIL.Image
|
||||
import torch
|
||||
import torch.utils
|
||||
from datasets import load_dataset
|
||||
from huggingface_hub import create_repo, snapshot_download, upload_folder
|
||||
from datasets import concatenate_datasets, load_dataset
|
||||
from huggingface_hub import HfApi, snapshot_download
|
||||
from huggingface_hub.constants import REPOCARD_NAME
|
||||
from huggingface_hub.errors import RevisionNotFoundError
|
||||
|
||||
from lerobot.common.datasets.compute_stats import aggregate_stats, compute_stats
|
||||
from lerobot.common.constants import HF_LEROBOT_HOME
|
||||
from lerobot.common.datasets.compute_stats import aggregate_stats, compute_episode_stats
|
||||
from lerobot.common.datasets.image_writer import AsyncImageWriter, write_image
|
||||
from lerobot.common.datasets.utils import (
|
||||
DEFAULT_FEATURES,
|
||||
DEFAULT_IMAGE_PATH,
|
||||
EPISODES_PATH,
|
||||
INFO_PATH,
|
||||
STATS_PATH,
|
||||
TASKS_PATH,
|
||||
append_jsonlines,
|
||||
backward_compatible_episodes_stats,
|
||||
check_delta_timestamps,
|
||||
check_timestamps_sync,
|
||||
check_version_compatibility,
|
||||
create_branch,
|
||||
create_empty_dataset_info,
|
||||
create_lerobot_dataset_card,
|
||||
embed_images,
|
||||
get_delta_indices,
|
||||
get_episode_data_index,
|
||||
get_features_from_robot,
|
||||
get_hf_features_from_features,
|
||||
get_hub_safe_version,
|
||||
get_safe_version,
|
||||
hf_transform_to_torch,
|
||||
is_valid_version,
|
||||
load_episodes,
|
||||
load_episodes_stats,
|
||||
load_info,
|
||||
load_stats,
|
||||
load_tasks,
|
||||
serialize_dict,
|
||||
validate_episode_buffer,
|
||||
validate_frame,
|
||||
write_episode,
|
||||
write_episode_stats,
|
||||
write_info,
|
||||
write_json,
|
||||
write_parquet,
|
||||
)
|
||||
from lerobot.common.datasets.video_utils import (
|
||||
VideoFrame,
|
||||
@@ -66,9 +73,7 @@ from lerobot.common.datasets.video_utils import (
|
||||
)
|
||||
from lerobot.common.robot_devices.robots.utils import Robot
|
||||
|
||||
# For maintainers, see lerobot/common/datasets/push_dataset_to_hub/CODEBASE_VERSION.md
|
||||
CODEBASE_VERSION = "v2.0"
|
||||
LEROBOT_HOME = Path(os.getenv("LEROBOT_HOME", "~/.cache/huggingface/lerobot")).expanduser()
|
||||
CODEBASE_VERSION = "v2.1"
|
||||
|
||||
|
||||
class LeRobotDatasetMetadata:
|
||||
@@ -76,19 +81,36 @@ class LeRobotDatasetMetadata:
|
||||
self,
|
||||
repo_id: str,
|
||||
root: str | Path | None = None,
|
||||
local_files_only: bool = False,
|
||||
revision: str | None = None,
|
||||
force_cache_sync: bool = False,
|
||||
):
|
||||
self.repo_id = repo_id
|
||||
self.root = Path(root) if root is not None else LEROBOT_HOME / repo_id
|
||||
self.local_files_only = local_files_only
|
||||
self.revision = revision if revision else CODEBASE_VERSION
|
||||
self.root = Path(root) if root is not None else HF_LEROBOT_HOME / repo_id
|
||||
|
||||
# Load metadata
|
||||
(self.root / "meta").mkdir(exist_ok=True, parents=True)
|
||||
self.pull_from_repo(allow_patterns="meta/")
|
||||
try:
|
||||
if force_cache_sync:
|
||||
raise FileNotFoundError
|
||||
self.load_metadata()
|
||||
except (FileNotFoundError, NotADirectoryError):
|
||||
if is_valid_version(self.revision):
|
||||
self.revision = get_safe_version(self.repo_id, self.revision)
|
||||
|
||||
(self.root / "meta").mkdir(exist_ok=True, parents=True)
|
||||
self.pull_from_repo(allow_patterns="meta/")
|
||||
self.load_metadata()
|
||||
|
||||
def load_metadata(self):
|
||||
self.info = load_info(self.root)
|
||||
self.stats = load_stats(self.root)
|
||||
self.tasks = load_tasks(self.root)
|
||||
check_version_compatibility(self.repo_id, self._version, CODEBASE_VERSION)
|
||||
self.tasks, self.task_to_task_index = load_tasks(self.root)
|
||||
self.episodes = load_episodes(self.root)
|
||||
if self._version < packaging.version.parse("v2.1"):
|
||||
self.stats = load_stats(self.root)
|
||||
self.episodes_stats = backward_compatible_episodes_stats(self.stats, self.episodes)
|
||||
else:
|
||||
self.episodes_stats = load_episodes_stats(self.root)
|
||||
self.stats = aggregate_stats(list(self.episodes_stats.values()))
|
||||
|
||||
def pull_from_repo(
|
||||
self,
|
||||
@@ -98,21 +120,16 @@ class LeRobotDatasetMetadata:
|
||||
snapshot_download(
|
||||
self.repo_id,
|
||||
repo_type="dataset",
|
||||
revision=self._hub_version,
|
||||
revision=self.revision,
|
||||
local_dir=self.root,
|
||||
allow_patterns=allow_patterns,
|
||||
ignore_patterns=ignore_patterns,
|
||||
local_files_only=self.local_files_only,
|
||||
)
|
||||
|
||||
@cached_property
|
||||
def _hub_version(self) -> str | None:
|
||||
return None if self.local_files_only else get_hub_safe_version(self.repo_id, CODEBASE_VERSION)
|
||||
|
||||
@property
|
||||
def _version(self) -> str:
|
||||
def _version(self) -> packaging.version.Version:
|
||||
"""Codebase version used to create this dataset."""
|
||||
return self.info["codebase_version"]
|
||||
return packaging.version.parse(self.info["codebase_version"])
|
||||
|
||||
def get_data_file_path(self, ep_index: int) -> Path:
|
||||
ep_chunk = self.get_episode_chunk(ep_index)
|
||||
@@ -202,54 +219,65 @@ class LeRobotDatasetMetadata:
|
||||
"""Max number of episodes per chunk."""
|
||||
return self.info["chunks_size"]
|
||||
|
||||
@property
|
||||
def task_to_task_index(self) -> dict:
|
||||
return {task: task_idx for task_idx, task in self.tasks.items()}
|
||||
|
||||
def get_task_index(self, task: str) -> int:
|
||||
def get_task_index(self, task: str) -> int | None:
|
||||
"""
|
||||
Given a task in natural language, returns its task_index if the task already exists in the dataset,
|
||||
otherwise creates a new task_index.
|
||||
otherwise return None.
|
||||
"""
|
||||
task_index = self.task_to_task_index.get(task, None)
|
||||
return task_index if task_index is not None else self.total_tasks
|
||||
return self.task_to_task_index.get(task, None)
|
||||
|
||||
def save_episode(self, episode_index: int, episode_length: int, task: str, task_index: int) -> None:
|
||||
def add_task(self, task: str):
|
||||
"""
|
||||
Given a task in natural language, add it to the dictionary of tasks.
|
||||
"""
|
||||
if task in self.task_to_task_index:
|
||||
raise ValueError(f"The task '{task}' already exists and can't be added twice.")
|
||||
|
||||
task_index = self.info["total_tasks"]
|
||||
self.task_to_task_index[task] = task_index
|
||||
self.tasks[task_index] = task
|
||||
self.info["total_tasks"] += 1
|
||||
|
||||
task_dict = {
|
||||
"task_index": task_index,
|
||||
"task": task,
|
||||
}
|
||||
append_jsonlines(task_dict, self.root / TASKS_PATH)
|
||||
|
||||
def save_episode(
|
||||
self,
|
||||
episode_index: int,
|
||||
episode_length: int,
|
||||
episode_tasks: list[str],
|
||||
episode_stats: dict[str, dict],
|
||||
) -> None:
|
||||
self.info["total_episodes"] += 1
|
||||
self.info["total_frames"] += episode_length
|
||||
|
||||
if task_index not in self.tasks:
|
||||
self.info["total_tasks"] += 1
|
||||
self.tasks[task_index] = task
|
||||
task_dict = {
|
||||
"task_index": task_index,
|
||||
"task": task,
|
||||
}
|
||||
append_jsonlines(task_dict, self.root / TASKS_PATH)
|
||||
|
||||
chunk = self.get_episode_chunk(episode_index)
|
||||
if chunk >= self.total_chunks:
|
||||
self.info["total_chunks"] += 1
|
||||
|
||||
self.info["splits"] = {"train": f"0:{self.info['total_episodes']}"}
|
||||
self.info["total_videos"] += len(self.video_keys)
|
||||
write_json(self.info, self.root / INFO_PATH)
|
||||
if len(self.video_keys) > 0:
|
||||
self.update_video_info()
|
||||
|
||||
write_info(self.info, self.root)
|
||||
|
||||
episode_dict = {
|
||||
"episode_index": episode_index,
|
||||
"tasks": [task],
|
||||
"tasks": episode_tasks,
|
||||
"length": episode_length,
|
||||
}
|
||||
self.episodes.append(episode_dict)
|
||||
append_jsonlines(episode_dict, self.root / EPISODES_PATH)
|
||||
self.episodes[episode_index] = episode_dict
|
||||
write_episode(episode_dict, self.root)
|
||||
|
||||
# TODO(aliberts): refactor stats in save_episodes
|
||||
# image_sampling = int(self.fps / 2) # sample 2 img/s for the stats
|
||||
# ep_stats = compute_episode_stats(episode_buffer, self.features, episode_length, image_sampling=image_sampling)
|
||||
# ep_stats = serialize_dict(ep_stats)
|
||||
# append_jsonlines(ep_stats, self.root / STATS_PATH)
|
||||
self.episodes_stats[episode_index] = episode_stats
|
||||
self.stats = aggregate_stats([self.stats, episode_stats]) if self.stats else episode_stats
|
||||
write_episode_stats(episode_index, episode_stats, self.root)
|
||||
|
||||
def write_video_info(self) -> None:
|
||||
def update_video_info(self) -> None:
|
||||
"""
|
||||
Warning: this function writes info from first episode videos, implicitly assuming that all videos have
|
||||
been encoded the same way. Also, this means it assumes the first episode exists.
|
||||
@@ -259,8 +287,6 @@ class LeRobotDatasetMetadata:
|
||||
video_path = self.root / self.get_video_file_path(ep_index=0, vid_key=key)
|
||||
self.info["features"][key]["info"] = get_video_info(video_path)
|
||||
|
||||
write_json(self.info, self.root / INFO_PATH)
|
||||
|
||||
def __repr__(self):
|
||||
feature_keys = list(self.features)
|
||||
return (
|
||||
@@ -286,7 +312,7 @@ class LeRobotDatasetMetadata:
|
||||
"""Creates metadata for a LeRobotDataset."""
|
||||
obj = cls.__new__(cls)
|
||||
obj.repo_id = repo_id
|
||||
obj.root = Path(root) if root is not None else LEROBOT_HOME / repo_id
|
||||
obj.root = Path(root) if root is not None else HF_LEROBOT_HOME / repo_id
|
||||
|
||||
obj.root.mkdir(parents=True, exist_ok=False)
|
||||
|
||||
@@ -304,6 +330,7 @@ class LeRobotDatasetMetadata:
|
||||
)
|
||||
else:
|
||||
# TODO(aliberts, rcadene): implement sanity check for features
|
||||
features = {**features, **DEFAULT_FEATURES}
|
||||
|
||||
# check if none of the features contains a "/" in their names,
|
||||
# as this would break the dict flattening in the stats computation, which uses '/' as separator
|
||||
@@ -313,12 +340,13 @@ class LeRobotDatasetMetadata:
|
||||
|
||||
features = {**features, **DEFAULT_FEATURES}
|
||||
|
||||
obj.tasks, obj.stats, obj.episodes = {}, {}, []
|
||||
obj.tasks, obj.task_to_task_index = {}, {}
|
||||
obj.episodes_stats, obj.stats, obj.episodes = {}, {}, {}
|
||||
obj.info = create_empty_dataset_info(CODEBASE_VERSION, fps, robot_type, features, use_videos)
|
||||
if len(obj.video_keys) > 0 and not use_videos:
|
||||
raise ValueError()
|
||||
write_json(obj.info, obj.root / INFO_PATH)
|
||||
obj.local_files_only = True
|
||||
obj.revision = None
|
||||
return obj
|
||||
|
||||
|
||||
@@ -331,8 +359,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
image_transforms: Callable | None = None,
|
||||
delta_timestamps: dict[list[float]] | None = None,
|
||||
tolerance_s: float = 1e-4,
|
||||
revision: str | None = None,
|
||||
force_cache_sync: bool = False,
|
||||
download_videos: bool = True,
|
||||
local_files_only: bool = False,
|
||||
video_backend: str | None = None,
|
||||
):
|
||||
"""
|
||||
@@ -342,7 +371,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
- On your local disk in the 'root' folder. This is typically the case when you recorded your
|
||||
dataset locally and you may or may not have pushed it to the hub yet. Instantiating this class
|
||||
with 'root' will load your dataset directly from disk. This can happen while you're offline (no
|
||||
internet connection), in that case, use local_files_only=True.
|
||||
internet connection).
|
||||
|
||||
- On the Hugging Face Hub at the address https://huggingface.co/datasets/{repo_id} and not on
|
||||
your local disk in the 'root' folder. Instantiating this class with this 'repo_id' will download
|
||||
@@ -362,7 +391,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
- info contains various information about the dataset like shapes, keys, fps etc.
|
||||
- stats stores the dataset statistics of the different modalities for normalization
|
||||
- tasks contains the prompts for each task of the dataset, which can be used for
|
||||
task-conditionned training.
|
||||
task-conditioned training.
|
||||
- hf_dataset (from datasets.Dataset), which will read any values from parquet files.
|
||||
- videos (optional) from which frames are loaded to be synchronous with data from parquet files.
|
||||
|
||||
@@ -424,24 +453,28 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
timestamps is separated to the next by 1/fps +/- tolerance_s. This also applies to frames
|
||||
decoded from video files. It is also used to check that `delta_timestamps` (when provided) are
|
||||
multiples of 1/fps. Defaults to 1e-4.
|
||||
revision (str, optional): An optional Git revision id which can be a branch name, a tag, or a
|
||||
commit hash. Defaults to current codebase version tag.
|
||||
sync_cache_first (bool, optional): Flag to sync and refresh local files first. If True and files
|
||||
are already present in the local cache, this will be faster. However, files loaded might not
|
||||
be in sync with the version on the hub, especially if you specified 'revision'. Defaults to
|
||||
False.
|
||||
download_videos (bool, optional): Flag to download the videos. Note that when set to True but the
|
||||
video files are already present on local disk, they won't be downloaded again. Defaults to
|
||||
True.
|
||||
local_files_only (bool, optional): Flag to use local files only. If True, no requests to the hub
|
||||
will be made. Defaults to False.
|
||||
video_backend (str | None, optional): Video backend to use for decoding videos. There is currently
|
||||
a single option which is the pyav decoder used by Torchvision. Defaults to pyav.
|
||||
"""
|
||||
super().__init__()
|
||||
self.repo_id = repo_id
|
||||
self.root = Path(root) if root else LEROBOT_HOME / repo_id
|
||||
self.root = Path(root) if root else HF_LEROBOT_HOME / repo_id
|
||||
self.image_transforms = image_transforms
|
||||
self.delta_timestamps = delta_timestamps
|
||||
self.episodes = episodes
|
||||
self.tolerance_s = tolerance_s
|
||||
self.revision = revision if revision else CODEBASE_VERSION
|
||||
self.video_backend = video_backend if video_backend else "pyav"
|
||||
self.delta_indices = None
|
||||
self.local_files_only = local_files_only
|
||||
|
||||
# Unused attributes
|
||||
self.image_writer = None
|
||||
@@ -450,64 +483,92 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
self.root.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
# Load metadata
|
||||
self.meta = LeRobotDatasetMetadata(self.repo_id, self.root, self.local_files_only)
|
||||
|
||||
# Check version
|
||||
check_version_compatibility(self.repo_id, self.meta._version, CODEBASE_VERSION)
|
||||
self.meta = LeRobotDatasetMetadata(
|
||||
self.repo_id, self.root, self.revision, force_cache_sync=force_cache_sync
|
||||
)
|
||||
if self.episodes is not None and self.meta._version >= packaging.version.parse("v2.1"):
|
||||
episodes_stats = [self.meta.episodes_stats[ep_idx] for ep_idx in self.episodes]
|
||||
self.stats = aggregate_stats(episodes_stats)
|
||||
|
||||
# Load actual data
|
||||
self.download_episodes(download_videos)
|
||||
self.hf_dataset = self.load_hf_dataset()
|
||||
try:
|
||||
if force_cache_sync:
|
||||
raise FileNotFoundError
|
||||
assert all((self.root / fpath).is_file() for fpath in self.get_episodes_file_paths())
|
||||
self.hf_dataset = self.load_hf_dataset()
|
||||
except (AssertionError, FileNotFoundError, NotADirectoryError):
|
||||
self.revision = get_safe_version(self.repo_id, self.revision)
|
||||
self.download_episodes(download_videos)
|
||||
self.hf_dataset = self.load_hf_dataset()
|
||||
|
||||
self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes)
|
||||
|
||||
# Check timestamps
|
||||
check_timestamps_sync(self.hf_dataset, self.episode_data_index, self.fps, self.tolerance_s)
|
||||
timestamps = torch.stack(self.hf_dataset["timestamp"]).numpy()
|
||||
episode_indices = torch.stack(self.hf_dataset["episode_index"]).numpy()
|
||||
ep_data_index_np = {k: t.numpy() for k, t in self.episode_data_index.items()}
|
||||
check_timestamps_sync(timestamps, episode_indices, ep_data_index_np, self.fps, self.tolerance_s)
|
||||
|
||||
# Setup delta_indices
|
||||
if self.delta_timestamps is not None:
|
||||
check_delta_timestamps(self.delta_timestamps, self.fps, self.tolerance_s)
|
||||
self.delta_indices = get_delta_indices(self.delta_timestamps, self.fps)
|
||||
|
||||
# Available stats implies all videos have been encoded and dataset is iterable
|
||||
self.consolidated = self.meta.stats is not None
|
||||
|
||||
def push_to_hub(
|
||||
self,
|
||||
branch: str | None = None,
|
||||
tags: list | None = None,
|
||||
license: str | None = "apache-2.0",
|
||||
tag_version: bool = True,
|
||||
push_videos: bool = True,
|
||||
private: bool = False,
|
||||
allow_patterns: list[str] | str | None = None,
|
||||
upload_large_folder: bool = False,
|
||||
**card_kwargs,
|
||||
) -> None:
|
||||
if not self.consolidated:
|
||||
logging.warning(
|
||||
"You are trying to upload to the hub a LeRobotDataset that has not been consolidated yet. "
|
||||
"Consolidating first."
|
||||
)
|
||||
self.consolidate()
|
||||
|
||||
ignore_patterns = ["images/"]
|
||||
if not push_videos:
|
||||
ignore_patterns.append("videos/")
|
||||
|
||||
create_repo(
|
||||
hub_api = HfApi()
|
||||
hub_api.create_repo(
|
||||
repo_id=self.repo_id,
|
||||
private=private,
|
||||
repo_type="dataset",
|
||||
exist_ok=True,
|
||||
)
|
||||
if branch:
|
||||
hub_api.create_branch(
|
||||
repo_id=self.repo_id,
|
||||
branch=branch,
|
||||
revision=self.revision,
|
||||
repo_type="dataset",
|
||||
exist_ok=True,
|
||||
)
|
||||
|
||||
upload_folder(
|
||||
repo_id=self.repo_id,
|
||||
folder_path=self.root,
|
||||
repo_type="dataset",
|
||||
ignore_patterns=ignore_patterns,
|
||||
)
|
||||
card = create_lerobot_dataset_card(
|
||||
tags=tags, dataset_info=self.meta.info, license=license, **card_kwargs
|
||||
)
|
||||
card.push_to_hub(repo_id=self.repo_id, repo_type="dataset")
|
||||
create_branch(repo_id=self.repo_id, branch=CODEBASE_VERSION, repo_type="dataset")
|
||||
upload_kwargs = {
|
||||
"repo_id": self.repo_id,
|
||||
"folder_path": self.root,
|
||||
"repo_type": "dataset",
|
||||
"revision": branch,
|
||||
"allow_patterns": allow_patterns,
|
||||
"ignore_patterns": ignore_patterns,
|
||||
}
|
||||
if upload_large_folder:
|
||||
hub_api.upload_large_folder(**upload_kwargs)
|
||||
else:
|
||||
hub_api.upload_folder(**upload_kwargs)
|
||||
|
||||
if not hub_api.file_exists(self.repo_id, REPOCARD_NAME, repo_type="dataset", revision=branch):
|
||||
card = create_lerobot_dataset_card(
|
||||
tags=tags, dataset_info=self.meta.info, license=license, **card_kwargs
|
||||
)
|
||||
card.push_to_hub(repo_id=self.repo_id, repo_type="dataset", revision=branch)
|
||||
|
||||
if tag_version:
|
||||
with contextlib.suppress(RevisionNotFoundError):
|
||||
hub_api.delete_tag(self.repo_id, tag=CODEBASE_VERSION, repo_type="dataset")
|
||||
hub_api.create_tag(self.repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset")
|
||||
|
||||
def pull_from_repo(
|
||||
self,
|
||||
@@ -517,11 +578,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
snapshot_download(
|
||||
self.repo_id,
|
||||
repo_type="dataset",
|
||||
revision=self.meta._hub_version,
|
||||
revision=self.revision,
|
||||
local_dir=self.root,
|
||||
allow_patterns=allow_patterns,
|
||||
ignore_patterns=ignore_patterns,
|
||||
local_files_only=self.local_files_only,
|
||||
)
|
||||
|
||||
def download_episodes(self, download_videos: bool = True) -> None:
|
||||
@@ -535,17 +595,23 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
files = None
|
||||
ignore_patterns = None if download_videos else "videos/"
|
||||
if self.episodes is not None:
|
||||
files = [str(self.meta.get_data_file_path(ep_idx)) for ep_idx in self.episodes]
|
||||
if len(self.meta.video_keys) > 0 and download_videos:
|
||||
video_files = [
|
||||
str(self.meta.get_video_file_path(ep_idx, vid_key))
|
||||
for vid_key in self.meta.video_keys
|
||||
for ep_idx in self.episodes
|
||||
]
|
||||
files += video_files
|
||||
files = self.get_episodes_file_paths()
|
||||
|
||||
self.pull_from_repo(allow_patterns=files, ignore_patterns=ignore_patterns)
|
||||
|
||||
def get_episodes_file_paths(self) -> list[Path]:
|
||||
episodes = self.episodes if self.episodes is not None else list(range(self.meta.total_episodes))
|
||||
fpaths = [str(self.meta.get_data_file_path(ep_idx)) for ep_idx in episodes]
|
||||
if len(self.meta.video_keys) > 0:
|
||||
video_files = [
|
||||
str(self.meta.get_video_file_path(ep_idx, vid_key))
|
||||
for vid_key in self.meta.video_keys
|
||||
for ep_idx in episodes
|
||||
]
|
||||
fpaths += video_files
|
||||
|
||||
return fpaths
|
||||
|
||||
def load_hf_dataset(self) -> datasets.Dataset:
|
||||
"""hf_dataset contains all the observations, states, actions, rewards, etc."""
|
||||
if self.episodes is None:
|
||||
@@ -557,7 +623,15 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
|
||||
# TODO(aliberts): hf_dataset.set_format("torch")
|
||||
hf_dataset.set_transform(hf_transform_to_torch)
|
||||
return hf_dataset
|
||||
|
||||
def create_hf_dataset(self) -> datasets.Dataset:
|
||||
features = get_hf_features_from_features(self.features)
|
||||
ft_dict = {col: [] for col in features}
|
||||
hf_dataset = datasets.Dataset.from_dict(ft_dict, features=features, split="train")
|
||||
|
||||
# TODO(aliberts): hf_dataset.set_format("torch")
|
||||
hf_dataset.set_transform(hf_transform_to_torch)
|
||||
return hf_dataset
|
||||
|
||||
@property
|
||||
@@ -624,7 +698,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
if key not in self.meta.video_keys
|
||||
}
|
||||
|
||||
def _query_videos(self, query_timestamps: dict[str, list[float]], ep_idx: int) -> dict:
|
||||
def _query_videos(self, query_timestamps: dict[str, list[float]], ep_idx: int) -> dict[str, torch.Tensor]:
|
||||
"""Note: When using data workers (e.g. DataLoader with num_workers>0), do not call this function
|
||||
in the main process (e.g. by using a second Dataloader with num_workers=0). It will result in a
|
||||
Segmentation Fault. This probably happens because a memory reference to the video loader is created in
|
||||
@@ -654,8 +728,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
|
||||
query_indices = None
|
||||
if self.delta_indices is not None:
|
||||
current_ep_idx = self.episodes.index(ep_idx) if self.episodes is not None else ep_idx
|
||||
query_indices, padding = self._get_query_indices(idx, current_ep_idx)
|
||||
query_indices, padding = self._get_query_indices(idx, ep_idx)
|
||||
query_result = self._query_hf_dataset(query_indices)
|
||||
item = {**item, **padding}
|
||||
for key, val in query_result.items():
|
||||
@@ -691,10 +764,13 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
|
||||
def create_episode_buffer(self, episode_index: int | None = None) -> dict:
|
||||
current_ep_idx = self.meta.total_episodes if episode_index is None else episode_index
|
||||
return {
|
||||
"size": 0,
|
||||
**{key: current_ep_idx if key == "episode_index" else [] for key in self.features},
|
||||
}
|
||||
ep_buffer = {}
|
||||
# size and task are special cases that are not in self.features
|
||||
ep_buffer["size"] = 0
|
||||
ep_buffer["task"] = []
|
||||
for key in self.features:
|
||||
ep_buffer[key] = current_ep_idx if key == "episode_index" else []
|
||||
return ep_buffer
|
||||
|
||||
def _get_image_file_path(self, episode_index: int, image_key: str, frame_index: int) -> Path:
|
||||
fpath = DEFAULT_IMAGE_PATH.format(
|
||||
@@ -716,25 +792,35 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
temporary directory — nothing is written to disk. To save those frames, the 'save_episode()' method
|
||||
then needs to be called.
|
||||
"""
|
||||
# TODO(aliberts, rcadene): Add sanity check for the input, check it's numpy or torch,
|
||||
# check the dtype and shape matches, etc.
|
||||
# Convert torch to numpy if needed
|
||||
for name in frame:
|
||||
if isinstance(frame[name], torch.Tensor):
|
||||
frame[name] = frame[name].numpy()
|
||||
|
||||
validate_frame(frame, self.features)
|
||||
|
||||
if self.episode_buffer is None:
|
||||
self.episode_buffer = self.create_episode_buffer()
|
||||
|
||||
# Automatically add frame_index and timestamp to episode buffer
|
||||
frame_index = self.episode_buffer["size"]
|
||||
timestamp = frame.pop("timestamp") if "timestamp" in frame else frame_index / self.fps
|
||||
self.episode_buffer["frame_index"].append(frame_index)
|
||||
self.episode_buffer["timestamp"].append(timestamp)
|
||||
|
||||
# Add frame features to episode_buffer
|
||||
for key in frame:
|
||||
if key not in self.features:
|
||||
raise ValueError(key)
|
||||
if key == "task":
|
||||
# Note: we associate the task in natural language to its task index during `save_episode`
|
||||
self.episode_buffer["task"].append(frame["task"])
|
||||
continue
|
||||
|
||||
if self.features[key]["dtype"] not in ["image", "video"]:
|
||||
item = frame[key].numpy() if isinstance(frame[key], torch.Tensor) else frame[key]
|
||||
self.episode_buffer[key].append(item)
|
||||
elif self.features[key]["dtype"] in ["image", "video"]:
|
||||
if key not in self.features:
|
||||
raise ValueError(
|
||||
f"An element of the frame is not in the features. '{key}' not in '{self.features.keys()}'."
|
||||
)
|
||||
|
||||
if self.features[key]["dtype"] in ["image", "video"]:
|
||||
img_path = self._get_image_file_path(
|
||||
episode_index=self.episode_buffer["episode_index"], image_key=key, frame_index=frame_index
|
||||
)
|
||||
@@ -742,80 +828,95 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
img_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
self._save_image(frame[key], img_path)
|
||||
self.episode_buffer[key].append(str(img_path))
|
||||
else:
|
||||
self.episode_buffer[key].append(frame[key])
|
||||
|
||||
self.episode_buffer["size"] += 1
|
||||
|
||||
def save_episode(self, task: str, encode_videos: bool = True, episode_data: dict | None = None) -> None:
|
||||
def save_episode(self, episode_data: dict | None = None) -> None:
|
||||
"""
|
||||
This will save to disk the current episode in self.episode_buffer. Note that since it affects files on
|
||||
disk, it sets self.consolidated to False to ensure proper consolidation later on before uploading to
|
||||
the hub.
|
||||
This will save to disk the current episode in self.episode_buffer.
|
||||
|
||||
Use 'encode_videos' if you want to encode videos during the saving of this episode. Otherwise,
|
||||
you can do it later with dataset.consolidate(). This is to give more flexibility on when to spend
|
||||
time for video encoding.
|
||||
Args:
|
||||
episode_data (dict | None, optional): Dict containing the episode data to save. If None, this will
|
||||
save the current episode in self.episode_buffer, which is filled with 'add_frame'. Defaults to
|
||||
None.
|
||||
"""
|
||||
if not episode_data:
|
||||
episode_buffer = self.episode_buffer
|
||||
|
||||
validate_episode_buffer(episode_buffer, self.meta.total_episodes, self.features)
|
||||
|
||||
# size and task are special cases that won't be added to hf_dataset
|
||||
episode_length = episode_buffer.pop("size")
|
||||
tasks = episode_buffer.pop("task")
|
||||
episode_tasks = list(set(tasks))
|
||||
episode_index = episode_buffer["episode_index"]
|
||||
if episode_index != self.meta.total_episodes:
|
||||
# TODO(aliberts): Add option to use existing episode_index
|
||||
raise NotImplementedError(
|
||||
"You might have manually provided the episode_buffer with an episode_index that doesn't "
|
||||
"match the total number of episodes in the dataset. This is not supported for now."
|
||||
)
|
||||
|
||||
if episode_length == 0:
|
||||
raise ValueError(
|
||||
"You must add one or several frames with `add_frame` before calling `add_episode`."
|
||||
)
|
||||
episode_buffer["index"] = np.arange(self.meta.total_frames, self.meta.total_frames + episode_length)
|
||||
episode_buffer["episode_index"] = np.full((episode_length,), episode_index)
|
||||
|
||||
task_index = self.meta.get_task_index(task)
|
||||
# Add new tasks to the tasks dictionary
|
||||
for task in episode_tasks:
|
||||
task_index = self.meta.get_task_index(task)
|
||||
if task_index is None:
|
||||
self.meta.add_task(task)
|
||||
|
||||
if not set(episode_buffer.keys()) == set(self.features):
|
||||
raise ValueError()
|
||||
# Given tasks in natural language, find their corresponding task indices
|
||||
episode_buffer["task_index"] = np.array([self.meta.get_task_index(task) for task in tasks])
|
||||
|
||||
for key, ft in self.features.items():
|
||||
if key == "index":
|
||||
episode_buffer[key] = np.arange(
|
||||
self.meta.total_frames, self.meta.total_frames + episode_length
|
||||
)
|
||||
elif key == "episode_index":
|
||||
episode_buffer[key] = np.full((episode_length,), episode_index)
|
||||
elif key == "task_index":
|
||||
episode_buffer[key] = np.full((episode_length,), task_index)
|
||||
elif ft["dtype"] in ["image", "video"]:
|
||||
# index, episode_index, task_index are already processed above, and image and video
|
||||
# are processed separately by storing image path and frame info as meta data
|
||||
if key in ["index", "episode_index", "task_index"] or ft["dtype"] in ["image", "video"]:
|
||||
continue
|
||||
elif len(ft["shape"]) == 1 and ft["shape"][0] == 1:
|
||||
episode_buffer[key] = np.array(episode_buffer[key], dtype=ft["dtype"])
|
||||
elif len(ft["shape"]) == 1 and ft["shape"][0] > 1:
|
||||
episode_buffer[key] = np.stack(episode_buffer[key])
|
||||
else:
|
||||
raise ValueError(key)
|
||||
episode_buffer[key] = np.stack(episode_buffer[key])
|
||||
|
||||
self._wait_image_writer()
|
||||
self._save_episode_table(episode_buffer, episode_index)
|
||||
ep_stats = compute_episode_stats(episode_buffer, self.features)
|
||||
|
||||
self.meta.save_episode(episode_index, episode_length, task, task_index)
|
||||
|
||||
if encode_videos and len(self.meta.video_keys) > 0:
|
||||
if len(self.meta.video_keys) > 0:
|
||||
video_paths = self.encode_episode_videos(episode_index)
|
||||
for key in self.meta.video_keys:
|
||||
episode_buffer[key] = video_paths[key]
|
||||
|
||||
# `meta.save_episode` be executed after encoding the videos
|
||||
self.meta.save_episode(episode_index, episode_length, episode_tasks, ep_stats)
|
||||
|
||||
ep_data_index = get_episode_data_index(self.meta.episodes, [episode_index])
|
||||
ep_data_index_np = {k: t.numpy() for k, t in ep_data_index.items()}
|
||||
check_timestamps_sync(
|
||||
episode_buffer["timestamp"],
|
||||
episode_buffer["episode_index"],
|
||||
ep_data_index_np,
|
||||
self.fps,
|
||||
self.tolerance_s,
|
||||
)
|
||||
|
||||
video_files = list(self.root.rglob("*.mp4"))
|
||||
assert len(video_files) == self.num_episodes * len(self.meta.video_keys)
|
||||
|
||||
parquet_files = list(self.root.rglob("*.parquet"))
|
||||
assert len(parquet_files) == self.num_episodes
|
||||
|
||||
# delete images
|
||||
img_dir = self.root / "images"
|
||||
if img_dir.is_dir():
|
||||
shutil.rmtree(self.root / "images")
|
||||
|
||||
if not episode_data: # Reset the buffer
|
||||
self.episode_buffer = self.create_episode_buffer()
|
||||
|
||||
self.consolidated = False
|
||||
|
||||
def _save_episode_table(self, episode_buffer: dict, episode_index: int) -> None:
|
||||
episode_dict = {key: episode_buffer[key] for key in self.hf_features}
|
||||
ep_dataset = datasets.Dataset.from_dict(episode_dict, features=self.hf_features, split="train")
|
||||
ep_dataset = embed_images(ep_dataset)
|
||||
self.hf_dataset = concatenate_datasets([self.hf_dataset, ep_dataset])
|
||||
self.hf_dataset.set_transform(hf_transform_to_torch)
|
||||
ep_data_path = self.root / self.meta.get_data_file_path(ep_index=episode_index)
|
||||
ep_data_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
write_parquet(ep_dataset, ep_data_path)
|
||||
ep_dataset.to_parquet(ep_data_path)
|
||||
|
||||
def clear_episode_buffer(self) -> None:
|
||||
episode_index = self.episode_buffer["episode_index"]
|
||||
@@ -884,38 +985,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
|
||||
return video_paths
|
||||
|
||||
def consolidate(self, run_compute_stats: bool = True, keep_image_files: bool = False) -> None:
|
||||
self.hf_dataset = self.load_hf_dataset()
|
||||
self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes)
|
||||
check_timestamps_sync(self.hf_dataset, self.episode_data_index, self.fps, self.tolerance_s)
|
||||
|
||||
if len(self.meta.video_keys) > 0:
|
||||
self.encode_videos()
|
||||
self.meta.write_video_info()
|
||||
|
||||
if not keep_image_files:
|
||||
img_dir = self.root / "images"
|
||||
if img_dir.is_dir():
|
||||
shutil.rmtree(self.root / "images")
|
||||
|
||||
video_files = list(self.root.rglob("*.mp4"))
|
||||
assert len(video_files) == self.num_episodes * len(self.meta.video_keys)
|
||||
|
||||
parquet_files = list(self.root.rglob("*.parquet"))
|
||||
assert len(parquet_files) == self.num_episodes
|
||||
|
||||
if run_compute_stats:
|
||||
self.stop_image_writer()
|
||||
# TODO(aliberts): refactor stats in save_episodes
|
||||
self.meta.stats = compute_stats(self)
|
||||
serialized_stats = serialize_dict(self.meta.stats)
|
||||
write_json(serialized_stats, self.root / STATS_PATH)
|
||||
self.consolidated = True
|
||||
else:
|
||||
logging.warning(
|
||||
"Skipping computation of the dataset statistics, dataset is not fully consolidated."
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
cls,
|
||||
@@ -944,7 +1013,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
)
|
||||
obj.repo_id = obj.meta.repo_id
|
||||
obj.root = obj.meta.root
|
||||
obj.local_files_only = obj.meta.local_files_only
|
||||
obj.revision = None
|
||||
obj.tolerance_s = tolerance_s
|
||||
obj.image_writer = None
|
||||
|
||||
@@ -954,14 +1023,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
# TODO(aliberts, rcadene, alexander-soare): Merge this with OnlineBuffer/DataBuffer
|
||||
obj.episode_buffer = obj.create_episode_buffer()
|
||||
|
||||
# This bool indicates that the current LeRobotDataset instance is in sync with the files on disk. It
|
||||
# is used to know when certain operations are need (for instance, computing dataset statistics). In
|
||||
# order to be able to push the dataset to the hub, it needs to be consolidated first by calling
|
||||
# self.consolidate().
|
||||
obj.consolidated = True
|
||||
|
||||
obj.episodes = None
|
||||
obj.hf_dataset = None
|
||||
obj.hf_dataset = obj.create_hf_dataset()
|
||||
obj.image_transforms = None
|
||||
obj.delta_timestamps = None
|
||||
obj.delta_indices = None
|
||||
@@ -986,12 +1049,11 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
||||
delta_timestamps: dict[list[float]] | None = None,
|
||||
tolerances_s: dict | None = None,
|
||||
download_videos: bool = True,
|
||||
local_files_only: bool = False,
|
||||
video_backend: str | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.repo_ids = repo_ids
|
||||
self.root = Path(root) if root else LEROBOT_HOME
|
||||
self.root = Path(root) if root else HF_LEROBOT_HOME
|
||||
self.tolerances_s = tolerances_s if tolerances_s else {repo_id: 1e-4 for repo_id in repo_ids}
|
||||
# Construct the underlying datasets passing everything but `transform` and `delta_timestamps` which
|
||||
# are handled by this class.
|
||||
@@ -1004,7 +1066,6 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
||||
delta_timestamps=delta_timestamps,
|
||||
tolerance_s=self.tolerances_s[repo_id],
|
||||
download_videos=download_videos,
|
||||
local_files_only=local_files_only,
|
||||
video_backend=video_backend,
|
||||
)
|
||||
for repo_id in repo_ids
|
||||
@@ -1032,7 +1093,10 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
||||
|
||||
self.image_transforms = image_transforms
|
||||
self.delta_timestamps = delta_timestamps
|
||||
self.stats = aggregate_stats(self._datasets)
|
||||
# TODO(rcadene, aliberts): We should not perform this aggregation for datasets
|
||||
# with multiple robots of different ranges. Instead we should have one normalization
|
||||
# per robot.
|
||||
self.stats = aggregate_stats([dataset.meta.stats for dataset in self._datasets])
|
||||
|
||||
@property
|
||||
def repo_id_to_index(self):
|
||||
|
||||
@@ -1,56 +0,0 @@
|
||||
## Using / Updating `CODEBASE_VERSION` (for maintainers)
|
||||
|
||||
Since our dataset pushed to the hub are decoupled with the evolution of this repo, we ensure compatibility of
|
||||
the datasets with our code, we use a `CODEBASE_VERSION` (defined in
|
||||
lerobot/common/datasets/lerobot_dataset.py) variable.
|
||||
|
||||
For instance, [`lerobot/pusht`](https://huggingface.co/datasets/lerobot/pusht) has many versions to maintain backward compatibility between LeRobot codebase versions:
|
||||
- [v1.0](https://huggingface.co/datasets/lerobot/pusht/tree/v1.0)
|
||||
- [v1.1](https://huggingface.co/datasets/lerobot/pusht/tree/v1.1)
|
||||
- [v1.2](https://huggingface.co/datasets/lerobot/pusht/tree/v1.2)
|
||||
- [v1.3](https://huggingface.co/datasets/lerobot/pusht/tree/v1.3)
|
||||
- [v1.4](https://huggingface.co/datasets/lerobot/pusht/tree/v1.4)
|
||||
- [v1.5](https://huggingface.co/datasets/lerobot/pusht/tree/v1.5)
|
||||
- [v1.6](https://huggingface.co/datasets/lerobot/pusht/tree/v1.6) <-- last version
|
||||
- [main](https://huggingface.co/datasets/lerobot/pusht/tree/main) <-- points to the last version
|
||||
|
||||
Starting with v1.6, every dataset pushed to the hub or saved locally also have this version number in their
|
||||
`info.json` metadata.
|
||||
|
||||
### Uploading a new dataset
|
||||
If you are pushing a new dataset, you don't need to worry about any of the instructions below, nor to be
|
||||
compatible with previous codebase versions. The `push_dataset_to_hub.py` script will automatically tag your
|
||||
dataset with the current `CODEBASE_VERSION`.
|
||||
|
||||
### Updating an existing dataset
|
||||
If you want to update an existing dataset, you need to change the `CODEBASE_VERSION` from `lerobot_dataset.py`
|
||||
before running `push_dataset_to_hub.py`. This is especially useful if you introduce a breaking change
|
||||
intentionally or not (i.e. something not backward compatible such as modifying the reward functions used,
|
||||
deleting some frames at the end of an episode, etc.). That way, people running a previous version of the
|
||||
codebase won't be affected by your change and backward compatibility is maintained.
|
||||
|
||||
However, you will need to update the version of ALL the other datasets so that they have the new
|
||||
`CODEBASE_VERSION` as a branch in their hugging face dataset repository. Don't worry, there is an easy way
|
||||
that doesn't require to run `push_dataset_to_hub.py`. You can just "branch-out" from the `main` branch on HF
|
||||
dataset repo by running this script which corresponds to a `git checkout -b` (so no copy or upload needed):
|
||||
|
||||
```python
|
||||
from huggingface_hub import HfApi
|
||||
|
||||
from lerobot import available_datasets
|
||||
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
|
||||
|
||||
api = HfApi()
|
||||
|
||||
for repo_id in available_datasets:
|
||||
dataset_info = api.list_repo_refs(repo_id, repo_type="dataset")
|
||||
branches = [b.name for b in dataset_info.branches]
|
||||
if CODEBASE_VERSION in branches:
|
||||
print(f"{repo_id} already @{CODEBASE_VERSION}, skipping.")
|
||||
continue
|
||||
else:
|
||||
# Now create a branch named after the new version by branching out from "main"
|
||||
# which is expected to be the preceding version
|
||||
api.create_branch(repo_id, repo_type="dataset", branch=CODEBASE_VERSION, revision="main")
|
||||
print(f"{repo_id} successfully updated @{CODEBASE_VERSION}")
|
||||
```
|
||||
@@ -152,7 +152,7 @@ def download_raw(raw_dir: Path, repo_id: str):
|
||||
stacklevel=1,
|
||||
)
|
||||
|
||||
# Send warning if raw_dir isn't well formated
|
||||
# Send warning if raw_dir isn't well formatted
|
||||
if raw_dir.parts[-2] != user_id or raw_dir.parts[-1] != dataset_id:
|
||||
warnings.warn(
|
||||
f"""`raw_dir` ({raw_dir}) doesn't contain a community or user id `/` the name of the dataset that
|
||||
|
||||
@@ -68,11 +68,11 @@ def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episod
|
||||
modality_df,
|
||||
on="timestamp_utc",
|
||||
# "nearest" is the best option over "backward", since the latter can desynchronizes camera timestamps by
|
||||
# matching timestamps that are too far appart, in order to fit the backward constraints. It's not the case for "nearest".
|
||||
# matching timestamps that are too far apart, in order to fit the backward constraints. It's not the case for "nearest".
|
||||
# However, note that "nearest" might synchronize the reference camera with other cameras on slightly future timestamps.
|
||||
# are too far appart.
|
||||
# are too far apart.
|
||||
direction="nearest",
|
||||
tolerance=pd.Timedelta(f"{1/fps} seconds"),
|
||||
tolerance=pd.Timedelta(f"{1 / fps} seconds"),
|
||||
)
|
||||
# Remove rows with episode_index -1 which indicates data that correspond to in-between episodes
|
||||
df = df[df["episode_index"] != -1]
|
||||
@@ -126,7 +126,7 @@ def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episod
|
||||
videos_dir.parent.mkdir(parents=True, exist_ok=True)
|
||||
videos_dir.symlink_to((raw_dir / "videos").absolute())
|
||||
|
||||
# sanity check the video paths are well formated
|
||||
# sanity check the video paths are well formatted
|
||||
for key in df:
|
||||
if "observation.images." not in key:
|
||||
continue
|
||||
@@ -143,7 +143,7 @@ def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episod
|
||||
# it is the case for video_frame dictionary = [{"path": ..., "timestamp": ...}]
|
||||
data_dict[key] = [video_frame[0] for video_frame in df[key].values]
|
||||
|
||||
# sanity check the video path is well formated
|
||||
# sanity check the video path is well formatted
|
||||
video_path = videos_dir.parent / data_dict[key][0]["path"]
|
||||
if not video_path.exists():
|
||||
raise ValueError(f"Video file not found in {video_path}")
|
||||
|
||||
@@ -17,7 +17,7 @@
|
||||
For all datasets in the RLDS format.
|
||||
For https://github.com/google-deepmind/open_x_embodiment (OPENX) datasets.
|
||||
|
||||
NOTE: You need to install tensorflow and tensorflow_datsets before running this script.
|
||||
NOTE: You need to install tensorflow and tensorflow_datasets before running this script.
|
||||
|
||||
Example:
|
||||
python lerobot/scripts/push_dataset_to_hub.py \
|
||||
|
||||
@@ -13,10 +13,10 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import contextlib
|
||||
import importlib.resources
|
||||
import json
|
||||
import logging
|
||||
import textwrap
|
||||
from collections.abc import Iterator
|
||||
from itertools import accumulate
|
||||
from pathlib import Path
|
||||
@@ -27,14 +27,21 @@ from typing import Any
|
||||
import datasets
|
||||
import jsonlines
|
||||
import numpy as np
|
||||
import pyarrow.compute as pc
|
||||
import packaging.version
|
||||
import torch
|
||||
from datasets.table import embed_table_storage
|
||||
from huggingface_hub import DatasetCard, DatasetCardData, HfApi
|
||||
from huggingface_hub.errors import RevisionNotFoundError
|
||||
from PIL import Image as PILImage
|
||||
from torchvision import transforms
|
||||
|
||||
from lerobot.common.datasets.backward_compatibility import (
|
||||
V21_MESSAGE,
|
||||
BackwardCompatibilityError,
|
||||
ForwardCompatibilityError,
|
||||
)
|
||||
from lerobot.common.robot_devices.robots.utils import Robot
|
||||
from lerobot.common.utils.utils import is_valid_numpy_dtype_string
|
||||
from lerobot.configs.types import DictLike, FeatureType, PolicyFeature
|
||||
|
||||
DEFAULT_CHUNK_SIZE = 1000 # Max number of episodes per chunk
|
||||
@@ -42,6 +49,7 @@ DEFAULT_CHUNK_SIZE = 1000 # Max number of episodes per chunk
|
||||
INFO_PATH = "meta/info.json"
|
||||
EPISODES_PATH = "meta/episodes.jsonl"
|
||||
STATS_PATH = "meta/stats.json"
|
||||
EPISODES_STATS_PATH = "meta/episodes_stats.jsonl"
|
||||
TASKS_PATH = "meta/tasks.jsonl"
|
||||
|
||||
DEFAULT_VIDEO_PATH = "videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4"
|
||||
@@ -112,17 +120,26 @@ def get_nested_item(obj: DictLike, flattened_key: str, sep: str = "/") -> Any:
|
||||
|
||||
|
||||
def serialize_dict(stats: dict[str, torch.Tensor | np.ndarray | dict]) -> dict:
|
||||
serialized_dict = {key: value.tolist() for key, value in flatten_dict(stats).items()}
|
||||
serialized_dict = {}
|
||||
for key, value in flatten_dict(stats).items():
|
||||
if isinstance(value, (torch.Tensor, np.ndarray)):
|
||||
serialized_dict[key] = value.tolist()
|
||||
elif isinstance(value, np.generic):
|
||||
serialized_dict[key] = value.item()
|
||||
elif isinstance(value, (int, float)):
|
||||
serialized_dict[key] = value
|
||||
else:
|
||||
raise NotImplementedError(f"The value '{value}' of type '{type(value)}' is not supported.")
|
||||
return unflatten_dict(serialized_dict)
|
||||
|
||||
|
||||
def write_parquet(dataset: datasets.Dataset, fpath: Path) -> None:
|
||||
def embed_images(dataset: datasets.Dataset) -> datasets.Dataset:
|
||||
# Embed image bytes into the table before saving to parquet
|
||||
format = dataset.format
|
||||
dataset = dataset.with_format("arrow")
|
||||
dataset = dataset.map(embed_table_storage, batched=False)
|
||||
dataset = dataset.with_format(**format)
|
||||
dataset.to_parquet(fpath)
|
||||
return dataset
|
||||
|
||||
|
||||
def load_json(fpath: Path) -> Any:
|
||||
@@ -153,6 +170,10 @@ def append_jsonlines(data: dict, fpath: Path) -> None:
|
||||
writer.write(data)
|
||||
|
||||
|
||||
def write_info(info: dict, local_dir: Path):
|
||||
write_json(info, local_dir / INFO_PATH)
|
||||
|
||||
|
||||
def load_info(local_dir: Path) -> dict:
|
||||
info = load_json(local_dir / INFO_PATH)
|
||||
for ft in info["features"].values():
|
||||
@@ -160,29 +181,76 @@ def load_info(local_dir: Path) -> dict:
|
||||
return info
|
||||
|
||||
|
||||
def load_stats(local_dir: Path) -> dict:
|
||||
if not (local_dir / STATS_PATH).exists():
|
||||
return None
|
||||
stats = load_json(local_dir / STATS_PATH)
|
||||
stats = {key: torch.tensor(value) for key, value in flatten_dict(stats).items()}
|
||||
def write_stats(stats: dict, local_dir: Path):
|
||||
serialized_stats = serialize_dict(stats)
|
||||
write_json(serialized_stats, local_dir / STATS_PATH)
|
||||
|
||||
|
||||
def cast_stats_to_numpy(stats) -> dict[str, dict[str, np.ndarray]]:
|
||||
stats = {key: np.array(value) for key, value in flatten_dict(stats).items()}
|
||||
return unflatten_dict(stats)
|
||||
|
||||
|
||||
def load_tasks(local_dir: Path) -> dict:
|
||||
def load_stats(local_dir: Path) -> dict[str, dict[str, np.ndarray]]:
|
||||
if not (local_dir / STATS_PATH).exists():
|
||||
return None
|
||||
stats = load_json(local_dir / STATS_PATH)
|
||||
return cast_stats_to_numpy(stats)
|
||||
|
||||
|
||||
def write_task(task_index: int, task: dict, local_dir: Path):
|
||||
task_dict = {
|
||||
"task_index": task_index,
|
||||
"task": task,
|
||||
}
|
||||
append_jsonlines(task_dict, local_dir / TASKS_PATH)
|
||||
|
||||
|
||||
def load_tasks(local_dir: Path) -> tuple[dict, dict]:
|
||||
tasks = load_jsonlines(local_dir / TASKS_PATH)
|
||||
return {item["task_index"]: item["task"] for item in sorted(tasks, key=lambda x: x["task_index"])}
|
||||
tasks = {item["task_index"]: item["task"] for item in sorted(tasks, key=lambda x: x["task_index"])}
|
||||
task_to_task_index = {task: task_index for task_index, task in tasks.items()}
|
||||
return tasks, task_to_task_index
|
||||
|
||||
|
||||
def write_episode(episode: dict, local_dir: Path):
|
||||
append_jsonlines(episode, local_dir / EPISODES_PATH)
|
||||
|
||||
|
||||
def load_episodes(local_dir: Path) -> dict:
|
||||
return load_jsonlines(local_dir / EPISODES_PATH)
|
||||
episodes = load_jsonlines(local_dir / EPISODES_PATH)
|
||||
return {item["episode_index"]: item for item in sorted(episodes, key=lambda x: x["episode_index"])}
|
||||
|
||||
|
||||
def load_image_as_numpy(fpath: str | Path, dtype="float32", channel_first: bool = True) -> np.ndarray:
|
||||
def write_episode_stats(episode_index: int, episode_stats: dict, local_dir: Path):
|
||||
# We wrap episode_stats in a dictionary since `episode_stats["episode_index"]`
|
||||
# is a dictionary of stats and not an integer.
|
||||
episode_stats = {"episode_index": episode_index, "stats": serialize_dict(episode_stats)}
|
||||
append_jsonlines(episode_stats, local_dir / EPISODES_STATS_PATH)
|
||||
|
||||
|
||||
def load_episodes_stats(local_dir: Path) -> dict:
|
||||
episodes_stats = load_jsonlines(local_dir / EPISODES_STATS_PATH)
|
||||
return {
|
||||
item["episode_index"]: cast_stats_to_numpy(item["stats"])
|
||||
for item in sorted(episodes_stats, key=lambda x: x["episode_index"])
|
||||
}
|
||||
|
||||
|
||||
def backward_compatible_episodes_stats(
|
||||
stats: dict[str, dict[str, np.ndarray]], episodes: list[int]
|
||||
) -> dict[str, dict[str, np.ndarray]]:
|
||||
return {ep_idx: stats for ep_idx in episodes}
|
||||
|
||||
|
||||
def load_image_as_numpy(
|
||||
fpath: str | Path, dtype: np.dtype = np.float32, channel_first: bool = True
|
||||
) -> np.ndarray:
|
||||
img = PILImage.open(fpath).convert("RGB")
|
||||
img_array = np.array(img, dtype=dtype)
|
||||
if channel_first: # (H, W, C) -> (C, H, W)
|
||||
img_array = np.transpose(img_array, (2, 0, 1))
|
||||
if "float" in dtype:
|
||||
if np.issubdtype(dtype, np.floating):
|
||||
img_array /= 255.0
|
||||
return img_array
|
||||
|
||||
@@ -201,77 +269,95 @@ def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
|
||||
elif first_item is None:
|
||||
pass
|
||||
else:
|
||||
items_dict[key] = [torch.tensor(x) for x in items_dict[key]]
|
||||
items_dict[key] = [x if isinstance(x, str) else torch.tensor(x) for x in items_dict[key]]
|
||||
return items_dict
|
||||
|
||||
|
||||
def _get_major_minor(version: str) -> tuple[int]:
|
||||
split = version.strip("v").split(".")
|
||||
return int(split[0]), int(split[1])
|
||||
|
||||
|
||||
class BackwardCompatibilityError(Exception):
|
||||
def __init__(self, repo_id, version):
|
||||
message = textwrap.dedent(f"""
|
||||
BackwardCompatibilityError: The dataset you requested ({repo_id}) is in {version} format.
|
||||
|
||||
We introduced a new format since v2.0 which is not backward compatible with v1.x.
|
||||
Please, use our conversion script. Modify the following command with your own task description:
|
||||
```
|
||||
python lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py \\
|
||||
--repo-id {repo_id} \\
|
||||
--single-task "TASK DESCRIPTION." # <---- /!\\ Replace TASK DESCRIPTION /!\\
|
||||
```
|
||||
|
||||
A few examples to replace TASK DESCRIPTION: "Pick up the blue cube and place it into the bin.",
|
||||
"Insert the peg into the socket.", "Slide open the ziploc bag.", "Take the elevator to the 1st floor.",
|
||||
"Open the top cabinet, store the pot inside it then close the cabinet.", "Push the T-shaped block onto the T-shaped target.",
|
||||
"Grab the spray paint on the shelf and place it in the bin on top of the robot dog.", "Fold the sweatshirt.", ...
|
||||
|
||||
If you encounter a problem, contact LeRobot maintainers on [Discord](https://discord.com/invite/s3KuuzsPFb)
|
||||
or open an [issue on GitHub](https://github.com/huggingface/lerobot/issues/new/choose).
|
||||
""")
|
||||
super().__init__(message)
|
||||
def is_valid_version(version: str) -> bool:
|
||||
try:
|
||||
packaging.version.parse(version)
|
||||
return True
|
||||
except packaging.version.InvalidVersion:
|
||||
return False
|
||||
|
||||
|
||||
def check_version_compatibility(
|
||||
repo_id: str, version_to_check: str, current_version: str, enforce_breaking_major: bool = True
|
||||
repo_id: str,
|
||||
version_to_check: str | packaging.version.Version,
|
||||
current_version: str | packaging.version.Version,
|
||||
enforce_breaking_major: bool = True,
|
||||
) -> None:
|
||||
current_major, _ = _get_major_minor(current_version)
|
||||
major_to_check, _ = _get_major_minor(version_to_check)
|
||||
if major_to_check < current_major and enforce_breaking_major:
|
||||
raise BackwardCompatibilityError(repo_id, version_to_check)
|
||||
elif float(version_to_check.strip("v")) < float(current_version.strip("v")):
|
||||
logging.warning(
|
||||
f"""The dataset you requested ({repo_id}) was created with a previous version ({version_to_check}) of the
|
||||
codebase. The current codebase version is {current_version}. You should be fine since
|
||||
backward compatibility is maintained. If you encounter a problem, contact LeRobot maintainers on
|
||||
Discord ('https://discord.com/invite/s3KuuzsPFb') or open an issue on github.""",
|
||||
)
|
||||
v_check = (
|
||||
packaging.version.parse(version_to_check)
|
||||
if not isinstance(version_to_check, packaging.version.Version)
|
||||
else version_to_check
|
||||
)
|
||||
v_current = (
|
||||
packaging.version.parse(current_version)
|
||||
if not isinstance(current_version, packaging.version.Version)
|
||||
else current_version
|
||||
)
|
||||
if v_check.major < v_current.major and enforce_breaking_major:
|
||||
raise BackwardCompatibilityError(repo_id, v_check)
|
||||
elif v_check.minor < v_current.minor:
|
||||
logging.warning(V21_MESSAGE.format(repo_id=repo_id, version=v_check))
|
||||
|
||||
|
||||
def get_hub_safe_version(repo_id: str, version: str) -> str:
|
||||
def get_repo_versions(repo_id: str) -> list[packaging.version.Version]:
|
||||
"""Returns available valid versions (branches and tags) on given repo."""
|
||||
api = HfApi()
|
||||
dataset_info = api.list_repo_refs(repo_id, repo_type="dataset")
|
||||
branches = [b.name for b in dataset_info.branches]
|
||||
if version not in branches:
|
||||
num_version = float(version.strip("v"))
|
||||
hub_num_versions = [float(v.strip("v")) for v in branches if v.startswith("v")]
|
||||
if num_version >= 2.0 and all(v < 2.0 for v in hub_num_versions):
|
||||
raise BackwardCompatibilityError(repo_id, version)
|
||||
repo_refs = api.list_repo_refs(repo_id, repo_type="dataset")
|
||||
repo_refs = [b.name for b in repo_refs.branches + repo_refs.tags]
|
||||
repo_versions = []
|
||||
for ref in repo_refs:
|
||||
with contextlib.suppress(packaging.version.InvalidVersion):
|
||||
repo_versions.append(packaging.version.parse(ref))
|
||||
|
||||
logging.warning(
|
||||
f"""You are trying to load a dataset from {repo_id} created with a previous version of the
|
||||
codebase. The following versions are available: {branches}.
|
||||
The requested version ('{version}') is not found. You should be fine since
|
||||
backward compatibility is maintained. If you encounter a problem, contact LeRobot maintainers on
|
||||
Discord ('https://discord.com/invite/s3KuuzsPFb') or open an issue on github.""",
|
||||
return repo_versions
|
||||
|
||||
|
||||
def get_safe_version(repo_id: str, version: str | packaging.version.Version) -> str:
|
||||
"""
|
||||
Returns the version if available on repo or the latest compatible one.
|
||||
Otherwise, will throw a `CompatibilityError`.
|
||||
"""
|
||||
target_version = (
|
||||
packaging.version.parse(version) if not isinstance(version, packaging.version.Version) else version
|
||||
)
|
||||
hub_versions = get_repo_versions(repo_id)
|
||||
|
||||
if not hub_versions:
|
||||
raise RevisionNotFoundError(
|
||||
f"""Your dataset must be tagged with a codebase version.
|
||||
Assuming _version_ is the codebase_version value in the info.json, you can run this:
|
||||
```python
|
||||
from huggingface_hub import HfApi
|
||||
|
||||
hub_api = HfApi()
|
||||
hub_api.create_tag("{repo_id}", tag="_version_", repo_type="dataset")
|
||||
```
|
||||
"""
|
||||
)
|
||||
if "main" not in branches:
|
||||
raise ValueError(f"Version 'main' not found on {repo_id}")
|
||||
return "main"
|
||||
else:
|
||||
return version
|
||||
|
||||
if target_version in hub_versions:
|
||||
return f"v{target_version}"
|
||||
|
||||
compatibles = [
|
||||
v for v in hub_versions if v.major == target_version.major and v.minor <= target_version.minor
|
||||
]
|
||||
if compatibles:
|
||||
return_version = max(compatibles)
|
||||
if return_version < target_version:
|
||||
logging.warning(f"Revision {version} for {repo_id} not found, using version v{return_version}")
|
||||
return f"v{return_version}"
|
||||
|
||||
lower_major = [v for v in hub_versions if v.major < target_version.major]
|
||||
if lower_major:
|
||||
raise BackwardCompatibilityError(repo_id, max(lower_major))
|
||||
|
||||
upper_versions = [v for v in hub_versions if v > target_version]
|
||||
assert len(upper_versions) > 0
|
||||
raise ForwardCompatibilityError(repo_id, min(upper_versions))
|
||||
|
||||
|
||||
def get_hf_features_from_features(features: dict) -> datasets.Features:
|
||||
@@ -283,11 +369,20 @@ def get_hf_features_from_features(features: dict) -> datasets.Features:
|
||||
hf_features[key] = datasets.Image()
|
||||
elif ft["shape"] == (1,):
|
||||
hf_features[key] = datasets.Value(dtype=ft["dtype"])
|
||||
else:
|
||||
assert len(ft["shape"]) == 1
|
||||
elif len(ft["shape"]) == 1:
|
||||
hf_features[key] = datasets.Sequence(
|
||||
length=ft["shape"][0], feature=datasets.Value(dtype=ft["dtype"])
|
||||
)
|
||||
elif len(ft["shape"]) == 2:
|
||||
hf_features[key] = datasets.Array2D(shape=ft["shape"], dtype=ft["dtype"])
|
||||
elif len(ft["shape"]) == 3:
|
||||
hf_features[key] = datasets.Array3D(shape=ft["shape"], dtype=ft["dtype"])
|
||||
elif len(ft["shape"]) == 4:
|
||||
hf_features[key] = datasets.Array4D(shape=ft["shape"], dtype=ft["dtype"])
|
||||
elif len(ft["shape"]) == 5:
|
||||
hf_features[key] = datasets.Array5D(shape=ft["shape"], dtype=ft["dtype"])
|
||||
else:
|
||||
raise ValueError(f"Corresponding feature is not valid: {ft}")
|
||||
|
||||
return datasets.Features(hf_features)
|
||||
|
||||
@@ -358,88 +453,85 @@ def create_empty_dataset_info(
|
||||
|
||||
|
||||
def get_episode_data_index(
|
||||
episode_dicts: list[dict], episodes: list[int] | None = None
|
||||
episode_dicts: dict[dict], episodes: list[int] | None = None
|
||||
) -> dict[str, torch.Tensor]:
|
||||
episode_lengths = {ep_idx: ep_dict["length"] for ep_idx, ep_dict in enumerate(episode_dicts)}
|
||||
episode_lengths = {ep_idx: ep_dict["length"] for ep_idx, ep_dict in episode_dicts.items()}
|
||||
if episodes is not None:
|
||||
episode_lengths = {ep_idx: episode_lengths[ep_idx] for ep_idx in episodes}
|
||||
|
||||
cumulative_lenghts = list(accumulate(episode_lengths.values()))
|
||||
cumulative_lengths = list(accumulate(episode_lengths.values()))
|
||||
return {
|
||||
"from": torch.LongTensor([0] + cumulative_lenghts[:-1]),
|
||||
"to": torch.LongTensor(cumulative_lenghts),
|
||||
}
|
||||
|
||||
|
||||
def calculate_total_episode(
|
||||
hf_dataset: datasets.Dataset, raise_if_not_contiguous: bool = True
|
||||
) -> dict[str, torch.Tensor]:
|
||||
episode_indices = sorted(hf_dataset.unique("episode_index"))
|
||||
total_episodes = len(episode_indices)
|
||||
if raise_if_not_contiguous and episode_indices != list(range(total_episodes)):
|
||||
raise ValueError("episode_index values are not sorted and contiguous.")
|
||||
return total_episodes
|
||||
|
||||
|
||||
def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> dict[str, torch.Tensor]:
|
||||
episode_lengths = []
|
||||
table = hf_dataset.data.table
|
||||
total_episodes = calculate_total_episode(hf_dataset)
|
||||
for ep_idx in range(total_episodes):
|
||||
ep_table = table.filter(pc.equal(table["episode_index"], ep_idx))
|
||||
episode_lengths.insert(ep_idx, len(ep_table))
|
||||
|
||||
cumulative_lenghts = list(accumulate(episode_lengths))
|
||||
return {
|
||||
"from": torch.LongTensor([0] + cumulative_lenghts[:-1]),
|
||||
"to": torch.LongTensor(cumulative_lenghts),
|
||||
"from": torch.LongTensor([0] + cumulative_lengths[:-1]),
|
||||
"to": torch.LongTensor(cumulative_lengths),
|
||||
}
|
||||
|
||||
|
||||
def check_timestamps_sync(
|
||||
hf_dataset: datasets.Dataset,
|
||||
episode_data_index: dict[str, torch.Tensor],
|
||||
timestamps: np.ndarray,
|
||||
episode_indices: np.ndarray,
|
||||
episode_data_index: dict[str, np.ndarray],
|
||||
fps: int,
|
||||
tolerance_s: float,
|
||||
raise_value_error: bool = True,
|
||||
) -> bool:
|
||||
"""
|
||||
This check is to make sure that each timestamps is separated to the next by 1/fps +/- tolerance to
|
||||
account for possible numerical error.
|
||||
"""
|
||||
timestamps = torch.stack(hf_dataset["timestamp"])
|
||||
diffs = torch.diff(timestamps)
|
||||
within_tolerance = torch.abs(diffs - 1 / fps) <= tolerance_s
|
||||
This check is to make sure that each timestamp is separated from the next by (1/fps) +/- tolerance
|
||||
to account for possible numerical error.
|
||||
|
||||
# We mask differences between the timestamp at the end of an episode
|
||||
# and the one at the start of the next episode since these are expected
|
||||
# to be outside tolerance.
|
||||
mask = torch.ones(len(diffs), dtype=torch.bool)
|
||||
ignored_diffs = episode_data_index["to"][:-1] - 1
|
||||
Args:
|
||||
timestamps (np.ndarray): Array of timestamps in seconds.
|
||||
episode_indices (np.ndarray): Array indicating the episode index for each timestamp.
|
||||
episode_data_index (dict[str, np.ndarray]): A dictionary that includes 'to',
|
||||
which identifies indices for the end of each episode.
|
||||
fps (int): Frames per second. Used to check the expected difference between consecutive timestamps.
|
||||
tolerance_s (float): Allowed deviation from the expected (1/fps) difference.
|
||||
raise_value_error (bool): Whether to raise a ValueError if the check fails.
|
||||
|
||||
Returns:
|
||||
bool: True if all checked timestamp differences lie within tolerance, False otherwise.
|
||||
|
||||
Raises:
|
||||
ValueError: If the check fails and `raise_value_error` is True.
|
||||
"""
|
||||
if timestamps.shape != episode_indices.shape:
|
||||
raise ValueError(
|
||||
"timestamps and episode_indices should have the same shape. "
|
||||
f"Found {timestamps.shape=} and {episode_indices.shape=}."
|
||||
)
|
||||
|
||||
# Consecutive differences
|
||||
diffs = np.diff(timestamps)
|
||||
within_tolerance = np.abs(diffs - (1.0 / fps)) <= tolerance_s
|
||||
|
||||
# Mask to ignore differences at the boundaries between episodes
|
||||
mask = np.ones(len(diffs), dtype=bool)
|
||||
ignored_diffs = episode_data_index["to"][:-1] - 1 # indices at the end of each episode
|
||||
mask[ignored_diffs] = False
|
||||
filtered_within_tolerance = within_tolerance[mask]
|
||||
|
||||
if not torch.all(filtered_within_tolerance):
|
||||
# Check if all remaining diffs are within tolerance
|
||||
if not np.all(filtered_within_tolerance):
|
||||
# Track original indices before masking
|
||||
original_indices = torch.arange(len(diffs))
|
||||
original_indices = np.arange(len(diffs))
|
||||
filtered_indices = original_indices[mask]
|
||||
outside_tolerance_filtered_indices = torch.nonzero(~filtered_within_tolerance) # .squeeze()
|
||||
outside_tolerance_filtered_indices = np.nonzero(~filtered_within_tolerance)[0]
|
||||
outside_tolerance_indices = filtered_indices[outside_tolerance_filtered_indices]
|
||||
episode_indices = torch.stack(hf_dataset["episode_index"])
|
||||
|
||||
outside_tolerances = []
|
||||
for idx in outside_tolerance_indices:
|
||||
entry = {
|
||||
"timestamps": [timestamps[idx], timestamps[idx + 1]],
|
||||
"diff": diffs[idx],
|
||||
"episode_index": episode_indices[idx].item(),
|
||||
"episode_index": episode_indices[idx].item()
|
||||
if hasattr(episode_indices[idx], "item")
|
||||
else episode_indices[idx],
|
||||
}
|
||||
outside_tolerances.append(entry)
|
||||
|
||||
if raise_value_error:
|
||||
raise ValueError(
|
||||
f"""One or several timestamps unexpectedly violate the tolerance inside episode range.
|
||||
This might be due to synchronization issues with timestamps during data collection.
|
||||
This might be due to synchronization issues during data collection.
|
||||
\n{pformat(outside_tolerances)}"""
|
||||
)
|
||||
return False
|
||||
@@ -604,3 +696,118 @@ class IterableNamespace(SimpleNamespace):
|
||||
|
||||
def keys(self):
|
||||
return vars(self).keys()
|
||||
|
||||
|
||||
def validate_frame(frame: dict, features: dict):
|
||||
optional_features = {"timestamp"}
|
||||
expected_features = (set(features) - set(DEFAULT_FEATURES.keys())) | {"task"}
|
||||
actual_features = set(frame.keys())
|
||||
|
||||
error_message = validate_features_presence(actual_features, expected_features, optional_features)
|
||||
|
||||
if "task" in frame:
|
||||
error_message += validate_feature_string("task", frame["task"])
|
||||
|
||||
common_features = actual_features & (expected_features | optional_features)
|
||||
for name in common_features - {"task"}:
|
||||
error_message += validate_feature_dtype_and_shape(name, features[name], frame[name])
|
||||
|
||||
if error_message:
|
||||
raise ValueError(error_message)
|
||||
|
||||
|
||||
def validate_features_presence(
|
||||
actual_features: set[str], expected_features: set[str], optional_features: set[str]
|
||||
):
|
||||
error_message = ""
|
||||
missing_features = expected_features - actual_features
|
||||
extra_features = actual_features - (expected_features | optional_features)
|
||||
|
||||
if missing_features or extra_features:
|
||||
error_message += "Feature mismatch in `frame` dictionary:\n"
|
||||
if missing_features:
|
||||
error_message += f"Missing features: {missing_features}\n"
|
||||
if extra_features:
|
||||
error_message += f"Extra features: {extra_features}\n"
|
||||
|
||||
return error_message
|
||||
|
||||
|
||||
def validate_feature_dtype_and_shape(name: str, feature: dict, value: np.ndarray | PILImage.Image | str):
|
||||
expected_dtype = feature["dtype"]
|
||||
expected_shape = feature["shape"]
|
||||
if is_valid_numpy_dtype_string(expected_dtype):
|
||||
return validate_feature_numpy_array(name, expected_dtype, expected_shape, value)
|
||||
elif expected_dtype in ["image", "video"]:
|
||||
return validate_feature_image_or_video(name, expected_shape, value)
|
||||
elif expected_dtype == "string":
|
||||
return validate_feature_string(name, value)
|
||||
else:
|
||||
raise NotImplementedError(f"The feature dtype '{expected_dtype}' is not implemented yet.")
|
||||
|
||||
|
||||
def validate_feature_numpy_array(
|
||||
name: str, expected_dtype: str, expected_shape: list[int], value: np.ndarray
|
||||
):
|
||||
error_message = ""
|
||||
if isinstance(value, np.ndarray):
|
||||
actual_dtype = value.dtype
|
||||
actual_shape = value.shape
|
||||
|
||||
if actual_dtype != np.dtype(expected_dtype):
|
||||
error_message += f"The feature '{name}' of dtype '{actual_dtype}' is not of the expected dtype '{expected_dtype}'.\n"
|
||||
|
||||
if actual_shape != expected_shape:
|
||||
error_message += f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{expected_shape}'.\n"
|
||||
else:
|
||||
error_message += f"The feature '{name}' is not a 'np.ndarray'. Expected type is '{expected_dtype}', but type '{type(value)}' provided instead.\n"
|
||||
|
||||
return error_message
|
||||
|
||||
|
||||
def validate_feature_image_or_video(name: str, expected_shape: list[str], value: np.ndarray | PILImage.Image):
|
||||
# Note: The check of pixels range ([0,1] for float and [0,255] for uint8) is done by the image writer threads.
|
||||
error_message = ""
|
||||
if isinstance(value, np.ndarray):
|
||||
actual_shape = value.shape
|
||||
c, h, w = expected_shape
|
||||
if len(actual_shape) != 3 or (actual_shape != (c, h, w) and actual_shape != (h, w, c)):
|
||||
error_message += f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{(c, h, w)}' or '{(h, w, c)}'.\n"
|
||||
elif isinstance(value, PILImage.Image):
|
||||
pass
|
||||
else:
|
||||
error_message += f"The feature '{name}' is expected to be of type 'PIL.Image' or 'np.ndarray' channel first or channel last, but type '{type(value)}' provided instead.\n"
|
||||
|
||||
return error_message
|
||||
|
||||
|
||||
def validate_feature_string(name: str, value: str):
|
||||
if not isinstance(value, str):
|
||||
return f"The feature '{name}' is expected to be of type 'str', but type '{type(value)}' provided instead.\n"
|
||||
return ""
|
||||
|
||||
|
||||
def validate_episode_buffer(episode_buffer: dict, total_episodes: int, features: dict):
|
||||
if "size" not in episode_buffer:
|
||||
raise ValueError("size key not found in episode_buffer")
|
||||
|
||||
if "task" not in episode_buffer:
|
||||
raise ValueError("task key not found in episode_buffer")
|
||||
|
||||
if episode_buffer["episode_index"] != total_episodes:
|
||||
# TODO(aliberts): Add option to use existing episode_index
|
||||
raise NotImplementedError(
|
||||
"You might have manually provided the episode_buffer with an episode_index that doesn't "
|
||||
"match the total number of episodes already in the dataset. This is not supported for now."
|
||||
)
|
||||
|
||||
if episode_buffer["size"] == 0:
|
||||
raise ValueError("You must add one or several frames with `add_frame` before calling `add_episode`.")
|
||||
|
||||
buffer_keys = set(episode_buffer.keys()) - {"task", "size"}
|
||||
if not buffer_keys == set(features):
|
||||
raise ValueError(
|
||||
f"Features from `episode_buffer` don't match the ones in `features`."
|
||||
f"In episode_buffer not in features: {buffer_keys - set(features)}"
|
||||
f"In features not in episode_buffer: {set(features) - buffer_keys}"
|
||||
)
|
||||
|
||||
@@ -31,6 +31,7 @@ from lerobot.common.robot_devices.robots.configs import AlohaRobotConfig
|
||||
|
||||
LOCAL_DIR = Path("data/")
|
||||
|
||||
# spellchecker:off
|
||||
ALOHA_MOBILE_INFO = {
|
||||
"robot_config": AlohaRobotConfig(),
|
||||
"license": "mit",
|
||||
@@ -856,6 +857,7 @@ DATASETS = {
|
||||
}""").lstrip(),
|
||||
},
|
||||
}
|
||||
# spellchecker:on
|
||||
|
||||
|
||||
def batch_convert():
|
||||
|
||||
@@ -17,7 +17,7 @@
|
||||
"""
|
||||
This script will help you convert any LeRobot dataset already pushed to the hub from codebase version 1.6 to
|
||||
2.0. You will be required to provide the 'tasks', which is a short but accurate description in plain English
|
||||
for each of the task performed in the dataset. This will allow to easily train models with task-conditionning.
|
||||
for each of the task performed in the dataset. This will allow to easily train models with task-conditioning.
|
||||
|
||||
We support 3 different scenarios for these tasks (see instructions below):
|
||||
1. Single task dataset: all episodes of your dataset have the same single task.
|
||||
@@ -130,7 +130,7 @@ from lerobot.common.datasets.utils import (
|
||||
create_branch,
|
||||
create_lerobot_dataset_card,
|
||||
flatten_dict,
|
||||
get_hub_safe_version,
|
||||
get_safe_version,
|
||||
load_json,
|
||||
unflatten_dict,
|
||||
write_json,
|
||||
@@ -443,7 +443,7 @@ def convert_dataset(
|
||||
test_branch: str | None = None,
|
||||
**card_kwargs,
|
||||
):
|
||||
v1 = get_hub_safe_version(repo_id, V16)
|
||||
v1 = get_safe_version(repo_id, V16)
|
||||
v1x_dir = local_dir / V16 / repo_id
|
||||
v20_dir = local_dir / V20 / repo_id
|
||||
v1x_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
73
lerobot/common/datasets/v21/_remove_language_instruction.py
Normal file
73
lerobot/common/datasets/v21/_remove_language_instruction.py
Normal file
@@ -0,0 +1,73 @@
|
||||
import logging
|
||||
import traceback
|
||||
from pathlib import Path
|
||||
|
||||
from datasets import get_dataset_config_info
|
||||
from huggingface_hub import HfApi
|
||||
|
||||
from lerobot import available_datasets
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDatasetMetadata
|
||||
from lerobot.common.datasets.utils import INFO_PATH, write_info
|
||||
from lerobot.common.datasets.v21.convert_dataset_v20_to_v21 import V20, SuppressWarnings
|
||||
|
||||
LOCAL_DIR = Path("data/")
|
||||
|
||||
hub_api = HfApi()
|
||||
|
||||
|
||||
def fix_dataset(repo_id: str) -> str:
|
||||
if not hub_api.revision_exists(repo_id, V20, repo_type="dataset"):
|
||||
return f"{repo_id}: skipped (not in {V20})."
|
||||
|
||||
dataset_info = get_dataset_config_info(repo_id, "default")
|
||||
with SuppressWarnings():
|
||||
lerobot_metadata = LeRobotDatasetMetadata(repo_id, revision=V20, force_cache_sync=True)
|
||||
|
||||
meta_features = {key for key, ft in lerobot_metadata.features.items() if ft["dtype"] != "video"}
|
||||
parquet_features = set(dataset_info.features)
|
||||
|
||||
diff_parquet_meta = parquet_features - meta_features
|
||||
diff_meta_parquet = meta_features - parquet_features
|
||||
|
||||
if diff_parquet_meta:
|
||||
raise ValueError(f"In parquet not in info.json: {parquet_features - meta_features}")
|
||||
|
||||
if not diff_meta_parquet:
|
||||
return f"{repo_id}: skipped (no diff)"
|
||||
|
||||
if diff_meta_parquet:
|
||||
logging.warning(f"In info.json not in parquet: {meta_features - parquet_features}")
|
||||
assert diff_meta_parquet == {"language_instruction"}
|
||||
lerobot_metadata.features.pop("language_instruction")
|
||||
write_info(lerobot_metadata.info, lerobot_metadata.root)
|
||||
commit_info = hub_api.upload_file(
|
||||
path_or_fileobj=lerobot_metadata.root / INFO_PATH,
|
||||
path_in_repo=INFO_PATH,
|
||||
repo_id=repo_id,
|
||||
repo_type="dataset",
|
||||
revision=V20,
|
||||
commit_message="Remove 'language_instruction'",
|
||||
create_pr=True,
|
||||
)
|
||||
return f"{repo_id}: success - PR: {commit_info.pr_url}"
|
||||
|
||||
|
||||
def batch_fix():
|
||||
status = {}
|
||||
LOCAL_DIR.mkdir(parents=True, exist_ok=True)
|
||||
logfile = LOCAL_DIR / "fix_features_v20.txt"
|
||||
for num, repo_id in enumerate(available_datasets):
|
||||
print(f"\nConverting {repo_id} ({num}/{len(available_datasets)})")
|
||||
print("---------------------------------------------------------")
|
||||
try:
|
||||
status = fix_dataset(repo_id)
|
||||
except Exception:
|
||||
status = f"{repo_id}: failed\n {traceback.format_exc()}"
|
||||
|
||||
logging.info(status)
|
||||
with open(logfile, "a") as file:
|
||||
file.write(status + "\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
batch_fix()
|
||||
@@ -0,0 +1,54 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
This script is for internal use to convert all datasets under the 'lerobot' hub user account to v2.1.
|
||||
"""
|
||||
|
||||
import traceback
|
||||
from pathlib import Path
|
||||
|
||||
from huggingface_hub import HfApi
|
||||
|
||||
from lerobot import available_datasets
|
||||
from lerobot.common.datasets.v21.convert_dataset_v20_to_v21 import V21, convert_dataset
|
||||
|
||||
LOCAL_DIR = Path("data/")
|
||||
|
||||
|
||||
def batch_convert():
|
||||
status = {}
|
||||
LOCAL_DIR.mkdir(parents=True, exist_ok=True)
|
||||
logfile = LOCAL_DIR / "conversion_log_v21.txt"
|
||||
hub_api = HfApi()
|
||||
for num, repo_id in enumerate(available_datasets):
|
||||
print(f"\nConverting {repo_id} ({num}/{len(available_datasets)})")
|
||||
print("---------------------------------------------------------")
|
||||
try:
|
||||
if hub_api.revision_exists(repo_id, V21, repo_type="dataset"):
|
||||
status = f"{repo_id}: success (already in {V21})."
|
||||
else:
|
||||
convert_dataset(repo_id)
|
||||
status = f"{repo_id}: success."
|
||||
except Exception:
|
||||
status = f"{repo_id}: failed\n {traceback.format_exc()}"
|
||||
|
||||
with open(logfile, "a") as file:
|
||||
file.write(status + "\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
batch_convert()
|
||||
100
lerobot/common/datasets/v21/convert_dataset_v20_to_v21.py
Normal file
100
lerobot/common/datasets/v21/convert_dataset_v20_to_v21.py
Normal file
@@ -0,0 +1,100 @@
|
||||
"""
|
||||
This script will help you convert any LeRobot dataset already pushed to the hub from codebase version 2.0 to
|
||||
2.1. It will:
|
||||
|
||||
- Generate per-episodes stats and writes them in `episodes_stats.jsonl`
|
||||
- Check consistency between these new stats and the old ones.
|
||||
- Remove the deprecated `stats.json`.
|
||||
- Update codebase_version in `info.json`.
|
||||
- Push this new version to the hub on the 'main' branch and tags it with "v2.1".
|
||||
|
||||
Usage:
|
||||
|
||||
```bash
|
||||
python lerobot/common/datasets/v21/convert_dataset_v20_to_v21.py \
|
||||
--repo-id=aliberts/koch_tutorial
|
||||
```
|
||||
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
|
||||
from huggingface_hub import HfApi
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
|
||||
from lerobot.common.datasets.utils import EPISODES_STATS_PATH, STATS_PATH, load_stats, write_info
|
||||
from lerobot.common.datasets.v21.convert_stats import check_aggregate_stats, convert_stats
|
||||
|
||||
V20 = "v2.0"
|
||||
V21 = "v2.1"
|
||||
|
||||
|
||||
class SuppressWarnings:
|
||||
def __enter__(self):
|
||||
self.previous_level = logging.getLogger().getEffectiveLevel()
|
||||
logging.getLogger().setLevel(logging.ERROR)
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
logging.getLogger().setLevel(self.previous_level)
|
||||
|
||||
|
||||
def convert_dataset(
|
||||
repo_id: str,
|
||||
branch: str | None = None,
|
||||
num_workers: int = 4,
|
||||
):
|
||||
with SuppressWarnings():
|
||||
dataset = LeRobotDataset(repo_id, revision=V20, force_cache_sync=True)
|
||||
|
||||
if (dataset.root / EPISODES_STATS_PATH).is_file():
|
||||
(dataset.root / EPISODES_STATS_PATH).unlink()
|
||||
|
||||
convert_stats(dataset, num_workers=num_workers)
|
||||
ref_stats = load_stats(dataset.root)
|
||||
check_aggregate_stats(dataset, ref_stats)
|
||||
|
||||
dataset.meta.info["codebase_version"] = CODEBASE_VERSION
|
||||
write_info(dataset.meta.info, dataset.root)
|
||||
|
||||
dataset.push_to_hub(branch=branch, tag_version=False, allow_patterns="meta/")
|
||||
|
||||
# delete old stats.json file
|
||||
if (dataset.root / STATS_PATH).is_file:
|
||||
(dataset.root / STATS_PATH).unlink()
|
||||
|
||||
hub_api = HfApi()
|
||||
if hub_api.file_exists(
|
||||
repo_id=dataset.repo_id, filename=STATS_PATH, revision=branch, repo_type="dataset"
|
||||
):
|
||||
hub_api.delete_file(
|
||||
path_in_repo=STATS_PATH, repo_id=dataset.repo_id, revision=branch, repo_type="dataset"
|
||||
)
|
||||
|
||||
hub_api.create_tag(repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--repo-id",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Repository identifier on Hugging Face: a community or a user name `/` the name of the dataset "
|
||||
"(e.g. `lerobot/pusht`, `cadene/aloha_sim_insertion_human`).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--branch",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Repo branch to push your dataset. Defaults to the main branch.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-workers",
|
||||
type=int,
|
||||
default=4,
|
||||
help="Number of workers for parallelizing stats compute. Defaults to 4.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
convert_dataset(**vars(args))
|
||||
85
lerobot/common/datasets/v21/convert_stats.py
Normal file
85
lerobot/common/datasets/v21/convert_stats.py
Normal file
@@ -0,0 +1,85 @@
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
|
||||
from lerobot.common.datasets.compute_stats import aggregate_stats, get_feature_stats, sample_indices
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.common.datasets.utils import write_episode_stats
|
||||
|
||||
|
||||
def sample_episode_video_frames(dataset: LeRobotDataset, episode_index: int, ft_key: str) -> np.ndarray:
|
||||
ep_len = dataset.meta.episodes[episode_index]["length"]
|
||||
sampled_indices = sample_indices(ep_len)
|
||||
query_timestamps = dataset._get_query_timestamps(0.0, {ft_key: sampled_indices})
|
||||
video_frames = dataset._query_videos(query_timestamps, episode_index)
|
||||
return video_frames[ft_key].numpy()
|
||||
|
||||
|
||||
def convert_episode_stats(dataset: LeRobotDataset, ep_idx: int):
|
||||
ep_start_idx = dataset.episode_data_index["from"][ep_idx]
|
||||
ep_end_idx = dataset.episode_data_index["to"][ep_idx]
|
||||
ep_data = dataset.hf_dataset.select(range(ep_start_idx, ep_end_idx))
|
||||
|
||||
ep_stats = {}
|
||||
for key, ft in dataset.features.items():
|
||||
if ft["dtype"] == "video":
|
||||
# We sample only for videos
|
||||
ep_ft_data = sample_episode_video_frames(dataset, ep_idx, key)
|
||||
else:
|
||||
ep_ft_data = np.array(ep_data[key])
|
||||
|
||||
axes_to_reduce = (0, 2, 3) if ft["dtype"] in ["image", "video"] else 0
|
||||
keepdims = True if ft["dtype"] in ["image", "video"] else ep_ft_data.ndim == 1
|
||||
ep_stats[key] = get_feature_stats(ep_ft_data, axis=axes_to_reduce, keepdims=keepdims)
|
||||
|
||||
if ft["dtype"] in ["image", "video"]: # remove batch dim
|
||||
ep_stats[key] = {
|
||||
k: v if k == "count" else np.squeeze(v, axis=0) for k, v in ep_stats[key].items()
|
||||
}
|
||||
|
||||
dataset.meta.episodes_stats[ep_idx] = ep_stats
|
||||
|
||||
|
||||
def convert_stats(dataset: LeRobotDataset, num_workers: int = 0):
|
||||
assert dataset.episodes is None
|
||||
print("Computing episodes stats")
|
||||
total_episodes = dataset.meta.total_episodes
|
||||
if num_workers > 0:
|
||||
with ThreadPoolExecutor(max_workers=num_workers) as executor:
|
||||
futures = {
|
||||
executor.submit(convert_episode_stats, dataset, ep_idx): ep_idx
|
||||
for ep_idx in range(total_episodes)
|
||||
}
|
||||
for future in tqdm(as_completed(futures), total=total_episodes):
|
||||
future.result()
|
||||
else:
|
||||
for ep_idx in tqdm(range(total_episodes)):
|
||||
convert_episode_stats(dataset, ep_idx)
|
||||
|
||||
for ep_idx in tqdm(range(total_episodes)):
|
||||
write_episode_stats(ep_idx, dataset.meta.episodes_stats[ep_idx], dataset.root)
|
||||
|
||||
|
||||
def check_aggregate_stats(
|
||||
dataset: LeRobotDataset,
|
||||
reference_stats: dict[str, dict[str, np.ndarray]],
|
||||
video_rtol_atol: tuple[float] = (1e-2, 1e-2),
|
||||
default_rtol_atol: tuple[float] = (5e-6, 6e-5),
|
||||
):
|
||||
"""Verifies that the aggregated stats from episodes_stats are close to reference stats."""
|
||||
agg_stats = aggregate_stats(list(dataset.meta.episodes_stats.values()))
|
||||
for key, ft in dataset.features.items():
|
||||
# These values might need some fine-tuning
|
||||
if ft["dtype"] == "video":
|
||||
# to account for image sub-sampling
|
||||
rtol, atol = video_rtol_atol
|
||||
else:
|
||||
rtol, atol = default_rtol_atol
|
||||
|
||||
for stat, val in agg_stats[key].items():
|
||||
if key in reference_stats and stat in reference_stats[key]:
|
||||
err_msg = f"feature='{key}' stats='{stat}'"
|
||||
np.testing.assert_allclose(
|
||||
val, reference_stats[key][stat], rtol=rtol, atol=atol, err_msg=err_msg
|
||||
)
|
||||
@@ -69,11 +69,11 @@ def decode_video_frames_torchvision(
|
||||
|
||||
# set the first and last requested timestamps
|
||||
# Note: previous timestamps are usually loaded, since we need to access the previous key frame
|
||||
first_ts = timestamps[0]
|
||||
last_ts = timestamps[-1]
|
||||
first_ts = min(timestamps)
|
||||
last_ts = max(timestamps)
|
||||
|
||||
# access closest key frame of the first requested frame
|
||||
# Note: closest key frame timestamp is usally smaller than `first_ts` (e.g. key frame can be the first frame of the video)
|
||||
# Note: closest key frame timestamp is usually smaller than `first_ts` (e.g. key frame can be the first frame of the video)
|
||||
# for details on what `seek` is doing see: https://pyav.basswood-io.com/docs/stable/api/container.html?highlight=inputcontainer#av.container.InputContainer.seek
|
||||
reader.seek(first_ts, keyframes_only=keyframes_only)
|
||||
|
||||
|
||||
@@ -37,12 +37,12 @@ def make_env(cfg: EnvConfig, n_envs: int = 1, use_async_envs: bool = False) -> g
|
||||
Args:
|
||||
cfg (EnvConfig): the config of the environment to instantiate.
|
||||
n_envs (int, optional): The number of parallelized env to return. Defaults to 1.
|
||||
use_async_envs (bool, optional): Wether to return an AsyncVectorEnv or a SyncVectorEnv. Defaults to
|
||||
use_async_envs (bool, optional): Whether to return an AsyncVectorEnv or a SyncVectorEnv. Defaults to
|
||||
False.
|
||||
|
||||
Raises:
|
||||
ValueError: if n_envs < 1
|
||||
ModuleNotFoundError: If the requested env package is not intalled
|
||||
ModuleNotFoundError: If the requested env package is not installed
|
||||
|
||||
Returns:
|
||||
gym.vector.VectorEnv: The parallelized gym.env instance.
|
||||
|
||||
@@ -1,240 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Borrowed from https://github.com/fyhMer/fowm/blob/main/src/logger.py
|
||||
|
||||
# TODO(rcadene, alexander-soare): clean this file
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from dataclasses import asdict
|
||||
from glob import glob
|
||||
from pathlib import Path
|
||||
|
||||
import draccus
|
||||
import torch
|
||||
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
|
||||
from termcolor import colored
|
||||
from torch.optim import Optimizer
|
||||
from torch.optim.lr_scheduler import LRScheduler
|
||||
|
||||
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.common.utils.utils import get_global_random_state
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode
|
||||
|
||||
PRETRAINED_MODEL = "pretrained_model"
|
||||
TRAINING_STATE = "training_state.pth"
|
||||
|
||||
|
||||
def log_output_dir(out_dir):
|
||||
logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {out_dir}")
|
||||
|
||||
|
||||
def cfg_to_group(cfg: TrainPipelineConfig, return_list: bool = False) -> list[str] | str:
|
||||
"""Return a group name for logging. Optionally returns group name as list."""
|
||||
lst = [
|
||||
f"policy:{cfg.policy.type}",
|
||||
f"dataset:{cfg.dataset.repo_id}",
|
||||
f"seed:{cfg.seed}",
|
||||
]
|
||||
if cfg.env is not None:
|
||||
lst.append(f"env:{cfg.env.type}")
|
||||
return lst if return_list else "-".join(lst)
|
||||
|
||||
|
||||
def get_wandb_run_id_from_filesystem(checkpoint_dir: Path) -> str:
|
||||
# Get the WandB run ID.
|
||||
paths = glob(str(checkpoint_dir / "../wandb/latest-run/run-*"))
|
||||
if len(paths) != 1:
|
||||
raise RuntimeError("Couldn't get the previous WandB run ID for run resumption.")
|
||||
match = re.search(r"run-([^\.]+).wandb", paths[0].split("/")[-1])
|
||||
if match is None:
|
||||
raise RuntimeError("Couldn't get the previous WandB run ID for run resumption.")
|
||||
wandb_run_id = match.groups(0)[0]
|
||||
return wandb_run_id
|
||||
|
||||
|
||||
class Logger:
|
||||
"""Primary logger object. Logs either locally or using wandb.
|
||||
|
||||
The logger creates the following directory structure:
|
||||
|
||||
provided_log_dir
|
||||
├── checkpoints
|
||||
│ ├── specific_checkpoint_name
|
||||
│ │ ├── pretrained_model # Hugging Face pretrained model directory
|
||||
│ │ │ ├── ...
|
||||
│ │ └── training_state.pth # optimizer, scheduler, and random states + training step
|
||||
| ├── another_specific_checkpoint_name
|
||||
│ │ ├── ...
|
||||
| ├── ...
|
||||
│ └── last # a softlink to the last logged checkpoint
|
||||
"""
|
||||
|
||||
pretrained_model_dir_name = PRETRAINED_MODEL
|
||||
training_state_file_name = TRAINING_STATE
|
||||
|
||||
def __init__(self, cfg: TrainPipelineConfig):
|
||||
self._cfg = cfg
|
||||
self.log_dir = cfg.output_dir
|
||||
self.log_dir.mkdir(parents=True, exist_ok=True)
|
||||
self.job_name = cfg.job_name
|
||||
self.checkpoints_dir = self.get_checkpoints_dir(self.log_dir)
|
||||
self.last_checkpoint_dir = self.get_last_checkpoint_dir(self.log_dir)
|
||||
self.last_pretrained_model_dir = self.get_last_pretrained_model_dir(self.log_dir)
|
||||
|
||||
# Set up WandB.
|
||||
self._group = cfg_to_group(cfg)
|
||||
run_offline = not cfg.wandb.enable or not cfg.wandb.project
|
||||
if run_offline:
|
||||
logging.info(colored("Logs will be saved locally.", "yellow", attrs=["bold"]))
|
||||
self._wandb = None
|
||||
else:
|
||||
os.environ["WANDB_SILENT"] = "true"
|
||||
import wandb
|
||||
|
||||
wandb_run_id = None
|
||||
if cfg.resume:
|
||||
wandb_run_id = get_wandb_run_id_from_filesystem(self.checkpoints_dir)
|
||||
|
||||
wandb.init(
|
||||
id=wandb_run_id,
|
||||
project=cfg.wandb.project,
|
||||
entity=cfg.wandb.entity,
|
||||
name=self.job_name,
|
||||
notes=cfg.wandb.notes,
|
||||
tags=cfg_to_group(cfg, return_list=True),
|
||||
dir=self.log_dir,
|
||||
config=asdict(self._cfg),
|
||||
# TODO(rcadene): try set to True
|
||||
save_code=False,
|
||||
# TODO(rcadene): split train and eval, and run async eval with job_type="eval"
|
||||
job_type="train_eval",
|
||||
resume="must" if cfg.resume else None,
|
||||
)
|
||||
print(colored("Logs will be synced with wandb.", "blue", attrs=["bold"]))
|
||||
logging.info(f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}")
|
||||
self._wandb = wandb
|
||||
|
||||
@classmethod
|
||||
def get_checkpoints_dir(cls, log_dir: str | Path) -> Path:
|
||||
"""Given the log directory, get the sub-directory in which checkpoints will be saved."""
|
||||
return Path(log_dir) / "checkpoints"
|
||||
|
||||
@classmethod
|
||||
def get_last_checkpoint_dir(cls, log_dir: str | Path) -> Path:
|
||||
"""Given the log directory, get the sub-directory in which the last checkpoint will be saved."""
|
||||
return cls.get_checkpoints_dir(log_dir) / "last"
|
||||
|
||||
@classmethod
|
||||
def get_last_pretrained_model_dir(cls, log_dir: str | Path) -> Path:
|
||||
"""
|
||||
Given the log directory, get the sub-directory in which the last checkpoint's pretrained weights will
|
||||
be saved.
|
||||
"""
|
||||
return cls.get_last_checkpoint_dir(log_dir) / cls.pretrained_model_dir_name
|
||||
|
||||
def save_model(self, save_dir: Path, policy: PreTrainedPolicy, wandb_artifact_name: str | None = None):
|
||||
"""Save the weights of the Policy model using PyTorchModelHubMixin.
|
||||
|
||||
The weights are saved in a folder called "pretrained_model" under the checkpoint directory.
|
||||
|
||||
Optionally also upload the model to WandB.
|
||||
"""
|
||||
|
||||
self.checkpoints_dir.mkdir(parents=True, exist_ok=True)
|
||||
register_features_types()
|
||||
policy.save_pretrained(save_dir)
|
||||
# Also save the full config for the env configuration.
|
||||
self._cfg.save_pretrained(save_dir)
|
||||
if self._wandb and not self._cfg.wandb.disable_artifact:
|
||||
# note wandb artifact does not accept ":" or "/" in its name
|
||||
artifact = self._wandb.Artifact(wandb_artifact_name, type="model")
|
||||
artifact.add_file(save_dir / SAFETENSORS_SINGLE_FILE)
|
||||
self._wandb.log_artifact(artifact)
|
||||
if self.last_checkpoint_dir.exists():
|
||||
os.remove(self.last_checkpoint_dir)
|
||||
|
||||
def save_training_state(
|
||||
self,
|
||||
save_dir: Path,
|
||||
train_step: int,
|
||||
optimizer: Optimizer | None = None,
|
||||
scheduler: LRScheduler | None = None,
|
||||
):
|
||||
"""Checkpoint the global training_step, optimizer state, scheduler state, and random state.
|
||||
|
||||
All of these are saved as "training_state.pth" under the checkpoint directory.
|
||||
"""
|
||||
training_state = {}
|
||||
training_state["step"] = train_step
|
||||
training_state.update(get_global_random_state())
|
||||
if optimizer is not None:
|
||||
training_state["optimizer"] = optimizer.state_dict()
|
||||
if scheduler is not None:
|
||||
training_state["scheduler"] = scheduler.state_dict()
|
||||
torch.save(training_state, save_dir / self.training_state_file_name)
|
||||
|
||||
def save_checkpoint(
|
||||
self,
|
||||
train_step: int,
|
||||
identifier: str,
|
||||
policy: PreTrainedPolicy,
|
||||
optimizer: Optimizer | None = None,
|
||||
scheduler: LRScheduler | None = None,
|
||||
):
|
||||
"""Checkpoint the model weights and the training state."""
|
||||
checkpoint_dir = self.checkpoints_dir / str(identifier)
|
||||
wandb_artifact_name = (
|
||||
None
|
||||
if self._wandb is None
|
||||
else f"{self._group.replace(':', '_').replace('/', '_')}-{self._cfg.seed}-{identifier}"
|
||||
)
|
||||
self.save_model(
|
||||
checkpoint_dir / self.pretrained_model_dir_name, policy, wandb_artifact_name=wandb_artifact_name
|
||||
)
|
||||
self.save_training_state(checkpoint_dir, train_step, optimizer, scheduler)
|
||||
|
||||
relative_target = checkpoint_dir.relative_to(self.last_checkpoint_dir.parent)
|
||||
self.last_checkpoint_dir.symlink_to(relative_target)
|
||||
|
||||
def log_dict(self, d: dict, step: int, mode: str = "train"):
|
||||
assert mode in {"train", "eval"}
|
||||
# TODO(alexander-soare): Add local text log.
|
||||
if self._wandb is not None:
|
||||
for k, v in d.items():
|
||||
if not isinstance(v, (int, float, str)):
|
||||
logging.warning(
|
||||
f'WandB logging of key "{k}" was ignored as its type is not handled by this wrapper.'
|
||||
)
|
||||
continue
|
||||
self._wandb.log({f"{mode}/{k}": v}, step=step)
|
||||
|
||||
def log_video(self, video_path: str, step: int, mode: str = "train"):
|
||||
assert mode in {"train", "eval"}
|
||||
assert self._wandb is not None
|
||||
wandb_video = self._wandb.Video(video_path, fps=self._cfg.env.fps, format="mp4")
|
||||
self._wandb.log({f"{mode}/video": wandb_video}, step=step)
|
||||
|
||||
|
||||
def register_features_types():
|
||||
draccus.decode.register(FeatureType, lambda x: FeatureType[x])
|
||||
draccus.encode.register(FeatureType, lambda x: x.name)
|
||||
|
||||
draccus.decode.register(NormalizationMode, lambda x: NormalizationMode[x])
|
||||
draccus.encode.register(NormalizationMode, lambda x: x.name)
|
||||
@@ -14,15 +14,11 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from torch.optim import Optimizer
|
||||
from torch.optim.lr_scheduler import LRScheduler
|
||||
|
||||
from lerobot.common.logger import TRAINING_STATE
|
||||
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.common.utils.utils import get_global_random_state, set_global_random_state
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
|
||||
|
||||
@@ -40,22 +36,5 @@ def make_optimizer_and_scheduler(
|
||||
"""
|
||||
params = policy.get_optim_params() if cfg.use_policy_training_preset else policy.parameters()
|
||||
optimizer = cfg.optimizer.build(params)
|
||||
lr_scheduler = cfg.scheduler.build(optimizer, cfg.offline.steps) if cfg.scheduler is not None else None
|
||||
lr_scheduler = cfg.scheduler.build(optimizer, cfg.steps) if cfg.scheduler is not None else None
|
||||
return optimizer, lr_scheduler
|
||||
|
||||
|
||||
def load_training_state(checkpoint_dir: Path, optimizer: Optimizer, scheduler: LRScheduler | None) -> int:
|
||||
"""
|
||||
Given the checkpoint directory, load the optimizer state, scheduler state, and random state, and
|
||||
return the global training step.
|
||||
"""
|
||||
# TODO(aliberts): use safetensors instead as weights_only=False is unsafe
|
||||
training_state = torch.load(checkpoint_dir / TRAINING_STATE, weights_only=False)
|
||||
optimizer.load_state_dict(training_state["optimizer"])
|
||||
if scheduler is not None:
|
||||
scheduler.load_state_dict(training_state["scheduler"])
|
||||
elif "scheduler" in training_state:
|
||||
raise ValueError("The checkpoint contains a scheduler state_dict, but no LRScheduler was provided.")
|
||||
# Small HACK to get the expected keys: use `get_global_random_state`.
|
||||
set_global_random_state({k: training_state[k] for k in get_global_random_state()})
|
||||
return training_state["step"], optimizer, scheduler
|
||||
|
||||
@@ -1,8 +1,32 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import abc
|
||||
from dataclasses import asdict, dataclass
|
||||
from pathlib import Path
|
||||
|
||||
import draccus
|
||||
import torch
|
||||
from safetensors.torch import load_file, save_file
|
||||
|
||||
from lerobot.common.constants import (
|
||||
OPTIMIZER_PARAM_GROUPS,
|
||||
OPTIMIZER_STATE,
|
||||
)
|
||||
from lerobot.common.datasets.utils import flatten_dict, unflatten_dict, write_json
|
||||
from lerobot.common.utils.io_utils import deserialize_json_into_object
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -68,3 +92,27 @@ class SGDConfig(OptimizerConfig):
|
||||
kwargs = asdict(self)
|
||||
kwargs.pop("grad_clip_norm")
|
||||
return torch.optim.SGD(params, **kwargs)
|
||||
|
||||
|
||||
def save_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Path) -> None:
|
||||
state = optimizer.state_dict()
|
||||
param_groups = state.pop("param_groups")
|
||||
flat_state = flatten_dict(state)
|
||||
save_file(flat_state, save_dir / OPTIMIZER_STATE)
|
||||
write_json(param_groups, save_dir / OPTIMIZER_PARAM_GROUPS)
|
||||
|
||||
|
||||
def load_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Path) -> torch.optim.Optimizer:
|
||||
current_state_dict = optimizer.state_dict()
|
||||
flat_state = load_file(save_dir / OPTIMIZER_STATE)
|
||||
state = unflatten_dict(flat_state)
|
||||
loaded_state_dict = {"state": {int(k): v for k, v in state["state"].items()}}
|
||||
|
||||
if "param_groups" in current_state_dict:
|
||||
param_groups = deserialize_json_into_object(
|
||||
save_dir / OPTIMIZER_PARAM_GROUPS, current_state_dict["param_groups"]
|
||||
)
|
||||
loaded_state_dict["param_groups"] = param_groups
|
||||
|
||||
optimizer.load_state_dict(loaded_state_dict)
|
||||
return optimizer
|
||||
|
||||
@@ -1,11 +1,31 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import abc
|
||||
import math
|
||||
from dataclasses import asdict, dataclass
|
||||
from pathlib import Path
|
||||
|
||||
import draccus
|
||||
from torch.optim import Optimizer
|
||||
from torch.optim.lr_scheduler import LambdaLR, LRScheduler
|
||||
|
||||
from lerobot.common.constants import SCHEDULER_STATE
|
||||
from lerobot.common.datasets.utils import write_json
|
||||
from lerobot.common.utils.io_utils import deserialize_json_into_object
|
||||
|
||||
|
||||
@dataclass
|
||||
class LRSchedulerConfig(draccus.ChoiceRegistry, abc.ABC):
|
||||
@@ -89,3 +109,14 @@ class CosineDecayWithWarmupSchedulerConfig(LRSchedulerConfig):
|
||||
return cosine_decay_schedule(current_step)
|
||||
|
||||
return LambdaLR(optimizer, lr_lambda, -1)
|
||||
|
||||
|
||||
def save_scheduler_state(scheduler: LRScheduler, save_dir: Path) -> None:
|
||||
state_dict = scheduler.state_dict()
|
||||
write_json(state_dict, save_dir / SCHEDULER_STATE)
|
||||
|
||||
|
||||
def load_scheduler_state(scheduler: LRScheduler, save_dir: Path) -> LRScheduler:
|
||||
state_dict = deserialize_json_into_object(save_dir / SCHEDULER_STATE, scheduler.state_dict())
|
||||
scheduler.load_state_dict(state_dict)
|
||||
return scheduler
|
||||
|
||||
@@ -64,7 +64,7 @@ class ACTConfig(PreTrainedConfig):
|
||||
output_normalization_modes: Similar dictionary as `normalize_input_modes`, but to unnormalize to the
|
||||
original scale. Note that this is also used for normalizing the training targets.
|
||||
vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
|
||||
pretrained_backbone_weights: Pretrained weights from torchvision to initalize the backbone.
|
||||
pretrained_backbone_weights: Pretrained weights from torchvision to initialize the backbone.
|
||||
`None` means no pretrained weights.
|
||||
replace_final_stride_with_dilation: Whether to replace the ResNet's final 2x2 stride with a dilated
|
||||
convolution.
|
||||
|
||||
@@ -144,7 +144,7 @@ class ACTPolicy(PreTrainedPolicy):
|
||||
self._action_queue.extend(actions.transpose(0, 1))
|
||||
return self._action_queue.popleft()
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
|
||||
"""Run the batch through the model and compute the loss for training or validation."""
|
||||
batch = self.normalize_inputs(batch)
|
||||
if self.config.image_features:
|
||||
@@ -169,11 +169,11 @@ class ACTPolicy(PreTrainedPolicy):
|
||||
(-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean()
|
||||
)
|
||||
loss_dict["kld_loss"] = mean_kld.item()
|
||||
loss_dict["loss"] = l1_loss + mean_kld * self.config.kl_weight
|
||||
loss = l1_loss + mean_kld * self.config.kl_weight
|
||||
else:
|
||||
loss_dict["loss"] = l1_loss
|
||||
loss = l1_loss
|
||||
|
||||
return loss_dict
|
||||
return loss, loss_dict
|
||||
|
||||
|
||||
class ACTTemporalEnsembler:
|
||||
@@ -409,9 +409,9 @@ class ACT(nn.Module):
|
||||
latent dimension.
|
||||
"""
|
||||
if self.config.use_vae and self.training:
|
||||
assert (
|
||||
"action" in batch
|
||||
), "actions must be provided when using the variational objective in training mode."
|
||||
assert "action" in batch, (
|
||||
"actions must be provided when using the variational objective in training mode."
|
||||
)
|
||||
|
||||
batch_size = (
|
||||
batch["observation.images"]
|
||||
|
||||
@@ -68,7 +68,7 @@ class DiffusionConfig(PreTrainedConfig):
|
||||
within the image size. If None, no cropping is done.
|
||||
crop_is_random: Whether the crop should be random at training time (it's always a center crop in eval
|
||||
mode).
|
||||
pretrained_backbone_weights: Pretrained weights from torchvision to initalize the backbone.
|
||||
pretrained_backbone_weights: Pretrained weights from torchvision to initialize the backbone.
|
||||
`None` means no pretrained weights.
|
||||
use_group_norm: Whether to replace batch normalization with group normalization in the backbone.
|
||||
The group sizes are set to be about 16 (to be precise, feature_dim // 16).
|
||||
@@ -99,7 +99,7 @@ class DiffusionConfig(PreTrainedConfig):
|
||||
num_inference_steps: Number of reverse diffusion steps to use at inference time (steps are evenly
|
||||
spaced). If not provided, this defaults to be the same as `num_train_timesteps`.
|
||||
do_mask_loss_for_padding: Whether to mask the loss when there are copy-padded actions. See
|
||||
`LeRobotDataset` and `load_previous_and_future_frames` for mor information. Note, this defaults
|
||||
`LeRobotDataset` and `load_previous_and_future_frames` for more information. Note, this defaults
|
||||
to False as the original Diffusion Policy implementation does the same.
|
||||
"""
|
||||
|
||||
@@ -221,7 +221,7 @@ class DiffusionConfig(PreTrainedConfig):
|
||||
for key, image_ft in self.image_features.items():
|
||||
if image_ft.shape != first_image_ft.shape:
|
||||
raise ValueError(
|
||||
f"`{key}` does not match `{first_image_key}`, but we " "expect all image shapes to match."
|
||||
f"`{key}` does not match `{first_image_key}`, but we expect all image shapes to match."
|
||||
)
|
||||
|
||||
@property
|
||||
|
||||
@@ -143,7 +143,7 @@ class DiffusionPolicy(PreTrainedPolicy):
|
||||
action = self._queues["action"].popleft()
|
||||
return action
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, None]:
|
||||
"""Run the batch through the model and compute the loss for training or validation."""
|
||||
batch = self.normalize_inputs(batch)
|
||||
if self.config.image_features:
|
||||
@@ -153,7 +153,8 @@ class DiffusionPolicy(PreTrainedPolicy):
|
||||
)
|
||||
batch = self.normalize_targets(batch)
|
||||
loss = self.diffusion.compute_loss(batch)
|
||||
return {"loss": loss}
|
||||
# no output_dict so returning None
|
||||
return loss, None
|
||||
|
||||
|
||||
def _make_noise_scheduler(name: str, **kwargs: dict) -> DDPMScheduler | DDIMScheduler:
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
@@ -77,17 +78,29 @@ def create_stats_buffers(
|
||||
}
|
||||
)
|
||||
|
||||
# TODO(aliberts, rcadene): harmonize this to only use one framework (np or torch)
|
||||
if stats:
|
||||
# Note: The clone is needed to make sure that the logic in save_pretrained doesn't see duplicated
|
||||
# tensors anywhere (for example, when we use the same stats for normalization and
|
||||
# unnormalization). See the logic here
|
||||
# https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97.
|
||||
if norm_mode is NormalizationMode.MEAN_STD:
|
||||
buffer["mean"].data = stats[key]["mean"].clone()
|
||||
buffer["std"].data = stats[key]["std"].clone()
|
||||
elif norm_mode is NormalizationMode.MIN_MAX:
|
||||
buffer["min"].data = stats[key]["min"].clone()
|
||||
buffer["max"].data = stats[key]["max"].clone()
|
||||
if isinstance(stats[key]["mean"], np.ndarray):
|
||||
if norm_mode is NormalizationMode.MEAN_STD:
|
||||
buffer["mean"].data = torch.from_numpy(stats[key]["mean"]).to(dtype=torch.float32)
|
||||
buffer["std"].data = torch.from_numpy(stats[key]["std"]).to(dtype=torch.float32)
|
||||
elif norm_mode is NormalizationMode.MIN_MAX:
|
||||
buffer["min"].data = torch.from_numpy(stats[key]["min"]).to(dtype=torch.float32)
|
||||
buffer["max"].data = torch.from_numpy(stats[key]["max"]).to(dtype=torch.float32)
|
||||
elif isinstance(stats[key]["mean"], torch.Tensor):
|
||||
# Note: The clone is needed to make sure that the logic in save_pretrained doesn't see duplicated
|
||||
# tensors anywhere (for example, when we use the same stats for normalization and
|
||||
# unnormalization). See the logic here
|
||||
# https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97.
|
||||
if norm_mode is NormalizationMode.MEAN_STD:
|
||||
buffer["mean"].data = stats[key]["mean"].clone().to(dtype=torch.float32)
|
||||
buffer["std"].data = stats[key]["std"].clone().to(dtype=torch.float32)
|
||||
elif norm_mode is NormalizationMode.MIN_MAX:
|
||||
buffer["min"].data = stats[key]["min"].clone().to(dtype=torch.float32)
|
||||
buffer["max"].data = stats[key]["max"].clone().to(dtype=torch.float32)
|
||||
else:
|
||||
type_ = type(stats[key]["mean"])
|
||||
raise ValueError(f"np.ndarray or torch.Tensor expected, but type is '{type_}' instead.")
|
||||
|
||||
stats_buffers[key] = buffer
|
||||
return stats_buffers
|
||||
@@ -141,6 +154,7 @@ class Normalize(nn.Module):
|
||||
batch = dict(batch) # shallow copy avoids mutating the input batch
|
||||
for key, ft in self.features.items():
|
||||
if key not in batch:
|
||||
# FIXME(aliberts, rcadene): This might lead to silent fail!
|
||||
continue
|
||||
|
||||
norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY)
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
Convert pi0 parameters from Jax to Pytorch
|
||||
|
||||
Follow [README of openpi](https://github.com/Physical-Intelligence/openpi) to create a new environment
|
||||
and install the required librairies.
|
||||
and install the required libraries.
|
||||
|
||||
```bash
|
||||
cd ~/code/openpi
|
||||
|
||||
@@ -300,7 +300,7 @@ class PI0Policy(PreTrainedPolicy):
|
||||
self._action_queue.extend(actions.transpose(0, 1))
|
||||
return self._action_queue.popleft()
|
||||
|
||||
def forward(self, batch: dict[str, Tensor], noise=None, time=None) -> dict[str, Tensor]:
|
||||
def forward(self, batch: dict[str, Tensor], noise=None, time=None) -> tuple[Tensor, dict[str, Tensor]]:
|
||||
"""Do a full training forward pass to compute the loss"""
|
||||
if self.config.adapt_to_pi_aloha:
|
||||
batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT])
|
||||
@@ -328,12 +328,12 @@ class PI0Policy(PreTrainedPolicy):
|
||||
losses = losses[:, :, : self.config.max_action_dim]
|
||||
loss_dict["losses_after_rm_padding"] = losses.clone()
|
||||
|
||||
loss = losses.mean()
|
||||
# For backward pass
|
||||
loss_dict["loss"] = loss
|
||||
loss = losses.mean()
|
||||
# For logging
|
||||
loss_dict["l2_loss"] = loss.item()
|
||||
return loss_dict
|
||||
|
||||
return loss, loss_dict
|
||||
|
||||
def prepare_images(self, batch):
|
||||
"""Apply Pi0 preprocessing to the images, like resizing to 224x224 and padding to keep aspect ratio, and
|
||||
|
||||
@@ -163,12 +163,17 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
# TODO(aliberts, rcadene): split into 'forward' and 'compute_loss'?
|
||||
@abc.abstractmethod
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict:
|
||||
"""Run the batch through the model and compute the loss for training or validation.
|
||||
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict | None]:
|
||||
"""_summary_
|
||||
|
||||
Returns a dictionary with "loss" and potentially other information. Apart from "loss" which is a Tensor, all
|
||||
other items should be logging-friendly, native Python types.
|
||||
Args:
|
||||
batch (dict[str, Tensor]): _description_
|
||||
|
||||
Returns:
|
||||
tuple[Tensor, dict | None]: The loss and potentially other information. Apart from the loss which
|
||||
is a Tensor, all other items should be logging-friendly, native Python types.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@@ -76,7 +76,7 @@ class TDMPCConfig(PreTrainedConfig):
|
||||
n_pi_samples: Number of samples to draw from the policy / world model rollout every CEM iteration. Can
|
||||
be zero.
|
||||
uncertainty_regularizer_coeff: Coefficient for the uncertainty regularization used when estimating
|
||||
trajectory values (this is the λ coeffiecient in eqn 4 of FOWM).
|
||||
trajectory values (this is the λ coefficient in eqn 4 of FOWM).
|
||||
n_elites: The number of elite samples to use for updating the gaussian parameters every CEM iteration.
|
||||
elite_weighting_temperature: The temperature to use for softmax weighting (by trajectory value) of the
|
||||
elites, when updating the gaussian parameters for CEM.
|
||||
@@ -165,7 +165,7 @@ class TDMPCConfig(PreTrainedConfig):
|
||||
"""Input validation (not exhaustive)."""
|
||||
if self.n_gaussian_samples <= 0:
|
||||
raise ValueError(
|
||||
f"The number of guassian samples for CEM should be non-zero. Got `{self.n_gaussian_samples=}`"
|
||||
f"The number of gaussian samples for CEM should be non-zero. Got `{self.n_gaussian_samples=}`"
|
||||
)
|
||||
if self.normalization_mapping["ACTION"] is not NormalizationMode.MIN_MAX:
|
||||
raise ValueError(
|
||||
|
||||
@@ -302,7 +302,7 @@ class TDMPCPolicy(PreTrainedPolicy):
|
||||
G -= running_discount * self.config.uncertainty_regularizer_coeff * terminal_values.std(0)
|
||||
return G
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]:
|
||||
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
|
||||
"""Run the batch through the model and compute the loss.
|
||||
|
||||
Returns a dictionary with loss as a tensor, and other information as native floats.
|
||||
@@ -495,7 +495,6 @@ class TDMPCPolicy(PreTrainedPolicy):
|
||||
"Q_value_loss": q_value_loss.item(),
|
||||
"V_value_loss": v_value_loss.item(),
|
||||
"pi_loss": pi_loss.item(),
|
||||
"loss": loss,
|
||||
"sum_loss": loss.item() * self.config.horizon,
|
||||
}
|
||||
)
|
||||
@@ -505,7 +504,7 @@ class TDMPCPolicy(PreTrainedPolicy):
|
||||
if isinstance(batch[key], torch.Tensor) and batch[key].ndim > 1:
|
||||
batch[key] = batch[key].transpose(1, 0)
|
||||
|
||||
return info
|
||||
return loss, info
|
||||
|
||||
def update(self):
|
||||
"""Update the target model's parameters with an EMA step."""
|
||||
@@ -595,9 +594,9 @@ class TDMPCTOLD(nn.Module):
|
||||
|
||||
self.apply(_apply_fn)
|
||||
for m in [self._reward, *self._Qs]:
|
||||
assert isinstance(
|
||||
m[-1], nn.Linear
|
||||
), "Sanity check. The last linear layer needs 0 initialization on weights."
|
||||
assert isinstance(m[-1], nn.Linear), (
|
||||
"Sanity check. The last linear layer needs 0 initialization on weights."
|
||||
)
|
||||
nn.init.zeros_(m[-1].weight)
|
||||
nn.init.zeros_(m[-1].bias) # this has already been done, but keep this line here for good measure
|
||||
|
||||
|
||||
@@ -66,7 +66,7 @@ class VQBeTConfig(PreTrainedConfig):
|
||||
within the image size. If None, no cropping is done.
|
||||
crop_is_random: Whether the crop should be random at training time (it's always a center crop in eval
|
||||
mode).
|
||||
pretrained_backbone_weights: Pretrained weights from torchvision to initalize the backbone.
|
||||
pretrained_backbone_weights: Pretrained weights from torchvision to initialize the backbone.
|
||||
`None` means no pretrained weights.
|
||||
use_group_norm: Whether to replace batch normalization with group normalization in the backbone.
|
||||
The group sizes are set to be about 16 (to be precise, feature_dim // 16).
|
||||
@@ -184,7 +184,7 @@ class VQBeTConfig(PreTrainedConfig):
|
||||
for key, image_ft in self.image_features.items():
|
||||
if image_ft.shape != first_image_ft.shape:
|
||||
raise ValueError(
|
||||
f"`{key}` does not match `{first_image_key}`, but we " "expect all image shapes to match."
|
||||
f"`{key}` does not match `{first_image_key}`, but we expect all image shapes to match."
|
||||
)
|
||||
|
||||
@property
|
||||
|
||||
@@ -156,7 +156,7 @@ class VQBeTPolicy(PreTrainedPolicy):
|
||||
action = self._queues["action"].popleft()
|
||||
return action
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
|
||||
"""Run the batch through the model and compute the loss for training or validation."""
|
||||
batch = self.normalize_inputs(batch)
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
@@ -170,16 +170,16 @@ class VQBeTPolicy(PreTrainedPolicy):
|
||||
loss, n_different_codes, n_different_combinations, recon_l1_error = (
|
||||
self.vqbet.action_head.discretize(self.config.n_vqvae_training_steps, batch["action"])
|
||||
)
|
||||
return {
|
||||
"loss": loss,
|
||||
return loss, {
|
||||
"n_different_codes": n_different_codes,
|
||||
"n_different_combinations": n_different_combinations,
|
||||
"recon_l1_error": recon_l1_error,
|
||||
}
|
||||
# if Residual VQ is already trained, VQ-BeT trains its GPT and bin prediction head / offset prediction head parts.
|
||||
_, loss_dict = self.vqbet(batch, rollout=False)
|
||||
loss = loss_dict.pop("loss")
|
||||
|
||||
return loss_dict
|
||||
return loss, loss_dict
|
||||
|
||||
|
||||
class SpatialSoftmax(nn.Module):
|
||||
@@ -342,7 +342,7 @@ class VQBeTModel(nn.Module):
|
||||
torch.row_stack([torch.arange(i, i + self.config.action_chunk_size) for i in range(num_tokens)]),
|
||||
)
|
||||
|
||||
def forward(self, batch: dict[str, Tensor], rollout: bool) -> Tensor:
|
||||
def forward(self, batch: dict[str, Tensor], rollout: bool) -> tuple[dict, dict]:
|
||||
# Input validation.
|
||||
assert set(batch).issuperset({"observation.state", "observation.images"})
|
||||
batch_size, n_obs_steps = batch["observation.state"].shape[:2]
|
||||
@@ -482,10 +482,10 @@ class VQBeTHead(nn.Module):
|
||||
param.requires_grad = False
|
||||
return loss, n_different_codes, n_different_combinations, recon_l1_error
|
||||
|
||||
def forward(self, x, **kwargs):
|
||||
def forward(self, x, **kwargs) -> dict:
|
||||
# N is the batch size, and T is number of action query tokens, which are process through same GPT
|
||||
N, T, _ = x.shape
|
||||
# we calculate N and T side parallely. Thus, the dimensions would be
|
||||
# we calculate N and T side parallelly. Thus, the dimensions would be
|
||||
# (batch size * number of action query tokens, action chunk size, action dimension)
|
||||
x = einops.rearrange(x, "N T WA -> (N T) WA")
|
||||
|
||||
@@ -772,7 +772,7 @@ class VqVae(nn.Module):
|
||||
Encoder and decoder are MLPs consisting of an input, output layer, and hidden layer, respectively.
|
||||
The vq_layer uses residual VQs.
|
||||
|
||||
This class contains functions for training the encoder and decoder along with the residual VQ layer (for trainign phase 1),
|
||||
This class contains functions for training the encoder and decoder along with the residual VQ layer (for training phase 1),
|
||||
as well as functions to help BeT training part in training phase 2.
|
||||
"""
|
||||
|
||||
|
||||
@@ -38,7 +38,7 @@ from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig
|
||||
This file is part of a VQ-BeT that utilizes code from the following repositories:
|
||||
|
||||
- Vector Quantize PyTorch code is licensed under the MIT License:
|
||||
Origianl source: https://github.com/lucidrains/vector-quantize-pytorch
|
||||
Original source: https://github.com/lucidrains/vector-quantize-pytorch
|
||||
|
||||
- nanoGPT part is an adaptation of Andrej Karpathy's nanoGPT implementation in PyTorch.
|
||||
Original source: https://github.com/karpathy/nanoGPT
|
||||
@@ -203,9 +203,9 @@ class GPT(nn.Module):
|
||||
def forward(self, input, targets=None):
|
||||
device = input.device
|
||||
b, t, d = input.size()
|
||||
assert (
|
||||
t <= self.config.gpt_block_size
|
||||
), f"Cannot forward sequence of length {t}, block size is only {self.config.gpt_block_size}"
|
||||
assert t <= self.config.gpt_block_size, (
|
||||
f"Cannot forward sequence of length {t}, block size is only {self.config.gpt_block_size}"
|
||||
)
|
||||
|
||||
# positional encodings that are added to the input embeddings
|
||||
pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0) # shape (1, t)
|
||||
@@ -273,10 +273,10 @@ class GPT(nn.Module):
|
||||
assert len(inter_params) == 0, "parameters {} made it into both decay/no_decay sets!".format(
|
||||
str(inter_params)
|
||||
)
|
||||
assert (
|
||||
len(param_dict.keys() - union_params) == 0
|
||||
), "parameters {} were not separated into either decay/no_decay set!".format(
|
||||
str(param_dict.keys() - union_params),
|
||||
assert len(param_dict.keys() - union_params) == 0, (
|
||||
"parameters {} were not separated into either decay/no_decay set!".format(
|
||||
str(param_dict.keys() - union_params),
|
||||
)
|
||||
)
|
||||
|
||||
decay = [param_dict[pn] for pn in sorted(decay)]
|
||||
@@ -289,7 +289,7 @@ class GPT(nn.Module):
|
||||
This file is a part for Residual Vector Quantization that utilizes code from the following repository:
|
||||
|
||||
- Phil Wang's vector-quantize-pytorch implementation in PyTorch.
|
||||
Origianl source: https://github.com/lucidrains/vector-quantize-pytorch
|
||||
Original source: https://github.com/lucidrains/vector-quantize-pytorch
|
||||
|
||||
- The vector-quantize-pytorch code is licensed under the MIT License:
|
||||
|
||||
@@ -419,9 +419,9 @@ class ResidualVQ(nn.Module):
|
||||
# and the network should be able to reconstruct
|
||||
|
||||
if quantize_dim < self.num_quantizers:
|
||||
assert (
|
||||
self.quantize_dropout > 0.0
|
||||
), "quantize dropout must be greater than 0 if you wish to reconstruct from a signal with less fine quantizations"
|
||||
assert self.quantize_dropout > 0.0, (
|
||||
"quantize dropout must be greater than 0 if you wish to reconstruct from a signal with less fine quantizations"
|
||||
)
|
||||
indices = F.pad(indices, (0, self.num_quantizers - quantize_dim), value=-1)
|
||||
|
||||
# get ready for gathering
|
||||
@@ -472,9 +472,9 @@ class ResidualVQ(nn.Module):
|
||||
all_indices = []
|
||||
|
||||
if return_loss:
|
||||
assert not torch.any(
|
||||
indices == -1
|
||||
), "some of the residual vq indices were dropped out. please use indices derived when the module is in eval mode to derive cross entropy loss"
|
||||
assert not torch.any(indices == -1), (
|
||||
"some of the residual vq indices were dropped out. please use indices derived when the module is in eval mode to derive cross entropy loss"
|
||||
)
|
||||
ce_losses = []
|
||||
|
||||
should_quantize_dropout = self.training and self.quantize_dropout and not return_loss
|
||||
@@ -887,9 +887,9 @@ class VectorQuantize(nn.Module):
|
||||
# only calculate orthogonal loss for the activated codes for this batch
|
||||
|
||||
if self.orthogonal_reg_active_codes_only:
|
||||
assert not (
|
||||
is_multiheaded and self.separate_codebook_per_head
|
||||
), "orthogonal regularization for only active codes not compatible with multi-headed with separate codebooks yet"
|
||||
assert not (is_multiheaded and self.separate_codebook_per_head), (
|
||||
"orthogonal regularization for only active codes not compatible with multi-headed with separate codebooks yet"
|
||||
)
|
||||
unique_code_ids = torch.unique(embed_ind)
|
||||
codebook = codebook[:, unique_code_ids]
|
||||
|
||||
@@ -999,9 +999,9 @@ def gumbel_sample(
|
||||
ind = sampling_logits.argmax(dim=dim)
|
||||
one_hot = F.one_hot(ind, size).type(dtype)
|
||||
|
||||
assert not (
|
||||
reinmax and not straight_through
|
||||
), "reinmax can only be turned on if using straight through gumbel softmax"
|
||||
assert not (reinmax and not straight_through), (
|
||||
"reinmax can only be turned on if using straight through gumbel softmax"
|
||||
)
|
||||
|
||||
if not straight_through or temperature <= 0.0 or not training:
|
||||
return ind, one_hot
|
||||
@@ -1209,9 +1209,9 @@ class EuclideanCodebook(nn.Module):
|
||||
self.gumbel_sample = gumbel_sample
|
||||
self.sample_codebook_temp = sample_codebook_temp
|
||||
|
||||
assert not (
|
||||
use_ddp and num_codebooks > 1 and kmeans_init
|
||||
), "kmeans init is not compatible with multiple codebooks in distributed environment for now"
|
||||
assert not (use_ddp and num_codebooks > 1 and kmeans_init), (
|
||||
"kmeans init is not compatible with multiple codebooks in distributed environment for now"
|
||||
)
|
||||
|
||||
self.sample_fn = sample_vectors_distributed if use_ddp and sync_kmeans else batched_sample_vectors
|
||||
self.kmeans_all_reduce_fn = distributed.all_reduce if use_ddp and sync_kmeans else noop
|
||||
@@ -1349,9 +1349,9 @@ class EuclideanCodebook(nn.Module):
|
||||
|
||||
# calculate distributed variance
|
||||
|
||||
variance_numer = reduce((data - batch_mean) ** 2, "h n d -> h 1 d", "sum")
|
||||
distributed.all_reduce(variance_numer)
|
||||
batch_variance = variance_numer / num_vectors
|
||||
variance_number = reduce((data - batch_mean) ** 2, "h n d -> h 1 d", "sum")
|
||||
distributed.all_reduce(variance_number)
|
||||
batch_variance = variance_number / num_vectors
|
||||
|
||||
self.update_with_decay("batch_variance", batch_variance, self.affine_param_batch_decay)
|
||||
|
||||
|
||||
@@ -60,15 +60,13 @@ class RecordControlConfig(ControlConfig):
|
||||
num_episodes: int = 50
|
||||
# Encode frames in the dataset into video
|
||||
video: bool = True
|
||||
# By default, run the computation of the data statistics at the end of data collection. Compute intensive and not required to just replay an episode.
|
||||
run_compute_stats: bool = True
|
||||
# Upload dataset to Hugging Face hub.
|
||||
push_to_hub: bool = True
|
||||
# Upload on private repository on the Hugging Face hub.
|
||||
private: bool = False
|
||||
# Add tags to your dataset on the hub.
|
||||
tags: list[str] | None = None
|
||||
# Number of subprocesses handling the saving of frames as PNGs. Set to 0 to use threads only;
|
||||
# Number of subprocesses handling the saving of frames as PNG. Set to 0 to use threads only;
|
||||
# set to ≥1 to use subprocesses, each using threads to write images. The best number of processes
|
||||
# and threads depends on your system. We recommend 4 threads per camera with 0 processes.
|
||||
# If fps is unstable, adjust the thread count. If still unstable, try using 1 or more subprocesses.
|
||||
@@ -83,9 +81,6 @@ class RecordControlConfig(ControlConfig):
|
||||
play_sounds: bool = True
|
||||
# Resume recording on an existing dataset.
|
||||
resume: bool = False
|
||||
# TODO(rcadene, aliberts): remove local_files_only when refactor with dataset as argument
|
||||
# Use local files only. By default, this script will try to fetch the dataset from the hub if it exists.
|
||||
local_files_only: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
# HACK: We parse again the cli args here to get the pretrained path if there was one.
|
||||
@@ -130,9 +125,12 @@ class ReplayControlConfig(ControlConfig):
|
||||
fps: int | None = None
|
||||
# Use vocal synthesis to read events.
|
||||
play_sounds: bool = True
|
||||
# TODO(rcadene, aliberts): remove local_files_only when refactor with dataset as argument
|
||||
# Use local files only. By default, this script will try to fetch the dataset from the hub if it exists.
|
||||
local_files_only: bool = False
|
||||
|
||||
|
||||
@ControlConfig.register_subclass("remote_robot")
|
||||
@dataclass
|
||||
class RemoteRobotConfig(ControlConfig):
|
||||
log_interval: int = 100
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -12,7 +12,6 @@ from functools import cache
|
||||
|
||||
import cv2
|
||||
import torch
|
||||
import tqdm
|
||||
from deepdiff import DeepDiff
|
||||
from termcolor import colored
|
||||
|
||||
@@ -33,7 +32,7 @@ def log_control_info(robot: Robot, dt_s, episode_index=None, frame_index=None, f
|
||||
|
||||
def log_dt(shortname, dt_val_s):
|
||||
nonlocal log_items, fps
|
||||
info_str = f"{shortname}:{dt_val_s * 1000:5.2f} ({1/ dt_val_s:3.1f}hz)"
|
||||
info_str = f"{shortname}:{dt_val_s * 1000:5.2f} ({1 / dt_val_s:3.1f}hz)"
|
||||
if fps is not None:
|
||||
actual_fps = 1 / dt_val_s
|
||||
if actual_fps < fps - 1:
|
||||
@@ -183,6 +182,7 @@ def record_episode(
|
||||
device,
|
||||
use_amp,
|
||||
fps,
|
||||
single_task,
|
||||
):
|
||||
control_loop(
|
||||
robot=robot,
|
||||
@@ -195,6 +195,7 @@ def record_episode(
|
||||
use_amp=use_amp,
|
||||
fps=fps,
|
||||
teleoperate=policy is None,
|
||||
single_task=single_task,
|
||||
)
|
||||
|
||||
|
||||
@@ -210,6 +211,7 @@ def control_loop(
|
||||
device: torch.device | str | None = None,
|
||||
use_amp: bool | None = None,
|
||||
fps: int | None = None,
|
||||
single_task: str | None = None,
|
||||
):
|
||||
# TODO(rcadene): Add option to record logs
|
||||
if not robot.is_connected:
|
||||
@@ -224,6 +226,9 @@ def control_loop(
|
||||
if teleoperate and policy is not None:
|
||||
raise ValueError("When `teleoperate` is True, `policy` should be None.")
|
||||
|
||||
if dataset is not None and single_task is None:
|
||||
raise ValueError("You need to provide a task as argument in `single_task`.")
|
||||
|
||||
if dataset is not None and fps is not None and dataset.fps != fps:
|
||||
raise ValueError(f"The dataset fps should be equal to requested fps ({dataset['fps']} != {fps}).")
|
||||
|
||||
@@ -248,7 +253,7 @@ def control_loop(
|
||||
action = {"action": action}
|
||||
|
||||
if dataset is not None:
|
||||
frame = {**observation, **action}
|
||||
frame = {**observation, **action, "task": single_task}
|
||||
dataset.add_frame(frame)
|
||||
|
||||
if display_cameras and not is_headless():
|
||||
@@ -270,24 +275,18 @@ def control_loop(
|
||||
break
|
||||
|
||||
|
||||
def reset_environment(robot, events, reset_time_s):
|
||||
def reset_environment(robot, events, reset_time_s, fps):
|
||||
# TODO(rcadene): refactor warmup_record and reset_environment
|
||||
# TODO(alibets): allow for teleop during reset
|
||||
if has_method(robot, "teleop_safety_stop"):
|
||||
robot.teleop_safety_stop()
|
||||
|
||||
timestamp = 0
|
||||
start_vencod_t = time.perf_counter()
|
||||
|
||||
# Wait if necessary
|
||||
with tqdm.tqdm(total=reset_time_s, desc="Waiting") as pbar:
|
||||
while timestamp < reset_time_s:
|
||||
time.sleep(1)
|
||||
timestamp = time.perf_counter() - start_vencod_t
|
||||
pbar.update(1)
|
||||
if events["exit_early"]:
|
||||
events["exit_early"] = False
|
||||
break
|
||||
control_loop(
|
||||
robot=robot,
|
||||
control_time_s=reset_time_s,
|
||||
events=events,
|
||||
fps=fps,
|
||||
teleoperate=True,
|
||||
)
|
||||
|
||||
|
||||
def stop_recording(robot, listener, display_cameras):
|
||||
|
||||
@@ -242,7 +242,7 @@ class DriveMode(enum.Enum):
|
||||
class CalibrationMode(enum.Enum):
|
||||
# Joints with rotational motions are expressed in degrees in nominal range of [-180, 180]
|
||||
DEGREE = 0
|
||||
# Joints with linear motions (like gripper of Aloha) are experessed in nominal range of [0, 100]
|
||||
# Joints with linear motions (like gripper of Aloha) are expressed in nominal range of [0, 100]
|
||||
LINEAR = 1
|
||||
|
||||
|
||||
@@ -610,7 +610,7 @@ class DynamixelMotorsBus:
|
||||
# 0-centered resolution range (e.g. [-2048, 2048] for resolution=4096)
|
||||
values[i] = values[i] / HALF_TURN_DEGREE * (resolution // 2)
|
||||
|
||||
# Substract the homing offsets to come back to actual motor range of values
|
||||
# Subtract the homing offsets to come back to actual motor range of values
|
||||
# which can be arbitrary.
|
||||
values[i] -= homing_offset
|
||||
|
||||
|
||||
@@ -221,7 +221,7 @@ class DriveMode(enum.Enum):
|
||||
class CalibrationMode(enum.Enum):
|
||||
# Joints with rotational motions are expressed in degrees in nominal range of [-180, 180]
|
||||
DEGREE = 0
|
||||
# Joints with linear motions (like gripper of Aloha) are experessed in nominal range of [0, 100]
|
||||
# Joints with linear motions (like gripper of Aloha) are expressed in nominal range of [0, 100]
|
||||
LINEAR = 1
|
||||
|
||||
|
||||
@@ -591,7 +591,7 @@ class FeetechMotorsBus:
|
||||
# 0-centered resolution range (e.g. [-2048, 2048] for resolution=4096)
|
||||
values[i] = values[i] / HALF_TURN_DEGREE * (resolution // 2)
|
||||
|
||||
# Substract the homing offsets to come back to actual motor range of values
|
||||
# Subtract the homing offsets to come back to actual motor range of values
|
||||
# which can be arbitrary.
|
||||
values[i] -= homing_offset
|
||||
|
||||
@@ -632,7 +632,7 @@ class FeetechMotorsBus:
|
||||
track["prev"][idx] = values[i]
|
||||
continue
|
||||
|
||||
# Detect a full rotation occured
|
||||
# Detect a full rotation occurred
|
||||
if abs(track["prev"][idx] - values[i]) > 2048:
|
||||
# Position went below 0 and got reset to 4095
|
||||
if track["prev"][idx] < values[i]:
|
||||
|
||||
@@ -514,3 +514,86 @@ class StretchRobotConfig(RobotConfig):
|
||||
)
|
||||
|
||||
mock: bool = False
|
||||
|
||||
|
||||
@RobotConfig.register_subclass("lekiwi")
|
||||
@dataclass
|
||||
class LeKiwiRobotConfig(RobotConfig):
|
||||
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
|
||||
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
|
||||
# the number of motors in your follower arms.
|
||||
max_relative_target: int | None = None
|
||||
|
||||
# Network Configuration
|
||||
ip: str = "192.168.0.193"
|
||||
port: int = 5555
|
||||
video_port: int = 5556
|
||||
|
||||
cameras: dict[str, CameraConfig] = field(
|
||||
default_factory=lambda: {
|
||||
"front": OpenCVCameraConfig(
|
||||
camera_index="/dev/video0", fps=30, width=640, height=480, rotation=90
|
||||
),
|
||||
"wrist": OpenCVCameraConfig(
|
||||
camera_index="/dev/video2", fps=30, width=640, height=480, rotation=180
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
calibration_dir: str = ".cache/calibration/lekiwi"
|
||||
|
||||
leader_arms: dict[str, MotorsBusConfig] = field(
|
||||
default_factory=lambda: {
|
||||
"main": FeetechMotorsBusConfig(
|
||||
port="/dev/tty.usbmodem585A0077581",
|
||||
motors={
|
||||
# name: (index, model)
|
||||
"shoulder_pan": [1, "sts3215"],
|
||||
"shoulder_lift": [2, "sts3215"],
|
||||
"elbow_flex": [3, "sts3215"],
|
||||
"wrist_flex": [4, "sts3215"],
|
||||
"wrist_roll": [5, "sts3215"],
|
||||
"gripper": [6, "sts3215"],
|
||||
},
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
follower_arms: dict[str, MotorsBusConfig] = field(
|
||||
default_factory=lambda: {
|
||||
"main": FeetechMotorsBusConfig(
|
||||
port="/dev/ttyACM0",
|
||||
motors={
|
||||
# name: (index, model)
|
||||
"shoulder_pan": [1, "sts3215"],
|
||||
"shoulder_lift": [2, "sts3215"],
|
||||
"elbow_flex": [3, "sts3215"],
|
||||
"wrist_flex": [4, "sts3215"],
|
||||
"wrist_roll": [5, "sts3215"],
|
||||
"gripper": [6, "sts3215"],
|
||||
"left_wheel": (7, "sts3215"),
|
||||
"back_wheel": (8, "sts3215"),
|
||||
"right_wheel": (9, "sts3215"),
|
||||
},
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
teleop_keys: dict[str, str] = field(
|
||||
default_factory=lambda: {
|
||||
# Movement
|
||||
"forward": "w",
|
||||
"backward": "s",
|
||||
"left": "a",
|
||||
"right": "d",
|
||||
"rotate_left": "z",
|
||||
"rotate_right": "x",
|
||||
# Speed control
|
||||
"speed_up": "r",
|
||||
"speed_down": "f",
|
||||
# quit teleop
|
||||
"quit": "q",
|
||||
}
|
||||
)
|
||||
|
||||
mock: bool = False
|
||||
|
||||
@@ -87,7 +87,7 @@ def run_arm_calibration(arm: MotorsBus, robot_type: str, arm_name: str, arm_type
|
||||
# For instance, if the motor rotates 90 degree, and its value is -90 after applying the homing offset, then we know its rotation direction
|
||||
# is inverted. However, for the calibration being successful, we need everyone to follow the same target position.
|
||||
# Sometimes, there is only one possible rotation direction. For instance, if the gripper is closed, there is only one direction which
|
||||
# corresponds to opening the gripper. When the rotation direction is ambiguous, we arbitrarely rotate clockwise from the point of view
|
||||
# corresponds to opening the gripper. When the rotation direction is ambiguous, we arbitrarily rotate clockwise from the point of view
|
||||
# of the previous motor in the kinetic chain.
|
||||
print("\nMove arm to rotated target position")
|
||||
print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rotated"))
|
||||
@@ -115,7 +115,7 @@ def run_arm_calibration(arm: MotorsBus, robot_type: str, arm_name: str, arm_type
|
||||
|
||||
# TODO(rcadene): make type of joints (DEGREE or LINEAR) configurable from yaml?
|
||||
if robot_type in ["aloha"] and "gripper" in arm.motor_names:
|
||||
# Joints with linear motions (like gripper of Aloha) are experessed in nominal range of [0, 100]
|
||||
# Joints with linear motions (like gripper of Aloha) are expressed in nominal range of [0, 100]
|
||||
calib_idx = arm.motor_names.index("gripper")
|
||||
calib_mode[calib_idx] = CalibrationMode.LINEAR.name
|
||||
|
||||
|
||||
@@ -443,7 +443,7 @@ def run_arm_manual_calibration(arm: MotorsBus, robot_type: str, arm_name: str, a
|
||||
# For instance, if the motor rotates 90 degree, and its value is -90 after applying the homing offset, then we know its rotation direction
|
||||
# is inverted. However, for the calibration being successful, we need everyone to follow the same target position.
|
||||
# Sometimes, there is only one possible rotation direction. For instance, if the gripper is closed, there is only one direction which
|
||||
# corresponds to opening the gripper. When the rotation direction is ambiguous, we arbitrarely rotate clockwise from the point of view
|
||||
# corresponds to opening the gripper. When the rotation direction is ambiguous, we arbitrarily rotate clockwise from the point of view
|
||||
# of the previous motor in the kinetic chain.
|
||||
print("\nMove arm to rotated target position")
|
||||
print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rotated"))
|
||||
|
||||
210
lerobot/common/robot_devices/robots/lekiwi_remote.py
Normal file
210
lerobot/common/robot_devices/robots/lekiwi_remote.py
Normal file
@@ -0,0 +1,210 @@
|
||||
import base64
|
||||
import json
|
||||
import threading
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import cv2
|
||||
import zmq
|
||||
|
||||
from lerobot.common.robot_devices.robots.mobile_manipulator import LeKiwi
|
||||
|
||||
|
||||
def setup_zmq_sockets(config):
|
||||
context = zmq.Context()
|
||||
cmd_socket = context.socket(zmq.PULL)
|
||||
cmd_socket.setsockopt(zmq.CONFLATE, 1)
|
||||
cmd_socket.bind(f"tcp://*:{config.port}")
|
||||
|
||||
video_socket = context.socket(zmq.PUSH)
|
||||
video_socket.setsockopt(zmq.CONFLATE, 1)
|
||||
video_socket.bind(f"tcp://*:{config.video_port}")
|
||||
|
||||
return context, cmd_socket, video_socket
|
||||
|
||||
|
||||
def run_camera_capture(cameras, images_lock, latest_images_dict, stop_event):
|
||||
while not stop_event.is_set():
|
||||
local_dict = {}
|
||||
for name, cam in cameras.items():
|
||||
frame = cam.async_read()
|
||||
ret, buffer = cv2.imencode(".jpg", frame, [int(cv2.IMWRITE_JPEG_QUALITY), 90])
|
||||
if ret:
|
||||
local_dict[name] = base64.b64encode(buffer).decode("utf-8")
|
||||
else:
|
||||
local_dict[name] = ""
|
||||
with images_lock:
|
||||
latest_images_dict.update(local_dict)
|
||||
time.sleep(0.01)
|
||||
|
||||
|
||||
def calibrate_follower_arm(motors_bus, calib_dir_str):
|
||||
"""
|
||||
Calibrates the follower arm. Attempts to load an existing calibration file;
|
||||
if not found, runs manual calibration and saves the result.
|
||||
"""
|
||||
calib_dir = Path(calib_dir_str)
|
||||
calib_dir.mkdir(parents=True, exist_ok=True)
|
||||
calib_file = calib_dir / "main_follower.json"
|
||||
try:
|
||||
from lerobot.common.robot_devices.robots.feetech_calibration import run_arm_manual_calibration
|
||||
except ImportError:
|
||||
print("[WARNING] Calibration function not available. Skipping calibration.")
|
||||
return
|
||||
|
||||
if calib_file.exists():
|
||||
with open(calib_file) as f:
|
||||
calibration = json.load(f)
|
||||
print(f"[INFO] Loaded calibration from {calib_file}")
|
||||
else:
|
||||
print("[INFO] Calibration file not found. Running manual calibration...")
|
||||
calibration = run_arm_manual_calibration(motors_bus, "lekiwi", "follower_arm", "follower")
|
||||
print(f"[INFO] Calibration complete. Saving to {calib_file}")
|
||||
with open(calib_file, "w") as f:
|
||||
json.dump(calibration, f)
|
||||
try:
|
||||
motors_bus.set_calibration(calibration)
|
||||
print("[INFO] Applied calibration for follower arm.")
|
||||
except Exception as e:
|
||||
print(f"[WARNING] Could not apply calibration: {e}")
|
||||
|
||||
|
||||
def run_lekiwi(robot_config):
|
||||
"""
|
||||
Runs the LeKiwi robot:
|
||||
- Sets up cameras and connects them.
|
||||
- Initializes the follower arm motors.
|
||||
- Calibrates the follower arm if necessary.
|
||||
- Creates ZeroMQ sockets for receiving commands and streaming observations.
|
||||
- Processes incoming commands (arm and wheel commands) and sends back sensor and camera data.
|
||||
"""
|
||||
# Import helper functions and classes
|
||||
from lerobot.common.robot_devices.cameras.utils import make_cameras_from_configs
|
||||
from lerobot.common.robot_devices.motors.feetech import FeetechMotorsBus, TorqueMode
|
||||
|
||||
# Initialize cameras from the robot configuration.
|
||||
cameras = make_cameras_from_configs(robot_config.cameras)
|
||||
for cam in cameras.values():
|
||||
cam.connect()
|
||||
|
||||
# Initialize the motors bus using the follower arm configuration.
|
||||
motor_config = robot_config.follower_arms.get("main")
|
||||
if motor_config is None:
|
||||
print("[ERROR] Follower arm 'main' configuration not found.")
|
||||
return
|
||||
motors_bus = FeetechMotorsBus(motor_config)
|
||||
motors_bus.connect()
|
||||
|
||||
# Calibrate the follower arm.
|
||||
calibrate_follower_arm(motors_bus, robot_config.calibration_dir)
|
||||
|
||||
# Create the LeKiwi robot instance.
|
||||
robot = LeKiwi(motors_bus)
|
||||
|
||||
# Define the expected arm motor IDs.
|
||||
arm_motor_ids = ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"]
|
||||
|
||||
# Disable torque for each arm motor.
|
||||
for motor in arm_motor_ids:
|
||||
motors_bus.write("Torque_Enable", TorqueMode.DISABLED.value, motor)
|
||||
|
||||
# Set up ZeroMQ sockets.
|
||||
context, cmd_socket, video_socket = setup_zmq_sockets(robot_config)
|
||||
|
||||
# Start the camera capture thread.
|
||||
latest_images_dict = {}
|
||||
images_lock = threading.Lock()
|
||||
stop_event = threading.Event()
|
||||
cam_thread = threading.Thread(
|
||||
target=run_camera_capture, args=(cameras, images_lock, latest_images_dict, stop_event), daemon=True
|
||||
)
|
||||
cam_thread.start()
|
||||
|
||||
last_cmd_time = time.time()
|
||||
print("LeKiwi robot server started. Waiting for commands...")
|
||||
|
||||
try:
|
||||
while True:
|
||||
loop_start_time = time.time()
|
||||
|
||||
# Process incoming commands (non-blocking).
|
||||
while True:
|
||||
try:
|
||||
msg = cmd_socket.recv_string(zmq.NOBLOCK)
|
||||
except zmq.Again:
|
||||
break
|
||||
try:
|
||||
data = json.loads(msg)
|
||||
# Process arm position commands.
|
||||
if "arm_positions" in data:
|
||||
arm_positions = data["arm_positions"]
|
||||
if not isinstance(arm_positions, list):
|
||||
print(f"[ERROR] Invalid arm_positions: {arm_positions}")
|
||||
elif len(arm_positions) < len(arm_motor_ids):
|
||||
print(
|
||||
f"[WARNING] Received {len(arm_positions)} arm positions, expected {len(arm_motor_ids)}"
|
||||
)
|
||||
else:
|
||||
for motor, pos in zip(arm_motor_ids, arm_positions, strict=False):
|
||||
motors_bus.write("Goal_Position", pos, motor)
|
||||
# Process wheel (base) commands.
|
||||
if "raw_velocity" in data:
|
||||
raw_command = data["raw_velocity"]
|
||||
# Expect keys: "left_wheel", "back_wheel", "right_wheel".
|
||||
command_speeds = [
|
||||
int(raw_command.get("left_wheel", 0)),
|
||||
int(raw_command.get("back_wheel", 0)),
|
||||
int(raw_command.get("right_wheel", 0)),
|
||||
]
|
||||
robot.set_velocity(command_speeds)
|
||||
last_cmd_time = time.time()
|
||||
except Exception as e:
|
||||
print(f"[ERROR] Parsing message failed: {e}")
|
||||
|
||||
# Watchdog: stop the robot if no command is received for over 0.5 seconds.
|
||||
now = time.time()
|
||||
if now - last_cmd_time > 0.5:
|
||||
robot.stop()
|
||||
last_cmd_time = now
|
||||
|
||||
# Read current wheel speeds from the robot.
|
||||
current_velocity = robot.read_velocity()
|
||||
|
||||
# Read the follower arm state from the motors bus.
|
||||
follower_arm_state = []
|
||||
for motor in arm_motor_ids:
|
||||
try:
|
||||
pos = motors_bus.read("Present_Position", motor)
|
||||
# Convert the position to a float (or use as is if already numeric).
|
||||
follower_arm_state.append(float(pos) if not isinstance(pos, (int, float)) else pos)
|
||||
except Exception as e:
|
||||
print(f"[ERROR] Reading motor {motor} failed: {e}")
|
||||
|
||||
# Get the latest camera images.
|
||||
with images_lock:
|
||||
images_dict_copy = dict(latest_images_dict)
|
||||
|
||||
# Build the observation dictionary.
|
||||
observation = {
|
||||
"images": images_dict_copy,
|
||||
"present_speed": current_velocity,
|
||||
"follower_arm_state": follower_arm_state,
|
||||
}
|
||||
# Send the observation over the video socket.
|
||||
video_socket.send_string(json.dumps(observation))
|
||||
|
||||
# Ensure a short sleep to avoid overloading the CPU.
|
||||
elapsed = time.time() - loop_start_time
|
||||
time.sleep(
|
||||
max(0.033 - elapsed, 0)
|
||||
) # If robot jitters increase the sleep and monitor cpu load with `top` in cmd
|
||||
except KeyboardInterrupt:
|
||||
print("Shutting down LeKiwi server.")
|
||||
finally:
|
||||
stop_event.set()
|
||||
cam_thread.join()
|
||||
robot.stop()
|
||||
motors_bus.disconnect()
|
||||
cmd_socket.close()
|
||||
video_socket.close()
|
||||
context.term()
|
||||
@@ -44,7 +44,7 @@ class ManipulatorRobot:
|
||||
# TODO(rcadene): Implement force feedback
|
||||
"""This class allows to control any manipulator robot of various number of motors.
|
||||
|
||||
Non exaustive list of robots:
|
||||
Non exhaustive list of robots:
|
||||
- [Koch v1.0](https://github.com/AlexanderKoch-Koch/low_cost_robot), with and without the wrist-to-elbow expansion, developed
|
||||
by Alexander Koch from [Tau Robotics](https://tau-robotics.com)
|
||||
- [Koch v1.1](https://github.com/jess-moss/koch-v1-1) developed by Jess Moss
|
||||
@@ -55,7 +55,7 @@ class ManipulatorRobot:
|
||||
robot = ManipulatorRobot(KochRobotConfig())
|
||||
```
|
||||
|
||||
Example of overwritting motors during instantiation:
|
||||
Example of overwriting motors during instantiation:
|
||||
```python
|
||||
# Defines how to communicate with the motors of the leader and follower arms
|
||||
leader_arms = {
|
||||
@@ -90,7 +90,7 @@ class ManipulatorRobot:
|
||||
robot = ManipulatorRobot(robot_config)
|
||||
```
|
||||
|
||||
Example of overwritting cameras during instantiation:
|
||||
Example of overwriting cameras during instantiation:
|
||||
```python
|
||||
# Defines how to communicate with 2 cameras connected to the computer.
|
||||
# Here, the webcam of the laptop and the phone (connected in USB to the laptop)
|
||||
@@ -229,7 +229,7 @@ class ManipulatorRobot:
|
||||
|
||||
if self.robot_type in ["koch", "koch_bimanual", "aloha"]:
|
||||
from lerobot.common.robot_devices.motors.dynamixel import TorqueMode
|
||||
elif self.robot_type in ["so100", "moss"]:
|
||||
elif self.robot_type in ["so100", "moss", "lekiwi"]:
|
||||
from lerobot.common.robot_devices.motors.feetech import TorqueMode
|
||||
|
||||
# We assume that at connection time, arms are in a rest position, and torque can
|
||||
@@ -246,7 +246,7 @@ class ManipulatorRobot:
|
||||
self.set_koch_robot_preset()
|
||||
elif self.robot_type == "aloha":
|
||||
self.set_aloha_robot_preset()
|
||||
elif self.robot_type in ["so100", "moss"]:
|
||||
elif self.robot_type in ["so100", "moss", "lekiwi"]:
|
||||
self.set_so100_robot_preset()
|
||||
|
||||
# Enable torque on all motors of the follower arms
|
||||
@@ -299,7 +299,7 @@ class ManipulatorRobot:
|
||||
|
||||
calibration = run_arm_calibration(arm, self.robot_type, name, arm_type)
|
||||
|
||||
elif self.robot_type in ["so100", "moss"]:
|
||||
elif self.robot_type in ["so100", "moss", "lekiwi"]:
|
||||
from lerobot.common.robot_devices.robots.feetech_calibration import (
|
||||
run_arm_manual_calibration,
|
||||
)
|
||||
@@ -348,7 +348,7 @@ class ManipulatorRobot:
|
||||
set_operating_mode_(self.follower_arms[name])
|
||||
|
||||
# Set better PID values to close the gap between recorded states and actions
|
||||
# TODO(rcadene): Implement an automatic procedure to set optimial PID values for each motor
|
||||
# TODO(rcadene): Implement an automatic procedure to set optimal PID values for each motor
|
||||
self.follower_arms[name].write("Position_P_Gain", 1500, "elbow_flex")
|
||||
self.follower_arms[name].write("Position_I_Gain", 0, "elbow_flex")
|
||||
self.follower_arms[name].write("Position_D_Gain", 600, "elbow_flex")
|
||||
@@ -500,7 +500,7 @@ class ManipulatorRobot:
|
||||
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"]
|
||||
self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t
|
||||
|
||||
# Populate output dictionnaries
|
||||
# Populate output dictionaries
|
||||
obs_dict, action_dict = {}, {}
|
||||
obs_dict["observation.state"] = state
|
||||
action_dict["action"] = action
|
||||
@@ -540,7 +540,7 @@ class ManipulatorRobot:
|
||||
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"]
|
||||
self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t
|
||||
|
||||
# Populate output dictionnaries and format to pytorch
|
||||
# Populate output dictionaries and format to pytorch
|
||||
obs_dict = {}
|
||||
obs_dict["observation.state"] = state
|
||||
for name in self.cameras:
|
||||
|
||||
691
lerobot/common/robot_devices/robots/mobile_manipulator.py
Normal file
691
lerobot/common/robot_devices/robots/mobile_manipulator.py
Normal file
@@ -0,0 +1,691 @@
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
import zmq
|
||||
|
||||
from lerobot.common.robot_devices.cameras.utils import make_cameras_from_configs
|
||||
from lerobot.common.robot_devices.motors.feetech import TorqueMode
|
||||
from lerobot.common.robot_devices.motors.utils import MotorsBus, make_motors_buses_from_configs
|
||||
from lerobot.common.robot_devices.robots.configs import LeKiwiRobotConfig
|
||||
from lerobot.common.robot_devices.robots.feetech_calibration import run_arm_manual_calibration
|
||||
from lerobot.common.robot_devices.robots.utils import get_arm_id
|
||||
from lerobot.common.robot_devices.utils import RobotDeviceNotConnectedError
|
||||
|
||||
PYNPUT_AVAILABLE = True
|
||||
try:
|
||||
# Only import if there's a valid X server or if we're not on a Pi
|
||||
if ("DISPLAY" not in os.environ) and ("linux" in sys.platform):
|
||||
print("No DISPLAY set. Skipping pynput import.")
|
||||
raise ImportError("pynput blocked intentionally due to no display.")
|
||||
|
||||
from pynput import keyboard
|
||||
except ImportError:
|
||||
keyboard = None
|
||||
PYNPUT_AVAILABLE = False
|
||||
except Exception as e:
|
||||
keyboard = None
|
||||
PYNPUT_AVAILABLE = False
|
||||
print(f"Could not import pynput: {e}")
|
||||
|
||||
|
||||
class MobileManipulator:
|
||||
"""
|
||||
MobileManipulator is a class for connecting to and controlling a remote mobile manipulator robot.
|
||||
The robot includes a three omniwheel mobile base and a remote follower arm.
|
||||
The leader arm is connected locally (on the laptop) and its joint positions are recorded and then
|
||||
forwarded to the remote follower arm (after applying a safety clamp).
|
||||
In parallel, keyboard teleoperation is used to generate raw velocity commands for the wheels.
|
||||
"""
|
||||
|
||||
def __init__(self, config: LeKiwiRobotConfig):
|
||||
"""
|
||||
Expected keys in config:
|
||||
- ip, port, video_port for the remote connection.
|
||||
- calibration_dir, leader_arms, follower_arms, max_relative_target, etc.
|
||||
"""
|
||||
self.robot_type = config.type
|
||||
self.config = config
|
||||
self.remote_ip = config.ip
|
||||
self.remote_port = config.port
|
||||
self.remote_port_video = config.video_port
|
||||
self.calibration_dir = Path(self.config.calibration_dir)
|
||||
self.logs = {}
|
||||
|
||||
self.teleop_keys = self.config.teleop_keys
|
||||
|
||||
# For teleoperation, the leader arm (local) is used to record the desired arm pose.
|
||||
self.leader_arms = make_motors_buses_from_configs(self.config.leader_arms)
|
||||
|
||||
self.follower_arms = make_motors_buses_from_configs(self.config.follower_arms)
|
||||
|
||||
self.cameras = make_cameras_from_configs(self.config.cameras)
|
||||
|
||||
self.is_connected = False
|
||||
|
||||
self.last_frames = {}
|
||||
self.last_present_speed = {}
|
||||
self.last_remote_arm_state = torch.zeros(6, dtype=torch.float32)
|
||||
|
||||
# Define three speed levels and a current index
|
||||
self.speed_levels = [
|
||||
{"xy": 0.1, "theta": 30}, # slow
|
||||
{"xy": 0.2, "theta": 60}, # medium
|
||||
{"xy": 0.3, "theta": 90}, # fast
|
||||
]
|
||||
self.speed_index = 0 # Start at slow
|
||||
|
||||
# ZeroMQ context and sockets.
|
||||
self.context = None
|
||||
self.cmd_socket = None
|
||||
self.video_socket = None
|
||||
|
||||
# Keyboard state for base teleoperation.
|
||||
self.running = True
|
||||
self.pressed_keys = {
|
||||
"forward": False,
|
||||
"backward": False,
|
||||
"left": False,
|
||||
"right": False,
|
||||
"rotate_left": False,
|
||||
"rotate_right": False,
|
||||
}
|
||||
|
||||
if PYNPUT_AVAILABLE:
|
||||
print("pynput is available - enabling local keyboard listener.")
|
||||
self.listener = keyboard.Listener(
|
||||
on_press=self.on_press,
|
||||
on_release=self.on_release,
|
||||
)
|
||||
self.listener.start()
|
||||
else:
|
||||
print("pynput not available - skipping local keyboard listener.")
|
||||
self.listener = None
|
||||
|
||||
def get_motor_names(self, arms: dict[str, MotorsBus]) -> list:
|
||||
return [f"{arm}_{motor}" for arm, bus in arms.items() for motor in bus.motors]
|
||||
|
||||
@property
|
||||
def camera_features(self) -> dict:
|
||||
cam_ft = {}
|
||||
for cam_key, cam in self.cameras.items():
|
||||
key = f"observation.images.{cam_key}"
|
||||
cam_ft[key] = {
|
||||
"shape": (cam.height, cam.width, cam.channels),
|
||||
"names": ["height", "width", "channels"],
|
||||
"info": None,
|
||||
}
|
||||
return cam_ft
|
||||
|
||||
@property
|
||||
def motor_features(self) -> dict:
|
||||
follower_arm_names = [
|
||||
"shoulder_pan",
|
||||
"shoulder_lift",
|
||||
"elbow_flex",
|
||||
"wrist_flex",
|
||||
"wrist_roll",
|
||||
"gripper",
|
||||
]
|
||||
observations = ["x_mm", "y_mm", "theta"]
|
||||
combined_names = follower_arm_names + observations
|
||||
return {
|
||||
"action": {
|
||||
"dtype": "float32",
|
||||
"shape": (len(combined_names),),
|
||||
"names": combined_names,
|
||||
},
|
||||
"observation.state": {
|
||||
"dtype": "float32",
|
||||
"shape": (len(combined_names),),
|
||||
"names": combined_names,
|
||||
},
|
||||
}
|
||||
|
||||
@property
|
||||
def features(self):
|
||||
return {**self.motor_features, **self.camera_features}
|
||||
|
||||
@property
|
||||
def has_camera(self):
|
||||
return len(self.cameras) > 0
|
||||
|
||||
@property
|
||||
def num_cameras(self):
|
||||
return len(self.cameras)
|
||||
|
||||
@property
|
||||
def available_arms(self):
|
||||
available = []
|
||||
for name in self.leader_arms:
|
||||
available.append(get_arm_id(name, "leader"))
|
||||
for name in self.follower_arms:
|
||||
available.append(get_arm_id(name, "follower"))
|
||||
return available
|
||||
|
||||
def on_press(self, key):
|
||||
try:
|
||||
# Movement
|
||||
if key.char == self.teleop_keys["forward"]:
|
||||
self.pressed_keys["forward"] = True
|
||||
elif key.char == self.teleop_keys["backward"]:
|
||||
self.pressed_keys["backward"] = True
|
||||
elif key.char == self.teleop_keys["left"]:
|
||||
self.pressed_keys["left"] = True
|
||||
elif key.char == self.teleop_keys["right"]:
|
||||
self.pressed_keys["right"] = True
|
||||
elif key.char == self.teleop_keys["rotate_left"]:
|
||||
self.pressed_keys["rotate_left"] = True
|
||||
elif key.char == self.teleop_keys["rotate_right"]:
|
||||
self.pressed_keys["rotate_right"] = True
|
||||
|
||||
# Quit teleoperation
|
||||
elif key.char == self.teleop_keys["quit"]:
|
||||
self.running = False
|
||||
return False
|
||||
|
||||
# Speed control
|
||||
elif key.char == self.teleop_keys["speed_up"]:
|
||||
self.speed_index = min(self.speed_index + 1, 2)
|
||||
print(f"Speed index increased to {self.speed_index}")
|
||||
elif key.char == self.teleop_keys["speed_down"]:
|
||||
self.speed_index = max(self.speed_index - 1, 0)
|
||||
print(f"Speed index decreased to {self.speed_index}")
|
||||
|
||||
except AttributeError:
|
||||
# e.g., if key is special like Key.esc
|
||||
if key == keyboard.Key.esc:
|
||||
self.running = False
|
||||
return False
|
||||
|
||||
def on_release(self, key):
|
||||
try:
|
||||
if hasattr(key, "char"):
|
||||
if key.char == self.teleop_keys["forward"]:
|
||||
self.pressed_keys["forward"] = False
|
||||
elif key.char == self.teleop_keys["backward"]:
|
||||
self.pressed_keys["backward"] = False
|
||||
elif key.char == self.teleop_keys["left"]:
|
||||
self.pressed_keys["left"] = False
|
||||
elif key.char == self.teleop_keys["right"]:
|
||||
self.pressed_keys["right"] = False
|
||||
elif key.char == self.teleop_keys["rotate_left"]:
|
||||
self.pressed_keys["rotate_left"] = False
|
||||
elif key.char == self.teleop_keys["rotate_right"]:
|
||||
self.pressed_keys["rotate_right"] = False
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
def connect(self):
|
||||
if not self.leader_arms:
|
||||
raise ValueError("MobileManipulator has no leader arm to connect.")
|
||||
for name in self.leader_arms:
|
||||
print(f"Connecting {name} leader arm.")
|
||||
self.calibrate_leader()
|
||||
|
||||
# Set up ZeroMQ sockets to communicate with the remote mobile robot.
|
||||
self.context = zmq.Context()
|
||||
self.cmd_socket = self.context.socket(zmq.PUSH)
|
||||
connection_string = f"tcp://{self.remote_ip}:{self.remote_port}"
|
||||
self.cmd_socket.connect(connection_string)
|
||||
self.cmd_socket.setsockopt(zmq.CONFLATE, 1)
|
||||
self.video_socket = self.context.socket(zmq.PULL)
|
||||
video_connection = f"tcp://{self.remote_ip}:{self.remote_port_video}"
|
||||
self.video_socket.connect(video_connection)
|
||||
self.video_socket.setsockopt(zmq.CONFLATE, 1)
|
||||
print(
|
||||
f"[INFO] Connected to remote robot at {connection_string} and video stream at {video_connection}."
|
||||
)
|
||||
self.is_connected = True
|
||||
|
||||
def load_or_run_calibration_(self, name, arm, arm_type):
|
||||
arm_id = get_arm_id(name, arm_type)
|
||||
arm_calib_path = self.calibration_dir / f"{arm_id}.json"
|
||||
|
||||
if arm_calib_path.exists():
|
||||
with open(arm_calib_path) as f:
|
||||
calibration = json.load(f)
|
||||
else:
|
||||
print(f"Missing calibration file '{arm_calib_path}'")
|
||||
calibration = run_arm_manual_calibration(arm, self.robot_type, name, arm_type)
|
||||
print(f"Calibration is done! Saving calibration file '{arm_calib_path}'")
|
||||
arm_calib_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(arm_calib_path, "w") as f:
|
||||
json.dump(calibration, f)
|
||||
|
||||
return calibration
|
||||
|
||||
def calibrate_leader(self):
|
||||
for name, arm in self.leader_arms.items():
|
||||
# Connect the bus
|
||||
arm.connect()
|
||||
|
||||
# Disable torque on all motors
|
||||
for motor_id in arm.motors:
|
||||
arm.write("Torque_Enable", TorqueMode.DISABLED.value, motor_id)
|
||||
|
||||
# Now run calibration
|
||||
calibration = self.load_or_run_calibration_(name, arm, "leader")
|
||||
arm.set_calibration(calibration)
|
||||
|
||||
def calibrate_follower(self):
|
||||
for name, bus in self.follower_arms.items():
|
||||
bus.connect()
|
||||
|
||||
# Disable torque on all motors
|
||||
for motor_id in bus.motors:
|
||||
bus.write("Torque_Enable", 0, motor_id)
|
||||
|
||||
# Then filter out wheels
|
||||
arm_only_dict = {k: v for k, v in bus.motors.items() if not k.startswith("wheel_")}
|
||||
if not arm_only_dict:
|
||||
continue
|
||||
|
||||
original_motors = bus.motors
|
||||
bus.motors = arm_only_dict
|
||||
|
||||
calibration = self.load_or_run_calibration_(name, bus, "follower")
|
||||
bus.set_calibration(calibration)
|
||||
|
||||
bus.motors = original_motors
|
||||
|
||||
def _get_data(self):
|
||||
"""
|
||||
Polls the video socket for up to 15 ms. If data arrives, decode only
|
||||
the *latest* message, returning frames, speed, and arm state. If
|
||||
nothing arrives for any field, use the last known values.
|
||||
"""
|
||||
frames = {}
|
||||
present_speed = {}
|
||||
remote_arm_state_tensor = torch.zeros(6, dtype=torch.float32)
|
||||
|
||||
# Poll up to 15 ms
|
||||
poller = zmq.Poller()
|
||||
poller.register(self.video_socket, zmq.POLLIN)
|
||||
socks = dict(poller.poll(15))
|
||||
if self.video_socket not in socks or socks[self.video_socket] != zmq.POLLIN:
|
||||
# No new data arrived → reuse ALL old data
|
||||
return (self.last_frames, self.last_present_speed, self.last_remote_arm_state)
|
||||
|
||||
# Drain all messages, keep only the last
|
||||
last_msg = None
|
||||
while True:
|
||||
try:
|
||||
obs_string = self.video_socket.recv_string(zmq.NOBLOCK)
|
||||
last_msg = obs_string
|
||||
except zmq.Again:
|
||||
break
|
||||
|
||||
if not last_msg:
|
||||
# No new message → also reuse old
|
||||
return (self.last_frames, self.last_present_speed, self.last_remote_arm_state)
|
||||
|
||||
# Decode only the final message
|
||||
try:
|
||||
observation = json.loads(last_msg)
|
||||
|
||||
images_dict = observation.get("images", {})
|
||||
new_speed = observation.get("present_speed", {})
|
||||
new_arm_state = observation.get("follower_arm_state", None)
|
||||
|
||||
# Convert images
|
||||
for cam_name, image_b64 in images_dict.items():
|
||||
if image_b64:
|
||||
jpg_data = base64.b64decode(image_b64)
|
||||
np_arr = np.frombuffer(jpg_data, dtype=np.uint8)
|
||||
frame_candidate = cv2.imdecode(np_arr, cv2.IMREAD_COLOR)
|
||||
if frame_candidate is not None:
|
||||
frames[cam_name] = frame_candidate
|
||||
|
||||
# If remote_arm_state is None and frames is None there is no message then use the previous message
|
||||
if new_arm_state is not None and frames is not None:
|
||||
self.last_frames = frames
|
||||
|
||||
remote_arm_state_tensor = torch.tensor(new_arm_state, dtype=torch.float32)
|
||||
self.last_remote_arm_state = remote_arm_state_tensor
|
||||
|
||||
present_speed = new_speed
|
||||
self.last_present_speed = new_speed
|
||||
else:
|
||||
frames = self.last_frames
|
||||
|
||||
remote_arm_state_tensor = self.last_remote_arm_state
|
||||
|
||||
present_speed = self.last_present_speed
|
||||
|
||||
except Exception as e:
|
||||
print(f"[DEBUG] Error decoding video message: {e}")
|
||||
# If decode fails, fall back to old data
|
||||
return (self.last_frames, self.last_present_speed, self.last_remote_arm_state)
|
||||
|
||||
return frames, present_speed, remote_arm_state_tensor
|
||||
|
||||
def _process_present_speed(self, present_speed: dict) -> torch.Tensor:
|
||||
state_tensor = torch.zeros(3, dtype=torch.int32)
|
||||
if present_speed:
|
||||
decoded = {key: MobileManipulator.raw_to_degps(value) for key, value in present_speed.items()}
|
||||
if "1" in decoded:
|
||||
state_tensor[0] = decoded["1"]
|
||||
if "2" in decoded:
|
||||
state_tensor[1] = decoded["2"]
|
||||
if "3" in decoded:
|
||||
state_tensor[2] = decoded["3"]
|
||||
return state_tensor
|
||||
|
||||
def teleop_step(
|
||||
self, record_data: bool = False
|
||||
) -> None | tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
|
||||
if not self.is_connected:
|
||||
raise RobotDeviceNotConnectedError("MobileManipulator is not connected. Run `connect()` first.")
|
||||
|
||||
speed_setting = self.speed_levels[self.speed_index]
|
||||
xy_speed = speed_setting["xy"] # e.g. 0.1, 0.25, or 0.4
|
||||
theta_speed = speed_setting["theta"] # e.g. 30, 60, or 90
|
||||
|
||||
# Prepare to assign the position of the leader to the follower
|
||||
arm_positions = []
|
||||
for name in self.leader_arms:
|
||||
pos = self.leader_arms[name].read("Present_Position")
|
||||
pos_tensor = torch.from_numpy(pos).float()
|
||||
# Instead of pos_tensor.item(), use tolist() to convert the entire tensor to a list
|
||||
arm_positions.extend(pos_tensor.tolist())
|
||||
|
||||
# (The rest of your code for generating wheel commands remains unchanged)
|
||||
x_cmd = 0.0 # m/s forward/backward
|
||||
y_cmd = 0.0 # m/s lateral
|
||||
theta_cmd = 0.0 # deg/s rotation
|
||||
if self.pressed_keys["forward"]:
|
||||
x_cmd += xy_speed
|
||||
if self.pressed_keys["backward"]:
|
||||
x_cmd -= xy_speed
|
||||
if self.pressed_keys["left"]:
|
||||
y_cmd += xy_speed
|
||||
if self.pressed_keys["right"]:
|
||||
y_cmd -= xy_speed
|
||||
if self.pressed_keys["rotate_left"]:
|
||||
theta_cmd += theta_speed
|
||||
if self.pressed_keys["rotate_right"]:
|
||||
theta_cmd -= theta_speed
|
||||
|
||||
wheel_commands = self.body_to_wheel_raw(x_cmd, y_cmd, theta_cmd)
|
||||
|
||||
message = {"raw_velocity": wheel_commands, "arm_positions": arm_positions}
|
||||
self.cmd_socket.send_string(json.dumps(message))
|
||||
|
||||
if not record_data:
|
||||
return
|
||||
|
||||
obs_dict = self.capture_observation()
|
||||
|
||||
arm_state_tensor = torch.tensor(arm_positions, dtype=torch.float32)
|
||||
|
||||
wheel_velocity_tuple = self.wheel_raw_to_body(wheel_commands)
|
||||
wheel_velocity_mm = (
|
||||
wheel_velocity_tuple[0] * 1000.0,
|
||||
wheel_velocity_tuple[1] * 1000.0,
|
||||
wheel_velocity_tuple[2],
|
||||
)
|
||||
wheel_tensor = torch.tensor(wheel_velocity_mm, dtype=torch.float32)
|
||||
action_tensor = torch.cat([arm_state_tensor, wheel_tensor])
|
||||
action_dict = {"action": action_tensor}
|
||||
|
||||
return obs_dict, action_dict
|
||||
|
||||
def capture_observation(self) -> dict:
|
||||
"""
|
||||
Capture observations from the remote robot: current follower arm positions,
|
||||
present wheel speeds (converted to body-frame velocities: x, y, theta),
|
||||
and a camera frame.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise RobotDeviceNotConnectedError("Not connected. Run `connect()` first.")
|
||||
|
||||
frames, present_speed, remote_arm_state_tensor = self._get_data()
|
||||
|
||||
body_state = self.wheel_raw_to_body(present_speed)
|
||||
|
||||
body_state_mm = (body_state[0] * 1000.0, body_state[1] * 1000.0, body_state[2]) # Convert x,y to mm/s
|
||||
wheel_state_tensor = torch.tensor(body_state_mm, dtype=torch.float32)
|
||||
combined_state_tensor = torch.cat((remote_arm_state_tensor, wheel_state_tensor), dim=0)
|
||||
|
||||
obs_dict = {"observation.state": combined_state_tensor}
|
||||
|
||||
# Loop over each configured camera
|
||||
for cam_name, cam in self.cameras.items():
|
||||
frame = frames.get(cam_name, None)
|
||||
if frame is None:
|
||||
# Create a black image using the camera's configured width, height, and channels
|
||||
frame = np.zeros((cam.height, cam.width, cam.channels), dtype=np.uint8)
|
||||
obs_dict[f"observation.images.{cam_name}"] = torch.from_numpy(frame)
|
||||
|
||||
return obs_dict
|
||||
|
||||
def send_action(self, action: torch.Tensor) -> torch.Tensor:
|
||||
if not self.is_connected:
|
||||
raise RobotDeviceNotConnectedError("Not connected. Run `connect()` first.")
|
||||
|
||||
# Ensure the action tensor has at least 9 elements:
|
||||
# - First 6: arm positions.
|
||||
# - Last 3: base commands.
|
||||
if action.numel() < 9:
|
||||
# Pad with zeros if there are not enough elements.
|
||||
padded = torch.zeros(9, dtype=action.dtype)
|
||||
padded[: action.numel()] = action
|
||||
action = padded
|
||||
|
||||
# Extract arm and base actions.
|
||||
arm_actions = action[:6].flatten()
|
||||
base_actions = action[6:].flatten()
|
||||
|
||||
x_cmd_mm = base_actions[0].item() # mm/s
|
||||
y_cmd_mm = base_actions[1].item() # mm/s
|
||||
theta_cmd = base_actions[2].item() # deg/s
|
||||
|
||||
# Convert mm/s to m/s for the kinematics calculations.
|
||||
x_cmd = x_cmd_mm / 1000.0 # m/s
|
||||
y_cmd = y_cmd_mm / 1000.0 # m/s
|
||||
|
||||
# Compute wheel commands from body commands.
|
||||
wheel_commands = self.body_to_wheel_raw(x_cmd, y_cmd, theta_cmd)
|
||||
|
||||
arm_positions_list = arm_actions.tolist()
|
||||
|
||||
message = {"raw_velocity": wheel_commands, "arm_positions": arm_positions_list}
|
||||
self.cmd_socket.send_string(json.dumps(message))
|
||||
|
||||
return action
|
||||
|
||||
def print_logs(self):
|
||||
pass
|
||||
|
||||
def disconnect(self):
|
||||
if not self.is_connected:
|
||||
raise RobotDeviceNotConnectedError("Not connected.")
|
||||
if self.cmd_socket:
|
||||
stop_cmd = {
|
||||
"raw_velocity": {"left_wheel": 0, "back_wheel": 0, "right_wheel": 0},
|
||||
"arm_positions": {},
|
||||
}
|
||||
self.cmd_socket.send_string(json.dumps(stop_cmd))
|
||||
self.cmd_socket.close()
|
||||
if self.video_socket:
|
||||
self.video_socket.close()
|
||||
if self.context:
|
||||
self.context.term()
|
||||
if PYNPUT_AVAILABLE:
|
||||
self.listener.stop()
|
||||
self.is_connected = False
|
||||
print("[INFO] Disconnected from remote robot.")
|
||||
|
||||
def __del__(self):
|
||||
if getattr(self, "is_connected", False):
|
||||
self.disconnect()
|
||||
if PYNPUT_AVAILABLE:
|
||||
self.listener.stop()
|
||||
|
||||
@staticmethod
|
||||
def degps_to_raw(degps: float) -> int:
|
||||
steps_per_deg = 4096.0 / 360.0
|
||||
speed_in_steps = abs(degps) * steps_per_deg
|
||||
speed_int = int(round(speed_in_steps))
|
||||
if speed_int > 0x7FFF:
|
||||
speed_int = 0x7FFF
|
||||
if degps < 0:
|
||||
return speed_int | 0x8000
|
||||
else:
|
||||
return speed_int & 0x7FFF
|
||||
|
||||
@staticmethod
|
||||
def raw_to_degps(raw_speed: int) -> float:
|
||||
steps_per_deg = 4096.0 / 360.0
|
||||
magnitude = raw_speed & 0x7FFF
|
||||
degps = magnitude / steps_per_deg
|
||||
if raw_speed & 0x8000:
|
||||
degps = -degps
|
||||
return degps
|
||||
|
||||
def body_to_wheel_raw(
|
||||
self,
|
||||
x_cmd: float,
|
||||
y_cmd: float,
|
||||
theta_cmd: float,
|
||||
wheel_radius: float = 0.05,
|
||||
base_radius: float = 0.125,
|
||||
max_raw: int = 3000,
|
||||
) -> dict:
|
||||
"""
|
||||
Convert desired body-frame velocities into wheel raw commands.
|
||||
|
||||
Parameters:
|
||||
x_cmd : Linear velocity in x (m/s).
|
||||
y_cmd : Linear velocity in y (m/s).
|
||||
theta_cmd : Rotational velocity (deg/s).
|
||||
wheel_radius: Radius of each wheel (meters).
|
||||
base_radius : Distance from the center of rotation to each wheel (meters).
|
||||
max_raw : Maximum allowed raw command (ticks) per wheel.
|
||||
|
||||
Returns:
|
||||
A dictionary with wheel raw commands:
|
||||
{"left_wheel": value, "back_wheel": value, "right_wheel": value}.
|
||||
|
||||
Notes:
|
||||
- Internally, the method converts theta_cmd to rad/s for the kinematics.
|
||||
- The raw command is computed from the wheels angular speed in deg/s
|
||||
using degps_to_raw(). If any command exceeds max_raw, all commands
|
||||
are scaled down proportionally.
|
||||
"""
|
||||
# Convert rotational velocity from deg/s to rad/s.
|
||||
theta_rad = theta_cmd * (np.pi / 180.0)
|
||||
# Create the body velocity vector [x, y, theta_rad].
|
||||
velocity_vector = np.array([x_cmd, y_cmd, theta_rad])
|
||||
|
||||
# Define the wheel mounting angles with a -90° offset.
|
||||
angles = np.radians(np.array([240, 120, 0]) - 90)
|
||||
# Build the kinematic matrix: each row maps body velocities to a wheel’s linear speed.
|
||||
# The third column (base_radius) accounts for the effect of rotation.
|
||||
m = np.array([[np.cos(a), np.sin(a), base_radius] for a in angles])
|
||||
|
||||
# Compute each wheel’s linear speed (m/s) and then its angular speed (rad/s).
|
||||
wheel_linear_speeds = m.dot(velocity_vector)
|
||||
wheel_angular_speeds = wheel_linear_speeds / wheel_radius
|
||||
|
||||
# Convert wheel angular speeds from rad/s to deg/s.
|
||||
wheel_degps = wheel_angular_speeds * (180.0 / np.pi)
|
||||
|
||||
# Scaling
|
||||
steps_per_deg = 4096.0 / 360.0
|
||||
raw_floats = [abs(degps) * steps_per_deg for degps in wheel_degps]
|
||||
max_raw_computed = max(raw_floats)
|
||||
if max_raw_computed > max_raw:
|
||||
scale = max_raw / max_raw_computed
|
||||
wheel_degps = wheel_degps * scale
|
||||
|
||||
# Convert each wheel’s angular speed (deg/s) to a raw integer.
|
||||
wheel_raw = [MobileManipulator.degps_to_raw(deg) for deg in wheel_degps]
|
||||
|
||||
return {"left_wheel": wheel_raw[0], "back_wheel": wheel_raw[1], "right_wheel": wheel_raw[2]}
|
||||
|
||||
def wheel_raw_to_body(
|
||||
self, wheel_raw: dict, wheel_radius: float = 0.05, base_radius: float = 0.125
|
||||
) -> tuple:
|
||||
"""
|
||||
Convert wheel raw command feedback back into body-frame velocities.
|
||||
|
||||
Parameters:
|
||||
wheel_raw : Dictionary with raw wheel commands (keys: "left_wheel", "back_wheel", "right_wheel").
|
||||
wheel_radius: Radius of each wheel (meters).
|
||||
base_radius : Distance from the robot center to each wheel (meters).
|
||||
|
||||
Returns:
|
||||
A tuple (x_cmd, y_cmd, theta_cmd) where:
|
||||
x_cmd : Linear velocity in x (m/s).
|
||||
y_cmd : Linear velocity in y (m/s).
|
||||
theta_cmd : Rotational velocity in deg/s.
|
||||
"""
|
||||
# Extract the raw values in order.
|
||||
raw_list = [
|
||||
int(wheel_raw.get("left_wheel", 0)),
|
||||
int(wheel_raw.get("back_wheel", 0)),
|
||||
int(wheel_raw.get("right_wheel", 0)),
|
||||
]
|
||||
|
||||
# Convert each raw command back to an angular speed in deg/s.
|
||||
wheel_degps = np.array([MobileManipulator.raw_to_degps(r) for r in raw_list])
|
||||
# Convert from deg/s to rad/s.
|
||||
wheel_radps = wheel_degps * (np.pi / 180.0)
|
||||
# Compute each wheel’s linear speed (m/s) from its angular speed.
|
||||
wheel_linear_speeds = wheel_radps * wheel_radius
|
||||
|
||||
# Define the wheel mounting angles with a -90° offset.
|
||||
angles = np.radians(np.array([240, 120, 0]) - 90)
|
||||
m = np.array([[np.cos(a), np.sin(a), base_radius] for a in angles])
|
||||
|
||||
# Solve the inverse kinematics: body_velocity = M⁻¹ · wheel_linear_speeds.
|
||||
m_inv = np.linalg.inv(m)
|
||||
velocity_vector = m_inv.dot(wheel_linear_speeds)
|
||||
x_cmd, y_cmd, theta_rad = velocity_vector
|
||||
theta_cmd = theta_rad * (180.0 / np.pi)
|
||||
return (x_cmd, y_cmd, theta_cmd)
|
||||
|
||||
|
||||
class LeKiwi:
|
||||
def __init__(self, motor_bus):
|
||||
"""
|
||||
Initializes the LeKiwi with Feetech motors bus.
|
||||
"""
|
||||
self.motor_bus = motor_bus
|
||||
self.motor_ids = ["left_wheel", "back_wheel", "right_wheel"]
|
||||
|
||||
# Initialize motors in velocity mode.
|
||||
self.motor_bus.write("Lock", 0)
|
||||
self.motor_bus.write("Mode", [1, 1, 1], self.motor_ids)
|
||||
self.motor_bus.write("Lock", 1)
|
||||
print("Motors set to velocity mode.")
|
||||
|
||||
def read_velocity(self):
|
||||
"""
|
||||
Reads the raw speeds for all wheels. Returns a dictionary with motor names:
|
||||
"""
|
||||
raw_speeds = self.motor_bus.read("Present_Speed", self.motor_ids)
|
||||
return {
|
||||
"left_wheel": int(raw_speeds[0]),
|
||||
"back_wheel": int(raw_speeds[1]),
|
||||
"right_wheel": int(raw_speeds[2]),
|
||||
}
|
||||
|
||||
def set_velocity(self, command_speeds):
|
||||
"""
|
||||
Sends raw velocity commands (16-bit encoded values) directly to the motor bus.
|
||||
The order of speeds must correspond to self.motor_ids.
|
||||
"""
|
||||
self.motor_bus.write("Goal_Speed", command_speeds, self.motor_ids)
|
||||
|
||||
def stop(self):
|
||||
"""Stops the robot by setting all motor speeds to zero."""
|
||||
self.motor_bus.write("Goal_Speed", [0, 0, 0], self.motor_ids)
|
||||
print("Motors stopped.")
|
||||
@@ -108,7 +108,7 @@ class StretchRobot(StretchAPI):
|
||||
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"]
|
||||
self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t
|
||||
|
||||
# Populate output dictionnaries
|
||||
# Populate output dictionaries
|
||||
obs_dict, action_dict = {}, {}
|
||||
obs_dict["observation.state"] = state
|
||||
action_dict["action"] = action
|
||||
@@ -153,7 +153,7 @@ class StretchRobot(StretchAPI):
|
||||
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"]
|
||||
self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t
|
||||
|
||||
# Populate output dictionnaries
|
||||
# Populate output dictionaries
|
||||
obs_dict = {}
|
||||
obs_dict["observation.state"] = state
|
||||
for name in self.cameras:
|
||||
|
||||
@@ -4,6 +4,7 @@ from lerobot.common.robot_devices.robots.configs import (
|
||||
AlohaRobotConfig,
|
||||
KochBimanualRobotConfig,
|
||||
KochRobotConfig,
|
||||
LeKiwiRobotConfig,
|
||||
ManipulatorRobotConfig,
|
||||
MossRobotConfig,
|
||||
RobotConfig,
|
||||
@@ -45,6 +46,8 @@ def make_robot_config(robot_type: str, **kwargs) -> RobotConfig:
|
||||
return So100RobotConfig(**kwargs)
|
||||
elif robot_type == "stretch":
|
||||
return StretchRobotConfig(**kwargs)
|
||||
elif robot_type == "lekiwi":
|
||||
return LeKiwiRobotConfig(**kwargs)
|
||||
else:
|
||||
raise ValueError(f"Robot type '{robot_type}' is not available.")
|
||||
|
||||
@@ -54,6 +57,10 @@ def make_robot_from_config(config: RobotConfig):
|
||||
from lerobot.common.robot_devices.robots.manipulator import ManipulatorRobot
|
||||
|
||||
return ManipulatorRobot(config)
|
||||
elif isinstance(config, LeKiwiRobotConfig):
|
||||
from lerobot.common.robot_devices.robots.mobile_manipulator import MobileManipulator
|
||||
|
||||
return MobileManipulator(config)
|
||||
else:
|
||||
from lerobot.common.robot_devices.robots.stretch import StretchRobot
|
||||
|
||||
|
||||
@@ -13,10 +13,16 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import json
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import TypeVar
|
||||
|
||||
import imageio
|
||||
|
||||
JsonLike = str | int | float | bool | None | list["JsonLike"] | dict[str, "JsonLike"] | tuple["JsonLike", ...]
|
||||
T = TypeVar("T", bound=JsonLike)
|
||||
|
||||
|
||||
def write_video(video_path, stacked_frames, fps):
|
||||
# Filter out DeprecationWarnings raised from pkg_resources
|
||||
@@ -25,3 +31,81 @@ def write_video(video_path, stacked_frames, fps):
|
||||
"ignore", "pkg_resources is deprecated as an API", category=DeprecationWarning
|
||||
)
|
||||
imageio.mimsave(video_path, stacked_frames, fps=fps)
|
||||
|
||||
|
||||
def deserialize_json_into_object(fpath: Path, obj: T) -> T:
|
||||
"""
|
||||
Loads the JSON data from `fpath` and recursively fills `obj` with the
|
||||
corresponding values (strictly matching structure and types).
|
||||
Tuples in `obj` are expected to be lists in the JSON data, which will be
|
||||
converted back into tuples.
|
||||
"""
|
||||
with open(fpath, encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
|
||||
def _deserialize(target, source):
|
||||
"""
|
||||
Recursively overwrite the structure in `target` with data from `source`,
|
||||
performing strict checks on structure and type.
|
||||
Returns the updated version of `target` (especially important for tuples).
|
||||
"""
|
||||
|
||||
# If the target is a dictionary, source must be a dictionary as well.
|
||||
if isinstance(target, dict):
|
||||
if not isinstance(source, dict):
|
||||
raise TypeError(f"Type mismatch: expected dict, got {type(source)}")
|
||||
|
||||
# Check that they have exactly the same set of keys.
|
||||
if target.keys() != source.keys():
|
||||
raise ValueError(
|
||||
f"Dictionary keys do not match.\nExpected: {target.keys()}, got: {source.keys()}"
|
||||
)
|
||||
|
||||
# Recursively update each key.
|
||||
for k in target:
|
||||
target[k] = _deserialize(target[k], source[k])
|
||||
|
||||
return target
|
||||
|
||||
# If the target is a list, source must be a list as well.
|
||||
elif isinstance(target, list):
|
||||
if not isinstance(source, list):
|
||||
raise TypeError(f"Type mismatch: expected list, got {type(source)}")
|
||||
|
||||
# Check length
|
||||
if len(target) != len(source):
|
||||
raise ValueError(f"List length mismatch: expected {len(target)}, got {len(source)}")
|
||||
|
||||
# Recursively update each element.
|
||||
for i in range(len(target)):
|
||||
target[i] = _deserialize(target[i], source[i])
|
||||
|
||||
return target
|
||||
|
||||
# If the target is a tuple, the source must be a list in JSON,
|
||||
# which we'll convert back to a tuple.
|
||||
elif isinstance(target, tuple):
|
||||
if not isinstance(source, list):
|
||||
raise TypeError(f"Type mismatch: expected list (for tuple), got {type(source)}")
|
||||
|
||||
if len(target) != len(source):
|
||||
raise ValueError(f"Tuple length mismatch: expected {len(target)}, got {len(source)}")
|
||||
|
||||
# Convert each element, forming a new tuple.
|
||||
converted_items = []
|
||||
for t_item, s_item in zip(target, source, strict=False):
|
||||
converted_items.append(_deserialize(t_item, s_item))
|
||||
|
||||
# Return a brand new tuple (tuples are immutable in Python).
|
||||
return tuple(converted_items)
|
||||
|
||||
# Otherwise, we're dealing with a "primitive" (int, float, str, bool, None).
|
||||
else:
|
||||
# Check the exact type. If these must match 1:1, do:
|
||||
if type(target) is not type(source):
|
||||
raise TypeError(f"Type mismatch: expected {type(target)}, got {type(source)}")
|
||||
return source
|
||||
|
||||
# Perform the in-place/recursive deserialization
|
||||
updated_obj = _deserialize(obj, data)
|
||||
return updated_obj
|
||||
|
||||
163
lerobot/common/utils/logging_utils.py
Normal file
163
lerobot/common/utils/logging_utils.py
Normal file
@@ -0,0 +1,163 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Any
|
||||
|
||||
from lerobot.common.utils.utils import format_big_number
|
||||
|
||||
|
||||
class AverageMeter:
|
||||
"""
|
||||
Computes and stores the average and current value
|
||||
Adapted from https://github.com/pytorch/examples/blob/main/imagenet/main.py
|
||||
"""
|
||||
|
||||
def __init__(self, name: str, fmt: str = ":f"):
|
||||
self.name = name
|
||||
self.fmt = fmt
|
||||
self.reset()
|
||||
|
||||
def reset(self) -> None:
|
||||
self.val = 0.0
|
||||
self.avg = 0.0
|
||||
self.sum = 0.0
|
||||
self.count = 0.0
|
||||
|
||||
def update(self, val: float, n: int = 1) -> None:
|
||||
self.val = val
|
||||
self.sum += val * n
|
||||
self.count += n
|
||||
self.avg = self.sum / self.count
|
||||
|
||||
def __str__(self):
|
||||
fmtstr = "{name}:{avg" + self.fmt + "}"
|
||||
return fmtstr.format(**self.__dict__)
|
||||
|
||||
|
||||
class MetricsTracker:
|
||||
"""
|
||||
A helper class to track and log metrics over time.
|
||||
|
||||
Usage pattern:
|
||||
|
||||
```python
|
||||
# initialize, potentially with non-zero initial step (e.g. if resuming run)
|
||||
metrics = {"loss": AverageMeter("loss", ":.3f")}
|
||||
train_metrics = MetricsTracker(cfg, dataset, metrics, initial_step=step)
|
||||
|
||||
# update metrics derived from step (samples, episodes, epochs) at each training step
|
||||
train_metrics.step()
|
||||
|
||||
# update various metrics
|
||||
loss = policy.forward(batch)
|
||||
train_metrics.loss = loss
|
||||
|
||||
# display current metrics
|
||||
logging.info(train_metrics)
|
||||
|
||||
# export for wandb
|
||||
wandb.log(train_metrics.to_dict())
|
||||
|
||||
# reset averages after logging
|
||||
train_metrics.reset_averages()
|
||||
```
|
||||
"""
|
||||
|
||||
__keys__ = [
|
||||
"_batch_size",
|
||||
"_num_frames",
|
||||
"_avg_samples_per_ep",
|
||||
"metrics",
|
||||
"steps",
|
||||
"samples",
|
||||
"episodes",
|
||||
"epochs",
|
||||
]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
batch_size: int,
|
||||
num_frames: int,
|
||||
num_episodes: int,
|
||||
metrics: dict[str, AverageMeter],
|
||||
initial_step: int = 0,
|
||||
):
|
||||
self.__dict__.update({k: None for k in self.__keys__})
|
||||
self._batch_size = batch_size
|
||||
self._num_frames = num_frames
|
||||
self._avg_samples_per_ep = num_frames / num_episodes
|
||||
self.metrics = metrics
|
||||
|
||||
self.steps = initial_step
|
||||
# A sample is an (observation,action) pair, where observation and action
|
||||
# can be on multiple timestamps. In a batch, we have `batch_size` number of samples.
|
||||
self.samples = self.steps * self._batch_size
|
||||
self.episodes = self.samples / self._avg_samples_per_ep
|
||||
self.epochs = self.samples / self._num_frames
|
||||
|
||||
def __getattr__(self, name: str) -> int | dict[str, AverageMeter] | AverageMeter | Any:
|
||||
if name in self.__dict__:
|
||||
return self.__dict__[name]
|
||||
elif name in self.metrics:
|
||||
return self.metrics[name]
|
||||
else:
|
||||
raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")
|
||||
|
||||
def __setattr__(self, name: str, value: Any) -> None:
|
||||
if name in self.__dict__:
|
||||
super().__setattr__(name, value)
|
||||
elif name in self.metrics:
|
||||
self.metrics[name].update(value)
|
||||
else:
|
||||
raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")
|
||||
|
||||
def step(self) -> None:
|
||||
"""
|
||||
Updates metrics that depend on 'step' for one step.
|
||||
"""
|
||||
self.steps += 1
|
||||
self.samples += self._batch_size
|
||||
self.episodes = self.samples / self._avg_samples_per_ep
|
||||
self.epochs = self.samples / self._num_frames
|
||||
|
||||
def __str__(self) -> str:
|
||||
display_list = [
|
||||
f"step:{format_big_number(self.steps)}",
|
||||
# number of samples seen during training
|
||||
f"smpl:{format_big_number(self.samples)}",
|
||||
# number of episodes seen during training
|
||||
f"ep:{format_big_number(self.episodes)}",
|
||||
# number of time all unique samples are seen
|
||||
f"epch:{self.epochs:.2f}",
|
||||
*[str(m) for m in self.metrics.values()],
|
||||
]
|
||||
return " ".join(display_list)
|
||||
|
||||
def to_dict(self, use_avg: bool = True) -> dict[str, int | float]:
|
||||
"""
|
||||
Returns the current metric values (or averages if `use_avg=True`) as a dict.
|
||||
"""
|
||||
return {
|
||||
"steps": self.steps,
|
||||
"samples": self.samples,
|
||||
"episodes": self.episodes,
|
||||
"epochs": self.epochs,
|
||||
**{k: m.avg if use_avg else m.val for k, m in self.metrics.items()},
|
||||
}
|
||||
|
||||
def reset_averages(self) -> None:
|
||||
"""Resets average meters."""
|
||||
for m in self.metrics.values():
|
||||
m.reset()
|
||||
191
lerobot/common/utils/random_utils.py
Normal file
191
lerobot/common/utils/random_utils.py
Normal file
@@ -0,0 +1,191 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import random
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
from typing import Any, Generator
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from safetensors.torch import load_file, save_file
|
||||
|
||||
from lerobot.common.constants import RNG_STATE
|
||||
from lerobot.common.datasets.utils import flatten_dict, unflatten_dict
|
||||
|
||||
|
||||
def serialize_python_rng_state() -> dict[str, torch.Tensor]:
|
||||
"""
|
||||
Returns the rng state for `random` in the form of a flat dict[str, torch.Tensor] to be saved using
|
||||
`safetensors.save_file()` or `torch.save()`.
|
||||
"""
|
||||
py_state = random.getstate()
|
||||
return {
|
||||
"py_rng_version": torch.tensor([py_state[0]], dtype=torch.int64),
|
||||
"py_rng_state": torch.tensor(py_state[1], dtype=torch.int64),
|
||||
}
|
||||
|
||||
|
||||
def deserialize_python_rng_state(rng_state_dict: dict[str, torch.Tensor]) -> None:
|
||||
"""
|
||||
Restores the rng state for `random` from a dictionary produced by `serialize_python_rng_state()`.
|
||||
"""
|
||||
py_state = (rng_state_dict["py_rng_version"].item(), tuple(rng_state_dict["py_rng_state"].tolist()), None)
|
||||
random.setstate(py_state)
|
||||
|
||||
|
||||
def serialize_numpy_rng_state() -> dict[str, torch.Tensor]:
|
||||
"""
|
||||
Returns the rng state for `numpy` in the form of a flat dict[str, torch.Tensor] to be saved using
|
||||
`safetensors.save_file()` or `torch.save()`.
|
||||
"""
|
||||
np_state = np.random.get_state()
|
||||
# Ensure no breaking changes from numpy
|
||||
assert np_state[0] == "MT19937"
|
||||
return {
|
||||
"np_rng_state_values": torch.tensor(np_state[1], dtype=torch.int64),
|
||||
"np_rng_state_index": torch.tensor([np_state[2]], dtype=torch.int64),
|
||||
"np_rng_has_gauss": torch.tensor([np_state[3]], dtype=torch.int64),
|
||||
"np_rng_cached_gaussian": torch.tensor([np_state[4]], dtype=torch.float32),
|
||||
}
|
||||
|
||||
|
||||
def deserialize_numpy_rng_state(rng_state_dict: dict[str, torch.Tensor]) -> None:
|
||||
"""
|
||||
Restores the rng state for `numpy` from a dictionary produced by `serialize_numpy_rng_state()`.
|
||||
"""
|
||||
np_state = (
|
||||
"MT19937",
|
||||
rng_state_dict["np_rng_state_values"].numpy(),
|
||||
rng_state_dict["np_rng_state_index"].item(),
|
||||
rng_state_dict["np_rng_has_gauss"].item(),
|
||||
rng_state_dict["np_rng_cached_gaussian"].item(),
|
||||
)
|
||||
np.random.set_state(np_state)
|
||||
|
||||
|
||||
def serialize_torch_rng_state() -> dict[str, torch.Tensor]:
|
||||
"""
|
||||
Returns the rng state for `torch` in the form of a flat dict[str, torch.Tensor] to be saved using
|
||||
`safetensors.save_file()` or `torch.save()`.
|
||||
"""
|
||||
torch_rng_state_dict = {"torch_rng_state": torch.get_rng_state()}
|
||||
if torch.cuda.is_available():
|
||||
torch_rng_state_dict["torch_cuda_rng_state"] = torch.cuda.get_rng_state()
|
||||
return torch_rng_state_dict
|
||||
|
||||
|
||||
def deserialize_torch_rng_state(rng_state_dict: dict[str, torch.Tensor]) -> None:
|
||||
"""
|
||||
Restores the rng state for `torch` from a dictionary produced by `serialize_torch_rng_state()`.
|
||||
"""
|
||||
torch.set_rng_state(rng_state_dict["torch_rng_state"])
|
||||
if torch.cuda.is_available() and "torch_cuda_rng_state" in rng_state_dict:
|
||||
torch.cuda.set_rng_state(rng_state_dict["torch_cuda_rng_state"])
|
||||
|
||||
|
||||
def serialize_rng_state() -> dict[str, torch.Tensor]:
|
||||
"""
|
||||
Returns the rng state for `random`, `numpy`, and `torch`, in the form of a flat
|
||||
dict[str, torch.Tensor] to be saved using `safetensors.save_file()` `torch.save()`.
|
||||
"""
|
||||
py_rng_state_dict = serialize_python_rng_state()
|
||||
np_rng_state_dict = serialize_numpy_rng_state()
|
||||
torch_rng_state_dict = serialize_torch_rng_state()
|
||||
|
||||
return {
|
||||
**py_rng_state_dict,
|
||||
**np_rng_state_dict,
|
||||
**torch_rng_state_dict,
|
||||
}
|
||||
|
||||
|
||||
def deserialize_rng_state(rng_state_dict: dict[str, torch.Tensor]) -> None:
|
||||
"""
|
||||
Restores the rng state for `random`, `numpy`, and `torch` from a dictionary produced by
|
||||
`serialize_rng_state()`.
|
||||
"""
|
||||
py_rng_state_dict = {k: v for k, v in rng_state_dict.items() if k.startswith("py")}
|
||||
np_rng_state_dict = {k: v for k, v in rng_state_dict.items() if k.startswith("np")}
|
||||
torch_rng_state_dict = {k: v for k, v in rng_state_dict.items() if k.startswith("torch")}
|
||||
|
||||
deserialize_python_rng_state(py_rng_state_dict)
|
||||
deserialize_numpy_rng_state(np_rng_state_dict)
|
||||
deserialize_torch_rng_state(torch_rng_state_dict)
|
||||
|
||||
|
||||
def save_rng_state(save_dir: Path) -> None:
|
||||
rng_state_dict = serialize_rng_state()
|
||||
flat_rng_state_dict = flatten_dict(rng_state_dict)
|
||||
save_file(flat_rng_state_dict, save_dir / RNG_STATE)
|
||||
|
||||
|
||||
def load_rng_state(save_dir: Path) -> None:
|
||||
flat_rng_state_dict = load_file(save_dir / RNG_STATE)
|
||||
rng_state_dict = unflatten_dict(flat_rng_state_dict)
|
||||
deserialize_rng_state(rng_state_dict)
|
||||
|
||||
|
||||
def get_rng_state() -> dict[str, Any]:
|
||||
"""Get the random state for `random`, `numpy`, and `torch`."""
|
||||
random_state_dict = {
|
||||
"random_state": random.getstate(),
|
||||
"numpy_random_state": np.random.get_state(),
|
||||
"torch_random_state": torch.random.get_rng_state(),
|
||||
}
|
||||
if torch.cuda.is_available():
|
||||
random_state_dict["torch_cuda_random_state"] = torch.cuda.random.get_rng_state()
|
||||
return random_state_dict
|
||||
|
||||
|
||||
def set_rng_state(random_state_dict: dict[str, Any]):
|
||||
"""Set the random state for `random`, `numpy`, and `torch`.
|
||||
|
||||
Args:
|
||||
random_state_dict: A dictionary of the form returned by `get_rng_state`.
|
||||
"""
|
||||
random.setstate(random_state_dict["random_state"])
|
||||
np.random.set_state(random_state_dict["numpy_random_state"])
|
||||
torch.random.set_rng_state(random_state_dict["torch_random_state"])
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.random.set_rng_state(random_state_dict["torch_cuda_random_state"])
|
||||
|
||||
|
||||
def set_seed(seed) -> None:
|
||||
"""Set seed for reproducibility."""
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def seeded_context(seed: int) -> Generator[None, None, None]:
|
||||
"""Set the seed when entering a context, and restore the prior random state at exit.
|
||||
|
||||
Example usage:
|
||||
|
||||
```
|
||||
a = random.random() # produces some random number
|
||||
with seeded_context(1337):
|
||||
b = random.random() # produces some other random number
|
||||
c = random.random() # produces yet another random number, but the same it would have if we never made `b`
|
||||
```
|
||||
"""
|
||||
random_state_dict = get_rng_state()
|
||||
set_seed(seed)
|
||||
yield None
|
||||
set_rng_state(random_state_dict)
|
||||
161
lerobot/common/utils/train_utils.py
Normal file
161
lerobot/common/utils/train_utils.py
Normal file
@@ -0,0 +1,161 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
from termcolor import colored
|
||||
from torch.optim import Optimizer
|
||||
from torch.optim.lr_scheduler import LRScheduler
|
||||
|
||||
from lerobot.common.constants import (
|
||||
CHECKPOINTS_DIR,
|
||||
LAST_CHECKPOINT_LINK,
|
||||
PRETRAINED_MODEL_DIR,
|
||||
TRAINING_STATE_DIR,
|
||||
TRAINING_STEP,
|
||||
)
|
||||
from lerobot.common.datasets.utils import load_json, write_json
|
||||
from lerobot.common.optim.optimizers import load_optimizer_state, save_optimizer_state
|
||||
from lerobot.common.optim.schedulers import load_scheduler_state, save_scheduler_state
|
||||
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.common.utils.random_utils import load_rng_state, save_rng_state
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
|
||||
|
||||
def log_output_dir(out_dir):
|
||||
logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {out_dir}")
|
||||
|
||||
|
||||
def get_step_identifier(step: int, total_steps: int) -> str:
|
||||
num_digits = max(6, len(str(total_steps)))
|
||||
return f"{step:0{num_digits}d}"
|
||||
|
||||
|
||||
def get_step_checkpoint_dir(output_dir: Path, total_steps: int, step: int) -> Path:
|
||||
"""Returns the checkpoint sub-directory corresponding to the step number."""
|
||||
step_identifier = get_step_identifier(step, total_steps)
|
||||
return output_dir / CHECKPOINTS_DIR / step_identifier
|
||||
|
||||
|
||||
def save_training_step(step: int, save_dir: Path) -> None:
|
||||
write_json({"step": step}, save_dir / TRAINING_STEP)
|
||||
|
||||
|
||||
def load_training_step(save_dir: Path) -> int:
|
||||
training_step = load_json(save_dir / TRAINING_STEP)
|
||||
return training_step["step"]
|
||||
|
||||
|
||||
def update_last_checkpoint(checkpoint_dir: Path) -> Path:
|
||||
last_checkpoint_dir = checkpoint_dir.parent / LAST_CHECKPOINT_LINK
|
||||
if last_checkpoint_dir.is_symlink():
|
||||
last_checkpoint_dir.unlink()
|
||||
relative_target = checkpoint_dir.relative_to(checkpoint_dir.parent)
|
||||
last_checkpoint_dir.symlink_to(relative_target)
|
||||
|
||||
|
||||
def save_checkpoint(
|
||||
checkpoint_dir: Path,
|
||||
step: int,
|
||||
cfg: TrainPipelineConfig,
|
||||
policy: PreTrainedPolicy,
|
||||
optimizer: Optimizer,
|
||||
scheduler: LRScheduler | None = None,
|
||||
) -> None:
|
||||
"""This function creates the following directory structure:
|
||||
|
||||
005000/ # training step at checkpoint
|
||||
├── pretrained_model/
|
||||
│ ├── config.json # policy config
|
||||
│ ├── model.safetensors # policy weights
|
||||
│ └── train_config.json # train config
|
||||
└── training_state/
|
||||
├── optimizer_param_groups.json # optimizer param groups
|
||||
├── optimizer_state.safetensors # optimizer state
|
||||
├── rng_state.safetensors # rng states
|
||||
├── scheduler_state.json # scheduler state
|
||||
└── training_step.json # training step
|
||||
|
||||
Args:
|
||||
cfg (TrainPipelineConfig): The training config used for this run.
|
||||
step (int): The training step at that checkpoint.
|
||||
policy (PreTrainedPolicy): The policy to save.
|
||||
optimizer (Optimizer | None, optional): The optimizer to save the state from. Defaults to None.
|
||||
scheduler (LRScheduler | None, optional): The scheduler to save the state from. Defaults to None.
|
||||
"""
|
||||
pretrained_dir = checkpoint_dir / PRETRAINED_MODEL_DIR
|
||||
policy.save_pretrained(pretrained_dir)
|
||||
cfg.save_pretrained(pretrained_dir)
|
||||
save_training_state(checkpoint_dir, step, optimizer, scheduler)
|
||||
|
||||
|
||||
def save_training_state(
|
||||
checkpoint_dir: Path,
|
||||
train_step: int,
|
||||
optimizer: Optimizer | None = None,
|
||||
scheduler: LRScheduler | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Saves the training step, optimizer state, scheduler state, and rng state.
|
||||
|
||||
Args:
|
||||
save_dir (Path): The directory to save artifacts to.
|
||||
train_step (int): Current training step.
|
||||
optimizer (Optimizer | None, optional): The optimizer from which to save the state_dict.
|
||||
Defaults to None.
|
||||
scheduler (LRScheduler | None, optional): The scheduler from which to save the state_dict.
|
||||
Defaults to None.
|
||||
"""
|
||||
save_dir = checkpoint_dir / TRAINING_STATE_DIR
|
||||
save_dir.mkdir(parents=True, exist_ok=True)
|
||||
save_training_step(train_step, save_dir)
|
||||
save_rng_state(save_dir)
|
||||
if optimizer is not None:
|
||||
save_optimizer_state(optimizer, save_dir)
|
||||
if scheduler is not None:
|
||||
save_scheduler_state(scheduler, save_dir)
|
||||
|
||||
|
||||
def load_training_state(
|
||||
checkpoint_dir: Path, optimizer: Optimizer, scheduler: LRScheduler | None
|
||||
) -> tuple[int, Optimizer, LRScheduler | None]:
|
||||
"""
|
||||
Loads the training step, optimizer state, scheduler state, and rng state.
|
||||
This is used to resume a training run.
|
||||
|
||||
Args:
|
||||
checkpoint_dir (Path): The checkpoint directory. Should contain a 'training_state' dir.
|
||||
optimizer (Optimizer): The optimizer to load the state_dict to.
|
||||
scheduler (LRScheduler | None): The scheduler to load the state_dict to (can be None).
|
||||
|
||||
Raises:
|
||||
NotADirectoryError: If 'checkpoint_dir' doesn't contain a 'training_state' dir
|
||||
|
||||
Returns:
|
||||
tuple[int, Optimizer, LRScheduler | None]: training step, optimizer and scheduler with their
|
||||
state_dict loaded.
|
||||
"""
|
||||
training_state_dir = checkpoint_dir / TRAINING_STATE_DIR
|
||||
if not training_state_dir.is_dir():
|
||||
raise NotADirectoryError(training_state_dir)
|
||||
|
||||
load_rng_state(training_state_dir)
|
||||
step = load_training_step(training_state_dir)
|
||||
optimizer = load_optimizer_state(optimizer, training_state_dir)
|
||||
if scheduler is not None:
|
||||
scheduler = load_scheduler_state(scheduler, training_state_dir)
|
||||
|
||||
return step, optimizer, scheduler
|
||||
@@ -17,12 +17,10 @@ import logging
|
||||
import os
|
||||
import os.path as osp
|
||||
import platform
|
||||
import random
|
||||
from contextlib import contextmanager
|
||||
import subprocess
|
||||
from copy import copy
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any, Generator
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -106,59 +104,6 @@ def is_amp_available(device: str):
|
||||
raise ValueError(f"Unknown device '{device}.")
|
||||
|
||||
|
||||
def get_global_random_state() -> dict[str, Any]:
|
||||
"""Get the random state for `random`, `numpy`, and `torch`."""
|
||||
random_state_dict = {
|
||||
"random_state": random.getstate(),
|
||||
"numpy_random_state": np.random.get_state(),
|
||||
"torch_random_state": torch.random.get_rng_state(),
|
||||
}
|
||||
if torch.cuda.is_available():
|
||||
random_state_dict["torch_cuda_random_state"] = torch.cuda.random.get_rng_state()
|
||||
return random_state_dict
|
||||
|
||||
|
||||
def set_global_random_state(random_state_dict: dict[str, Any]):
|
||||
"""Set the random state for `random`, `numpy`, and `torch`.
|
||||
|
||||
Args:
|
||||
random_state_dict: A dictionary of the form returned by `get_global_random_state`.
|
||||
"""
|
||||
random.setstate(random_state_dict["random_state"])
|
||||
np.random.set_state(random_state_dict["numpy_random_state"])
|
||||
torch.random.set_rng_state(random_state_dict["torch_random_state"])
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.random.set_rng_state(random_state_dict["torch_cuda_random_state"])
|
||||
|
||||
|
||||
def set_global_seed(seed):
|
||||
"""Set seed for reproducibility."""
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def seeded_context(seed: int) -> Generator[None, None, None]:
|
||||
"""Set the seed when entering a context, and restore the prior random state at exit.
|
||||
|
||||
Example usage:
|
||||
|
||||
```
|
||||
a = random.random() # produces some random number
|
||||
with seeded_context(1337):
|
||||
b = random.random() # produces some other random number
|
||||
c = random.random() # produces yet another random number, but the same it would have if we never made `b`
|
||||
```
|
||||
"""
|
||||
random_state_dict = get_global_random_state()
|
||||
set_global_seed(seed)
|
||||
yield None
|
||||
set_global_random_state(random_state_dict)
|
||||
|
||||
|
||||
def init_logging():
|
||||
def custom_format(record):
|
||||
dt = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
@@ -221,23 +166,31 @@ def capture_timestamp_utc():
|
||||
|
||||
|
||||
def say(text, blocking=False):
|
||||
# Check if mac, linux, or windows.
|
||||
if platform.system() == "Darwin":
|
||||
cmd = f'say "{text}"'
|
||||
if not blocking:
|
||||
cmd += " &"
|
||||
elif platform.system() == "Linux":
|
||||
cmd = f'spd-say "{text}"'
|
||||
if blocking:
|
||||
cmd += " --wait"
|
||||
elif platform.system() == "Windows":
|
||||
# TODO(rcadene): Make blocking option work for Windows
|
||||
cmd = (
|
||||
'PowerShell -Command "Add-Type -AssemblyName System.Speech; '
|
||||
f"(New-Object System.Speech.Synthesis.SpeechSynthesizer).Speak('{text}')\""
|
||||
)
|
||||
system = platform.system()
|
||||
|
||||
os.system(cmd)
|
||||
if system == "Darwin":
|
||||
cmd = ["say", text]
|
||||
|
||||
elif system == "Linux":
|
||||
cmd = ["spd-say", text]
|
||||
if blocking:
|
||||
cmd.append("--wait")
|
||||
|
||||
elif system == "Windows":
|
||||
cmd = [
|
||||
"PowerShell",
|
||||
"-Command",
|
||||
"Add-Type -AssemblyName System.Speech; "
|
||||
f"(New-Object System.Speech.Synthesis.SpeechSynthesizer).Speak('{text}')",
|
||||
]
|
||||
|
||||
else:
|
||||
raise RuntimeError("Unsupported operating system for text-to-speech.")
|
||||
|
||||
if blocking:
|
||||
subprocess.run(cmd, check=True)
|
||||
else:
|
||||
subprocess.Popen(cmd, creationflags=subprocess.CREATE_NO_WINDOW if system == "Windows" else 0)
|
||||
|
||||
|
||||
def log_say(text, play_sounds, blocking=False):
|
||||
@@ -257,5 +210,18 @@ def get_channel_first_image_shape(image_shape: tuple) -> tuple:
|
||||
return shape
|
||||
|
||||
|
||||
def has_method(cls: object, method_name: str):
|
||||
def has_method(cls: object, method_name: str) -> bool:
|
||||
return hasattr(cls, method_name) and callable(getattr(cls, method_name))
|
||||
|
||||
|
||||
def is_valid_numpy_dtype_string(dtype_str: str) -> bool:
|
||||
"""
|
||||
Return True if a given string can be converted to a numpy dtype.
|
||||
"""
|
||||
try:
|
||||
# Attempt to convert the string to a numpy dtype
|
||||
np.dtype(dtype_str)
|
||||
return True
|
||||
except TypeError:
|
||||
# If a TypeError is raised, the string is not a valid dtype
|
||||
return False
|
||||
|
||||
121
lerobot/common/utils/wandb_utils.py
Normal file
121
lerobot/common/utils/wandb_utils.py
Normal file
@@ -0,0 +1,121 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from glob import glob
|
||||
from pathlib import Path
|
||||
|
||||
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
|
||||
from termcolor import colored
|
||||
|
||||
from lerobot.common.constants import PRETRAINED_MODEL_DIR
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
|
||||
|
||||
def cfg_to_group(cfg: TrainPipelineConfig, return_list: bool = False) -> list[str] | str:
|
||||
"""Return a group name for logging. Optionally returns group name as list."""
|
||||
lst = [
|
||||
f"policy:{cfg.policy.type}",
|
||||
f"dataset:{cfg.dataset.repo_id}",
|
||||
f"seed:{cfg.seed}",
|
||||
]
|
||||
if cfg.env is not None:
|
||||
lst.append(f"env:{cfg.env.type}")
|
||||
return lst if return_list else "-".join(lst)
|
||||
|
||||
|
||||
def get_wandb_run_id_from_filesystem(log_dir: Path) -> str:
|
||||
# Get the WandB run ID.
|
||||
paths = glob(str(log_dir / "wandb/latest-run/run-*"))
|
||||
if len(paths) != 1:
|
||||
raise RuntimeError("Couldn't get the previous WandB run ID for run resumption.")
|
||||
match = re.search(r"run-([^\.]+).wandb", paths[0].split("/")[-1])
|
||||
if match is None:
|
||||
raise RuntimeError("Couldn't get the previous WandB run ID for run resumption.")
|
||||
wandb_run_id = match.groups(0)[0]
|
||||
return wandb_run_id
|
||||
|
||||
|
||||
def get_safe_wandb_artifact_name(name: str):
|
||||
"""WandB artifacts don't accept ":" or "/" in their name."""
|
||||
return name.replace(":", "_").replace("/", "_")
|
||||
|
||||
|
||||
class WandBLogger:
|
||||
"""A helper class to log object using wandb."""
|
||||
|
||||
def __init__(self, cfg: TrainPipelineConfig):
|
||||
self.cfg = cfg.wandb
|
||||
self.log_dir = cfg.output_dir
|
||||
self.job_name = cfg.job_name
|
||||
self.env_fps = cfg.env.fps if cfg.env else None
|
||||
self._group = cfg_to_group(cfg)
|
||||
|
||||
# Set up WandB.
|
||||
os.environ["WANDB_SILENT"] = "True"
|
||||
import wandb
|
||||
|
||||
wandb_run_id = get_wandb_run_id_from_filesystem(self.log_dir) if cfg.resume else None
|
||||
wandb.init(
|
||||
id=wandb_run_id,
|
||||
project=self.cfg.project,
|
||||
entity=self.cfg.entity,
|
||||
name=self.job_name,
|
||||
notes=self.cfg.notes,
|
||||
tags=cfg_to_group(cfg, return_list=True),
|
||||
dir=self.log_dir,
|
||||
config=cfg.to_dict(),
|
||||
# TODO(rcadene): try set to True
|
||||
save_code=False,
|
||||
# TODO(rcadene): split train and eval, and run async eval with job_type="eval"
|
||||
job_type="train_eval",
|
||||
resume="must" if cfg.resume else None,
|
||||
)
|
||||
print(colored("Logs will be synced with wandb.", "blue", attrs=["bold"]))
|
||||
logging.info(f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}")
|
||||
self._wandb = wandb
|
||||
|
||||
def log_policy(self, checkpoint_dir: Path):
|
||||
"""Checkpoints the policy to wandb."""
|
||||
if self.cfg.disable_artifact:
|
||||
return
|
||||
|
||||
step_id = checkpoint_dir.name
|
||||
artifact_name = f"{self._group}-{step_id}"
|
||||
artifact_name = get_safe_wandb_artifact_name(artifact_name)
|
||||
artifact = self._wandb.Artifact(artifact_name, type="model")
|
||||
artifact.add_file(checkpoint_dir / PRETRAINED_MODEL_DIR / SAFETENSORS_SINGLE_FILE)
|
||||
self._wandb.log_artifact(artifact)
|
||||
|
||||
def log_dict(self, d: dict, step: int, mode: str = "train"):
|
||||
if mode not in {"train", "eval"}:
|
||||
raise ValueError(mode)
|
||||
|
||||
for k, v in d.items():
|
||||
if not isinstance(v, (int, float, str)):
|
||||
logging.warning(
|
||||
f'WandB logging of key "{k}" was ignored as its type is not handled by this wrapper.'
|
||||
)
|
||||
continue
|
||||
self._wandb.log({f"{mode}/{k}": v}, step=step)
|
||||
|
||||
def log_video(self, video_path: str, step: int, mode: str = "train"):
|
||||
if mode not in {"train", "eval"}:
|
||||
raise ValueError(mode)
|
||||
|
||||
wandb_video = self._wandb.Video(video_path, fps=self.env_fps, format="mp4")
|
||||
self._wandb.log({f"{mode}/video": wandb_video}, step=step)
|
||||
@@ -27,11 +27,13 @@ class DatasetConfig:
|
||||
# You may provide a list of datasets here. `train.py` creates them all and concatenates them. Note: only data
|
||||
# keys common between the datasets are kept. Each dataset gets and additional transform that inserts the
|
||||
# "dataset_index" into the returned item. The index mapping is made according to the order in which the
|
||||
# datsets are provided.
|
||||
# datasets are provided.
|
||||
repo_id: str
|
||||
# Root directory where the dataset will be stored (e.g. 'dataset/path').
|
||||
root: str | None = None
|
||||
episodes: list[int] | None = None
|
||||
image_transforms: ImageTransformsConfig = field(default_factory=ImageTransformsConfig)
|
||||
local_files_only: bool = False
|
||||
revision: str | None = None
|
||||
use_imagenet_stats: bool = True
|
||||
video_backend: str = "pyav"
|
||||
|
||||
|
||||
@@ -21,68 +21,6 @@ from lerobot.configs.policies import PreTrainedConfig
|
||||
TRAIN_CONFIG_NAME = "train_config.json"
|
||||
|
||||
|
||||
@dataclass
|
||||
class OfflineConfig:
|
||||
steps: int = 100_000
|
||||
|
||||
|
||||
@dataclass
|
||||
class OnlineConfig:
|
||||
"""
|
||||
The online training loop looks something like:
|
||||
|
||||
```python
|
||||
for i in range(steps):
|
||||
do_online_rollout_and_update_online_buffer()
|
||||
for j in range(steps_between_rollouts):
|
||||
batch = next(dataloader_with_offline_and_online_data)
|
||||
loss = policy(batch)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
```
|
||||
|
||||
Note that the online training loop adopts most of the options from the offline loop unless specified
|
||||
otherwise.
|
||||
"""
|
||||
|
||||
steps: int = 0
|
||||
# How many episodes to collect at once when we reach the online rollout part of the training loop.
|
||||
rollout_n_episodes: int = 1
|
||||
# The number of environments to use in the gym.vector.VectorEnv. This ends up also being the batch size for
|
||||
# the policy. Ideally you should set this to by an even divisor of rollout_n_episodes.
|
||||
rollout_batch_size: int = 1
|
||||
# How many optimization steps (forward, backward, optimizer step) to do between running rollouts.
|
||||
steps_between_rollouts: int | None = None
|
||||
# The proportion of online samples (vs offline samples) to include in the online training batches.
|
||||
sampling_ratio: float = 0.5
|
||||
# First seed to use for the online rollout environment. Seeds for subsequent rollouts are incremented by 1.
|
||||
env_seed: int | None = None
|
||||
# Sets the maximum number of frames that are stored in the online buffer for online training. The buffer is
|
||||
# FIFO.
|
||||
buffer_capacity: int | None = None
|
||||
# The minimum number of frames to have in the online buffer before commencing online training.
|
||||
# If buffer_seed_size > rollout_n_episodes, the rollout will be run multiple times until the
|
||||
# seed size condition is satisfied.
|
||||
buffer_seed_size: int = 0
|
||||
# Whether to run the online rollouts asynchronously. This means we can run the online training steps in
|
||||
# parallel with the rollouts. This might be advised if your GPU has the bandwidth to handle training
|
||||
# + eval + environment rendering simultaneously.
|
||||
do_rollout_async: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
if self.steps == 0:
|
||||
return
|
||||
|
||||
if self.steps_between_rollouts is None:
|
||||
raise ValueError(
|
||||
"'steps_between_rollouts' must be set to a positive integer, but it is currently None."
|
||||
)
|
||||
if self.env_seed is None:
|
||||
raise ValueError("'env_seed' must be set to a positive integer, but it is currently None.")
|
||||
if self.buffer_capacity is None:
|
||||
raise ValueError("'buffer_capacity' must be set to a positive integer, but it is currently None.")
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainPipelineConfig(HubMixin):
|
||||
dataset: DatasetConfig
|
||||
@@ -107,13 +45,12 @@ class TrainPipelineConfig(HubMixin):
|
||||
# Number of workers for the dataloader.
|
||||
num_workers: int = 4
|
||||
batch_size: int = 8
|
||||
steps: int = 100_000
|
||||
eval_freq: int = 20_000
|
||||
log_freq: int = 200
|
||||
save_checkpoint: bool = True
|
||||
# Checkpoint is saved every `save_freq` training iterations and after the last training step.
|
||||
save_freq: int = 20_000
|
||||
offline: OfflineConfig = field(default_factory=OfflineConfig)
|
||||
online: OnlineConfig = field(default_factory=OnlineConfig)
|
||||
use_policy_training_preset: bool = True
|
||||
optimizer: OptimizerConfig | None = None
|
||||
scheduler: LRSchedulerConfig | None = None
|
||||
@@ -148,6 +85,11 @@ class TrainPipelineConfig(HubMixin):
|
||||
config_path = parser.parse_arg("config_path")
|
||||
if not config_path:
|
||||
raise ValueError("A config_path is expected when resuming a run.")
|
||||
if not Path(config_path).resolve().exists():
|
||||
raise NotADirectoryError(
|
||||
f"{config_path=} is expected to be a local path. "
|
||||
"Resuming from the hub is not supported for now."
|
||||
)
|
||||
policy_path = Path(config_path).parent
|
||||
self.policy.pretrained_path = policy_path
|
||||
self.checkpoint_path = policy_path.parent
|
||||
@@ -160,7 +102,7 @@ class TrainPipelineConfig(HubMixin):
|
||||
|
||||
if not self.resume and isinstance(self.output_dir, Path) and self.output_dir.is_dir():
|
||||
raise FileExistsError(
|
||||
f"Output directory {self.output_dir} alreay exists and resume is {self.resume}. "
|
||||
f"Output directory {self.output_dir} already exists and resume is {self.resume}. "
|
||||
f"Please change your output directory so that {self.output_dir} is not overwritten."
|
||||
)
|
||||
elif not self.output_dir:
|
||||
@@ -168,11 +110,8 @@ class TrainPipelineConfig(HubMixin):
|
||||
train_dir = f"{now:%Y-%m-%d}/{now:%H-%M-%S}_{self.job_name}"
|
||||
self.output_dir = Path("outputs/train") / train_dir
|
||||
|
||||
if self.online.steps > 0:
|
||||
if isinstance(self.dataset.repo_id, list):
|
||||
raise NotImplementedError("Online training with LeRobotMultiDataset is not implemented.")
|
||||
if self.env is None:
|
||||
raise ValueError("An environment is required for online training")
|
||||
if isinstance(self.dataset.repo_id, list):
|
||||
raise NotImplementedError("LeRobotMultiDataset is not currently implemented.")
|
||||
|
||||
if not self.use_policy_training_preset and (self.optimizer is None or self.scheduler is None):
|
||||
raise ValueError("Optimizer and Scheduler must be set when the policy presets are not used.")
|
||||
@@ -185,6 +124,9 @@ class TrainPipelineConfig(HubMixin):
|
||||
"""This enables the parser to load config from the policy using `--policy.path=local/dir`"""
|
||||
return ["policy"]
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return draccus.encode(self)
|
||||
|
||||
def _save_pretrained(self, save_directory: Path) -> None:
|
||||
with open(save_directory / TRAIN_CONFIG_NAME, "w") as f, draccus.config_type("json"):
|
||||
draccus.dump(self, f, indent=4)
|
||||
|
||||
@@ -77,6 +77,13 @@ python lerobot/scripts/control_robot.py record \
|
||||
--control.reset_time_s=10
|
||||
```
|
||||
|
||||
- For remote controlled robots like LeKiwi, run this script on the robot edge device (e.g. RaspBerryPi):
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py \
|
||||
--robot.type=lekiwi \
|
||||
--control.type=remote_robot
|
||||
```
|
||||
|
||||
**NOTE**: You can use your keyboard to control data recording flow.
|
||||
- Tap right arrow key '->' to early exit while recording an episode and go to resseting the environment.
|
||||
- Tap right arrow key '->' to early exit while resetting the environment and got to recording the next episode.
|
||||
@@ -85,7 +92,6 @@ python lerobot/scripts/control_robot.py record \
|
||||
This might require a sudo permission to allow your terminal to monitor keyboard events.
|
||||
|
||||
**NOTE**: You can resume/continue data recording by running the same data recording command and adding `--control.resume=true`.
|
||||
If the dataset you want to extend is not on the hub, you also need to add `--control.local_files_only=true`.
|
||||
|
||||
- Train on this dataset with the ACT policy:
|
||||
```bash
|
||||
@@ -127,6 +133,7 @@ from lerobot.common.robot_devices.control_configs import (
|
||||
CalibrateControlConfig,
|
||||
ControlPipelineConfig,
|
||||
RecordControlConfig,
|
||||
RemoteRobotConfig,
|
||||
ReplayControlConfig,
|
||||
TeleoperateControlConfig,
|
||||
)
|
||||
@@ -188,6 +195,16 @@ def calibrate(robot: Robot, cfg: CalibrateControlConfig):
|
||||
if robot.is_connected:
|
||||
robot.disconnect()
|
||||
|
||||
if robot.robot_type.startswith("lekiwi") and "main_follower" in arms:
|
||||
print("Calibrating only the lekiwi follower arm 'main_follower'...")
|
||||
robot.calibrate_follower()
|
||||
return
|
||||
|
||||
if robot.robot_type.startswith("lekiwi") and "main_leader" in arms:
|
||||
print("Calibrating only the lekiwi leader arm 'main_leader'...")
|
||||
robot.calibrate_leader()
|
||||
return
|
||||
|
||||
# Calling `connect` automatically runs calibration
|
||||
# when the calibration file is missing
|
||||
robot.connect()
|
||||
@@ -216,7 +233,6 @@ def record(
|
||||
dataset = LeRobotDataset(
|
||||
cfg.repo_id,
|
||||
root=cfg.root,
|
||||
local_files_only=cfg.local_files_only,
|
||||
)
|
||||
if len(robot.cameras) > 0:
|
||||
dataset.start_image_writer(
|
||||
@@ -263,8 +279,8 @@ def record(
|
||||
|
||||
log_say(f"Recording episode {dataset.num_episodes}", cfg.play_sounds)
|
||||
record_episode(
|
||||
dataset=dataset,
|
||||
robot=robot,
|
||||
dataset=dataset,
|
||||
events=events,
|
||||
episode_time_s=cfg.episode_time_s,
|
||||
display_cameras=cfg.display_cameras,
|
||||
@@ -272,6 +288,7 @@ def record(
|
||||
device=cfg.device,
|
||||
use_amp=cfg.use_amp,
|
||||
fps=cfg.fps,
|
||||
single_task=cfg.single_task,
|
||||
)
|
||||
|
||||
# Execute a few seconds without recording to give time to manually reset the environment
|
||||
@@ -282,7 +299,7 @@ def record(
|
||||
(recorded_episodes < cfg.num_episodes - 1) or events["rerecord_episode"]
|
||||
):
|
||||
log_say("Reset the environment", cfg.play_sounds)
|
||||
reset_environment(robot, events, cfg.reset_time_s)
|
||||
reset_environment(robot, events, cfg.reset_time_s, cfg.fps)
|
||||
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-record episode", cfg.play_sounds)
|
||||
@@ -291,7 +308,7 @@ def record(
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
|
||||
dataset.save_episode(cfg.single_task)
|
||||
dataset.save_episode()
|
||||
recorded_episodes += 1
|
||||
|
||||
if events["stop_recording"]:
|
||||
@@ -300,11 +317,6 @@ def record(
|
||||
log_say("Stop recording", cfg.play_sounds, blocking=True)
|
||||
stop_recording(robot, listener, cfg.display_cameras)
|
||||
|
||||
if cfg.run_compute_stats:
|
||||
logging.info("Computing dataset statistics")
|
||||
|
||||
dataset.consolidate(cfg.run_compute_stats)
|
||||
|
||||
if cfg.push_to_hub:
|
||||
dataset.push_to_hub(tags=cfg.tags, private=cfg.private)
|
||||
|
||||
@@ -320,9 +332,7 @@ def replay(
|
||||
# TODO(rcadene, aliberts): refactor with control_loop, once `dataset` is an instance of LeRobotDataset
|
||||
# TODO(rcadene): Add option to record logs
|
||||
|
||||
dataset = LeRobotDataset(
|
||||
cfg.repo_id, root=cfg.root, episodes=[cfg.episode], local_files_only=cfg.local_files_only
|
||||
)
|
||||
dataset = LeRobotDataset(cfg.repo_id, root=cfg.root, episodes=[cfg.episode])
|
||||
actions = dataset.hf_dataset.select_columns("action")
|
||||
|
||||
if not robot.is_connected:
|
||||
@@ -357,6 +367,10 @@ def control_robot(cfg: ControlPipelineConfig):
|
||||
record(robot, cfg.control)
|
||||
elif isinstance(cfg.control, ReplayControlConfig):
|
||||
replay(robot, cfg.control)
|
||||
elif isinstance(cfg.control, RemoteRobotConfig):
|
||||
from lerobot.common.robot_devices.robots.lekiwi_remote import run_lekiwi
|
||||
|
||||
run_lekiwi(cfg.robot)
|
||||
|
||||
if robot.is_connected:
|
||||
# Disconnect manually to avoid a "Core dump" during process
|
||||
|
||||
@@ -59,8 +59,8 @@ python lerobot/scripts/control_sim_robot.py record \
|
||||
```
|
||||
|
||||
**NOTE**: You can use your keyboard to control data recording flow.
|
||||
- Tap right arrow key '->' to early exit while recording an episode and go to reseting the environment.
|
||||
- Tap right arrow key '->' to early exit while reseting the environment and got to recording the next episode.
|
||||
- Tap right arrow key '->' to early exit while recording an episode and go to resetting the environment.
|
||||
- Tap right arrow key '->' to early exit while resetting the environment and got to recording the next episode.
|
||||
- Tap left arrow key '<-' to early exit and re-record the current episode.
|
||||
- Tap escape key 'esc' to stop the data recording.
|
||||
This might require a sudo permission to allow your terminal to monitor keyboard events.
|
||||
@@ -131,7 +131,7 @@ def none_or_int(value):
|
||||
|
||||
def init_sim_calibration(robot, cfg):
|
||||
# Constants necessary for transforming the joint pos of the real robot to the sim
|
||||
# depending on the robot discription used in that sim.
|
||||
# depending on the robot description used in that sim.
|
||||
start_pos = np.array(robot.leader_arms.main.calibration["start_pos"])
|
||||
axis_directions = np.array(cfg.get("axis_directions", [1]))
|
||||
offsets = np.array(cfg.get("offsets", [0])) * np.pi
|
||||
@@ -445,7 +445,7 @@ if __name__ == "__main__":
|
||||
type=int,
|
||||
default=0,
|
||||
help=(
|
||||
"Number of subprocesses handling the saving of frames as PNGs. Set to 0 to use threads only; "
|
||||
"Number of subprocesses handling the saving of frames as PNG. Set to 0 to use threads only; "
|
||||
"set to ≥1 to use subprocesses, each using threads to write images. The best number of processes "
|
||||
"and threads depends on your system. We recommend 4 threads per camera with 0 processes. "
|
||||
"If fps is unstable, adjust the thread count. If still unstable, try using 1 or more subprocesses."
|
||||
|
||||
@@ -61,21 +61,21 @@ import einops
|
||||
import gymnasium as gym
|
||||
import numpy as np
|
||||
import torch
|
||||
from termcolor import colored
|
||||
from torch import Tensor, nn
|
||||
from tqdm import trange
|
||||
|
||||
from lerobot.common.envs.factory import make_env
|
||||
from lerobot.common.envs.utils import preprocess_observation
|
||||
from lerobot.common.logger import log_output_dir
|
||||
from lerobot.common.policies.factory import make_policy
|
||||
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.common.policies.utils import get_device_from_parameters
|
||||
from lerobot.common.utils.io_utils import write_video
|
||||
from lerobot.common.utils.random_utils import set_seed
|
||||
from lerobot.common.utils.utils import (
|
||||
get_safe_torch_device,
|
||||
init_logging,
|
||||
inside_slurm,
|
||||
set_global_seed,
|
||||
)
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.eval import EvalPipelineConfig
|
||||
@@ -125,9 +125,6 @@ def rollout(
|
||||
# Reset the policy and environments.
|
||||
policy.reset()
|
||||
|
||||
if hasattr(policy, "use_ema_modules"):
|
||||
policy.use_ema_modules()
|
||||
|
||||
observation, info = env.reset(seed=seeds)
|
||||
if render_callback is not None:
|
||||
render_callback(env)
|
||||
@@ -154,7 +151,9 @@ def rollout(
|
||||
if return_observations:
|
||||
all_observations.append(deepcopy(observation))
|
||||
|
||||
observation = {key: observation[key].to(device, non_blocking=True) for key in observation}
|
||||
observation = {
|
||||
key: observation[key].to(device, non_blocking=device.type == "cuda") for key in observation
|
||||
}
|
||||
|
||||
with torch.inference_mode():
|
||||
action = policy.select_action(observation)
|
||||
@@ -455,7 +454,7 @@ def _compile_episode_data(
|
||||
|
||||
|
||||
@parser.wrap()
|
||||
def eval(cfg: EvalPipelineConfig):
|
||||
def eval_main(cfg: EvalPipelineConfig):
|
||||
logging.info(pformat(asdict(cfg)))
|
||||
|
||||
# Check device is available
|
||||
@@ -463,9 +462,9 @@ def eval(cfg: EvalPipelineConfig):
|
||||
|
||||
torch.backends.cudnn.benchmark = True
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
set_global_seed(cfg.seed)
|
||||
set_seed(cfg.seed)
|
||||
|
||||
log_output_dir(cfg.output_dir)
|
||||
logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}")
|
||||
|
||||
logging.info("Making environment.")
|
||||
env = make_env(cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs)
|
||||
@@ -500,4 +499,4 @@ def eval(cfg: EvalPipelineConfig):
|
||||
|
||||
if __name__ == "__main__":
|
||||
init_logging()
|
||||
eval()
|
||||
eval_main()
|
||||
|
||||
@@ -175,7 +175,7 @@ def push_dataset_to_hub(
|
||||
# Robustify when `local_dir` is str instead of Path
|
||||
local_dir = Path(local_dir)
|
||||
|
||||
# Send warning if local_dir isn't well formated
|
||||
# Send warning if local_dir isn't well formatted
|
||||
if local_dir.parts[-2] != user_id or local_dir.parts[-1] != dataset_id:
|
||||
warnings.warn(
|
||||
f"`local_dir` ({local_dir}) doesn't contain a community or user id `/` the name of the dataset that match the `repo_id` (e.g. 'data/lerobot/pusht'). Following this naming convention is advised, but not mandatory.",
|
||||
|
||||
@@ -15,61 +15,64 @@
|
||||
# limitations under the License.
|
||||
import logging
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from contextlib import nullcontext
|
||||
from copy import deepcopy
|
||||
from dataclasses import asdict
|
||||
from pprint import pformat
|
||||
from threading import Lock
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from termcolor import colored
|
||||
from torch.amp import GradScaler
|
||||
from torch.optim import Optimizer
|
||||
|
||||
from lerobot.common.datasets.factory import make_dataset, resolve_delta_timestamps
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.common.datasets.online_buffer import OnlineBuffer, compute_sampler_weights
|
||||
from lerobot.common.datasets.factory import make_dataset
|
||||
from lerobot.common.datasets.sampler import EpisodeAwareSampler
|
||||
from lerobot.common.datasets.utils import cycle
|
||||
from lerobot.common.envs.factory import make_env
|
||||
from lerobot.common.logger import Logger, log_output_dir
|
||||
from lerobot.common.optim.factory import load_training_state, make_optimizer_and_scheduler
|
||||
from lerobot.common.optim.factory import make_optimizer_and_scheduler
|
||||
from lerobot.common.policies.factory import make_policy
|
||||
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.common.policies.utils import get_device_from_parameters
|
||||
from lerobot.common.utils.logging_utils import AverageMeter, MetricsTracker
|
||||
from lerobot.common.utils.random_utils import set_seed
|
||||
from lerobot.common.utils.train_utils import (
|
||||
get_step_checkpoint_dir,
|
||||
get_step_identifier,
|
||||
load_training_state,
|
||||
save_checkpoint,
|
||||
update_last_checkpoint,
|
||||
)
|
||||
from lerobot.common.utils.utils import (
|
||||
format_big_number,
|
||||
get_safe_dtype,
|
||||
get_safe_torch_device,
|
||||
has_method,
|
||||
init_logging,
|
||||
set_global_seed,
|
||||
)
|
||||
from lerobot.common.utils.wandb_utils import WandBLogger
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
from lerobot.scripts.eval import eval_policy
|
||||
|
||||
|
||||
def update_policy(
|
||||
policy,
|
||||
batch,
|
||||
optimizer,
|
||||
grad_clip_norm,
|
||||
train_metrics: MetricsTracker,
|
||||
policy: PreTrainedPolicy,
|
||||
batch: Any,
|
||||
optimizer: Optimizer,
|
||||
grad_clip_norm: float,
|
||||
grad_scaler: GradScaler,
|
||||
lr_scheduler=None,
|
||||
use_amp: bool = False,
|
||||
lock=None,
|
||||
):
|
||||
"""Returns a dictionary of items for logging."""
|
||||
) -> tuple[MetricsTracker, dict]:
|
||||
start_time = time.perf_counter()
|
||||
device = get_device_from_parameters(policy)
|
||||
policy.train()
|
||||
with torch.autocast(device_type=device.type) if use_amp else nullcontext():
|
||||
output_dict = policy.forward(batch)
|
||||
loss, output_dict = policy.forward(batch)
|
||||
# TODO(rcadene): policy.unnormalize_outputs(out_dict)
|
||||
loss = output_dict["loss"]
|
||||
grad_scaler.scale(loss).backward()
|
||||
|
||||
# Unscale the graident of the optimzer's assigned params in-place **prior to gradient clipping**.
|
||||
# Unscale the gradient of the optimizer's assigned params in-place **prior to gradient clipping**.
|
||||
grad_scaler.unscale_(optimizer)
|
||||
|
||||
grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
@@ -87,9 +90,6 @@ def update_policy(
|
||||
|
||||
optimizer.zero_grad()
|
||||
|
||||
if hasattr(policy, "update_ema_modules"):
|
||||
policy.update_ema_modules()
|
||||
|
||||
# Step through pytorch scheduler at every batch instead of epoch
|
||||
if lr_scheduler is not None:
|
||||
lr_scheduler.step()
|
||||
@@ -98,113 +98,34 @@ def update_policy(
|
||||
# To possibly update an internal buffer (for instance an Exponential Moving Average like in TDMPC).
|
||||
policy.update()
|
||||
|
||||
info = {
|
||||
"loss": loss.item(),
|
||||
"grad_norm": float(grad_norm),
|
||||
"lr": optimizer.param_groups[0]["lr"],
|
||||
"update_s": time.perf_counter() - start_time,
|
||||
**{k: v for k, v in output_dict.items() if k != "loss"},
|
||||
}
|
||||
info.update({k: v for k, v in output_dict.items() if k not in info})
|
||||
|
||||
return info
|
||||
|
||||
|
||||
def log_train_info(
|
||||
logger: Logger, info: dict, step: int, cfg: TrainPipelineConfig, dataset: LeRobotDataset, is_online: bool
|
||||
):
|
||||
loss = info["loss"]
|
||||
grad_norm = info["grad_norm"]
|
||||
lr = info["lr"]
|
||||
update_s = info["update_s"]
|
||||
dataloading_s = info["dataloading_s"]
|
||||
|
||||
# A sample is an (observation,action) pair, where observation and action
|
||||
# can be on multiple timestamps. In a batch, we have `batch_size`` number of samples.
|
||||
num_samples = (step + 1) * cfg.batch_size
|
||||
avg_samples_per_ep = dataset.num_frames / dataset.num_episodes
|
||||
num_episodes = num_samples / avg_samples_per_ep
|
||||
num_epochs = num_samples / dataset.num_frames
|
||||
log_items = [
|
||||
f"step:{format_big_number(step)}",
|
||||
# number of samples seen during training
|
||||
f"smpl:{format_big_number(num_samples)}",
|
||||
# number of episodes seen during training
|
||||
f"ep:{format_big_number(num_episodes)}",
|
||||
# number of time all unique samples are seen
|
||||
f"epch:{num_epochs:.2f}",
|
||||
f"loss:{loss:.3f}",
|
||||
f"grdn:{grad_norm:.3f}",
|
||||
f"lr:{lr:0.1e}",
|
||||
# in seconds
|
||||
f"updt_s:{update_s:.3f}",
|
||||
f"data_s:{dataloading_s:.3f}", # if not ~0, you are bottlenecked by cpu or io
|
||||
]
|
||||
logging.info(" ".join(log_items))
|
||||
|
||||
info["step"] = step
|
||||
info["num_samples"] = num_samples
|
||||
info["num_episodes"] = num_episodes
|
||||
info["num_epochs"] = num_epochs
|
||||
info["is_online"] = is_online
|
||||
|
||||
logger.log_dict(info, step, mode="train")
|
||||
|
||||
|
||||
def log_eval_info(logger, info, step, cfg, dataset, is_online):
|
||||
eval_s = info["eval_s"]
|
||||
avg_sum_reward = info["avg_sum_reward"]
|
||||
pc_success = info["pc_success"]
|
||||
|
||||
# A sample is an (observation,action) pair, where observation and action
|
||||
# can be on multiple timestamps. In a batch, we have `batch_size`` number of samples.
|
||||
num_samples = (step + 1) * cfg.batch_size
|
||||
avg_samples_per_ep = dataset.num_frames / dataset.num_episodes
|
||||
num_episodes = num_samples / avg_samples_per_ep
|
||||
num_epochs = num_samples / dataset.num_frames
|
||||
log_items = [
|
||||
f"step:{format_big_number(step)}",
|
||||
# number of samples seen during training
|
||||
f"smpl:{format_big_number(num_samples)}",
|
||||
# number of episodes seen during training
|
||||
f"ep:{format_big_number(num_episodes)}",
|
||||
# number of time all unique samples are seen
|
||||
f"epch:{num_epochs:.2f}",
|
||||
f"∑rwrd:{avg_sum_reward:.3f}",
|
||||
f"success:{pc_success:.1f}%",
|
||||
f"eval_s:{eval_s:.3f}",
|
||||
]
|
||||
logging.info(" ".join(log_items))
|
||||
|
||||
info["step"] = step
|
||||
info["num_samples"] = num_samples
|
||||
info["num_episodes"] = num_episodes
|
||||
info["num_epochs"] = num_epochs
|
||||
info["is_online"] = is_online
|
||||
|
||||
logger.log_dict(info, step, mode="eval")
|
||||
train_metrics.loss = loss.item()
|
||||
train_metrics.grad_norm = grad_norm.item()
|
||||
train_metrics.lr = optimizer.param_groups[0]["lr"]
|
||||
train_metrics.update_s = time.perf_counter() - start_time
|
||||
return train_metrics, output_dict
|
||||
|
||||
|
||||
@parser.wrap()
|
||||
def train(cfg: TrainPipelineConfig):
|
||||
cfg.validate()
|
||||
logging.info(pformat(cfg.to_dict()))
|
||||
|
||||
logging.info(pformat(asdict(cfg)))
|
||||
|
||||
# log metrics to terminal and wandb
|
||||
logger = Logger(cfg)
|
||||
if cfg.wandb.enable and cfg.wandb.project:
|
||||
wandb_logger = WandBLogger(cfg)
|
||||
else:
|
||||
wandb_logger = None
|
||||
logging.info(colored("Logs will be saved locally.", "yellow", attrs=["bold"]))
|
||||
|
||||
if cfg.seed is not None:
|
||||
set_global_seed(cfg.seed)
|
||||
set_seed(cfg.seed)
|
||||
|
||||
# Check device is available
|
||||
device = get_safe_torch_device(cfg.device, log=True)
|
||||
|
||||
torch.backends.cudnn.benchmark = True
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
|
||||
logging.info("Creating dataset")
|
||||
offline_dataset = make_dataset(cfg)
|
||||
dataset = make_dataset(cfg)
|
||||
|
||||
# Create environment used for evaluating checkpoints during training on simulation data.
|
||||
# On real-world data, no need to create an environment as evaluations are done outside train.py,
|
||||
@@ -218,7 +139,7 @@ def train(cfg: TrainPipelineConfig):
|
||||
policy = make_policy(
|
||||
cfg=cfg.policy,
|
||||
device=device,
|
||||
ds_meta=offline_dataset.meta,
|
||||
ds_meta=dataset.meta,
|
||||
)
|
||||
|
||||
logging.info("Creating optimizer and scheduler")
|
||||
@@ -233,65 +154,29 @@ def train(cfg: TrainPipelineConfig):
|
||||
num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
|
||||
num_total_params = sum(p.numel() for p in policy.parameters())
|
||||
|
||||
log_output_dir(cfg.output_dir)
|
||||
logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}")
|
||||
if cfg.env is not None:
|
||||
logging.info(f"{cfg.env.task=}")
|
||||
logging.info(f"{cfg.offline.steps=} ({format_big_number(cfg.offline.steps)})")
|
||||
logging.info(f"{cfg.online.steps=}")
|
||||
logging.info(f"{offline_dataset.num_frames=} ({format_big_number(offline_dataset.num_frames)})")
|
||||
logging.info(f"{offline_dataset.num_episodes=}")
|
||||
logging.info(f"{cfg.steps=} ({format_big_number(cfg.steps)})")
|
||||
logging.info(f"{dataset.num_frames=} ({format_big_number(dataset.num_frames)})")
|
||||
logging.info(f"{dataset.num_episodes=}")
|
||||
logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})")
|
||||
logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
|
||||
|
||||
# Note: this helper will be used in offline and online training loops.
|
||||
def evaluate_and_checkpoint_if_needed(step: int, is_online: bool):
|
||||
_num_digits = max(6, len(str(cfg.offline.steps + cfg.online.steps)))
|
||||
step_identifier = f"{step:0{_num_digits}d}"
|
||||
|
||||
if cfg.env is not None and cfg.eval_freq > 0 and step % cfg.eval_freq == 0:
|
||||
logging.info(f"Eval policy at step {step}")
|
||||
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.use_amp else nullcontext():
|
||||
eval_info = eval_policy(
|
||||
eval_env,
|
||||
policy,
|
||||
cfg.eval.n_episodes,
|
||||
videos_dir=cfg.output_dir / "eval" / f"videos_step_{step_identifier}",
|
||||
max_episodes_rendered=4,
|
||||
start_seed=cfg.seed,
|
||||
)
|
||||
log_eval_info(logger, eval_info["aggregated"], step, cfg, offline_dataset, is_online=is_online)
|
||||
if cfg.wandb.enable:
|
||||
logger.log_video(eval_info["video_paths"][0], step, mode="eval")
|
||||
logging.info("Resume training")
|
||||
|
||||
if cfg.save_checkpoint and (
|
||||
step % cfg.save_freq == 0 or step == cfg.offline.steps + cfg.online.steps
|
||||
):
|
||||
logging.info(f"Checkpoint policy after step {step}")
|
||||
# Note: Save with step as the identifier, and format it to have at least 6 digits but more if
|
||||
# needed (choose 6 as a minimum for consistency without being overkill).
|
||||
logger.save_checkpoint(
|
||||
step,
|
||||
step_identifier,
|
||||
policy,
|
||||
optimizer,
|
||||
lr_scheduler,
|
||||
)
|
||||
logging.info("Resume training")
|
||||
|
||||
# create dataloader for offline training
|
||||
if getattr(cfg.policy, "drop_n_last_frames", None):
|
||||
if hasattr(cfg.policy, "drop_n_last_frames"):
|
||||
shuffle = False
|
||||
sampler = EpisodeAwareSampler(
|
||||
offline_dataset.episode_data_index,
|
||||
dataset.episode_data_index,
|
||||
drop_n_last_frames=cfg.policy.drop_n_last_frames,
|
||||
shuffle=True,
|
||||
)
|
||||
else:
|
||||
shuffle = True
|
||||
sampler = None
|
||||
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
offline_dataset,
|
||||
dataset,
|
||||
num_workers=cfg.num_workers,
|
||||
batch_size=cfg.batch_size,
|
||||
shuffle=shuffle,
|
||||
@@ -303,23 +188,30 @@ def train(cfg: TrainPipelineConfig):
|
||||
|
||||
policy.train()
|
||||
|
||||
if hasattr(policy, "init_ema_modules"):
|
||||
policy.init_ema_modules()
|
||||
train_metrics = {
|
||||
"loss": AverageMeter("loss", ":.3f"),
|
||||
"grad_norm": AverageMeter("grdn", ":.3f"),
|
||||
"lr": AverageMeter("lr", ":0.1e"),
|
||||
"update_s": AverageMeter("updt_s", ":.3f"),
|
||||
"dataloading_s": AverageMeter("data_s", ":.3f"),
|
||||
}
|
||||
|
||||
offline_step = 0
|
||||
for _ in range(step, cfg.offline.steps):
|
||||
if offline_step == 0:
|
||||
logging.info("Start offline training on a fixed dataset")
|
||||
train_tracker = MetricsTracker(
|
||||
cfg.batch_size, dataset.num_frames, dataset.num_episodes, train_metrics, initial_step=step
|
||||
)
|
||||
|
||||
logging.info("Start offline training on a fixed dataset")
|
||||
for _ in range(step, cfg.steps):
|
||||
start_time = time.perf_counter()
|
||||
batch = next(dl_iter)
|
||||
dataloading_s = time.perf_counter() - start_time
|
||||
train_tracker.dataloading_s = time.perf_counter() - start_time
|
||||
|
||||
for key in batch:
|
||||
if isinstance(batch[key], torch.Tensor):
|
||||
batch[key] = batch[key].to(device, non_blocking=True)
|
||||
|
||||
train_info = update_policy(
|
||||
train_tracker, output_dict = update_policy(
|
||||
train_tracker,
|
||||
policy,
|
||||
batch,
|
||||
optimizer,
|
||||
@@ -329,231 +221,60 @@ def train(cfg: TrainPipelineConfig):
|
||||
use_amp=cfg.use_amp,
|
||||
)
|
||||
|
||||
train_info["dataloading_s"] = dataloading_s
|
||||
|
||||
if step % cfg.log_freq == 0:
|
||||
log_train_info(logger, train_info, step, cfg, offline_dataset, is_online=False)
|
||||
|
||||
# Note: evaluate_and_checkpoint_if_needed happens **after** the `step`th training update has completed,
|
||||
# so we pass in step + 1.
|
||||
evaluate_and_checkpoint_if_needed(step + 1, is_online=False)
|
||||
|
||||
# Note: eval and checkpoint happens *after* the `step`th training update has completed, so we
|
||||
# increment `step` here.
|
||||
step += 1
|
||||
offline_step += 1 # noqa: SIM113
|
||||
train_tracker.step()
|
||||
is_log_step = cfg.log_freq > 0 and step % cfg.log_freq == 0
|
||||
is_saving_step = step % cfg.save_freq == 0 or step == cfg.steps
|
||||
is_eval_step = cfg.eval_freq > 0 and step % cfg.eval_freq == 0
|
||||
|
||||
if cfg.online.steps == 0:
|
||||
if eval_env:
|
||||
eval_env.close()
|
||||
logging.info("End of training")
|
||||
return
|
||||
if is_log_step:
|
||||
logging.info(train_tracker)
|
||||
if wandb_logger:
|
||||
wandb_log_dict = train_tracker.to_dict()
|
||||
if output_dict:
|
||||
wandb_log_dict.update(output_dict)
|
||||
wandb_logger.log_dict(wandb_log_dict, step)
|
||||
train_tracker.reset_averages()
|
||||
|
||||
# Online training.
|
||||
if cfg.save_checkpoint and is_saving_step:
|
||||
logging.info(f"Checkpoint policy after step {step}")
|
||||
checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, cfg.steps, step)
|
||||
save_checkpoint(checkpoint_dir, step, cfg, policy, optimizer, lr_scheduler)
|
||||
update_last_checkpoint(checkpoint_dir)
|
||||
if wandb_logger:
|
||||
wandb_logger.log_policy(checkpoint_dir)
|
||||
|
||||
# Create an env dedicated to online episodes collection from policy rollout.
|
||||
online_env = make_env(cfg.env, n_envs=cfg.online.rollout_batch_size)
|
||||
delta_timestamps = resolve_delta_timestamps(cfg.policy, offline_dataset.meta)
|
||||
online_buffer_path = logger.log_dir / "online_buffer"
|
||||
if cfg.resume and not online_buffer_path.exists():
|
||||
# If we are resuming a run, we default to the data shapes and buffer capacity from the saved online
|
||||
# buffer.
|
||||
logging.warning(
|
||||
"When online training is resumed, we load the latest online buffer from the prior run, "
|
||||
"and this might not coincide with the state of the buffer as it was at the moment the checkpoint "
|
||||
"was made. This is because the online buffer is updated on disk during training, independently "
|
||||
"of our explicit checkpointing mechanisms."
|
||||
)
|
||||
online_dataset = OnlineBuffer(
|
||||
online_buffer_path,
|
||||
data_spec={
|
||||
**{
|
||||
key: {"shape": ft.shape, "dtype": np.dtype("float32")}
|
||||
for key, ft in policy.config.input_features.items()
|
||||
},
|
||||
**{
|
||||
key: {"shape": ft.shape, "dtype": np.dtype("float32")}
|
||||
for key, ft in policy.config.output_features.items()
|
||||
},
|
||||
"next.reward": {"shape": (), "dtype": np.dtype("float32")},
|
||||
"next.done": {"shape": (), "dtype": np.dtype("?")},
|
||||
"task_index": {"shape": (), "dtype": np.dtype("int64")},
|
||||
# FIXME: 'task' is a string
|
||||
# "task": {"shape": (), "dtype": np.dtype("?")},
|
||||
# FIXME: 'next.success' is expected by pusht env but not xarm
|
||||
"next.success": {"shape": (), "dtype": np.dtype("?")},
|
||||
},
|
||||
buffer_capacity=cfg.online.buffer_capacity,
|
||||
fps=online_env.unwrapped.metadata["render_fps"],
|
||||
delta_timestamps=delta_timestamps,
|
||||
)
|
||||
|
||||
# If we are doing online rollouts asynchronously, deepcopy the policy to use for online rollouts (this
|
||||
# makes it possible to do online rollouts in parallel with training updates).
|
||||
online_rollout_policy = deepcopy(policy) if cfg.online.do_rollout_async else policy
|
||||
|
||||
# Create dataloader for online training.
|
||||
concat_dataset = torch.utils.data.ConcatDataset([offline_dataset, online_dataset])
|
||||
sampler_weights = compute_sampler_weights(
|
||||
offline_dataset,
|
||||
offline_drop_n_last_frames=getattr(cfg.policy, "drop_n_last_frames", 0),
|
||||
online_dataset=online_dataset,
|
||||
# +1 because online rollouts return an extra frame for the "final observation". Note: we don't have
|
||||
# this final observation in the offline datasets, but we might add them in future.
|
||||
online_drop_n_last_frames=getattr(cfg.policy, "drop_n_last_frames", 0) + 1,
|
||||
online_sampling_ratio=cfg.online.sampling_ratio,
|
||||
)
|
||||
sampler = torch.utils.data.WeightedRandomSampler(
|
||||
sampler_weights,
|
||||
num_samples=len(concat_dataset),
|
||||
replacement=True,
|
||||
)
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
concat_dataset,
|
||||
batch_size=cfg.batch_size,
|
||||
num_workers=cfg.num_workers,
|
||||
sampler=sampler,
|
||||
pin_memory=device.type != "cpu",
|
||||
drop_last=True,
|
||||
)
|
||||
dl_iter = cycle(dataloader)
|
||||
|
||||
if cfg.online.do_rollout_async:
|
||||
# Lock and thread pool executor for asynchronous online rollouts.
|
||||
lock = Lock()
|
||||
# Note: 1 worker because we only ever want to run one set of online rollouts at a time. Batch
|
||||
# parallelization of rollouts is handled within the job.
|
||||
executor = ThreadPoolExecutor(max_workers=1)
|
||||
else:
|
||||
lock = None
|
||||
|
||||
online_step = 0
|
||||
online_rollout_s = 0 # time take to do online rollout
|
||||
update_online_buffer_s = 0 # time taken to update the online buffer with the online rollout data
|
||||
# Time taken waiting for the online buffer to finish being updated. This is relevant when using the async
|
||||
# online rollout option.
|
||||
await_update_online_buffer_s = 0
|
||||
rollout_start_seed = cfg.online.env_seed
|
||||
|
||||
while True:
|
||||
if online_step == cfg.online.steps:
|
||||
break
|
||||
|
||||
if online_step == 0:
|
||||
logging.info("Start online training by interacting with environment")
|
||||
|
||||
def sample_trajectory_and_update_buffer():
|
||||
nonlocal rollout_start_seed
|
||||
|
||||
with lock if lock is not None else nullcontext():
|
||||
online_rollout_policy.load_state_dict(policy.state_dict())
|
||||
|
||||
online_rollout_policy.eval()
|
||||
start_rollout_time = time.perf_counter()
|
||||
|
||||
with torch.no_grad():
|
||||
if cfg.env and is_eval_step:
|
||||
step_id = get_step_identifier(step, cfg.steps)
|
||||
logging.info(f"Eval policy at step {step}")
|
||||
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.use_amp else nullcontext():
|
||||
eval_info = eval_policy(
|
||||
online_env,
|
||||
online_rollout_policy,
|
||||
n_episodes=cfg.online.rollout_n_episodes,
|
||||
max_episodes_rendered=min(10, cfg.online.rollout_n_episodes),
|
||||
videos_dir=logger.log_dir / "online_rollout_videos",
|
||||
return_episode_data=True,
|
||||
start_seed=(rollout_start_seed := (rollout_start_seed + cfg.batch_size) % 1000000),
|
||||
eval_env,
|
||||
policy,
|
||||
cfg.eval.n_episodes,
|
||||
videos_dir=cfg.output_dir / "eval" / f"videos_step_{step_id}",
|
||||
max_episodes_rendered=4,
|
||||
start_seed=cfg.seed,
|
||||
)
|
||||
online_rollout_s = time.perf_counter() - start_rollout_time
|
||||
|
||||
if len(offline_dataset.meta.tasks) > 1:
|
||||
raise NotImplementedError("Add support for multi task.")
|
||||
|
||||
# TODO(rcadene, aliberts): Hack to add a task to the online_dataset (0 is the first task of the offline_dataset)
|
||||
total_num_frames = eval_info["episodes"]["index"].shape[0]
|
||||
eval_info["episodes"]["task_index"] = torch.tensor([0] * total_num_frames, dtype=torch.int64)
|
||||
eval_info["episodes"]["task"] = ["do the thing"] * total_num_frames
|
||||
|
||||
with lock if lock is not None else nullcontext():
|
||||
start_update_buffer_time = time.perf_counter()
|
||||
online_dataset.add_data(eval_info["episodes"])
|
||||
|
||||
# Update the concatenated dataset length used during sampling.
|
||||
concat_dataset.cumulative_sizes = concat_dataset.cumsum(concat_dataset.datasets)
|
||||
|
||||
# Update the sampling weights.
|
||||
sampler.weights = compute_sampler_weights(
|
||||
offline_dataset,
|
||||
offline_drop_n_last_frames=getattr(cfg.policy, "drop_n_last_frames", 0),
|
||||
online_dataset=online_dataset,
|
||||
# +1 because online rollouts return an extra frame for the "final observation". Note: we don't have
|
||||
# this final observation in the offline datasets, but we might add them in future.
|
||||
online_drop_n_last_frames=getattr(cfg.policy, "drop_n_last_frames", 0) + 1,
|
||||
online_sampling_ratio=cfg.online.sampling_ratio,
|
||||
)
|
||||
sampler.num_frames = len(concat_dataset)
|
||||
|
||||
update_online_buffer_s = time.perf_counter() - start_update_buffer_time
|
||||
|
||||
return online_rollout_s, update_online_buffer_s
|
||||
|
||||
if lock is None:
|
||||
online_rollout_s, update_online_buffer_s = sample_trajectory_and_update_buffer()
|
||||
else:
|
||||
future = executor.submit(sample_trajectory_and_update_buffer)
|
||||
# If we aren't doing async rollouts, or if we haven't yet gotten enough examples in our buffer, wait
|
||||
# here until the rollout and buffer update is done, before proceeding to the policy update steps.
|
||||
if len(online_dataset) <= cfg.online.buffer_seed_size:
|
||||
online_rollout_s, update_online_buffer_s = future.result()
|
||||
|
||||
if len(online_dataset) <= cfg.online.buffer_seed_size:
|
||||
logging.info(f"Seeding online buffer: {len(online_dataset)}/{cfg.online.buffer_seed_size}")
|
||||
continue
|
||||
|
||||
policy.train()
|
||||
for _ in range(cfg.online.steps_between_rollouts):
|
||||
with lock if lock is not None else nullcontext():
|
||||
start_time = time.perf_counter()
|
||||
batch = next(dl_iter)
|
||||
dataloading_s = time.perf_counter() - start_time
|
||||
|
||||
for key in batch:
|
||||
if isinstance(batch[key], torch.Tensor):
|
||||
dtype = get_safe_dtype(batch[key].dtype, device)
|
||||
batch[key] = batch[key].to(device=device, dtype=dtype, non_blocking=True)
|
||||
|
||||
train_info = update_policy(
|
||||
policy,
|
||||
batch,
|
||||
optimizer,
|
||||
cfg.optimizer.grad_clip_norm,
|
||||
grad_scaler=grad_scaler,
|
||||
lr_scheduler=lr_scheduler,
|
||||
use_amp=cfg.use_amp,
|
||||
lock=lock,
|
||||
eval_metrics = {
|
||||
"avg_sum_reward": AverageMeter("∑rwrd", ":.3f"),
|
||||
"pc_success": AverageMeter("success", ":.1f"),
|
||||
"eval_s": AverageMeter("eval_s", ":.3f"),
|
||||
}
|
||||
eval_tracker = MetricsTracker(
|
||||
cfg.batch_size, dataset.num_frames, dataset.num_episodes, eval_metrics, initial_step=step
|
||||
)
|
||||
|
||||
train_info["dataloading_s"] = dataloading_s
|
||||
train_info["online_rollout_s"] = online_rollout_s
|
||||
train_info["update_online_buffer_s"] = update_online_buffer_s
|
||||
train_info["await_update_online_buffer_s"] = await_update_online_buffer_s
|
||||
with lock if lock is not None else nullcontext():
|
||||
train_info["online_buffer_size"] = len(online_dataset)
|
||||
|
||||
if step % cfg.log_freq == 0:
|
||||
log_train_info(logger, train_info, step, cfg, online_dataset, is_online=True)
|
||||
|
||||
# Note: evaluate_and_checkpoint_if_needed happens **after** the `step`th training update has completed,
|
||||
# so we pass in step + 1.
|
||||
evaluate_and_checkpoint_if_needed(step + 1, is_online=True)
|
||||
|
||||
step += 1
|
||||
online_step += 1
|
||||
|
||||
# If we're doing async rollouts, we should now wait until we've completed them before proceeding
|
||||
# to do the next batch of rollouts.
|
||||
if cfg.online.do_rollout_async and future.running():
|
||||
start = time.perf_counter()
|
||||
online_rollout_s, update_online_buffer_s = future.result()
|
||||
await_update_online_buffer_s = time.perf_counter() - start
|
||||
|
||||
if online_step >= cfg.online.steps:
|
||||
break
|
||||
eval_tracker.eval_s = eval_info["aggregated"].pop("eval_s")
|
||||
eval_tracker.avg_sum_reward = eval_info["aggregated"].pop("avg_sum_reward")
|
||||
eval_tracker.pc_success = eval_info["aggregated"].pop("pc_success")
|
||||
logging.info(eval_tracker)
|
||||
if wandb_logger:
|
||||
wandb_log_dict = {**eval_tracker.to_dict(), **eval_info}
|
||||
wandb_logger.log_dict(wandb_log_dict, step, mode="eval")
|
||||
wandb_logger.log_video(eval_info["video_paths"][0], step, mode="eval")
|
||||
|
||||
if eval_env:
|
||||
eval_env.close()
|
||||
|
||||
@@ -111,9 +111,9 @@ def visualize_dataset(
|
||||
output_dir: Path | None = None,
|
||||
) -> Path | None:
|
||||
if save:
|
||||
assert (
|
||||
output_dir is not None
|
||||
), "Set an output directory where to write .rrd files with `--output-dir path/to/directory`."
|
||||
assert output_dir is not None, (
|
||||
"Set an output directory where to write .rrd files with `--output-dir path/to/directory`."
|
||||
)
|
||||
|
||||
repo_id = dataset.repo_id
|
||||
|
||||
@@ -207,12 +207,6 @@ def main():
|
||||
required=True,
|
||||
help="Episode to visualize.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--local-files-only",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Use local files only. By default, this script will try to fetch the dataset from the hub if it exists.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--root",
|
||||
type=Path,
|
||||
@@ -275,10 +269,9 @@ def main():
|
||||
kwargs = vars(args)
|
||||
repo_id = kwargs.pop("repo_id")
|
||||
root = kwargs.pop("root")
|
||||
local_files_only = kwargs.pop("local_files_only")
|
||||
|
||||
logging.info("Loading dataset")
|
||||
dataset = LeRobotDataset(repo_id, root=root, local_files_only=local_files_only)
|
||||
dataset = LeRobotDataset(repo_id, root=root)
|
||||
|
||||
visualize_dataset(dataset, **vars(args))
|
||||
|
||||
|
||||
@@ -150,7 +150,7 @@ def run_server(
|
||||
400,
|
||||
)
|
||||
dataset_version = (
|
||||
dataset.meta._version if isinstance(dataset, LeRobotDataset) else dataset.codebase_version
|
||||
str(dataset.meta._version) if isinstance(dataset, LeRobotDataset) else dataset.codebase_version
|
||||
)
|
||||
match = re.search(r"v(\d+)\.", dataset_version)
|
||||
if match:
|
||||
@@ -194,7 +194,7 @@ def run_server(
|
||||
]
|
||||
|
||||
response = requests.get(
|
||||
f"https://huggingface.co/datasets/{repo_id}/resolve/main/meta/episodes.jsonl"
|
||||
f"https://huggingface.co/datasets/{repo_id}/resolve/main/meta/episodes.jsonl", timeout=5
|
||||
)
|
||||
response.raise_for_status()
|
||||
# Split into lines and parse each line as JSON
|
||||
@@ -245,16 +245,17 @@ def get_episode_data(dataset: LeRobotDataset | IterableNamespace, episode_index)
|
||||
if isinstance(dataset, LeRobotDataset)
|
||||
else dataset.features[column_name].shape[0]
|
||||
)
|
||||
header += [f"{column_name}_{i}" for i in range(dim_state)]
|
||||
|
||||
if "names" in dataset.features[column_name] and dataset.features[column_name]["names"]:
|
||||
column_names = dataset.features[column_name]["names"]
|
||||
while not isinstance(column_names, list):
|
||||
column_names = list(column_names.values())[0]
|
||||
else:
|
||||
column_names = [f"motor_{i}" for i in range(dim_state)]
|
||||
column_names = [f"{column_name}_{i}" for i in range(dim_state)]
|
||||
columns.append({"key": column_name, "value": column_names})
|
||||
|
||||
header += column_names
|
||||
|
||||
selected_columns.insert(0, "timestamp")
|
||||
|
||||
if isinstance(dataset, LeRobotDataset):
|
||||
@@ -317,7 +318,9 @@ def get_episode_language_instruction(dataset: LeRobotDataset, ep_index: int) ->
|
||||
|
||||
|
||||
def get_dataset_info(repo_id: str) -> IterableNamespace:
|
||||
response = requests.get(f"https://huggingface.co/datasets/{repo_id}/resolve/main/meta/info.json")
|
||||
response = requests.get(
|
||||
f"https://huggingface.co/datasets/{repo_id}/resolve/main/meta/info.json", timeout=5
|
||||
)
|
||||
response.raise_for_status() # Raises an HTTPError for bad responses
|
||||
dataset_info = response.json()
|
||||
dataset_info["repo_id"] = repo_id
|
||||
@@ -364,7 +367,7 @@ def visualize_dataset_html(
|
||||
template_folder=template_dir,
|
||||
)
|
||||
else:
|
||||
# Create a simlink from the dataset video folder containg mp4 files to the output directory
|
||||
# Create a simlink from the dataset video folder containing mp4 files to the output directory
|
||||
# so that the http server can get access to the mp4 files.
|
||||
if isinstance(dataset, LeRobotDataset):
|
||||
ln_videos_dir = static_dir / "videos"
|
||||
@@ -384,12 +387,6 @@ def main():
|
||||
default=None,
|
||||
help="Name of hugging face repositery containing a LeRobotDataset dataset (e.g. `lerobot/pusht` for https://huggingface.co/datasets/lerobot/pusht).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--local-files-only",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Use local files only. By default, this script will try to fetch the dataset from the hub if it exists.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--root",
|
||||
type=Path,
|
||||
@@ -445,15 +442,10 @@ def main():
|
||||
repo_id = kwargs.pop("repo_id")
|
||||
load_from_hf_hub = kwargs.pop("load_from_hf_hub")
|
||||
root = kwargs.pop("root")
|
||||
local_files_only = kwargs.pop("local_files_only")
|
||||
|
||||
dataset = None
|
||||
if repo_id:
|
||||
dataset = (
|
||||
LeRobotDataset(repo_id, root=root, local_files_only=local_files_only)
|
||||
if not load_from_hf_hub
|
||||
else get_dataset_info(repo_id)
|
||||
)
|
||||
dataset = LeRobotDataset(repo_id, root=root) if not load_from_hf_hub else get_dataset_info(repo_id)
|
||||
|
||||
visualize_dataset_html(dataset, **vars(args))
|
||||
|
||||
|
||||
@@ -109,7 +109,7 @@ def visualize_image_transforms(cfg: DatasetConfig, output_dir: Path = OUTPUT_DIR
|
||||
dataset = LeRobotDataset(
|
||||
repo_id=cfg.repo_id,
|
||||
episodes=cfg.episodes,
|
||||
local_files_only=cfg.local_files_only,
|
||||
revision=cfg.revision,
|
||||
video_backend=cfg.video_backend,
|
||||
)
|
||||
|
||||
|
||||
@@ -14,21 +14,7 @@
|
||||
<!-- Use [Alpin.js](https://alpinejs.dev), a lightweight and easy to learn JS framework -->
|
||||
<!-- Use [tailwindcss](https://tailwindcss.com/), CSS classes for styling html -->
|
||||
<!-- Use [dygraphs](https://dygraphs.com/), a lightweight JS charting library -->
|
||||
<body class="flex flex-col md:flex-row h-screen max-h-screen bg-slate-950 text-gray-200" x-data="createAlpineData()" @keydown.window="(e) => {
|
||||
// Use the space bar to play and pause, instead of default action (e.g. scrolling)
|
||||
const { keyCode, key } = e;
|
||||
if (keyCode === 32 || key === ' ') {
|
||||
e.preventDefault();
|
||||
$refs.btnPause.classList.contains('hidden') ? $refs.btnPlay.click() : $refs.btnPause.click();
|
||||
}else if (key === 'ArrowDown' || key === 'ArrowUp'){
|
||||
const nextEpisodeId = key === 'ArrowDown' ? {{ episode_id }} + 1 : {{ episode_id }} - 1;
|
||||
const lowestEpisodeId = {{ episodes }}.at(0);
|
||||
const highestEpisodeId = {{ episodes }}.at(-1);
|
||||
if(nextEpisodeId >= lowestEpisodeId && nextEpisodeId <= highestEpisodeId){
|
||||
window.location.href = `./episode_${nextEpisodeId}`;
|
||||
}
|
||||
}
|
||||
}">
|
||||
<body class="flex flex-col md:flex-row h-screen max-h-screen bg-slate-950 text-gray-200" x-data="createAlpineData()">
|
||||
<!-- Sidebar -->
|
||||
<div x-ref="sidebar" class="bg-slate-900 p-5 break-words overflow-y-auto shrink-0 md:shrink md:w-60 md:max-h-screen">
|
||||
<a href="https://github.com/huggingface/lerobot" target="_blank" class="hidden md:block">
|
||||
@@ -52,25 +38,55 @@
|
||||
|
||||
<p>Episodes:</p>
|
||||
<!-- episodes menu for medium & large screens -->
|
||||
<ul class="ml-2 hidden md:block">
|
||||
{% for episode in episodes %}
|
||||
<li class="font-mono text-sm mt-0.5">
|
||||
<a href="episode_{{ episode }}" class="underline {% if episode_id == episode %}font-bold -ml-1{% endif %}">
|
||||
Episode {{ episode }}
|
||||
</a>
|
||||
</li>
|
||||
{% endfor %}
|
||||
</ul>
|
||||
<div class="ml-2 hidden md:block" x-data="episodePagination">
|
||||
<ul>
|
||||
<template x-for="episode in paginatedEpisodes" :key="episode">
|
||||
<li class="font-mono text-sm mt-0.5">
|
||||
<a :href="'episode_' + episode"
|
||||
:class="{'underline': true, 'font-bold -ml-1': episode == {{ episode_id }}}"
|
||||
x-text="'Episode ' + episode"></a>
|
||||
</li>
|
||||
</template>
|
||||
</ul>
|
||||
|
||||
<div class="flex items-center mt-3 text-xs" x-show="totalPages > 1">
|
||||
<button @click="prevPage()"
|
||||
class="px-2 py-1 bg-slate-800 rounded mr-2"
|
||||
:class="{'opacity-50 cursor-not-allowed': page === 1}"
|
||||
:disabled="page === 1">
|
||||
« Prev
|
||||
</button>
|
||||
<span class="font-mono mr-2" x-text="` ${page} / ${totalPages}`"></span>
|
||||
<button @click="nextPage()"
|
||||
class="px-2 py-1 bg-slate-800 rounded"
|
||||
:class="{'opacity-50 cursor-not-allowed': page === totalPages}"
|
||||
:disabled="page === totalPages">
|
||||
Next »
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- episodes menu for small screens -->
|
||||
<div class="flex overflow-x-auto md:hidden">
|
||||
{% for episode in episodes %}
|
||||
<p class="font-mono text-sm mt-0.5 border-r last:border-r-0 px-2 {% if episode_id == episode %}font-bold{% endif %}">
|
||||
<a href="episode_{{ episode }}" class="">
|
||||
{{ episode }}
|
||||
</a>
|
||||
</p>
|
||||
{% endfor %}
|
||||
<div class="flex overflow-x-auto md:hidden" x-data="episodePagination">
|
||||
<button @click="prevPage()"
|
||||
class="px-2 bg-slate-800 rounded mr-2"
|
||||
:class="{'opacity-50 cursor-not-allowed': page === 1}"
|
||||
:disabled="page === 1">«</button>
|
||||
<div class="flex">
|
||||
<template x-for="(episode, index) in paginatedEpisodes" :key="episode">
|
||||
<p class="font-mono text-sm mt-0.5 px-2"
|
||||
:class="{
|
||||
'font-bold': episode == {{ episode_id }},
|
||||
'border-r': index !== paginatedEpisodes.length - 1
|
||||
}">
|
||||
<a :href="'episode_' + episode" x-text="episode"></a>
|
||||
</p>
|
||||
</template>
|
||||
</div>
|
||||
<button @click="nextPage()"
|
||||
class="px-2 bg-slate-800 rounded ml-2"
|
||||
:class="{'opacity-50 cursor-not-allowed': page === totalPages}"
|
||||
:disabled="page === totalPages">» </button>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
@@ -230,14 +246,16 @@
|
||||
<div class="flex gap-x-2 max-w-64 font-semibold px-1 break-all">
|
||||
<input type="checkbox" :checked="isRowChecked(rowIndex)"
|
||||
@change="toggleRow(rowIndex)">
|
||||
<p x-text="`${rowLabels[rowIndex]}`"></p>
|
||||
</div>
|
||||
</td>
|
||||
<template x-for="(cell, colIndex) in row">
|
||||
<td x-show="cell" class="border border-slate-700">
|
||||
<div class="flex gap-x-2 w-24 justify-between px-2" :class="{ 'hidden': cell.isNull }">
|
||||
<input type="checkbox" x-model="cell.checked" @change="updateTableValues()">
|
||||
<span x-text="`${!cell.isNull ? cell.value.toFixed(2) : null}`"
|
||||
<div class="flex gap-x-2 justify-between px-2" :class="{ 'hidden': cell.isNull }">
|
||||
<div class="flex gap-x-2">
|
||||
<input type="checkbox" x-model="cell.checked" @change="updateTableValues()">
|
||||
<span x-text="`${!cell.isNull ? cell.label : null}`"></span>
|
||||
</div>
|
||||
<span class="w-14 text-right" x-text="`${!cell.isNull ? (typeof cell.value === 'number' ? cell.value.toFixed(2) : cell.value) : null}`"
|
||||
:style="`color: ${cell.color}`"></span>
|
||||
</div>
|
||||
</td>
|
||||
@@ -279,7 +297,6 @@
|
||||
videosKeys: {{ videos_info | map(attribute='filename') | list | tojson }},
|
||||
videosKeysSelected: [],
|
||||
columns: {{ columns | tojson }},
|
||||
rowLabels: {{ columns | tojson }}.reduce((colA, colB) => colA.value.length > colB.value.length ? colA : colB).value,
|
||||
|
||||
// alpine initialization
|
||||
init() {
|
||||
@@ -452,6 +469,68 @@
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
document.addEventListener('alpine:init', () => {
|
||||
// Episode pagination component
|
||||
Alpine.data('episodePagination', () => ({
|
||||
episodes: {{ episodes }},
|
||||
pageSize: 100,
|
||||
page: 1,
|
||||
|
||||
init() {
|
||||
// Find which page contains the current episode_id
|
||||
const currentEpisodeId = {{ episode_id }};
|
||||
const episodeIndex = this.episodes.indexOf(currentEpisodeId);
|
||||
if (episodeIndex !== -1) {
|
||||
this.page = Math.floor(episodeIndex / this.pageSize) + 1;
|
||||
}
|
||||
},
|
||||
|
||||
get totalPages() {
|
||||
return Math.ceil(this.episodes.length / this.pageSize);
|
||||
},
|
||||
|
||||
get paginatedEpisodes() {
|
||||
const start = (this.page - 1) * this.pageSize;
|
||||
const end = start + this.pageSize;
|
||||
return this.episodes.slice(start, end);
|
||||
},
|
||||
|
||||
nextPage() {
|
||||
if (this.page < this.totalPages) {
|
||||
this.page++;
|
||||
}
|
||||
},
|
||||
|
||||
prevPage() {
|
||||
if (this.page > 1) {
|
||||
this.page--;
|
||||
}
|
||||
}
|
||||
}));
|
||||
});
|
||||
</script>
|
||||
|
||||
<script>
|
||||
window.addEventListener('keydown', (e) => {
|
||||
// Use the space bar to play and pause, instead of default action (e.g. scrolling)
|
||||
const { keyCode, key } = e;
|
||||
|
||||
if (keyCode === 32 || key === ' ') {
|
||||
e.preventDefault();
|
||||
const btnPause = document.querySelector('[x-ref="btnPause"]');
|
||||
const btnPlay = document.querySelector('[x-ref="btnPlay"]');
|
||||
btnPause.classList.contains('hidden') ? btnPlay.click() : btnPause.click();
|
||||
} else if (key === 'ArrowDown' || key === 'ArrowUp') {
|
||||
const episodes = {{ episodes }}; // Access episodes directly from the Jinja template
|
||||
const nextEpisodeId = key === 'ArrowDown' ? {{ episode_id }} + 1 : {{ episode_id }} - 1;
|
||||
const lowestEpisodeId = episodes.at(0);
|
||||
const highestEpisodeId = episodes.at(-1);
|
||||
if (nextEpisodeId >= lowestEpisodeId && nextEpisodeId <= highestEpisodeId) {
|
||||
window.location.href = `./episode_${nextEpisodeId}`;
|
||||
}
|
||||
}
|
||||
});
|
||||
</script>
|
||||
</body>
|
||||
|
||||
|
||||
BIN
media/lekiwi/mobile_calib_rest.webp
Normal file
BIN
media/lekiwi/mobile_calib_rest.webp
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 220 KiB |
BIN
media/lekiwi/mobile_calib_rotated.webp
Normal file
BIN
media/lekiwi/mobile_calib_rotated.webp
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 272 KiB |
BIN
media/lekiwi/mobile_calib_zero.webp
Normal file
BIN
media/lekiwi/mobile_calib_zero.webp
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 186 KiB |
BIN
media/lekiwi/motor_ids.webp
Normal file
BIN
media/lekiwi/motor_ids.webp
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 185 KiB |
7939
poetry.lock
generated
7939
poetry.lock
generated
File diff suppressed because it is too large
Load Diff
161
pyproject.toml
161
pyproject.toml
@@ -1,18 +1,24 @@
|
||||
[tool.poetry]
|
||||
[project.urls]
|
||||
homepage = "https://github.com/huggingface/lerobot"
|
||||
issues = "https://github.com/huggingface/lerobot/issues"
|
||||
discord = "https://discord.gg/s3KuuzsPFb"
|
||||
|
||||
[project]
|
||||
name = "lerobot"
|
||||
version = "0.1.0"
|
||||
description = "🤗 LeRobot: State-of-the-art Machine Learning for Real-World Robotics in Pytorch"
|
||||
authors = [
|
||||
"Rémi Cadène <re.cadene@gmail.com>",
|
||||
"Simon Alibert <alibert.sim@gmail.com>",
|
||||
"Alexander Soare <alexander.soare159@gmail.com>",
|
||||
"Quentin Gallouédec <quentin.gallouedec@ec-lyon.fr>",
|
||||
"Adil Zouitine <adilzouitinegm@gmail.com>",
|
||||
"Thomas Wolf <thomaswolfcontact@gmail.com>",
|
||||
{name = "Rémi Cadène", email = "re.cadene@gmail.com"},
|
||||
{name = "Simon Alibert", email = "alibert.sim@gmail.com"},
|
||||
{name = "Alexander Soare", email = "alexander.soare159@gmail.com"},
|
||||
{name = "Quentin Gallouédec", email = "quentin.gallouedec@ec-lyon.fr"},
|
||||
{name = "Adil Zouitine", email = "adilzouitinegm@gmail.com"},
|
||||
{name = "Thomas Wolf", email = "thomaswolfcontact@gmail.com"},
|
||||
]
|
||||
repository = "https://github.com/huggingface/lerobot"
|
||||
readme = "README.md"
|
||||
license = "Apache-2.0"
|
||||
license = {text = "Apache-2.0"}
|
||||
requires-python = ">=3.10"
|
||||
keywords = ["robotics", "deep learning", "pytorch"]
|
||||
classifiers=[
|
||||
"Development Status :: 3 - Alpha",
|
||||
"Intended Audience :: Developers",
|
||||
@@ -23,70 +29,59 @@ classifiers=[
|
||||
"License :: OSI Approved :: Apache Software License",
|
||||
"Programming Language :: Python :: 3.10",
|
||||
]
|
||||
packages = [{include = "lerobot"}]
|
||||
dependencies = [
|
||||
"cmake>=3.29.0.1",
|
||||
"datasets>=2.19.0",
|
||||
"deepdiff>=7.0.1",
|
||||
"diffusers>=0.27.2",
|
||||
"draccus>=0.10.0",
|
||||
"einops>=0.8.0",
|
||||
"flask>=3.0.3",
|
||||
"gdown>=5.1.0",
|
||||
"gymnasium==0.29.1", # TODO(rcadene, aliberts): Make gym 1.0.0 work
|
||||
"h5py>=3.10.0",
|
||||
"huggingface-hub[hf-transfer,cli]>=0.27.1 ; python_version < '4.0'",
|
||||
"hydra-core>=1.3.2",
|
||||
"imageio[ffmpeg]>=2.34.0",
|
||||
"jsonlines>=4.0.0",
|
||||
"numba>=0.59.0",
|
||||
"omegaconf>=2.3.0",
|
||||
"opencv-python>=4.9.0",
|
||||
"packaging>=24.2",
|
||||
"av>=12.0.5",
|
||||
"pymunk>=6.6.0",
|
||||
"pynput>=1.7.7",
|
||||
"pyzmq>=26.2.1",
|
||||
"rerun-sdk>=0.21.0",
|
||||
"termcolor>=2.4.0",
|
||||
"torch>=2.2.1",
|
||||
"torchvision>=0.21.0",
|
||||
"wandb>=0.16.3",
|
||||
"zarr>=2.17.0",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
aloha = ["gym-aloha>=0.1.1 ; python_version < '4.0'"]
|
||||
dev = ["pre-commit>=3.7.0", "debugpy>=1.8.1"]
|
||||
dora = ["gym-dora @ git+https://github.com/dora-rs/dora-lerobot.git#subdirectory=gym_dora ; python_version < '4.0'"]
|
||||
dynamixel = ["dynamixel-sdk>=3.7.31", "pynput>=1.7.7"]
|
||||
feetech = ["feetech-servo-sdk>=1.0.0", "pynput>=1.7.7"]
|
||||
intelrealsense = ["pyrealsense2>=2.55.1.6486 ; sys_platform != 'darwin'"]
|
||||
pi0 = ["transformers>=4.48.0"]
|
||||
pusht = ["gym-pusht>=0.1.5 ; python_version < '4.0'"]
|
||||
stretch = [
|
||||
"hello-robot-stretch-body>=0.7.27 ; python_version < '4.0' and sys_platform == 'linux'",
|
||||
"pyrender @ git+https://github.com/mmatl/pyrender.git ; sys_platform == 'linux'",
|
||||
"pyrealsense2>=2.55.1.6486 ; sys_platform != 'darwin'",
|
||||
"pynput>=1.7.7"
|
||||
]
|
||||
test = ["pytest>=8.1.0", "pytest-cov>=5.0.0", "pyserial>=3.5"]
|
||||
umi = ["imagecodecs>=2024.1.1"]
|
||||
video_benchmark = ["scikit-image>=0.23.2", "pandas>=2.2.2"]
|
||||
xarm = ["gym-xarm>=0.1.1 ; python_version < '4.0'"]
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = ">=3.10,<3.13"
|
||||
termcolor = ">=2.4.0"
|
||||
wandb = ">=0.16.3"
|
||||
imageio = {extras = ["ffmpeg"], version = ">=2.34.0"}
|
||||
gdown = ">=5.1.0"
|
||||
einops = ">=0.8.0"
|
||||
pymunk = ">=6.6.0"
|
||||
zarr = ">=2.17.0"
|
||||
numba = ">=0.59.0"
|
||||
torch = ">=2.2.1"
|
||||
opencv-python = ">=4.9.0"
|
||||
diffusers = ">=0.27.2"
|
||||
torchvision = ">=0.21.0"
|
||||
h5py = ">=3.10.0"
|
||||
huggingface-hub = {extras = ["hf-transfer", "cli"], version = ">=0.27.1"}
|
||||
gymnasium = "==0.29.1" # TODO(rcadene, aliberts): Make gym 1.0.0 work
|
||||
cmake = ">=3.29.0.1"
|
||||
gym-dora = { git = "https://github.com/dora-rs/dora-lerobot.git", subdirectory = "gym_dora", optional = true }
|
||||
gym-pusht = { version = ">=0.1.5", optional = true}
|
||||
gym-xarm = { version = ">=0.1.1", optional = true}
|
||||
gym-aloha = { version = ">=0.1.1", optional = true}
|
||||
pre-commit = {version = ">=3.7.0", optional = true}
|
||||
debugpy = {version = ">=1.8.1", optional = true}
|
||||
pytest = {version = ">=8.1.0", optional = true}
|
||||
pytest-cov = {version = ">=5.0.0", optional = true}
|
||||
datasets = ">=2.19.0"
|
||||
imagecodecs = { version = ">=2024.1.1", optional = true }
|
||||
pyav = ">=12.0.5"
|
||||
rerun-sdk = ">=0.21.0"
|
||||
deepdiff = ">=7.0.1"
|
||||
flask = ">=3.0.3"
|
||||
pandas = {version = ">=2.2.2", optional = true}
|
||||
scikit-image = {version = ">=0.23.2", optional = true}
|
||||
dynamixel-sdk = {version = ">=3.7.31", optional = true}
|
||||
pynput = {version = ">=1.7.7", optional = true}
|
||||
feetech-servo-sdk = {version = ">=1.0.0", optional = true}
|
||||
setuptools = {version = "!=71.0.1", optional = true} # TODO(rcadene, aliberts): 71.0.1 has a bug
|
||||
pyrealsense2 = {version = ">=2.55.1.6486", markers = "sys_platform != 'darwin'", optional = true} # TODO(rcadene, aliberts): Fix on Mac
|
||||
pyrender = {git = "https://github.com/mmatl/pyrender.git", markers = "sys_platform == 'linux'", optional = true}
|
||||
hello-robot-stretch-body = {version = ">=0.7.27", markers = "sys_platform == 'linux'", optional = true}
|
||||
pyserial = {version = ">=3.5", optional = true}
|
||||
jsonlines = ">=4.0.0"
|
||||
transformers = ">=4.48.0"
|
||||
draccus = {git = "https://github.com/dlwh/draccus.git"} # replace with draccus = ">=0.9.4" when https://github.com/dlwh/draccus/pull/29 is in release
|
||||
|
||||
|
||||
[tool.poetry.extras]
|
||||
dora = ["gym-dora"]
|
||||
pusht = ["gym-pusht"]
|
||||
xarm = ["gym-xarm"]
|
||||
aloha = ["gym-aloha"]
|
||||
dev = ["pre-commit", "debugpy"]
|
||||
test = ["pytest", "pytest-cov", "pyserial"]
|
||||
umi = ["imagecodecs"]
|
||||
video_benchmark = ["scikit-image", "pandas"]
|
||||
dynamixel = ["dynamixel-sdk", "pynput"]
|
||||
feetech = ["feetech-servo-sdk", "pynput"]
|
||||
intelrealsense = ["pyrealsense2"]
|
||||
stretch = ["hello-robot-stretch-body", "pyrender", "pyrealsense2", "pynput"]
|
||||
pi0 = ["transformers"]
|
||||
[tool.poetry]
|
||||
requires-poetry = ">=2.1"
|
||||
|
||||
[tool.ruff]
|
||||
line-length = 110
|
||||
@@ -116,10 +111,32 @@ exclude = [
|
||||
"venv",
|
||||
]
|
||||
|
||||
|
||||
[tool.ruff.lint]
|
||||
select = ["E4", "E7", "E9", "F", "I", "N", "B", "C4", "SIM"]
|
||||
|
||||
[tool.bandit]
|
||||
exclude_dirs = [
|
||||
"tests",
|
||||
"benchmarks",
|
||||
"lerobot/common/datasets/push_dataset_to_hub",
|
||||
"lerobot/common/datasets/v2/convert_dataset_v1_to_v2",
|
||||
"lerobot/common/policies/pi0/conversion_scripts",
|
||||
"lerobot/scripts/push_dataset_to_hub.py",
|
||||
]
|
||||
skips = ["B101", "B311", "B404", "B603"]
|
||||
|
||||
[tool.typos]
|
||||
default.extend-ignore-re = [
|
||||
"(?Rm)^.*(#|//)\\s*spellchecker:disable-line$", # spellchecker:disable-line
|
||||
"(?s)(#|//)\\s*spellchecker:off.*?\\n\\s*(#|//)\\s*spellchecker:on" # spellchecker:<on|off>
|
||||
]
|
||||
default.extend-ignore-identifiers-re = [
|
||||
# Add individual words here to ignore them
|
||||
"2nd",
|
||||
"pn",
|
||||
"ser",
|
||||
"ein",
|
||||
]
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core"]
|
||||
|
||||
@@ -28,6 +28,7 @@ pytest_plugins = [
|
||||
"tests.fixtures.dataset_factories",
|
||||
"tests.fixtures.files",
|
||||
"tests.fixtures.hub",
|
||||
"tests.fixtures.optimizers",
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -1,3 +0,0 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:7841afb9ef99c0601448c43a20c25eb029440c73816319c67c5d7e1c5cde2445
|
||||
size 136
|
||||
@@ -1,11 +0,0 @@
|
||||
{
|
||||
"codebase_version": "v1.6",
|
||||
"fps": 50,
|
||||
"video": true,
|
||||
"encoding": {
|
||||
"vcodec": "libsvtav1",
|
||||
"pix_fmt": "yuv420p",
|
||||
"g": 2,
|
||||
"crf": 30
|
||||
}
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user